commit 74c4dbf7813fa00dd4ab6987b2ffd48da4e57d29 Author: Arthur Meyre Date: Fri Oct 21 09:46:35 2022 +0200 feat(tfhe): new tfhe-rs package, initial commit diff --git a/.config/nextest.toml b/.config/nextest.toml new file mode 100644 index 000000000..84d656c09 --- /dev/null +++ b/.config/nextest.toml @@ -0,0 +1,17 @@ +[profile.ci] +# Print out output for failing tests as soon as they fail, and also at the end +# of the run (for easy scrollability). +failure-output = "final" +fail-fast = false +retries = 0 +slow-timeout = "5m" + + +[[profile.ci.overrides]] +filter = 'test(/^.*param_message_1_carry_[567]$/) or test(/^.*param_message_4_carry_4$/)' +retries = 3 + +[[profile.ci.overrides]] +filter = 'test(/^.*param_message_[23]_carry_[23]$/)' +retries = 1 + diff --git a/.gitbook.yaml b/.gitbook.yaml new file mode 100644 index 000000000..6da46390a --- /dev/null +++ b/.gitbook.yaml @@ -0,0 +1 @@ +root: ./tfhe/docs diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 000000000..73c9bf27d --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Report a problem with concrete +title: '' +labels: triage_required +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behaviour +1. first step +2. second step +3. etc + +**Expected behaviour** +A clear and concise description of what you expected to happen. + +**Evidence** +If applicable, add material to help explain your problem (e.g. screenshots, logs, artifacts, etc.). + +**Configuration(please complete the following information):** + - OS: [e.g. Ubuntu 20.04] + + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 000000000..123993571 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for concrete +title: '' +labels: feature_request +assignees: '' + +--- + +**What is the problem you want to solve and can not with the current version?** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/aws_tfhe_tests.yml b/.github/workflows/aws_tfhe_tests.yml new file mode 100644 index 000000000..61d0f1e08 --- /dev/null +++ b/.github/workflows/aws_tfhe_tests.yml @@ -0,0 +1,102 @@ +name: AWS Tests on CPU + +env: + CARGO_TERM_COLOR: always + ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + RUSTFLAGS: "-C target-cpu=native" + +on: + # Allows you to run this workflow manually from the Actions tab as an alternative. + workflow_dispatch: + # All the inputs are provided by Slab + inputs: + instance_id: + description: "AWS instance ID" + type: string + instance_image_id: + description: "AWS instance AMI ID" + type: string + instance_type: + description: "AWS instance product type" + type: string + runner_name: + description: "Action runner name" + type: string + +jobs: + shortint-tests: + concurrency: + group: ${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} + cancel-in-progress: true + runs-on: ${{ github.event.inputs.runner_name }} + steps: + # Step used for log purpose. + - name: Instance configuration used + run: | + echo "ID: ${{ github.event.inputs.instance_id }}" + echo "AMI: ${{ github.event.inputs.instance_image_id }}" + echo "Type: ${{ github.event.inputs.instance_type }}" + + - uses: actions/checkout@v2 + + - name: Set up home + run: | + echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}" + + - name: Install latest stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + default: true + + - name: Run core tests + run: | + make test_core_crypto + + - name: Run C API tests + run: | + make test_c_api + + - name: Run user docs tests + run: | + make test_user_doc + + - name: Install AWS CLI + run: | + apt update + apt install -y awscli + + - name: Configure AWS credentials from Test account + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_IAM_ID }} + aws-secret-access-key: ${{ secrets.AWS_IAM_KEY }} + role-to-assume: concrete-lib-ci + aws-region: eu-west-3 + role-duration-seconds: 10800 + + - name: Download keys locally + run: aws s3 cp --recursive --no-progress s3://concrete-libs-keycache ./keys + + - name: Gen Keys if required + run: | + make gen_key_cache + + - name: Sync keys + run: aws s3 sync ./keys s3://concrete-libs-keycache + + - name: Run shortint tests + run: | + make test_shortint_ci + + - name: Slack Notification + if: ${{ always() }} + continue-on-error: true + uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7 + env: + SLACK_COLOR: ${{ job.status }} + SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }} + SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png + SLACK_MESSAGE: "Shortint tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})" + SLACK_USERNAME: ${{ secrets.BOT_USERNAME }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.github/workflows/aws_tfhe_tests_w_gpu.yml b/.github/workflows/aws_tfhe_tests_w_gpu.yml new file mode 100644 index 000000000..fb1ec848a --- /dev/null +++ b/.github/workflows/aws_tfhe_tests_w_gpu.yml @@ -0,0 +1,113 @@ +# Compile and test project on an AWS instance +name: AWS tests on GPU + +# This workflow is meant to be run via Zama CI bot Slab. +on: + workflow_dispatch: + inputs: + instance_id: + description: "AWS instance ID" + type: string + instance_image_id: + description: "AWS instance AMI ID" + type: string + instance_type: + description: "AWS EC2 instance product type" + type: string + runner_name: + description: "Action runner name" + type: string + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-C target-cpu=native" + ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + +jobs: + run-tests-linux: + concurrency: + group: ${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }} + cancel-in-progress: true + name: Test code in EC2 + runs-on: ${{ github.event.inputs.runner_name }} + strategy: + fail-fast: false + # explicit include-based build matrix, of known valid options + matrix: + include: + - os: ubuntu-20.04 + cuda: "11.8" + old_cuda: "11.1" + cuda_arch: "70" + gcc: 8 + env: + CUDA_PATH: /usr/local/cuda-${{ matrix.cuda }} + OLD_CUDA_PATH: /usr/local/cuda-${{ matrix.old_cuda }} + + steps: + - name: EC2 instance configuration used + run: | + echo "IDs: ${{ github.event.inputs.instance_id }}" + echo "AMI: ${{ github.event.inputs.instance_image_id }}" + echo "Type: ${{ github.event.inputs.instance_type }}" + - uses: actions/checkout@v2 + - name: Set up home + run: | + echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}" + - name: Export CUDA variables + run: | + echo "CUDA_PATH=$CUDA_PATH" >> "${GITHUB_ENV}" + echo "$CUDA_PATH/bin" >> "${GITHUB_PATH}" + echo "LD_LIBRARY_PATH=$CUDA_PATH/lib:$LD_LIBRARY_PATH" >> "${GITHUB_ENV}" + # Specify the correct host compilers + - name: Export gcc and g++ variables + run: | + echo "CC=/usr/bin/gcc-${{ matrix.gcc }}" >> "${GITHUB_ENV}" + echo "CXX=/usr/bin/g++-${{ matrix.gcc }}" >> "${GITHUB_ENV}" + echo "CUDAHOSTCXX=/usr/bin/g++-${{ matrix.gcc }}" >> "${GITHUB_ENV}" + echo "CUDACXX=$CUDA_PATH/bin/nvcc" >> "${GITHUB_ENV}" + echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}" + - name: Install latest stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + default: true + + - name: Cuda clippy + run: | + make clippy_cuda + + - name: Run core cuda tests + run: | + make test_core_crypto_cuda + + - name: Test tfhe-rs/boolean with cpu + run: | + make test_boolean + + - name: Test tfhe-rs/boolean with cuda backend with CUDA 11.8 + run: | + make test_boolean_cuda + + - name: Export variables for CUDA 11.1 + run: | + echo "CUDA_PATH=$OLD_CUDA_PATH" >> "${GITHUB_ENV}" + echo "LD_LIBRARY_PATH=$OLD_CUDA_PATH/lib:$LD_LIBRARY_PATH" >> "${GITHUB_ENV}" + echo "CUDACXX=$OLD_CUDA_PATH/bin/nvcc" >> "${GITHUB_ENV}" + + - name: Test tfhe-rs/boolean with cuda backend with CUDA 11.1 + run: | + cargo clean + make test_boolean_cuda + + - name: Slack Notification + if: ${{ always() }} + continue-on-error: true + uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7 + env: + SLACK_COLOR: ${{ job.status }} + SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }} + SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png + SLACK_MESSAGE: "(Slab ci-bot beta) AWS tests GPU finished with status ${{ job.status }}. (${{ env.ACTION_RUN_URL }})" + SLACK_USERNAME: ${{ secrets.BOT_USERNAME }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.github/workflows/cargo_build.yml b/.github/workflows/cargo_build.yml new file mode 100644 index 000000000..e36895921 --- /dev/null +++ b/.github/workflows/cargo_build.yml @@ -0,0 +1,69 @@ +name: Cargo Build + +on: + pull_request: + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-C target-cpu=native" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref }} + cancel-in-progress: true + +jobs: + cargo-builds: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + + - name: Get rust toolchain to use for checks and lints + id: toolchain + run: | + echo "rs-toolchain=$(make rs_toolchain)" >> "${GITHUB_OUTPUT}" + + - name: Check format + run: | + make check_fmt + + - name: Build doc + run: | + make doc + + - name: Clippy boolean + run: | + make clippy_boolean + + - name: Build Release boolean + run: | + make build_boolean + + - name: Clippy shortint + run: | + make clippy_shortint + + - name: Build Release shortint + run: | + make build_shortint + + - name: Clippy shortint and boolean + run: | + make clippy + + - name: Build Release shortint and boolean + run: | + make build_boolean_and_shortint + + - name: C API Clippy + run: | + make clippy_c_api + + - name: Build Release c_api + run: | + make build_c_api diff --git a/.github/workflows/check_commit.yml b/.github/workflows/check_commit.yml new file mode 100644 index 000000000..5da948bb0 --- /dev/null +++ b/.github/workflows/check_commit.yml @@ -0,0 +1,33 @@ +# Check commit and PR compliance +name: Check commit and PR compliance +on: + pull_request: + branches: + - main + - dev +jobs: + check-commit-pr: + name: Check commit and PR + runs-on: ubuntu-latest + steps: + - name: Check first line + uses: gsactions/commit-message-checker@v1 + with: + pattern: '^((feat|fix|chore|refactor|style|test|docs|doc)\(\w+\)\:) .+$' + flags: "gs" + error: 'Your first line has to contain a commit type and scope like "feat(my_feature): msg".' + excludeDescription: "true" # optional: this excludes the description body of a pull request + excludeTitle: "true" # optional: this excludes the title of a pull request + checkAllCommitMessages: "true" # optional: this checks all commits associated with a pull request + accessToken: ${{ secrets.GITHUB_TOKEN }} # github access token is only required if checkAllCommitMessages is true + + - name: Check line length + uses: gsactions/commit-message-checker@v1 + with: + pattern: '(^.{0,74}$\r?\n?){0,20}' + flags: "gm" + error: "The maximum line length of 74 characters is exceeded." + excludeDescription: "true" # optional: this excludes the description body of a pull request + excludeTitle: "true" # optional: this excludes the title of a pull request + checkAllCommitMessages: "true" # optional: this checks all commits associated with a pull request + accessToken: ${{ secrets.GITHUB_TOKEN }} # github access token is only required if checkAllCommitMessages is true diff --git a/.github/workflows/m1_tests.yml b/.github/workflows/m1_tests.yml new file mode 100644 index 000000000..e1dc7ad31 --- /dev/null +++ b/.github/workflows/m1_tests.yml @@ -0,0 +1,128 @@ +name: Tests on M1 CPU + +on: + workflow_dispatch: + pull_request: + types: [labeled] + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-C target-cpu=native" + ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref }} + cancel-in-progress: true + +jobs: + cargo-builds: + if: "github.event_name != 'pull_request' || contains(github.event.label.name, 'm1_test')" + runs-on: ["self-hosted", "m1mac"] + + steps: + - uses: actions/checkout@v2 + + - name: Install latest stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + default: true + + - name: Build doc + run: | + make doc + + - name: Clippy boolean + run: | + make clippy_boolean + + - name: Build Release boolean + run: | + make build_boolean + + - name: Clippy shortint + run: | + make clippy_shortint + + - name: Build Release shortint + run: | + make build_shortint + + - name: Clippy shortint and boolean + run: | + make clippy + + - name: Build Release shortint and boolean + run: | + make build_boolean_and_shortint + + - name: C API Clippy + run: | + make clippy_c_api + + - name: Build Release c_api + run: | + make build_c_api + + - name: Test tfhe-rs/boolean with cpu + run: | + make test_boolean + + - name: Run core tests + run: | + make test_core_crypto + + - name: Run C API tests + run: | + make test_c_api + + - name: Run user docs tests + run: | + make test_user_doc + + - name: Configure AWS credentials from Test account + uses: aws-actions/configure-aws-credentials@v1 + with: + aws-access-key-id: ${{ secrets.AWS_IAM_ID }} + aws-secret-access-key: ${{ secrets.AWS_IAM_KEY }} + role-to-assume: concrete-lib-ci + aws-region: eu-west-3 + role-duration-seconds: 10800 + + - name: Download keys locally + run: aws s3 cp --recursive --no-progress s3://concrete-libs-keycache ./keys + + - name: Gen Keys if required + run: | + make gen_key_cache + + - name: Sync keys + run: aws s3 sync ./keys s3://concrete-libs-keycache + + - name: Run shortint tests + run: | + make test_shortint_ci + + remove_label: + name: Remove m1_test label + runs-on: ubuntu-latest + needs: + - cargo-builds + if: ${{ always() }} + steps: + - uses: actions-ecosystem/action-remove-labels@v1 + with: + labels: m1_test + github_token: ${{ secrets.GITHUB_TOKEN }} + + - name: Slack Notification + if: ${{ always() }} + continue-on-error: true + uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7 + env: + SLACK_COLOR: ${{ needs.cargo-builds.result }} + SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }} + SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png + SLACK_MESSAGE: "M1 tests finished with status: ${{ needs.cargo-builds.result }}. (${{ env.ACTION_RUN_URL }})" + SLACK_USERNAME: ${{ secrets.BOT_USERNAME }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..2c01b436e --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +target/ +.idea/ +.vscode/ + +# Path we use for internal-keycache during tests +keys/ + +**/Cargo.lock +**/*.bin diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000..403b36a51 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,9 @@ +[workspace] +resolver = "2" +members = ["tfhe"] + +[profile.bench] +lto = "fat" + +[profile.release] +lto = "fat" diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..a05bfc3f6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,33 @@ +BSD 3-Clause Clear License + +Copyright © 2022 ZAMA. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation and/or other +materials provided with the distribution. + +3. Neither the name of ZAMA nor the names of its contributors may be used to endorse +or promote products derived from this software without specific prior written permission. + +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*. +THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive, +free and non-commercial license on all patents filed in its name relating to the open-source +code (the "Patents") for the sole purpose of evaluation, development, research, prototyping +and experimentation. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..c4721e1a4 --- /dev/null +++ b/Makefile @@ -0,0 +1,151 @@ +SHELL:=$(shell /usr/bin/env which bash) +RS_CHECK_TOOLCHAIN:=$(shell cat toolchain.txt) +CARGO_RS_CHECK_TOOLCHAIN:=+$(RS_CHECK_TOOLCHAIN) +TARGET_ARCH_FEATURE:=$(shell ./scripts/get_arch_feature.sh) +RS_BUILD_TOOLCHAIN:=$(shell \ + ( (echo $(TARGET_ARCH_FEATURE) | grep -q x86) && echo stable) || echo $(RS_CHECK_TOOLCHAIN)) +CARGO_RS_BUILD_TOOLCHAIN:=+$(RS_BUILD_TOOLCHAIN) +# This is done to avoid forgetting it, we still precise the RUSTFLAGS in the commands to be able to +# copy paste the command in the termianl and change them if required without forgetting the flags +export RUSTFLAGS:=-C target-cpu=native + +.PHONY: rs_check_toolchain # Echo the rust toolchain used for checks +rs_check_toolchain: + @echo $(RS_CHECK_TOOLCHAIN) + +.PHONY: rs_build_toolchain # Echo the rust toolchain used for builds +rs_build_toolchain: + @echo $(RS_BUILD_TOOLCHAIN) + +.PHONY: install_rs_check_toolchain # Install the toolchain used for checks +install_rs_check_toolchain: + @rustup toolchain list | grep -q "$(RS_CHECK_TOOLCHAIN)" || \ + rustup toolchain install --profile default "$(RS_CHECK_TOOLCHAIN)" || \ + echo "Unable to install $(RS_CHECK_TOOLCHAIN) toolchain, check your rustup installation. \ + Rustup can be downloaded at https://rustup.rs/" + +.PHONY: install_rs_build_toolchain # Install the toolchain used for builds +install_rs_build_toolchain: + @rustup toolchain list | grep -q "$(RS_BUILD_TOOLCHAIN)" || \ + rustup toolchain install --profile default "$(RS_BUILD_TOOLCHAIN)" || \ + echo "Unable to install $(RS_BUILD_TOOLCHAIN) toolchain, check your rustup installation. \ + Rustup can be downloaded at https://rustup.rs/" + +.PHONY: install_cargo_nextest # Install cargo nextest used for shortint tests +install_cargo_nextest: install_rs_build_toolchain + @cargo nextest --version > /dev/null 2>&1 || \ + cargo $(CARGO_RS_BUILD_TOOLCHAIN) install cargo-nextest --locked || \ + echo "Unable to install cargo nextest, unknown error." + +.PHONY: fmt # Format rust code +fmt: install_rs_check_toolchain + cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt + +.PHONT: check_fmt # Check rust code format +check_fmt: install_rs_check_toolchain + cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt --check + +.PHONY: clippy_boolean # Run clippy lints enabling the boolean features +clippy_boolean: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ + --features=$(TARGET_ARCH_FEATURE),boolean \ + -p tfhe -- --no-deps -D warnings + +.PHONY: clippy_shortint # Run clippy lints enabling the shortint features +clippy_shortint: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ + --features=$(TARGET_ARCH_FEATURE),shortint \ + -p tfhe -- --no-deps -D warnings + +.PHONY: clippy # Run clippy lints enabling the boolean, shortint +clippy: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint \ + -p tfhe -- --no-deps -D warnings + +.PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API +clippy_c_api: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \ + -p tfhe -- --no-deps -D warnings + +.PHONY: clippy_cuda # Run clippy lints enabling the boolean, shortint, cuda and c API features +clippy_cuda: install_rs_check_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ + --features=$(TARGET_ARCH_FEATURE),cuda,boolean-c-api,shortint-c-api \ + -p tfhe -- --no-deps -D warnings + +.PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests +gen_key_cache: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --release \ + --example generates_test_keys \ + --features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe + +.PHONY: build_boolean # Build with boolean enabled +build_boolean: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ + --features=$(TARGET_ARCH_FEATURE),boolean -p tfhe + +.PHONY: build_shortint # Build with shortint enabled +build_shortint: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ + --features=$(TARGET_ARCH_FEATURE),shortint -p tfhe + +.PHONY: build_boolean_and_shortint # Build with boolean and shortint enabled +build_boolean_and_shortint: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint -p tfhe + +.PHONY: build_c_api # Build the C API for boolean and shortint +build_c_api: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) build --release + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api -p tfhe + +.PHONY: test_core_crypto # Run the tests of the core_crypto module +test_core_crypto: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE) -p tfhe -- core_crypto:: + +.PHONY: test_core_crypto_cuda # Run the tests of the core_crypto module with cuda enabled +test_core_crypto_cuda: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE),cuda -p tfhe -- core_crypto::backends::cuda:: + +.PHONY: test_boolean # Run the tests of the boolean module +test_boolean: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE),boolean -p tfhe -- boolean:: + +.PHONY: test_boolean_cuda # Run the tests of the boolean module with cuda enabled +test_boolean_cuda: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE),boolean,cuda -p tfhe -- boolean:: + +.PHONY: test_c_api # Run the tests for the C API +test_c_api: install_rs_build_toolchain + ./scripts/c_api_tests.sh $(CARGO_RS_BUILD_TOOLCHAIN) + +.PHONY: test_shortint_ci # Run the tests for shortint ci +test_shortint_ci: install_rs_build_toolchain install_cargo_nextest + ./scripts/shortint-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN) + +.PHONY: test_shortint # Run all the tests for shortint +test_shortint: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \ + --features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe -- shortint:: + +.PHONY: test_user_doc # Run tests from the .md documentation +test_user_doc: install_rs_build_toolchain + RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release --doc \ + --features=$(TARGET_ARCH_FEATURE),shortint,boolean,internal-keycache -p tfhe \ + -- test_user_docs:: + +.PHONY: doc # Build rust doc +doc: install_rs_check_toolchain + RUSTDOCFLAGS="--html-in-header katex-header.html" \ + cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint --no-deps + +.PHONY: help # Generate list of targets with descriptions +help: + @grep '^.PHONY: .* #' Makefile | sed 's/\.PHONY: \(.*\) # \(.*\)/\1\t\2/' | expand -t30 | sort diff --git a/README.md b/README.md new file mode 100644 index 000000000..4a26217ac --- /dev/null +++ b/README.md @@ -0,0 +1,133 @@ +

+ + +

+

+ + + + + + + + + + + + + + + + + + + + +

+ +**TFHE-rs** is a pure Rust implementation of TFHE for boolean and small integer +arithmetics over encrypted data. It includes: + - a **Rust** API + - a **C** API + - and a **client-side WASM** API + +**TFHE-rs** is meant for developers and researchers who want full control over +what they can do with TFHE, while not having to worry about the low level +implementation. The goal is to have a stable, simple, high-performance and +production-ready library for all the advanced features of TFHE. + +## Getting Started + +To use `TFHE-rs` in your project, you first need to add it as a dependency in your `Cargo.toml`: + +```toml +tfhe = { version = "0.1.0", features = [ "boolean","shortint","x86_64-unix" ] } +``` + +Here is a full example evaluating a Boolean circuit: + +```rust +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (mut client_key, mut server_key) = gen_keys(); + +// We use the client secret key to encrypt two messages: + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + +// We use the server public key to execute a boolean circuit: +// if ((NOT ct_2) NAND (ct_1 AND ct_2)) then (NOT ct_2) else (ct_1 AND ct_2) + let ct_3 = server_key.not(&ct_2); + let ct_4 = server_key.and(&ct_1, &ct_2); + let ct_5 = server_key.nand(&ct_3, &ct_4); + let ct_6 = server_key.mux(&ct_5, &ct_3, &ct_4); + +// We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_6); + assert_eq!(output, true); +} +``` + +Another example of how the library can be used with shortints: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); + + let msg1 = 1; + let msg2 = 0; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 + msg2) % modulus as u64); +} +``` + +## Contributing + +There are two ways to contribute to TFHE-rs: + +- you can open issues to report bugs or typos and to suggest new ideas +- you can ask to become an official contributor by emailing [hello@zama.ai](mailto:hello@zama.ai). +(becoming an approved contributor involves signing our Contributor License Agreement (CLA)) + +Only approved contributors can send pull requests, so please make sure to get in touch before you do! + +## Credits + +This library uses several dependencies and we would like to thank the contributors of those +libraries. + +## License + +This software is distributed under the BSD-3-Clause-Clear license. If you have any questions, +please contact us at `hello@zama.ai`. + +## Disclaimers + +### Security Estimation + +Security estimations are done using the +[Lattice Estimator](https://github.com/malb/lattice-estimator) +with `red_cost_model = reduction.RC.BDGL16`. + +When a new update is published in the Lattice Estimator, we update parameters accordingly. + +### Side-Channel Attacks + +Mitigation for side channel attacks have not yet been implemented in TFHE-rs, +and will be released in upcoming versions. diff --git a/ci/slab.toml b/ci/slab.toml new file mode 100644 index 000000000..bb30a0313 --- /dev/null +++ b/ci/slab.toml @@ -0,0 +1,21 @@ +[profile.cpu-big] +region = "eu-west-3" +image_id = "ami-04deffe45b5b236fd" +instance_type = "c5a.8xlarge" + +[profile.gpu] +region = "us-east-1" +image_id = "ami-0ae662beb44082155" +instance_type = "p3.2xlarge" +subnet_id = "subnet-8123c9e7" +security_group = "sg-0466d33ced960ba35" + +[command.cpu_test] +workflow = "aws_tfhe_tests.yml" +profile = "cpu-big" +check_run_name = "Shortint CPU AWS Tests" + +[command.gpu_test] +workflow = "aws_tfhe_tests_w_gpu.yml" +profile = "gpu" +check_run_name = "AWS tests GPU (Slab)" diff --git a/katex-header.html b/katex-header.html new file mode 100644 index 000000000..be4a7271b --- /dev/null +++ b/katex-header.html @@ -0,0 +1,15 @@ + + + + diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 000000000..a1fed167f --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,5 @@ +unstable_features = true +imports_granularity="Module" +format_code_in_doc_comments = true +wrap_comments = true +comment_width = 100 diff --git a/scripts/c_api_tests.sh b/scripts/c_api_tests.sh new file mode 100755 index 000000000..a37c2ae32 --- /dev/null +++ b/scripts/c_api_tests.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +set -e + +CURR_DIR="$(dirname "$0")" +ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" +REPO_ROOT="${CURR_DIR}/.." +TFHE_BUILD_DIR="${REPO_ROOT}/tfhe/build/" + +mkdir -p "${TFHE_BUILD_DIR}" + +cd "${TFHE_BUILD_DIR}" + +cmake .. -DCMAKE_BUILD_TYPE=RELEASE + +RUSTFLAGS="-C target-cpu=native" cargo ${1:+"${1}"} build \ +--release --features="${ARCH_FEATURE}",boolean-c-api,shortint-c-api -p tfhe + +make -j +make "test" diff --git a/scripts/get_arch_feature.sh b/scripts/get_arch_feature.sh new file mode 100755 index 000000000..41b3a44a7 --- /dev/null +++ b/scripts/get_arch_feature.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash + +set -e + +ARCH_FEATURE=x86_64 + +IS_AARCH64="$( (uname -a | grep -c arm64) || true)" + +if [[ "${IS_AARCH64}" != "0" ]]; then + ARCH_FEATURE=aarch64 +fi + +UNAME="$(uname)" + +if [[ "${UNAME}" == "Linux" || "${UNAME}" == "Darwin" ]]; then + ARCH_FEATURE="${ARCH_FEATURE}-unix" +fi + +echo "${ARCH_FEATURE}" diff --git a/scripts/shortint-tests.sh b/scripts/shortint-tests.sh new file mode 100755 index 000000000..392ac4a31 --- /dev/null +++ b/scripts/shortint-tests.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +set -e + +CURR_DIR="$(dirname "$0")" +ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" + +nproc_bin=nproc + +# macOS detects CPUs differently +if [[ $(uname) == "Darwin" ]]; then + nproc_bin="sysctl -n hw.logicalcpu" +fi + +n_threads="$(${nproc_bin})" + +if uname -a | grep "arm64"; then + if [[ $(uname) == "Darwin" ]]; then + # Keys are 4.7 gigs at max, CI M1 macs only has 8 gigs of RAM + n_threads=1 + fi +else + # Keys are 4.7 gigs at max, test machine has 32 gigs of RAM + n_threads=6 +fi + +filter_expression=''\ +'('\ +' test(/^shortint::server_key::.*_param_message_1_carry_1$/)'\ +'or test(/^shortint::server_key::.*_param_message_1_carry_2$/)'\ +'or test(/^shortint::server_key::.*_param_message_1_carry_3$/)'\ +'or test(/^shortint::server_key::.*_param_message_1_carry_4$/)'\ +'or test(/^shortint::server_key::.*_param_message_1_carry_5$/)'\ +'or test(/^shortint::server_key::.*_param_message_1_carry_6$/)'\ +'or test(/^shortint::server_key::.*_param_message_2_carry_2$/)'\ +'or test(/^shortint::server_key::.*_param_message_3_carry_3$/)'\ +'or test(/^shortint::server_key::.*_param_message_4_carry_4$/)'\ +')'\ +'and not test(~smart_add_and_mul)' # This test is too slow + +export RUSTFLAGS="-C target-cpu=native" + +# Run tests only no examples or benches +cargo ${1:+"${1}"} nextest run \ + --tests \ + --release \ + --package tfhe \ + --profile ci \ + --features="${ARCH_FEATURE}",shortint,internal-keycache \ + --test-threads "${n_threads}" \ + -E "${filter_expression}" + +cargo ${1:+"${1}"} test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",shortint,internal-keycache \ + --doc \ + shortint:: + +echo "Test ran in $SECONDS seconds" diff --git a/tfhe/.gitignore b/tfhe/.gitignore new file mode 100644 index 000000000..d16386367 --- /dev/null +++ b/tfhe/.gitignore @@ -0,0 +1 @@ +build/ \ No newline at end of file diff --git a/tfhe/CMakeLists.txt b/tfhe/CMakeLists.txt new file mode 100644 index 000000000..4ccb1f5b9 --- /dev/null +++ b/tfhe/CMakeLists.txt @@ -0,0 +1,6 @@ +# tfhe/CMakeLists.txt +cmake_minimum_required(VERSION 3.16) +project(tfhe-c-api C) +set(SOURCE c_api_tests/*.c) +enable_testing() +add_subdirectory(c_api_tests) diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml new file mode 100644 index 000000000..8a69c275e --- /dev/null +++ b/tfhe/Cargo.toml @@ -0,0 +1,180 @@ +[package] +name = "tfhe" +version = "0.1.0" +edition = "2021" +readme = "../README.md" +keywords = ["fully", "homomorphic", "encryption", "fhe", "cryptography"] +homepage = "https://zama.ai/" +documentation = "https://docs.zama.ai/tfhe-rs" +repository = "https://github.com/zama-ai/tfhe-rs" +license = "BSD-3-Clause-Clear" +description = "Concrete is a fully homomorphic encryption (FHE) library that implements Zama's variant of TFHE." +build = "build.rs" +exclude = ["/docs/", "/c_api_tests/", "/CMakeLists.txt"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dev-dependencies] +rand = "0.7" +kolmogorov_smirnov = "1.1.0" +paste = "1.0.7" +lazy_static = { version = "1.4.0" } +criterion = "0.3.5" +doc-comment = "0.3.3" +# Used in user documentation +bincode = "1.3.3" +fs2 = { version = "0.4.3"} + +[build-dependencies] +cbindgen = { version = "0.24.3", optional = true } + +[dependencies] +concrete-csprng = { version = "0.2.1" } +concrete-cuda = { version = "0.1.1", optional = true } +lazy_static = { version = "1.4.0", optional = true } +serde = { version = "1.0", optional = true } +rayon = { version = "1.5.0", optional = true } +bincode = { version = "1.3.3", optional = true } +concrete-fft = { version = "0.1", optional = true } +aligned-vec = "0.5" +dyn-stack = { version = "0.8", optional = true } +once_cell = "1.13" +paste = "1.0.7" +fs2 = { version = "0.4.3", optional = true } + +# wasm deps +wasm-bindgen = { version = "0.2.63", features = [ + "serde-serialize", +], optional = true } +js-sys = { version = "0.3", optional = true } +console_error_panic_hook = { version = "0.1.7", optional = true } +serde-wasm-bindgen = { version = "0.4", optional = true } +getrandom = { version = "0.2.8", optional = true } + +[features] +boolean = ["minimal_core_crypto_features"] +shortint = ["minimal_core_crypto_features"] +internal-keycache = ["lazy_static", "fs2"] + +__c_api = ["cbindgen", "minimal_core_crypto_features"] +boolean-c-api = ["boolean", "__c_api"] +shortint-c-api = ["shortint", "__c_api"] + +__wasm_api = [ + "wasm-bindgen", + "js-sys", + "console_error_panic_hook", + "serde-wasm-bindgen", + "getrandom", + "getrandom/js", +] +boolean-client-js-wasm-api = ["boolean", "__wasm_api"] +shortint-client-js-wasm-api = ["shortint", "__wasm_api"] + +cuda = ["backend_cuda"] +nightly-avx512 = ["backend_fft_nightly_avx512"] + +# A pure-rust CPU backend. +backend_default = ["concrete-csprng/generator_soft"] + +# An accelerated backend, using the `concrete-fft` library. +backend_fft = ["concrete-fft", "dyn-stack"] +backend_fft_serialization = [ + "bincode", + "concrete-fft/serde", + "aligned-vec/serde", + "__commons_serialization", +] +backend_fft_nightly_avx512 = ["concrete-fft/nightly"] + +# Enables the parallel engine in default backend. +backend_default_parallel = ["__commons_parallel"] + +# Enable the x86_64 specific accelerated implementation of the random generator for the default +# backend +backend_default_generator_x86_64_aesni = [ + "concrete-csprng/generator_x86_64_aesni", +] + +# Enable the aarch64 specific accelerated implementation of the random generator for the default +# backend +backend_default_generator_aarch64_aes = [ + "concrete-csprng/generator_aarch64_aes", +] + +# Enable the serialization engine in the default backend. +backend_default_serialization = ["bincode", "__commons_serialization"] + +# A GPU backend, relying on Cuda acceleration +backend_cuda = ["concrete-cuda"] + +# Private features +__profiling = [] +__private_docs = [] +__commons_parallel = ["rayon", "concrete-csprng/parallel"] +__commons_serialization = ["serde", "serde/derive"] + +seeder_unix = ["concrete-csprng/seeder_unix"] +seeder_x86_64_rdseed = ["concrete-csprng/seeder_x86_64_rdseed"] + +minimal_core_crypto_features = [ + "backend_default", + "backend_default_parallel", + "backend_default_serialization", + "backend_fft", + "backend_fft_serialization", +] + +# These target_arch features enable a set of public features for concrete-core if users want a known +# good/working configuration for concrete-core. +# For a target_arch that does not yet have such a feature, one can still enable features manually or +# create a feature for said target_arch to make its use simpler. +x86_64 = [ + "minimal_core_crypto_features", + "backend_default_generator_x86_64_aesni", + "seeder_x86_64_rdseed", +] +x86_64-unix = ["x86_64", "seeder_unix"] + +# CUDA builds are Unix only at the moment +x86_64-unix-cuda = ["x86_64-unix", "cuda"] + +aarch64 = [ + "minimal_core_crypto_features", + "backend_default_generator_aarch64_aes", +] +aarch64-unix = ["aarch64", "seeder_unix"] + +[package.metadata.docs.rs] +# TODO: manage builds for docs.rs based on their documentation https://docs.rs/about +features = ["x86_64-unix", "boolean", "shortint"] +rustdoc-args = ["--html-in-header", "katex-header.html"] + +########### +# # +# Benches # +# # +########### + +[[bench]] +name = "boolean-bench" +path = "benches/boolean/bench.rs" +harness = false +required-features = ["boolean", "internal-keycache"] + +[[bench]] +name = "shortint-bench" +path = "benches/shortint/bench.rs" +harness = false +required-features = ["shortint", "internal-keycache"] + +[[example]] +name = "generates_test_keys" +required-features = ["shortint", "internal-keycache"] + +[[example]] +name = "micro_bench_and" +required-features = ["boolean"] + +[lib] +crate-type = ["lib", "staticlib", "cdylib"] diff --git a/tfhe/LICENSE b/tfhe/LICENSE new file mode 100644 index 000000000..e26e594c2 --- /dev/null +++ b/tfhe/LICENSE @@ -0,0 +1,32 @@ +BSD 3-Clause Clear License + +Copyright © 2022 ZAMA. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this +list of conditions and the following disclaimer in the documentation and/or other +materials provided with the distribution. + +3. Neither the name of ZAMA nor the names of its contributors may be used to endorse +or promote products derived from this software without specific prior written permission. +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*. +THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL +ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS +OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive, +free and non-commercial license on all patents filed in its name relating to the open-source +code (the "Patents") for the sole purpose of evaluation, development, research, prototyping +and experimentation. diff --git a/tfhe/benches/boolean/bench.rs b/tfhe/benches/boolean/bench.rs new file mode 100644 index 000000000..468a573d1 --- /dev/null +++ b/tfhe/benches/boolean/bench.rs @@ -0,0 +1,60 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use tfhe::boolean::client_key::ClientKey; +use tfhe::boolean::parameters::{BooleanParameters, DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS}; +use tfhe::boolean::prelude::BinaryBooleanGates; +use tfhe::boolean::server_key::ServerKey; + +criterion_group!( + gates_benches, + bench_default_parameters, + bench_tfhe_lib_parameters +); + +criterion_main!(gates_benches); + +// Put all `bench_function` in one place +// so the keygen is only run once per parameters saving time. +fn bench_gates(c: &mut Criterion, params: BooleanParameters, parameter_name: &str) { + let cks = ClientKey::new(¶ms); + let sks = ServerKey::new(&cks); + + let ct1 = cks.encrypt(true); + let ct2 = cks.encrypt(false); + let ct3 = cks.encrypt(true); + + let id = format!("AND gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.and(&ct1, &ct2)))); + + let id = format!("NAND gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.nand(&ct1, &ct2)))); + + let id = format!("OR gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.or(&ct1, &ct2)))); + + let id = format!("XOR gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.xor(&ct1, &ct2)))); + + let id = format!("XNOR gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.xnor(&ct1, &ct2)))); + + let id = format!("NOT gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.not(&ct1)))); + + let id = format!("MUX gate {}", parameter_name); + c.bench_function(&id, |b| b.iter(|| black_box(sks.mux(&ct1, &ct2, &ct3)))); +} + +#[cfg(not(feature = "cuda"))] +fn bench_default_parameters(c: &mut Criterion) { + bench_gates(c, DEFAULT_PARAMETERS, "DEFAULT_PARAMETERS"); +} + +#[cfg(feature = "cuda")] +fn bench_default_parameters(_: &mut Criterion) { + let _ = DEFAULT_PARAMETERS; // to avoid unused import warnings + println!("DEFAULT_PARAMETERS not benched as they are not compatible with the cuda feature."); +} + +fn bench_tfhe_lib_parameters(c: &mut Criterion) { + bench_gates(c, TFHE_LIB_PARAMETERS, "TFHE_LIB_PARAMETERS"); +} diff --git a/tfhe/benches/shortint/bench.rs b/tfhe/benches/shortint/bench.rs new file mode 100644 index 000000000..6844aa573 --- /dev/null +++ b/tfhe/benches/shortint/bench.rs @@ -0,0 +1,232 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use tfhe::shortint::parameters::*; +use tfhe::shortint::{Ciphertext, Parameters, ServerKey}; + +use rand::Rng; +use tfhe::shortint::keycache::KEY_CACHE; + +use tfhe::shortint::keycache::KEY_CACHE_WOPBS; +use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6; + +macro_rules! named_param { + ($param:ident) => { + (stringify!($param), $param) + }; +} + +const SERVER_KEY_BENCH_PARAMS: [(&str, Parameters); 4] = [ + named_param!(PARAM_MESSAGE_1_CARRY_1), + named_param!(PARAM_MESSAGE_2_CARRY_2), + named_param!(PARAM_MESSAGE_3_CARRY_3), + named_param!(PARAM_MESSAGE_4_CARRY_4), +]; + +fn bench_server_key_binary_function(c: &mut Criterion, bench_name: &str, binary_op: F) +where + F: Fn(&ServerKey, &mut Ciphertext, &mut Ciphertext), +{ + let mut bench_group = c.benchmark_group(bench_name); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let modulus = 1_u64 << cks.parameters.message_modulus.0; + + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ct_0 = cks.encrypt(clear_0); + let mut ct_1 = cks.encrypt(clear_1); + + let bench_id = format!("{}::{}", bench_name, param_name); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + binary_op(sks, &mut ct_0, &mut ct_1); + }) + }); + } + + bench_group.finish() +} + +fn bench_server_key_binary_scalar_function(c: &mut Criterion, bench_name: &str, binary_op: F) +where + F: Fn(&ServerKey, &mut Ciphertext, u8), +{ + let mut bench_group = c.benchmark_group(bench_name); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let modulus = 1_u64 << cks.parameters.message_modulus.0; + + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + let mut ct_0 = cks.encrypt(clear_0); + + let bench_id = format!("{}::{}", bench_name, param_name); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + binary_op(sks, &mut ct_0, clear_1 as u8); + }) + }); + } + + bench_group.finish() +} + +fn carry_extract(c: &mut Criterion) { + let mut bench_group = c.benchmark_group("carry_extract"); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let modulus = 1_u64 << cks.parameters.message_modulus.0; + + let clear_0 = rng.gen::() % modulus; + + let ct_0 = cks.encrypt(clear_0); + + let bench_id = format!("ServerKey::carry_extract::{}", param_name); + bench_group.bench_function(&bench_id, |b| { + b.iter(|| { + sks.carry_extract(&ct_0); + }) + }); + } + + bench_group.finish() +} + +fn programmable_bootstrapping(c: &mut Criterion) { + let mut bench_group = c.benchmark_group("programmable_bootstrap"); + + for (param_name, param) in SERVER_KEY_BENCH_PARAMS { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + let acc = sks.generate_accumulator(|x| x); + + let clear_0 = rng.gen::() % modulus; + + let ctxt = cks.encrypt(clear_0); + + let id = format!("ServerKey::programmable_bootstrap::{}", param_name); + + bench_group.bench_function(&id, |b| { + b.iter(|| { + sks.keyswitch_programmable_bootstrap(&ctxt, &acc); + }) + }); + } + + bench_group.finish(); +} + +fn bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) { + let mut bench_group = c.benchmark_group("programmable_bootstrap"); + + let param = WOPBS_PARAM_MESSAGE_4_NORM2_6; + + let keys = KEY_CACHE_WOPBS.get_from_param((param, param)); + let (cks, _, wopbs_key) = (keys.client_key(), keys.server_key(), keys.wopbs_key()); + + let mut rng = rand::thread_rng(); + + let clear = rng.gen::() % param.message_modulus.0; + let mut ct = cks.encrypt_without_padding(clear as u64); + let vec_lut = wopbs_key.generate_lut_native_crt(&ct, |x| x); + + let id = format!("Shortint WOPBS: {:?}", param); + + bench_group.bench_function(&id, |b| { + b.iter(|| { + wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &vec_lut); + }) + }); + + bench_group.finish(); +} + +macro_rules! define_server_key_bench_fn ( + ($server_key_method:ident) => { + fn $server_key_method(c: &mut Criterion) { + bench_server_key_binary_function( + c, + concat!("ServerKey::", stringify!($server_key_method)), + |server_key, lhs, rhs| { + server_key.$server_key_method(lhs, rhs); + }) + } + } +); + +macro_rules! define_server_key_scalar_bench_fn ( + ($server_key_method:ident) => { + fn $server_key_method(c: &mut Criterion) { + bench_server_key_binary_scalar_function( + c, + concat!("ServerKey::", stringify!($server_key_method)), + |server_key, lhs, rhs| { + server_key.$server_key_method(lhs, rhs); + }) + } + } +); + +define_server_key_bench_fn!(unchecked_add); +define_server_key_bench_fn!(unchecked_sub); +define_server_key_bench_fn!(unchecked_mul_lsb); +define_server_key_bench_fn!(unchecked_mul_msb); +define_server_key_bench_fn!(smart_bitand); +define_server_key_bench_fn!(smart_bitor); +define_server_key_bench_fn!(smart_bitxor); +define_server_key_bench_fn!(smart_add); +define_server_key_bench_fn!(smart_sub); +define_server_key_bench_fn!(smart_mul_lsb); + +define_server_key_scalar_bench_fn!(unchecked_scalar_add); +define_server_key_scalar_bench_fn!(unchecked_scalar_mul); + +criterion_group!( + arithmetic_operation, + unchecked_add, + unchecked_sub, + unchecked_mul_lsb, + unchecked_mul_msb, + smart_bitand, + smart_bitor, + smart_bitxor, + smart_add, + smart_sub, + smart_mul_lsb, + carry_extract, + // programmable_bootstrapping, + // multivalue_programmable_bootstrapping + //bench_two_block_pbs + //wopbs_v0_norm2_2, + bench_wopbs_param_message_8_norm2_5, + programmable_bootstrapping +); + +criterion_group!( + arithmetic_scalar_operation, + unchecked_scalar_add, + unchecked_scalar_mul, +); + +criterion_main!(arithmetic_operation,); // arithmetic_scalar_operation,); diff --git a/tfhe/build.rs b/tfhe/build.rs new file mode 100644 index 000000000..4725c1e51 --- /dev/null +++ b/tfhe/build.rs @@ -0,0 +1,35 @@ +// tfhe/build.rs + +#[cfg(feature = "__c_api")] +fn gen_c_api() { + use std::env; + use std::path::PathBuf; + + /// Find the location of the `target/` directory. Note that this may be + /// overridden by `cmake`, so we also need to check the `CARGO_TARGET_DIR` + /// variable. + fn target_dir() -> PathBuf { + if let Ok(target) = env::var("CARGO_TARGET_DIR") { + PathBuf::from(target) + } else { + PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join("../target/release") + } + } + + extern crate cbindgen; + let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + let package_name = env::var("CARGO_PKG_NAME").unwrap(); + let output_file = target_dir() + .join(format!("{package_name}.h")) + .display() + .to_string(); + + cbindgen::generate(crate_dir) + .unwrap() + .write_to_file(output_file); +} + +fn main() { + #[cfg(feature = "__c_api")] + gen_c_api() +} diff --git a/tfhe/c_api_tests/.clang-format b/tfhe/c_api_tests/.clang-format new file mode 100644 index 000000000..3b0e7ff30 --- /dev/null +++ b/tfhe/c_api_tests/.clang-format @@ -0,0 +1 @@ +ColumnLimit: 100 diff --git a/tfhe/c_api_tests/CMakeLists.txt b/tfhe/c_api_tests/CMakeLists.txt new file mode 100644 index 000000000..21cd966ca --- /dev/null +++ b/tfhe/c_api_tests/CMakeLists.txt @@ -0,0 +1,37 @@ +project(tfhe-c-api-tests) + +cmake_minimum_required(VERSION 3.16) + +set(TFHE_C_API_RELEASE "${CMAKE_CURRENT_SOURCE_DIR}/../../target/release/") + +include_directories(${TFHE_C_API_RELEASE}) +add_library(Tfhe STATIC IMPORTED) +set_target_properties(Tfhe PROPERTIES IMPORTED_LOCATION ${TFHE_C_API_RELEASE}/libtfhe.a) + +if(APPLE) + find_library(SECURITY_FRAMEWORK Security) + if (NOT SECURITY_FRAMEWORK) + message(FATAL_ERROR "Security framework not found") + endif() +endif() + +file(GLOB TEST_CASES test_*.c) +foreach (testsourcefile ${TEST_CASES}) + get_filename_component(testname ${testsourcefile} NAME_WLE) + get_filename_component(groupname ${testsourcefile} DIRECTORY) + add_executable(${testname} ${testsourcefile}) + add_test( + NAME ${testname} + COMMAND ${testname} + WORKING_DIRECTORY ${CMAKE_BINARY_DIR}/Testing + ) + target_include_directories(${testname} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + target_link_libraries(${testname} LINK_PUBLIC Tfhe m pthread dl) + if(APPLE) + target_link_libraries(${testname} LINK_PUBLIC ${SECURITY_FRAMEWORK}) + endif() + target_compile_options(${testname} PRIVATE -Werror) + # Enabled asserts even in release mode + add_definitions(-UNDEBUG) +endforeach (testsourcefile ${TEST_CASES}) + diff --git a/tfhe/c_api_tests/test_boolean_keygen.c b/tfhe/c_api_tests/test_boolean_keygen.c new file mode 100644 index 000000000..a8894747a --- /dev/null +++ b/tfhe/c_api_tests/test_boolean_keygen.c @@ -0,0 +1,123 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include + +void test_default_keygen_w_serde(void) { + BooleanClientKey *cks = NULL; + BooleanServerKey *sks = NULL; + BooleanCiphertext *ct = NULL; + Buffer ct_ser_buffer = {.pointer = NULL, .length = 0}; + BooleanCiphertext *deser_ct = NULL; + + int gen_keys_ok = boolean_gen_keys_with_default_parameters(&cks, &sks); + assert(gen_keys_ok == 0); + + int encrypt_ok = boolean_client_key_encrypt(cks, true, &ct); + assert(encrypt_ok == 0); + + int ser_ok = boolean_serialize_ciphertext(ct, &ct_ser_buffer); + assert(ser_ok == 0); + + BufferView deser_view = {.pointer = ct_ser_buffer.pointer, .length = ct_ser_buffer.length}; + + int deser_ok = boolean_deserialize_ciphertext(deser_view, &deser_ct); + assert(deser_ok == 0); + + assert(deser_view.length == ct_ser_buffer.length); + for (size_t idx = 0; idx < deser_view.length; ++idx) { + assert(deser_view.pointer[idx] == ct_ser_buffer.pointer[idx]); + } + + bool result = false; + int decrypt_ok = boolean_client_key_decrypt(cks, deser_ct, &result); + assert(decrypt_ok == 0); + + assert(result == true); + + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); + destroy_boolean_ciphertext(ct); + destroy_boolean_ciphertext(deser_ct); + destroy_buffer(&ct_ser_buffer); +} + +void test_predefined_keygen_w_serde(void) { + BooleanClientKey *cks = NULL; + BooleanServerKey *sks = NULL; + + int gen_keys_ok = boolean_gen_keys_with_predefined_parameters_set( + BOOLEAN_PARAMETERS_SET_DEFAULT_PARAMETERS, &cks, &sks); + + assert(gen_keys_ok == 0); + + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); + + gen_keys_ok = boolean_gen_keys_with_predefined_parameters_set( + BOOLEAN_PARAMETERS_SET_THFE_LIB_PARAMETERS, &cks, &sks); + + assert(gen_keys_ok == 0); + + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); +} + +void test_custom_keygen(void) { + BooleanClientKey *cks = NULL; + BooleanServerKey *sks = NULL; + BooleanParameters *params = NULL; + + int params_ok = boolean_create_parameters(10, 1, 1024, 10e-100, 10e-100, 3, 1, 4, 2, ¶ms); + assert(params_ok == 0); + + int gen_keys_ok = boolean_gen_keys_with_parameters(params, &cks, &sks); + + assert(gen_keys_ok == 0); + + destroy_boolean_parameters(params); + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); +} + +void test_public_keygen(void) { + BooleanClientKey *cks = NULL; + BooleanPublicKey *pks = NULL; + BooleanParameters *params = NULL; + BooleanCiphertext *ct = NULL; + + int get_params_ok = boolean_get_parameters(BOOLEAN_PARAMETERS_SET_DEFAULT_PARAMETERS, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = boolean_gen_client_key(params, &cks); + assert(gen_keys_ok == 0); + + int gen_pks = boolean_gen_public_key(cks, &pks); + assert(gen_pks == 0); + + bool msg = true; + + int encrypt_ok = boolean_public_key_encrypt(pks, msg, &ct); + assert(encrypt_ok == 0); + + bool result = false; + int decrypt_ok = boolean_client_key_decrypt(cks, ct, &result); + assert(decrypt_ok == 0); + + assert(result == true); + + destroy_boolean_parameters(params); + destroy_boolean_client_key(cks); + destroy_boolean_public_key(pks); + destroy_boolean_ciphertext(ct); +} + +int main(void) { + test_default_keygen_w_serde(); + test_predefined_keygen_w_serde(); + test_custom_keygen(); + test_public_keygen(); + return EXIT_SUCCESS; +} diff --git a/tfhe/c_api_tests/test_boolean_server_key.c b/tfhe/c_api_tests/test_boolean_server_key.c new file mode 100644 index 000000000..34ee4272e --- /dev/null +++ b/tfhe/c_api_tests/test_boolean_server_key.c @@ -0,0 +1,403 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include + +void test_binary_boolean_function(BooleanClientKey *cks, BooleanServerKey *sks, + bool (*c_fun)(bool, bool), + int (*api_fun)(const BooleanServerKey *, + const BooleanCiphertext *, + const BooleanCiphertext *, BooleanCiphertext **)) { + for (int idx_left_trivial = 0; idx_left_trivial < 2; ++idx_left_trivial) { + for (int idx_right_trivial = 0; idx_right_trivial < 2; ++idx_right_trivial) { + for (int idx_left = 0; idx_left < 2; ++idx_left) { + for (int idx_right = 0; idx_right < 2; ++idx_right) { + BooleanCiphertext *ct_left = NULL; + BooleanCiphertext *ct_right = NULL; + BooleanCiphertext *ct_result = NULL; + + bool left = (bool)idx_left; + bool right = (bool)idx_right; + bool left_trivial = (bool)idx_left_trivial; + bool right_trivial = (bool)idx_right_trivial; + + bool expected = c_fun(left, right); + + if (left_trivial) { + int encrypt_left_ok = boolean_trivial_encrypt(left, &ct_left); + assert(encrypt_left_ok == 0); + } else { + int encrypt_left_ok = boolean_client_key_encrypt(cks, left, &ct_left); + assert(encrypt_left_ok == 0); + } + + if (right_trivial) { + int encrypt_left_ok = boolean_trivial_encrypt(right, &ct_right); + assert(encrypt_left_ok == 0); + } else { + int encrypt_right_ok = boolean_client_key_encrypt(cks, right, &ct_right); + assert(encrypt_right_ok == 0); + } + + int api_call_ok = api_fun(sks, ct_left, ct_right, &ct_result); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_left); + destroy_boolean_ciphertext(ct_right); + destroy_boolean_ciphertext(ct_result); + } + } + } + } +} + +void test_binary_boolean_function_assign( + BooleanClientKey *cks, BooleanServerKey *sks, bool (*c_fun)(bool, bool), + int (*api_fun)(const BooleanServerKey *, BooleanCiphertext *, const BooleanCiphertext *)) { + for (int idx_left_trivial = 0; idx_left_trivial < 2; ++idx_left_trivial) { + for (int idx_right_trivial = 0; idx_right_trivial < 2; ++idx_right_trivial) { + for (int idx_left = 0; idx_left < 2; ++idx_left) { + for (int idx_right = 0; idx_right < 2; ++idx_right) { + BooleanCiphertext *ct_left_and_result = NULL; + BooleanCiphertext *ct_right = NULL; + + bool left = (bool)idx_left; + bool right = (bool)idx_right; + bool left_trivial = (bool)idx_left_trivial; + bool right_trivial = (bool)idx_right_trivial; + + bool expected = c_fun(left, right); + + if (left_trivial) { + int encrypt_left_ok = boolean_trivial_encrypt(left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + } else { + int encrypt_left_ok = boolean_client_key_encrypt(cks, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + } + + if (right_trivial) { + int encrypt_left_ok = boolean_trivial_encrypt(right, &ct_right); + assert(encrypt_left_ok == 0); + } else { + int encrypt_right_ok = boolean_client_key_encrypt(cks, right, &ct_right); + assert(encrypt_right_ok == 0); + } + + int api_call_ok = api_fun(sks, ct_left_and_result, ct_right); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_left_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_left_and_result); + destroy_boolean_ciphertext(ct_right); + } + } + } + } +} + +void test_binary_boolean_function_scalar(BooleanClientKey *cks, BooleanServerKey *sks, + bool (*c_fun)(bool, bool), + int (*api_fun)(const BooleanServerKey *, + const BooleanCiphertext *, bool, + BooleanCiphertext **)) { + for (int idx_left = 0; idx_left < 2; ++idx_left) { + for (int idx_right = 0; idx_right < 2; ++idx_right) { + BooleanCiphertext *ct_left = NULL; + BooleanCiphertext *ct_result = NULL; + + bool left = (bool)idx_left; + bool right = (bool)idx_right; + + bool expected = c_fun(left, right); + + int encrypt_left_ok = boolean_client_key_encrypt(cks, left, &ct_left); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_left, right, &ct_result); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_left); + destroy_boolean_ciphertext(ct_result); + } + } +} + +void test_binary_boolean_function_scalar_assign(BooleanClientKey *cks, BooleanServerKey *sks, + bool (*c_fun)(bool, bool), + int (*api_fun)(const BooleanServerKey *, + BooleanCiphertext *, bool)) { + for (int idx_left = 0; idx_left < 2; ++idx_left) { + for (int idx_right = 0; idx_right < 2; ++idx_right) { + BooleanCiphertext *ct_left_and_result = NULL; + + bool left = (bool)idx_left; + bool right = (bool)idx_right; + + bool expected = c_fun(left, right); + + int encrypt_left_ok = boolean_client_key_encrypt(cks, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_left_and_result, right); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_left_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_left_and_result); + } + } +} + +void test_not(BooleanClientKey *cks, BooleanServerKey *sks) { + for (int idx_in_trivial = 0; idx_in_trivial < 2; ++idx_in_trivial) { + for (int idx_in = 0; idx_in < 2; ++idx_in) { + BooleanCiphertext *ct_in = NULL; + BooleanCiphertext *ct_result = NULL; + + bool in = (bool)idx_in; + bool in_trivial = (bool)idx_in_trivial; + + bool expected = !in; + + if (in_trivial) { + int encrypt_in_ok = boolean_trivial_encrypt(in, &ct_in); + assert(encrypt_in_ok == 0); + } else { + int encrypt_in_ok = boolean_client_key_encrypt(cks, in, &ct_in); + assert(encrypt_in_ok == 0); + } + + int api_call_ok = boolean_server_key_not(sks, ct_in, &ct_result); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_in); + destroy_boolean_ciphertext(ct_result); + } + } +} + +void test_not_assign(BooleanClientKey *cks, BooleanServerKey *sks) { + for (int idx_in_trivial = 0; idx_in_trivial < 2; ++idx_in_trivial) { + for (int idx_in = 0; idx_in < 2; ++idx_in) { + BooleanCiphertext *ct_in_and_result = NULL; + + bool in = (bool)idx_in; + bool in_trivial = (bool)idx_in_trivial; + + bool expected = !in; + + if (in_trivial) { + int encrypt_in_ok = boolean_trivial_encrypt(in, &ct_in_and_result); + assert(encrypt_in_ok == 0); + } else { + int encrypt_in_ok = boolean_client_key_encrypt(cks, in, &ct_in_and_result); + assert(encrypt_in_ok == 0); + } + + int api_call_ok = boolean_server_key_not_assign(sks, ct_in_and_result); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_in_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_in_and_result); + } + } +} + +void test_mux(BooleanClientKey *cks, BooleanServerKey *sks) { + for (int idx_cond_trivial = 0; idx_cond_trivial < 2; ++idx_cond_trivial) { + for (int idx_then_trivial = 0; idx_then_trivial < 2; ++idx_then_trivial) { + for (int idx_else_trivial = 0; idx_else_trivial < 2; ++idx_else_trivial) { + for (int idx_condition = 0; idx_condition < 2; ++idx_condition) { + for (int idx_then = 0; idx_then < 2; ++idx_then) { + for (int idx_else = 0; idx_else < 2; ++idx_else) { + BooleanCiphertext *ct_cond = NULL; + BooleanCiphertext *ct_then = NULL; + BooleanCiphertext *ct_else = NULL; + BooleanCiphertext *ct_result = NULL; + + bool cond = (bool)idx_else; + bool then = (bool)idx_then; + bool else_ = (bool)idx_else; + bool cond_trivial = (bool)idx_cond_trivial; + bool then_trivial = (bool)idx_then_trivial; + bool else_trivial = (bool)idx_else_trivial; + + bool expected = else_; + if (cond) { + expected = then; + } + + if (cond_trivial) { + int encrypt_cond_ok = boolean_trivial_encrypt(cond, &ct_cond); + assert(encrypt_cond_ok == 0); + } else { + int encrypt_cond_ok = boolean_client_key_encrypt(cks, cond, &ct_cond); + assert(encrypt_cond_ok == 0); + } + if (then_trivial) { + int encrypt_then_ok = boolean_trivial_encrypt(then, &ct_then); + assert(encrypt_then_ok == 0); + } else { + int encrypt_then_ok = boolean_client_key_encrypt(cks, then, &ct_then); + assert(encrypt_then_ok == 0); + } + if (else_trivial) { + int encrypt_else_ok = boolean_trivial_encrypt(else_, &ct_else); + assert(encrypt_else_ok == 0); + } else { + int encrypt_else_ok = boolean_client_key_encrypt(cks, else_, &ct_else); + assert(encrypt_else_ok == 0); + } + + int api_call_ok = boolean_server_key_mux(sks, ct_cond, ct_then, ct_else, &ct_result); + assert(api_call_ok == 0); + + bool decrypted_result = false; + + int decrypt_ok = boolean_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_boolean_ciphertext(ct_cond); + destroy_boolean_ciphertext(ct_then); + destroy_boolean_ciphertext(ct_else); + destroy_boolean_ciphertext(ct_result); + } + } + } + } + } + } +} + +bool c_and(bool left, bool right) { return left && right; } + +bool c_nand(bool left, bool right) { return !c_and(left, right); } + +bool c_or(bool left, bool right) { return left || right; } + +bool c_nor(bool left, bool right) { return !c_or(left, right); } + +bool c_xor(bool left, bool right) { return left != right; } + +bool c_xnor(bool left, bool right) { return !c_xor(left, right); } + +void test_server_key(void) { + BooleanClientKey *cks = NULL; + BooleanServerKey *sks = NULL; + Buffer cks_ser_buffer = {.pointer = NULL, .length = 0}; + BooleanClientKey *deser_cks = NULL; + Buffer sks_ser_buffer = {.pointer = NULL, .length = 0}; + BooleanServerKey *deser_sks = NULL; + + int gen_keys_ok = boolean_gen_keys_with_default_parameters(&cks, &sks); + assert(gen_keys_ok == 0); + + int ser_cks_ok = boolean_serialize_client_key(cks, &cks_ser_buffer); + assert(ser_cks_ok == 0); + + BufferView deser_view = {.pointer = cks_ser_buffer.pointer, .length = cks_ser_buffer.length}; + + int deser_cks_ok = boolean_deserialize_client_key(deser_view, &deser_cks); + assert(deser_cks_ok == 0); + + int ser_sks_ok = boolean_serialize_server_key(sks, &sks_ser_buffer); + assert(ser_sks_ok == 0); + + deser_view.pointer = sks_ser_buffer.pointer; + deser_view.length = sks_ser_buffer.length; + + int deser_sks_ok = boolean_deserialize_server_key(deser_view, &deser_sks); + assert(deser_sks_ok == 0); + + test_binary_boolean_function(deser_cks, deser_sks, c_and, boolean_server_key_and); + test_binary_boolean_function(deser_cks, deser_sks, c_nand, boolean_server_key_nand); + test_binary_boolean_function(deser_cks, deser_sks, c_or, boolean_server_key_or); + test_binary_boolean_function(deser_cks, deser_sks, c_nor, boolean_server_key_nor); + test_binary_boolean_function(deser_cks, deser_sks, c_xor, boolean_server_key_xor); + test_binary_boolean_function(deser_cks, deser_sks, c_xnor, boolean_server_key_xnor); + test_not(deser_cks, deser_sks); + test_mux(deser_cks, deser_sks); + + test_binary_boolean_function_assign(deser_cks, deser_sks, c_and, boolean_server_key_and_assign); + test_binary_boolean_function_assign(deser_cks, deser_sks, c_nand, boolean_server_key_nand_assign); + test_binary_boolean_function_assign(deser_cks, deser_sks, c_or, boolean_server_key_or_assign); + test_binary_boolean_function_assign(deser_cks, deser_sks, c_nor, boolean_server_key_nor_assign); + test_binary_boolean_function_assign(deser_cks, deser_sks, c_xor, boolean_server_key_xor_assign); + test_binary_boolean_function_assign(deser_cks, deser_sks, c_xnor, boolean_server_key_xnor_assign); + test_not_assign(deser_cks, deser_sks); + + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_and, boolean_server_key_and_scalar); + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_nand, boolean_server_key_nand_scalar); + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_or, boolean_server_key_or_scalar); + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_nor, boolean_server_key_nor_scalar); + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_xor, boolean_server_key_xor_scalar); + test_binary_boolean_function_scalar(deser_cks, deser_sks, c_xnor, boolean_server_key_xnor_scalar); + + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_and, + boolean_server_key_and_scalar_assign); + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_nand, + boolean_server_key_nand_scalar_assign); + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_or, + boolean_server_key_or_scalar_assign); + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_nor, + boolean_server_key_nor_scalar_assign); + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_xor, + boolean_server_key_xor_scalar_assign); + test_binary_boolean_function_scalar_assign(deser_cks, deser_sks, c_xnor, + boolean_server_key_xnor_scalar_assign); + + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); + destroy_boolean_client_key(deser_cks); + destroy_boolean_server_key(deser_sks); + destroy_buffer(&cks_ser_buffer); + destroy_buffer(&sks_ser_buffer); +} + +int main(void) { + test_server_key(); + return EXIT_SUCCESS; +} diff --git a/tfhe/c_api_tests/test_micro_bench_and.c b/tfhe/c_api_tests/test_micro_bench_and.c new file mode 100644 index 000000000..d2586ee5b --- /dev/null +++ b/tfhe/c_api_tests/test_micro_bench_and.c @@ -0,0 +1,51 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include +#include + +void micro_bench_and() { + BooleanClientKey *cks = NULL; + BooleanServerKey *sks = NULL; + + // int gen_keys_ok = boolean_gen_keys_with_default_parameters(&cks, &sks); + // assert(gen_keys_ok == 0); + + int gen_keys_ok = boolean_gen_keys_with_predefined_parameters_set( + BOOLEAN_PARAMETERS_SET_THFE_LIB_PARAMETERS, &cks, &sks); + assert(gen_keys_ok == 0); + + int num_loops = 10000; + + BooleanCiphertext *ct_left = NULL; + BooleanCiphertext *ct_right = NULL; + + int encrypt_left_ok = boolean_client_key_encrypt(cks, false, &ct_left); + assert(encrypt_left_ok == 0); + int encrypt_right_ok = boolean_client_key_encrypt(cks, true, &ct_right); + assert(encrypt_right_ok == 0); + + clock_t start = clock(); + + for (int idx_loops = 0; idx_loops < num_loops; ++idx_loops) { + BooleanCiphertext *ct_result = NULL; + boolean_server_key_and(sks, ct_left, ct_right, &ct_result); + destroy_boolean_ciphertext(ct_result); + } + + clock_t stop = clock(); + double elapsed_ms = (double)((stop - start) * 1000) / CLOCKS_PER_SEC; + double mean_ms = elapsed_ms / num_loops; + + printf("%g ms, mean %g ms\n", elapsed_ms, mean_ms); + + destroy_boolean_client_key(cks); + destroy_boolean_server_key(sks); +} + +int main(void) { + micro_bench_and(); + return EXIT_SUCCESS; +} diff --git a/tfhe/c_api_tests/test_shortint_keygen.c b/tfhe/c_api_tests/test_shortint_keygen.c new file mode 100644 index 000000000..f95e4c377 --- /dev/null +++ b/tfhe/c_api_tests/test_shortint_keygen.c @@ -0,0 +1,112 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include + +void test_predefined_keygen_w_serde(void) { + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + ShortintCiphertext *ct = NULL; + Buffer ct_ser_buffer = {.pointer = NULL, .length = 0}; + ShortintCiphertext *deser_ct = NULL; + + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + int encrypt_ok = shortint_client_key_encrypt(cks, 3, &ct); + assert(encrypt_ok == 0); + + int ser_ok = shortint_serialize_ciphertext(ct, &ct_ser_buffer); + assert(ser_ok == 0); + + BufferView deser_view = {.pointer = ct_ser_buffer.pointer, .length = ct_ser_buffer.length}; + + int deser_ok = shortint_deserialize_ciphertext(deser_view, &deser_ct); + assert(deser_ok == 0); + + assert(deser_view.length == ct_ser_buffer.length); + for (size_t idx = 0; idx < deser_view.length; ++idx) { + assert(deser_view.pointer[idx] == ct_ser_buffer.pointer[idx]); + } + + uint64_t result = -1; + int decrypt_ok = shortint_client_key_decrypt(cks, deser_ct, &result); + assert(decrypt_ok == 0); + + assert(result == 3); + + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_parameters(params); + destroy_shortint_ciphertext(ct); + destroy_shortint_ciphertext(deser_ct); + destroy_buffer(&ct_ser_buffer); +} + +void test_custom_keygen(void) { + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + + int params_ok = shortint_create_parameters(10, 1, 1024, 10e-100, 10e-100, 2, 3, 2, 3, 2, 3, + 10e-100, 2, 3, 2, 2, ¶ms); + assert(params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + + assert(gen_keys_ok == 0); + + destroy_shortint_parameters(params); + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); +} + +void test_public_keygen(void) { + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintPublicKey *pks = NULL; + ShortintParameters *params = NULL; + ShortintCiphertext *ct = NULL; + + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_client_key(params, &cks); + assert(gen_keys_ok == 0); + + int gen_pks = shortint_gen_public_key(cks, &pks); + assert(gen_pks == 0); + + int gen_sks = shortint_gen_server_key(cks, &sks); + assert(gen_sks == 0); + + uint64_t msg = 2; + + int encrypt_ok = shortint_public_key_encrypt(pks, sks, msg, &ct); + assert(encrypt_ok == 0); + + uint64_t result = -1; + int decrypt_ok = shortint_client_key_decrypt(cks, ct, &result); + assert(decrypt_ok == 0); + + assert(result == 2); + + destroy_shortint_parameters(params); + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_public_key(pks); + destroy_shortint_ciphertext(ct); +} + +int main(void) { + test_predefined_keygen_w_serde(); + test_custom_keygen(); + test_public_keygen(); + return EXIT_SUCCESS; +} diff --git a/tfhe/c_api_tests/test_shortint_pbs.c b/tfhe/c_api_tests/test_shortint_pbs.c new file mode 100644 index 000000000..e0716b5b5 --- /dev/null +++ b/tfhe/c_api_tests/test_shortint_pbs.c @@ -0,0 +1,197 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include + +uint64_t double_accumulator_2_bits_message(uint64_t in) { return (in * 2) % 4; } + +uint64_t get_max_value_of_accumulator_generator(uint64_t (*accumulator_func)(uint64_t), + size_t message_bits) { + uint64_t max_value = 0; + for (size_t idx = 0; idx < (1 << message_bits); ++idx) { + uint64_t acc_value = accumulator_func((uint64_t)idx); + max_value = acc_value > max_value ? acc_value : max_value; + } + + return max_value; +} + +uint64_t product_accumulator_2_bits_encrypted_mul(uint64_t left, uint64_t right) { + return (left * right) % 4; +} + +uint64_t get_max_value_of_bivariate_accumulator_generator(uint64_t (*accumulator_func)(uint64_t, + uint64_t), + size_t message_bits_left, + size_t message_bits_right) { + uint64_t max_value = 0; + for (size_t idx_left = 0; idx_left < (1 << message_bits_left); ++idx_left) { + for (size_t idx_right = 0; idx_right < (1 << message_bits_right); ++idx_right) { + uint64_t acc_value = accumulator_func((uint64_t)idx_left, (uint64_t)idx_right); + max_value = acc_value > max_value ? acc_value : max_value; + } + } + + return max_value; +} + +void test_shortint_pbs_2_bits_message(void) { + ShortintPBSAccumulator *accumulator = NULL; + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + int gen_acc_ok = shortint_server_key_generate_pbs_accumulator( + sks, double_accumulator_2_bits_message, &accumulator); + assert(gen_acc_ok == 0); + + for (int in_idx = 0; in_idx < 4; ++in_idx) { + ShortintCiphertext *ct = NULL; + ShortintCiphertext *ct_out = NULL; + + uint64_t in_val = (uint64_t)in_idx; + + int encrypt_ok = shortint_client_key_encrypt(cks, in_val, &ct); + assert(encrypt_ok == 0); + + size_t degree = -1; + int get_degree_ok = shortint_ciphertext_get_degree(ct, °ree); + assert(get_degree_ok == 0); + + assert(degree == 3); + + int pbs_ok = shortint_server_key_programmable_bootstrap(sks, accumulator, ct, &ct_out); + assert(pbs_ok == 0); + + size_t degree_to_set = + (size_t)get_max_value_of_accumulator_generator(double_accumulator_2_bits_message, 2); + + int set_degree_ok = shortint_ciphertext_set_degree(ct_out, degree_to_set); + assert(set_degree_ok == 0); + + degree = -1; + get_degree_ok = shortint_ciphertext_get_degree(ct_out, °ree); + assert(get_degree_ok == 0); + + assert(degree == degree_to_set); + + uint64_t result_non_assign = -1; + int decrypt_non_assign_ok = shortint_client_key_decrypt(cks, ct_out, &result_non_assign); + assert(decrypt_non_assign_ok == 0); + + assert(result_non_assign == double_accumulator_2_bits_message(in_val)); + + int pbs_assign_ok = shortint_server_key_programmable_bootstrap_assign(sks, accumulator, ct_out); + assert(pbs_assign_ok == 0); + + degree_to_set = + (size_t)get_max_value_of_accumulator_generator(double_accumulator_2_bits_message, 2); + + set_degree_ok = shortint_ciphertext_set_degree(ct_out, degree_to_set); + assert(set_degree_ok == 0); + + uint64_t result_assign = -1; + int decrypt_assign_ok = shortint_client_key_decrypt(cks, ct_out, &result_assign); + assert(decrypt_assign_ok == 0); + + assert(result_assign == double_accumulator_2_bits_message(result_non_assign)); + + destroy_shortint_ciphertext(ct); + destroy_shortint_ciphertext(ct_out); + } + + destroy_shortint_pbs_accumulator(accumulator); + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_parameters(params); +} + +void test_shortint_bivariate_pbs_2_bits_message(void) { + ShortintBivariatePBSAccumulator *accumulator = NULL; + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + int gen_acc_ok = shortint_server_key_generate_bivariate_pbs_accumulator( + sks, product_accumulator_2_bits_encrypted_mul, &accumulator); + assert(gen_acc_ok == 0); + + for (int left_idx = 0; left_idx < 4; ++left_idx) { + for (int right_idx = 0; right_idx < 4; ++right_idx) { + ShortintCiphertext *ct_left = NULL; + ShortintCiphertext *ct_right = NULL; + ShortintCiphertext *ct_out = NULL; + + uint64_t left_val = (uint64_t)left_idx; + uint64_t right_val = (uint64_t)right_idx; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, left_val, &ct_left); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks, right_val, &ct_right); + assert(encrypt_right_ok == 0); + + int pbs_ok = shortint_server_key_bivariate_programmable_bootstrap(sks, accumulator, ct_left, + ct_right, &ct_out); + assert(pbs_ok == 0); + + size_t degree_to_set = (size_t)get_max_value_of_bivariate_accumulator_generator( + product_accumulator_2_bits_encrypted_mul, 2, 2); + + int set_degree_ok = shortint_ciphertext_set_degree(ct_right, degree_to_set); + assert(set_degree_ok == 0); + + uint64_t result_non_assign = -1; + int decrypt_non_assign_ok = shortint_client_key_decrypt(cks, ct_out, &result_non_assign); + assert(decrypt_non_assign_ok == 0); + + assert(result_non_assign == product_accumulator_2_bits_encrypted_mul(left_val, right_val)); + + int pbs_assign_ok = shortint_server_key_bivariate_programmable_bootstrap_assign( + sks, accumulator, ct_out, ct_right); + assert(pbs_assign_ok == 0); + + degree_to_set = + (size_t)get_max_value_of_accumulator_generator(double_accumulator_2_bits_message, 2); + + set_degree_ok = shortint_ciphertext_set_degree(ct_out, degree_to_set); + assert(set_degree_ok == 0); + + uint64_t result_assign = -1; + int decrypt_assign_ok = shortint_client_key_decrypt(cks, ct_out, &result_assign); + assert(decrypt_assign_ok == 0); + + assert(result_assign == + product_accumulator_2_bits_encrypted_mul(result_non_assign, right_val)); + + destroy_shortint_ciphertext(ct_left); + destroy_shortint_ciphertext(ct_right); + destroy_shortint_ciphertext(ct_out); + } + } + + destroy_shortint_bivariate_pbs_accumulator(accumulator); + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_parameters(params); +} + +int main(void) { + test_shortint_pbs_2_bits_message(); + test_shortint_bivariate_pbs_2_bits_message(); + return EXIT_SUCCESS; +} diff --git a/tfhe/c_api_tests/test_shortint_server_key.c b/tfhe/c_api_tests/test_shortint_server_key.c new file mode 100644 index 000000000..da3f3b5c5 --- /dev/null +++ b/tfhe/c_api_tests/test_shortint_server_key.c @@ -0,0 +1,557 @@ +#include "tfhe.h" +#include +#include +#include +#include +#include + +void test_shortint_unary_op(const ShortintClientKey *cks, const ShortintServerKey *sks, + const uint32_t message_bits, const uint32_t carry_bits, + uint64_t (*c_fun)(uint64_t), + int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, + ShortintCiphertext **)) { + + int message_max = 1 << message_bits; + + for (int val_in = 0; val_in < message_max; ++val_in) { + ShortintCiphertext *ct_in = NULL; + ShortintCiphertext *ct_result = NULL; + + uint64_t in = (uint64_t)val_in; + + uint64_t expected = c_fun(in) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, in, &ct_in); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_in, &ct_result); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_in); + destroy_shortint_ciphertext(ct_result); + } +} + +void test_shortint_unary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks, + const uint32_t message_bits, const uint32_t carry_bits, + uint64_t (*c_fun)(uint64_t), + int (*api_fun)(const ShortintServerKey *, + ShortintCiphertext *)) { + + int message_max = 1 << message_bits; + + for (int in = 0; in < message_max; ++in) { + ShortintCiphertext *ct_in_and_result = NULL; + + uint64_t in = (uint64_t)in; + + uint64_t expected = c_fun(in) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, in, &ct_in_and_result); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_in_and_result); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_in_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_in_and_result); + } +} + +void test_shortint_binary_op(const ShortintClientKey *cks, const ShortintServerKey *sks, + const uint32_t message_bits, const uint32_t carry_bits, + uint64_t (*c_fun)(uint64_t, uint64_t), + int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, + ShortintCiphertext *, ShortintCiphertext **)) { + + int message_max = 1 << message_bits; + + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left = NULL; + ShortintCiphertext *ct_right = NULL; + ShortintCiphertext *ct_result = NULL; + + uint64_t left = (uint64_t)val_left; + uint64_t right = (uint64_t)val_right; + + uint64_t expected = c_fun(left, right) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, left, &ct_left); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks, right, &ct_right); + assert(encrypt_right_ok == 0); + + int api_call_ok = api_fun(sks, ct_left, ct_right, &ct_result); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_left); + destroy_shortint_ciphertext(ct_right); + destroy_shortint_ciphertext(ct_result); + } + } +} + +void test_shortint_binary_op_assign(const ShortintClientKey *cks, const ShortintServerKey *sks, + const uint32_t message_bits, const uint32_t carry_bits, + uint64_t (*c_fun)(uint64_t, uint64_t), + int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, + ShortintCiphertext *)) { + + int message_max = 1 << message_bits; + + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left_and_result = NULL; + ShortintCiphertext *ct_right = NULL; + + uint64_t left = (uint64_t)val_left; + uint64_t right = (uint64_t)val_right; + + uint64_t expected = c_fun(left, right) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + + int encrypt_right_ok = shortint_client_key_encrypt(cks, right, &ct_right); + assert(encrypt_right_ok == 0); + + int api_call_ok = api_fun(sks, ct_left_and_result, ct_right); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_left_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_left_and_result); + destroy_shortint_ciphertext(ct_right); + } + } +} + +void test_shortint_binary_scalar_op( + const ShortintClientKey *cks, const ShortintServerKey *sks, const uint32_t message_bits, + const uint32_t carry_bits, uint64_t (*c_fun)(uint64_t, uint8_t), + int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t, ShortintCiphertext **), + uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) { + + int message_max = 1 << message_bits; + + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left = NULL; + ShortintCiphertext *ct_result = NULL; + + uint64_t left = (uint64_t)val_left; + uint8_t scalar_right = (uint8_t)val_right; + + if (forbidden_scalar_values != NULL) { + bool found_forbidden_value = false; + for (int idx = 0; idx < forbidden_scalar_values_len; ++idx) { + if (forbidden_scalar_values[idx] == scalar_right) { + found_forbidden_value = true; + break; + } + } + + if (found_forbidden_value) { + continue; + } + } + + uint64_t expected = c_fun(left, scalar_right) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, left, &ct_left); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_left, scalar_right, &ct_result); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_left); + destroy_shortint_ciphertext(ct_result); + } + } +} + +void test_shortint_binary_scalar_op_assign( + const ShortintClientKey *cks, const ShortintServerKey *sks, const uint32_t message_bits, + const uint32_t carry_bits, uint64_t (*c_fun)(uint64_t, uint8_t), + int (*api_fun)(const ShortintServerKey *, ShortintCiphertext *, uint8_t), + uint8_t forbidden_scalar_values[], size_t forbidden_scalar_values_len) { + + int message_max = 1 << message_bits; + + for (int val_left = 0; val_left < message_max; ++val_left) { + for (int val_right = 0; val_right < message_max; ++val_right) { + ShortintCiphertext *ct_left_and_result = NULL; + + uint64_t left = (uint64_t)val_left; + uint8_t scalar_right = (uint8_t)val_right; + + if (forbidden_scalar_values != NULL) { + bool found_forbidden_value = false; + for (int idx = 0; idx < forbidden_scalar_values_len; ++idx) { + if (forbidden_scalar_values[idx] == scalar_right) { + found_forbidden_value = true; + break; + } + } + + if (found_forbidden_value) { + continue; + } + } + + uint64_t expected = c_fun(left, scalar_right) % message_max; + + int encrypt_left_ok = shortint_client_key_encrypt(cks, left, &ct_left_and_result); + assert(encrypt_left_ok == 0); + + int api_call_ok = api_fun(sks, ct_left_and_result, scalar_right); + assert(api_call_ok == 0); + + uint64_t decrypted_result = -1; + + int decrypt_ok = shortint_client_key_decrypt(cks, ct_left_and_result, &decrypted_result); + assert(decrypt_ok == 0); + + assert(decrypted_result == expected); + + destroy_shortint_ciphertext(ct_left_and_result); + } + } +} + +uint64_t add(uint64_t left, uint64_t right) { return left + right; } +uint64_t sub(uint64_t left, uint64_t right) { return left - right; } +uint64_t mul(uint64_t left, uint64_t right) { return left * right; } +uint64_t neg(uint64_t in) { return -in; } + +uint64_t homomorphic_div(uint64_t left, uint64_t right) { + if (right != 0) { + return left / right; + } else { + // Special value chosen in the shortint implementation in case of a division by 0 + return 0; + } +} + +uint64_t bitand(uint64_t left, uint64_t right) { return left & right; } +uint64_t bitxor(uint64_t left, uint64_t right) { return left ^ right; } +uint64_t bitor (uint64_t left, uint64_t right) { return left | right; } + +uint64_t greater(uint64_t left, uint64_t right) { return (uint64_t)(left > right); } +uint64_t greater_or_equal(uint64_t left, uint64_t right) { return (uint64_t)(left >= right); } +uint64_t less(uint64_t left, uint64_t right) { return (uint64_t)(left < right); } +uint64_t less_or_equal(uint64_t left, uint64_t right) { return (uint64_t)(left <= right); } +uint64_t equal(uint64_t left, uint64_t right) { return (uint64_t)(left == right); } +uint64_t not_equal(uint64_t left, uint64_t right) { return (uint64_t)(left != right); } + +uint64_t scalar_greater(uint64_t left, uint8_t right) { return (uint64_t)(left > right); } +uint64_t scalar_greater_or_equal(uint64_t left, uint8_t right) { return (uint64_t)(left >= right); } +uint64_t scalar_less(uint64_t left, uint8_t right) { return (uint64_t)(left < right); } +uint64_t scalar_less_or_equal(uint64_t left, uint8_t right) { return (uint64_t)(left <= right); } +uint64_t scalar_equal(uint64_t left, uint8_t right) { return (uint64_t)(left == right); } +uint64_t scalar_not_equal(uint64_t left, uint8_t right) { return (uint64_t)(left != right); } + +uint64_t scalar_add(uint64_t left, uint8_t right) { return left + right; } +uint64_t scalar_sub(uint64_t left, uint8_t right) { return left - right; } +uint64_t scalar_mul(uint64_t left, uint8_t right) { return left * right; } +uint64_t scalar_div(uint64_t left, uint8_t right) { return left / right; } +uint64_t scalar_mod(uint64_t left, uint8_t right) { return left % right; } + +uint64_t left_shift(uint64_t left, uint8_t right) { return left << right; } +uint64_t right_shift(uint64_t left, uint8_t right) { return left >> right; } + +void test_server_key(void) { + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + Buffer cks_ser_buffer = {.pointer = NULL, .length = 0}; + ShortintClientKey *deser_cks = NULL; + Buffer sks_ser_buffer = {.pointer = NULL, .length = 0}; + ShortintServerKey *deser_sks = NULL; + ShortintParameters *params = NULL; + + const uint32_t message_bits = 2; + const uint32_t carry_bits = 2; + + int get_params_ok = shortint_get_parameters(message_bits, carry_bits, ¶ms); + assert(get_params_ok == 0); + + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + int ser_cks_ok = shortint_serialize_client_key(cks, &cks_ser_buffer); + assert(ser_cks_ok == 0); + + BufferView deser_view = {.pointer = cks_ser_buffer.pointer, .length = cks_ser_buffer.length}; + + int deser_cks_ok = shortint_deserialize_client_key(deser_view, &deser_cks); + assert(deser_cks_ok == 0); + + int ser_sks_ok = shortint_serialize_server_key(sks, &sks_ser_buffer); + assert(ser_sks_ok == 0); + + deser_view.pointer = sks_ser_buffer.pointer; + deser_view.length = sks_ser_buffer.length; + + int deser_sks_ok = shortint_deserialize_server_key(deser_view, &deser_sks); + assert(deser_sks_ok == 0); + + printf("add\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, add, + shortint_server_key_smart_add); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, add, + shortint_server_key_unchecked_add); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, add, + shortint_server_key_smart_add_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, add, + shortint_server_key_unchecked_add_assign); + + printf("sub\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, sub, + shortint_server_key_smart_sub); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, sub, + shortint_server_key_unchecked_sub); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, sub, + shortint_server_key_smart_sub_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, sub, + shortint_server_key_unchecked_sub_assign); + + printf("mul\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, mul, + shortint_server_key_smart_mul); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, mul, + shortint_server_key_unchecked_mul); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, mul, + shortint_server_key_smart_mul_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, mul, + shortint_server_key_unchecked_mul_assign); + + printf("left_shift\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, left_shift, + shortint_server_key_smart_scalar_left_shift, NULL, 0); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, left_shift, + shortint_server_key_unchecked_scalar_left_shift, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, left_shift, + shortint_server_key_smart_scalar_left_shift_assign, NULL, + 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, left_shift, + shortint_server_key_unchecked_scalar_left_shift_assign, + NULL, 0); + + printf("right_shift\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, right_shift, + shortint_server_key_smart_scalar_right_shift, NULL, 0); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, right_shift, + shortint_server_key_unchecked_scalar_right_shift, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, right_shift, + shortint_server_key_smart_scalar_right_shift_assign, NULL, + 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, right_shift, + shortint_server_key_unchecked_scalar_right_shift_assign, + NULL, 0); + + printf("scalar_add\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_add, + shortint_server_key_smart_scalar_add, NULL, 0); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_add, + shortint_server_key_unchecked_scalar_add, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_add, + shortint_server_key_smart_scalar_add_assign, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_add, + shortint_server_key_unchecked_scalar_add_assign, NULL, 0); + + printf("scalar_sub\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_sub, + shortint_server_key_smart_scalar_sub, NULL, 0); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_sub, + shortint_server_key_unchecked_scalar_sub, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_sub, + shortint_server_key_smart_scalar_sub_assign, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_sub, + shortint_server_key_unchecked_scalar_sub_assign, NULL, 0); + + printf("scalar_mul\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_mul, + shortint_server_key_smart_scalar_mul, NULL, 0); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_mul, + shortint_server_key_unchecked_scalar_mul, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_mul, + shortint_server_key_smart_scalar_mul_assign, NULL, 0); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_mul, + shortint_server_key_unchecked_scalar_mul_assign, NULL, 0); + + printf("bitand\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitand, + shortint_server_key_smart_bitand); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitand, + shortint_server_key_unchecked_bitand); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitand, + shortint_server_key_smart_bitand_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitand, + shortint_server_key_unchecked_bitand_assign); + + printf("bitxor\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitxor, + shortint_server_key_smart_bitxor); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitxor, + shortint_server_key_unchecked_bitxor); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitxor, + shortint_server_key_smart_bitxor_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitxor, + shortint_server_key_unchecked_bitxor_assign); + + printf("bitor\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitor, + shortint_server_key_smart_bitor); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, bitor, + shortint_server_key_unchecked_bitor); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitor, + shortint_server_key_smart_bitor_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, bitor, + shortint_server_key_unchecked_bitor_assign); + + printf("greater\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, greater, + shortint_server_key_smart_greater); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, greater, + shortint_server_key_unchecked_greater); + + printf("greater_or_equal\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, greater_or_equal, + shortint_server_key_smart_greater_or_equal); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, greater_or_equal, + shortint_server_key_unchecked_greater_or_equal); + + printf("less\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, less, + shortint_server_key_smart_less); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, less, + shortint_server_key_unchecked_less); + + printf("less_or_equal\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, less_or_equal, + shortint_server_key_smart_less_or_equal); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, less_or_equal, + shortint_server_key_unchecked_less_or_equal); + + printf("equal\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, equal, + shortint_server_key_smart_equal); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, equal, + shortint_server_key_unchecked_equal); + + printf("not_equal\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, not_equal, + shortint_server_key_smart_not_equal); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, not_equal, + shortint_server_key_unchecked_not_equal); + + printf("scalar_greater\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_greater, + shortint_server_key_smart_scalar_greater, NULL, 0); + + printf("scalar_greater_or_equal\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, + scalar_greater_or_equal, + shortint_server_key_smart_scalar_greater_or_equal, NULL, 0); + + printf("scalar_less\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_less, + shortint_server_key_smart_scalar_less, NULL, 0); + + printf("scalar_less_or_equal\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, + scalar_less_or_equal, + shortint_server_key_smart_scalar_less_or_equal, NULL, 0); + + printf("scalar_equal\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_equal, + shortint_server_key_smart_scalar_equal, NULL, 0); + + printf("scalar_not_equal\n"); + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_not_equal, + shortint_server_key_smart_scalar_not_equal, NULL, 0); + + printf("neg\n"); + test_shortint_unary_op(deser_cks, deser_sks, message_bits, carry_bits, neg, + shortint_server_key_smart_neg); + test_shortint_unary_op(deser_cks, deser_sks, message_bits, carry_bits, neg, + shortint_server_key_unchecked_neg); + test_shortint_unary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, neg, + shortint_server_key_smart_neg_assign); + test_shortint_unary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, neg, + shortint_server_key_unchecked_neg_assign); + + printf("div\n"); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, homomorphic_div, + shortint_server_key_smart_div); + test_shortint_binary_op(deser_cks, deser_sks, message_bits, carry_bits, homomorphic_div, + shortint_server_key_unchecked_div); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, homomorphic_div, + shortint_server_key_smart_div_assign); + test_shortint_binary_op_assign(deser_cks, deser_sks, message_bits, carry_bits, homomorphic_div, + shortint_server_key_unchecked_div_assign); + + printf("scalar_div\n"); + uint8_t forbidden_scalar_div_values[1] = {0}; + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_div, + shortint_server_key_unchecked_scalar_div, + forbidden_scalar_div_values, 1); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_div, + shortint_server_key_unchecked_scalar_div_assign, + forbidden_scalar_div_values, 1); + printf("scalar_mod\n"); + uint8_t forbidden_scalar_mod_values[1] = {0}; + test_shortint_binary_scalar_op(deser_cks, deser_sks, message_bits, carry_bits, scalar_mod, + shortint_server_key_unchecked_scalar_mod, + forbidden_scalar_mod_values, 1); + test_shortint_binary_scalar_op_assign(deser_cks, deser_sks, message_bits, carry_bits, scalar_mod, + shortint_server_key_unchecked_scalar_mod_assign, + forbidden_scalar_mod_values, 1); + + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_client_key(deser_cks); + destroy_shortint_server_key(deser_sks); + destroy_shortint_parameters(params); + destroy_buffer(&cks_ser_buffer); + destroy_buffer(&sks_ser_buffer); +} + +int main(void) { + test_server_key(); + return EXIT_SUCCESS; +} diff --git a/tfhe/cbindgen.toml b/tfhe/cbindgen.toml new file mode 100644 index 000000000..deefe789a --- /dev/null +++ b/tfhe/cbindgen.toml @@ -0,0 +1,129 @@ +# This is a template cbindgen.toml file with all of the default values. +# Some values are commented out because their absence is the real default. +# +# See https://github.com/eqrion/cbindgen/blob/master/docs.md#cbindgentoml +# for detailed documentation of every option here. + + +language = "C" + + +############## Options for Wrapping the Contents of the Header ################# + +header = "// Copyright © 2022 ZAMA.\n// All rights reserved." +# trailer = "/* Text to put at the end of the generated file */" +include_guard = "TFHE_RS_C_API_H" +# pragma_once = true +autogen_warning = "// Warning, this file is autogenerated by cbindgen. Do not modify this manually." +include_version = false +#namespace = "tfhe_rs_c_api" +namespaces = [] +using_namespaces = [] +sys_includes = [] +includes = [] +no_includes = false +cpp_compat = true +after_includes = "" + + +############################ Code Style Options ################################ + +braces = "SameLine" +line_length = 100 +tab_width = 2 +documentation = false +documentation_style = "auto" +line_endings = "LF" # also "CR", "CRLF", "Native" + + +############################# Codegen Options ################################## + +style = "both" +sort_by = "Name" # default for `fn.sort_by` and `const.sort_by` +usize_is_size_t = true + + +[defines] +# "target_os = freebsd" = "DEFINE_FREEBSD" +# "feature = serde" = "DEFINE_SERDE" + + +[export] +include = ["BooleanParametersSet"] +exclude = [] +#prefix = "CAPI_" +item_types = [] +renaming_overrides_prefixing = false + + +[export.rename] + + +[export.body] + + +[export.mangle] + + +[fn] +rename_args = "None" +# must_use = "MUST_USE_FUNC" +# no_return = "NO_RETURN" +# prefix = "START_FUNC" +# postfix = "END_FUNC" +args = "auto" +sort_by = "Name" + + +[struct] +rename_fields = "None" +# must_use = "MUST_USE_STRUCT" +derive_constructor = false +derive_eq = false +derive_neq = false +derive_lt = false +derive_lte = false +derive_gt = false +derive_gte = false + + +[enum] +rename_variants = "None" +# must_use = "MUST_USE_ENUM" +add_sentinel = false +prefix_with_name = false +derive_helper_methods = false +derive_const_casts = false +derive_mut_casts = false +# cast_assert_name = "ASSERT" +derive_tagged_enum_destructor = false +derive_tagged_enum_copy_constructor = false +enum_class = true +private_default_tagged_enum_constructor = false + + +[const] +allow_static_const = true +allow_constexpr = false +sort_by = "Name" + + +[macro_expansion] +bitflags = false + + +############## Options for How Your Rust library Should Be Parsed ############## + +[parse] +parse_deps = true +include = ["tfhe"] +exclude = [] +clean = false +extra_bindings = [] + + +[parse.expand] +crates = [] +all_features = false +default_features = true +features = [] diff --git a/tfhe/docs/Booleans/operations.md b/tfhe/docs/Booleans/operations.md new file mode 100644 index 000000000..ee53c77c5 --- /dev/null +++ b/tfhe/docs/Booleans/operations.md @@ -0,0 +1,90 @@ +# Operations and Examples + +In thfe::boolean, the available operations are mainly related to their equivalent Boolean gates, +i.e., AND, OR,... In what follows, an example of a unary gate (NOT) and one about a binary gate +(XOR). The last one is about the ternary MUX gate are detailed, which gives the possibility to +homomorphically compute conditional statements of the form ``If..Then..Else``. + +## The NOT unary gate + +```rust +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (mut client_key, mut server_key) = gen_keys(); + +// We use the client secret key to encrypt a message: + let ct_1 = client_key.encrypt(true); + +// We use the server public key to execute the NOT gate: + let ct_not = server_key.not(&ct_1); + +// We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_not); + assert_eq!(output, false); +} +``` + + +## Binary gates + +```rust +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (mut client_key, mut server_key) = gen_keys(); + +// We use the client secret key to encrypt a message: + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + +// We use the server public key to execute the XOR gate: + let ct_xor = server_key.xor(&ct_1, &ct_2); + +// We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_xor); + assert_eq!(output, true^false); +} +``` + + +## The MUX ternary gate +Let ``ct_1, ct_2, ct_3`` be three Boolean +ciphertexts. Then, the MUX gate (abbreviation of MUtipleXer) is equivalent to the operation: +```r +if ct_1 { + return ct_2 +} else { + return ct_3 +} +``` + +This example show how to use the MUX ternary gate. + +```rust +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (mut client_key, mut server_key) = gen_keys(); + + let bool1 = true; + let bool2 = false; + let bool3 = true; + +// We use the client secret key to encrypt a message: + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + let ct_3 = client_key.encrypt(false); + + +// We use the server public key to execute the NOT gate: + let ct_xor = server_key.mux(&ct_1, &ct_2, &ct_3); + +// We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_xor); + assert_eq!(output, if bool1 {bool2} else {bool3}); +} +``` diff --git a/tfhe/docs/Booleans/parameters.md b/tfhe/docs/Booleans/parameters.md new file mode 100644 index 000000000..20a3abfa0 --- /dev/null +++ b/tfhe/docs/Booleans/parameters.md @@ -0,0 +1,54 @@ +# Cryptographic parameters + +## Default parameters + +The TFHE cryptographic scheme relies on a variant of [Regev cryptosystem](https://cims.nyu.edu/~regev/papers/lwesurvey.pdf), and is based on a problem so hard to solve, that is even post-quantum resistant. + +In practice, you need to tune some cryptographic parameters, in order to ensure the correctness of the result, and the security of the computation. + +To make it simpler, **we provide two sets of parameters**, which ensure correct computations for a certain probability with the standard security of 128 bits. There exists an error probability due the probabilistic nature of the encryption, which requires adding randomness (called noise) following a Gaussian distribution. If this noise is too large, the decryption will not give a correct result. There is a trade-off between efficiency and correctness: generally, using a less efficient parameter set (in terms of computation time) leads to a smaller risk of having an error during homomorphic evaluation. + +In the two proposed sets of parameters, the only difference lies into this probability error. +The default parameter set ensures a probability error of at most $$2^{-40}$$ when computing a +programmable bootstrapping (i.e., any gates but the `not`). The other one is closer to the error +probability claimed into the original [TFHE paper](https://eprint.iacr.org/2018/421), +namely $$2^{-165}$$, but up to date regarding security requirements. + +The following array summarizes this: + +| Parameter set | Error probability | +|:-------------------:|:-----------------:| +| DEFAULT_PARAMETERS | $$ 2^{-40} $$ | +| TFHE_LIB_PARAMETERS | $$ 2^{-165} $$ | + + +## User-defined parameters + + +Note that if you desire, you can also create your own set of parameters. +This is an `unsafe` operation as failing to properly fix the parameters will potentially result with an incorrect and/or insecure computation: + +```rust + +use tfhe::boolean::prelude::*; + +fn main() { +// WARNING: might be insecure and/or incorrect +// You can create your own set of parameters + let parameters = unsafe { + BooleanParameters::new( + LweDimension(586), + GlweDimension(2), + PolynomialSize(512), + StandardDev(0.00008976167396834998), + StandardDev(0.00000002989040792967434), + DecompositionBaseLog(8), + DecompositionLevelCount(2), + DecompositionBaseLog(2), + DecompositionLevelCount(5), + ) + }; +} +``` + + diff --git a/tfhe/docs/Booleans/serialization.md b/tfhe/docs/Booleans/serialization.md new file mode 100644 index 000000000..a927702a0 --- /dev/null +++ b/tfhe/docs/Booleans/serialization.md @@ -0,0 +1,57 @@ +# Save and Load Keys From Files + +Since the `ServerKey` and `ClientKey` types both implement the `Serialize` and +`Deserialize` traits, you are free to use any serializer that suits you to save and load the +keys to disk. + +Here is an example using the `bincode` serialization library, which serializes to a +binary format: + +```rust +use std::fs::File; +use std::io::{Write, Read}; +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(); + +// We serialize the keys to bytes: + let encoded_server_key: Vec = bincode::serialize(&server_key).unwrap(); + let encoded_client_key: Vec = bincode::serialize(&client_key).unwrap(); + + let server_key_file = "/tmp/ser_example_server_key.bin"; + let client_key_file = "/tmp/ser_example_client_key.bin"; + +// We write the keys to files: + let mut file = File::create(server_key_file) + .expect("failed to create server key file"); + file.write_all(encoded_server_key.as_slice()).expect("failed to write key to file"); + let mut file = File::create(client_key_file) + .expect("failed to create client key file"); + file.write_all(encoded_client_key.as_slice()).expect("failed to write key to file"); + +// We retrieve the keys: + let mut file = File::open(server_key_file) + .expect("failed to open server key file"); + let mut encoded_server_key: Vec = Vec::new(); + file.read_to_end(&mut encoded_server_key).expect("failed to read the key"); + + let mut file = File::open(client_key_file) + .expect("failed to open client key file"); + let mut encoded_client_key: Vec = Vec::new(); + file.read_to_end(&mut encoded_client_key).expect("failed to read the key"); + +// We deserialize the keys: + let loaded_server_key: ServerKey = bincode::deserialize(&encoded_server_key[..]) + .expect("failed to deserialize"); + let loaded_client_key: ClientKey = bincode::deserialize(&encoded_client_key[..]) + .expect("failed to deserialize"); + + + let ct_1 = client_key.encrypt(false); + +// We check for equality: + assert_eq!(false, loaded_client_key.decrypt(&ct_1)); +} +``` diff --git a/tfhe/docs/Booleans/tutorial.md b/tfhe/docs/Booleans/tutorial.md new file mode 100644 index 000000000..a7a9a3d7b --- /dev/null +++ b/tfhe/docs/Booleans/tutorial.md @@ -0,0 +1,246 @@ +# Tutorial: a first boolean circuit + +This library is meant to be used both on the **server side** and on the **client side**. +The usual use case would follow those steps: + +1. On the **client side**, generate the `client` and `server keys`. +2. Send the `server key` to the **server**. +3. Then any number of times: + + On the **client side**, *encryption* of the input data with the `client key`. + + Transmit the encrypted input to the **server**. + + On the **server side**, *homomorphic computation* with the `server key`. + + Transmit the encrypted output to the **client**. + + On the **client side**, *decryption* of the output data with `client key`. + +## 1. Setup + +In the first step, the client creates two keys: the `client key` and the `server key`, +with the +`concrete_boolean::gen_keys` function: +```rust +use tfhe::boolean::prelude::*; + +fn main() { + +// We generate the client key and the server key, +// using the default parameters: + let (client_key, server_key): (ClientKey, ServerKey) = gen_keys(); +} +``` + +In more details: + ++ The `client_key` is of type `ClientKey`. It is **secret**, and must **never** be transmitted. + This key will only be used to encrypt and decrypt data. ++ The `server_key` is of type `ServerKey`. It is a **public key**, and can be shared with any + party. + This key has to be sent to the server because it is required for the homomorphic computation. + +Note that both the `client_key` and `server_key` implement the `Serialize` and `Deserialize` traits. +This way you can use any compatible serializer to store/send the data. For instance, to store +the `server_key` in a binary file, you can use the `bincode` library: +```rust +use std::fs::File; +use std::io::{Write, Read}; +use tfhe::boolean::prelude::*; + +fn main() { + +//---------------------------- CLIENT SIDE ---------------------------- + +// We generate a client key and a server key, using the default parameters: + let (client_key, server_key) = gen_keys(); + +// We serialize the server key to bytes, and store them in a file: + let encoded: Vec = bincode::serialize(&server_key).unwrap(); + + let server_key_file = "/tmp/tutorial_server_key.bin"; + +// We write the server key to a file: + let mut file = File::create(server_key_file) + .expect("failed to create server key file"); + file.write_all(encoded.as_slice()).expect("failed to write key to file"); + +// ... +// We send the key to server side +// ... + + +//---------------------------- SERVER SIDE ---------------------------- + +// We read the file: + let mut file = File::open(server_key_file) + .expect("failed to open server key file"); + let mut encoded: Vec = Vec::new(); + file.read_to_end(&mut encoded).expect("failed to read key"); + +// We deserialize the server key: + let key: ServerKey = bincode::deserialize(&encoded[..]) + .expect("failed to deserialize"); +} +``` + +## 2. Encrypting Inputs + +Once the server key is available on the **server side**, it is possible to perform some +homomorphic computations. +The client simply needs to encrypt some data and send it to the server. +Again, the `Ciphertext` type implements the `Serialize` and +the `Deserialize` traits, so that any serializer and communication tool suiting your use case +can be +used: +```rust +use tfhe::boolean::prelude::*; + +fn main() { + // Don't consider the following line; you should follow the procedure above. + let (mut client_key, _) = gen_keys(); + +//---------------------------- SERVER SIDE + +// We use the client key to encrypt the messages: + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + +// We serialize the ciphertexts: + let encoded_1: Vec = bincode::serialize(&ct_1).unwrap(); + let encoded_2: Vec = bincode::serialize(&ct_2).unwrap(); + +// ... +// And we send them to the server somehow +// ... +} +``` + +## 2bis. Encrypting Inputs using public key + +Once the server key is available on the **server side**, it is possible to perform some +homomorphic computations. +The client simply needs to encrypt some data and send it to the server. +Again, the `Ciphertext` type implements the `Serialize` and +the `Deserialize` traits, so that any serializer and communication tool suiting your use case +can be +used: +```rust +use tfhe::boolean::prelude::*; + +fn main() { + // Don't consider the following line; you should follow the procedure above. + let (client_key, _) = gen_keys(); + let public_key = PublicKey::new(&client_key); + +//---------------------------- SERVER SIDE + +// We use the public key to encrypt the messages: + let ct_1 = public_key.encrypt(true); + let ct_2 = public_key.encrypt(false); + +// We serialize the ciphertexts: + let encoded_1: Vec = bincode::serialize(&ct_1).unwrap(); + let encoded_2: Vec = bincode::serialize(&ct_2).unwrap(); + +// ... +// And we send them to the server somehow +// ... +} +``` + + +## Executing a Boolean Circuit + +Once the encrypted inputs are on the **server side**, the `server_key` can be used to +homomorphically execute the desired boolean circuit: + +```rust +use std::fs::File; +use std::io::{Write, Read}; +use tfhe::boolean::prelude::*; + +fn main() { + // Don't consider the following lines; you should follow the procedure above. + let (mut client_key, mut server_key) = gen_keys(); + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + let encoded_1: Vec = bincode::serialize(&ct_1).unwrap(); + let encoded_2: Vec = bincode::serialize(&ct_2).unwrap(); + +//---------------------------- ON SERVER SIDE ---------------------------- + +// We deserialize the ciphertexts: + let ct_1: Ciphertext = bincode::deserialize(&encoded_1[..]) + .expect("failed to deserialize"); + let ct_2: Ciphertext = bincode::deserialize(&encoded_2[..]) + .expect("failed to deserialize"); + +// We use the server key to execute the boolean circuit: +// if ((NOT ct_2) NAND (ct_1 AND ct_2)) then (NOT ct_2) else (ct_1 AND ct_2) + let ct_3 = server_key.not(&ct_2); + let ct_4 = server_key.and(&ct_1, &ct_2); + let ct_5 = server_key.nand(&ct_3, &ct_4); + let ct_6 = server_key.mux(&ct_5, &ct_3, &ct_4); + +// Then we serialize the output of the circuit: + let encoded_output: Vec = bincode::serialize(&ct_6) + .expect("failed to serialize output"); + +// ... +// And we send the output to the client +// ... +} +``` + +## Decrypting the output + +Once the encrypted output is on the client side, the `client_key` can be used to +decrypt it: + +```rust +use std::fs::File; +use std::io::{Write, Read}; +use tfhe::boolean::prelude::*; + +fn main() { + // Don't consider the following lines; you should follow the procedure above. + let (mut client_key, mut server_key) = gen_keys(); + let ct_6 = client_key.encrypt(true); + let encoded_output: Vec = bincode::serialize(&ct_6).unwrap(); + +//---------------------------- ON CLIENT SIDE + +// We deserialize the output ciphertext: + let output: Ciphertext = bincode::deserialize(&encoded_output[..]) + .expect("failed to deserialize"); + +// Finally, we decrypt the output: + let output = client_key.decrypt(&output); + +// And check that the result is the expected one: + assert_eq!(output, true); +} +``` + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tfhe/docs/README.md b/tfhe/docs/README.md new file mode 100644 index 000000000..84d67b239 --- /dev/null +++ b/tfhe/docs/README.md @@ -0,0 +1,38 @@ +# What is TFHE-rs? + +⭐️ [Star the repo on Github](https://github.com/zama-ai/tfhe-rs) | 🗣 [Community support forum ](https://community.zama.ai)| 📁 [Contribute to the project](https://docs.zama.ai/tfhe-rs/developers/contributing) + +![](_static/docs\_home.jpg) + +TFHE-rs is a pure Rust implementation of TFHE for boolean and small integer arithmetics over encrypted data. It includes a Rust and C API, as well as a client-side WASM API. + +TFHE-rs is meant for developers and researchers who want full control over what they can do with TFHE, while not having to worry about the low level implementation. + +The goal is to have a stable, simple, high-performance, and production-ready library for all the advanced features of TFHE. + +### Key Cryptographic concepts + +TFHE-rs library implements Zama’s variant of Fully Homomorphic Encryption over the Torus (TFHE). TFHE is based on Learning With Errors (LWE), a well studied cryptographic primitive believed to be secure even against quantum computers. + +In cryptography, a raw value is called a message (also sometimes called a cleartext), an encoded message is called a plaintext and an encrypted plaintext is called a ciphertext. + +The idea of homomorphic encryption is that you can compute on ciphertexts while not knowing messages encrypted in them. A scheme is said to be _fully homomorphic_, meaning any program can be evaluated with it, if at least two of the following operations are supported \($$x$$is a plaintext and $$E[x]$$ is the +corresponding ciphertext\): + +* homomorphic univariate function evaluation: $$f(E[x]) = E[f(x)]$$ +* homomorphic addition: $$E[x] + E[y] = E[x + y]$$ +* homomorphic multiplication: $$E[x] * E[y] = E[x * y]$$ + +Zama's variant of TFHE is fully homomorphic and deals with fixed-precision numbers as messages. It implements all needed homomorphic operations, such as addition and function evaluation via **Programmable Bootstrapping**. You can read more about Zama's TFHE variant in the [preliminary whitepaper](https://whitepaper.zama.ai/). + +Using FHE in a Rust program with TFHE-rs consists in: + +* generating a client key and a server key using secure parameters: + * client key encrypts/decrypts data and must be kept secret + * server key is used to perform operations on encrypted data and could be + public (also called evaluation key) +* encrypting plaintexts using the client key to produce ciphertexts +* operating homomorphically on ciphertexts with the server key +* decrypting the resulting ciphertexts into plaintexts using the client key + +If you would like to know more about the problems that FHE solves, we suggest you review our [6 minute introduction to homomorphic encryption](https://6min.zama.ai/). diff --git a/tfhe/docs/SUMMARY.md b/tfhe/docs/SUMMARY.md new file mode 100644 index 000000000..c720f53a2 --- /dev/null +++ b/tfhe/docs/SUMMARY.md @@ -0,0 +1,33 @@ +# Table of contents + +* [What is TFHE-rs?](README.md) + +## Getting Started + +* [Installation](getting\_started/installation.md) +* [Quick Start](getting\_started/quick\_start.md) +* [Supported Operations](getting\_started/operations.md) +* [Benchmarks](getting\_started/benchmarks.md) +* [Security and Cryptography](getting\_started/security\_and\_cryptography.md) + +## Booleans +* [Tutorial](Booleans/tutorial.md) +* [Operations](Booleans/operations.md) +* [Cryptographic Parameters](Booleans/parameters.md) +* [Serialization/Deserialization](Booleans/serialization.md) + +## Shortint +* [Tutorial](shortint/tutorial.md) +* [Operations](shortint/operations.md) +* [Cryptographic Parameters](shortint/parameters.md) +* [Serialization/Deserialization](shortint/serialization.md) + +## C API +* [Tutorial](c_api/tutorial.md) + +## Developers +* [Contributing](dev/contributing.md) + +## API references +* [docs.rs](https://docs.rs/tfhe/) + diff --git a/tfhe/docs/_static/ciphertext-representation.svg b/tfhe/docs/_static/ciphertext-representation.svg new file mode 100644 index 000000000..b1c4f1cc8 --- /dev/null +++ b/tfhe/docs/_static/ciphertext-representation.svg @@ -0,0 +1,16 @@ + + + + + + + noisecarrymessageLSBMSBCiphertext \ No newline at end of file diff --git a/tfhe/docs/_static/docs_home.jpg b/tfhe/docs/_static/docs_home.jpg new file mode 100644 index 000000000..c0b4da36f Binary files /dev/null and b/tfhe/docs/_static/docs_home.jpg differ diff --git a/tfhe/docs/_static/fig6.png b/tfhe/docs/_static/fig6.png new file mode 100644 index 000000000..7cc40ff89 Binary files /dev/null and b/tfhe/docs/_static/fig6.png differ diff --git a/tfhe/docs/_static/fig7.png b/tfhe/docs/_static/fig7.png new file mode 100644 index 000000000..927ed00b2 Binary files /dev/null and b/tfhe/docs/_static/fig7.png differ diff --git a/tfhe/docs/_static/fig8.png b/tfhe/docs/_static/fig8.png new file mode 100644 index 000000000..87abe8d71 Binary files /dev/null and b/tfhe/docs/_static/fig8.png differ diff --git a/tfhe/docs/_static/lwe.png b/tfhe/docs/_static/lwe.png new file mode 100644 index 000000000..ca47cd911 Binary files /dev/null and b/tfhe/docs/_static/lwe.png differ diff --git a/tfhe/docs/c_api/tutorial.md b/tfhe/docs/c_api/tutorial.md new file mode 100644 index 000000000..79785f926 --- /dev/null +++ b/tfhe/docs/c_api/tutorial.md @@ -0,0 +1,174 @@ +# Tutorial: using the C API + +Welcome to this `TFHE-rs` C API tutorial! + +This library exposes a C binding to the `TFHE-rs` primitives to implement _Fully Homomorphic Encryption_ (FHE) programs. + +# First steps using `TFHE-rs` C API + +## Setting-up `TFHE-rs` C API for use in a C program. + + `TFHE-rs` C API can be built on a Unix x86_64 machine using the following command: + +```shell +RUSTFLAGS="-C target-cpu=native" cargo build --release --features=x86_64-unix,boolean-c-api,shortint-c-api -p tfhe +``` + +All features are opt-in, but for simplicity here, the C API is enabled for booleans and shortints. + +The `tfhe.h` header as well as the static (.a) and dynamic (.so) `libtfhe` binaries can then be found in "${REPO_ROOT}/target/release/" + +The build system needs to be set-up so that the C or C++ program links against `TFHE-rs` C API +binaries. + +Here is a minimal CMakeLists.txt allowing to do just that: + +```cmake +project(my-project) + +cmake_minimum_required(VERSION 3.16) + +set(TFHE_C_API "/path/to/tfhe-rs/binaries/and/header") + +include_directories(${TFHE_C_API}) +add_library(tfhe STATIC IMPORTED) +set_target_properties(tfhe PROPERTIES IMPORTED_LOCATION ${TFHE_C_API}/libtfhe.a) + +if(APPLE) + find_library(SECURITY_FRAMEWORK Security) + if (NOT SECURITY_FRAMEWORK) + message(FATAL_ERROR "Security framework not found") + endif() +endif() + +set(EXECUTABLE_NAME my-executable) +add_executable(${EXECUTABLE_NAME} main.c) +target_include_directories(${EXECUTABLE_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_libraries(${EXECUTABLE_NAME} LINK_PUBLIC tfhe m pthread dl) +if(APPLE) + target_link_libraries(${EXECUTABLE_NAME} LINK_PUBLIC ${SECURITY_FRAMEWORK}) +endif() +target_compile_options(${EXECUTABLE_NAME} PRIVATE -Werror) +``` + +## Commented code of a PBS doubling a 2 bits encrypted message using `TFHE-rs C API` + +The steps required to perform the mutiplication by 2 of a 2 bits ciphertext +using a PBS are detailed. +This is NOT the most efficient way of doing this operation, +but it allows to show the management required to run a PBS manually using the C API. + +WARNING: The following example does not have proper memory management in the error case to make it easier to fit the code on this page. + +To run the example below, the above CMakeLists.txt and main.c files need to be in the same +directory. The commands to run are: +```shell +# /!\ Be sure to update CMakeLists.txt to give the absolute path to the compiled tfhe library +$ ls +CMakeLists.txt main.c +$ mkdir build && cd build +$ cmake .. -DCMAKE_BUILD_TYPE=RELEASE +... +$ make +... +$ ./my-executable +Result: 2 +$ +``` + +```c +#include "tfhe.h" +#include +#include +#include + +uint64_t double_accumulator_2_bits_message(uint64_t in) { return (in * 2) % 4; } + +uint64_t get_max_value_of_accumulator_generator(uint64_t (*accumulator_func)(uint64_t), + size_t message_bits) +{ + uint64_t max_value = 0; + for (size_t idx = 0; idx < (1 << message_bits); ++idx) + { + uint64_t acc_value = accumulator_func((uint64_t)idx); + max_value = acc_value > max_value ? acc_value : max_value; + } + + return max_value; +} + +int main(void) +{ + ShortintPBSAccumulator *accumulator = NULL; + ShortintClientKey *cks = NULL; + ShortintServerKey *sks = NULL; + ShortintParameters *params = NULL; + + // Get the parameters for 2 bits messages with 2 bits of carry + int get_params_ok = shortint_get_parameters(2, 2, ¶ms); + assert(get_params_ok == 0); + + // Generate the keys with the parameters + int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks); + assert(gen_keys_ok == 0); + + // Generate the accumulator for the PBS + int gen_acc_ok = shortint_server_key_generate_pbs_accumulator( + sks, double_accumulator_2_bits_message, &accumulator); + assert(gen_acc_ok == 0); + + ShortintCiphertext *ct = NULL; + ShortintCiphertext *ct_out = NULL; + + // We will compute 1 * 2 using a PBS, it's not the recommended way to perform a multiplication, + // but it shows how to manage a PBS manually in the C API + uint64_t in_val = 1; + + // Encrypt the input value + int encrypt_ok = shortint_client_key_encrypt(cks, in_val, &ct); + assert(encrypt_ok == 0); + + // Check the degree is set to the maximum value that can be encrypted on 2 bits, i.e. 3 + // This check is not required and is just added to show, the degree information can be retrieved + // in the C APi + size_t degree = -1; + int get_degree_ok = shortint_ciphertext_get_degree(ct, °ree); + assert(get_degree_ok == 0); + + assert(degree == 3); + + // Apply the PBS on our encrypted input + int pbs_ok = shortint_server_key_programmable_bootstrap(sks, accumulator, ct, &ct_out); + assert(pbs_ok == 0); + + // Set the degree to keep consistency for potential further computations + // Note: This is only required for the PBS + size_t degree_to_set = + (size_t)get_max_value_of_accumulator_generator(double_accumulator_2_bits_message, 2); + + int set_degree_ok = shortint_ciphertext_set_degree(ct_out, degree_to_set); + assert(set_degree_ok == 0); + + // Decrypt the result + uint64_t result = -1; + int decrypt_non_assign_ok = shortint_client_key_decrypt(cks, ct_out, &result); + assert(decrypt_non_assign_ok == 0); + + // Check the result is what we expect i.e. 2 + assert(result == double_accumulator_2_bits_message(in_val)); + printf("Result: %ld\n", result); + + // Destroy entities from the C API + destroy_shortint_ciphertext(ct); + destroy_shortint_ciphertext(ct_out); + destroy_shortint_pbs_accumulator(accumulator); + destroy_shortint_client_key(cks); + destroy_shortint_server_key(sks); + destroy_shortint_parameters(params); + return EXIT_SUCCESS; +} +``` + +# Audience + +Programmers wishing to use `TFHE-rs` but unable to use Rust (for various reasons) can use these bindings in their language of choice as long as it can interface with C code to bring `TFHE-rs` functionalities to said language. diff --git a/tfhe/docs/dev/contributing.md b/tfhe/docs/dev/contributing.md new file mode 100644 index 000000000..8a720d9df --- /dev/null +++ b/tfhe/docs/dev/contributing.md @@ -0,0 +1,6 @@ +# Contribute + +There are two ways to contribute to **TFHE-rs**: + +* you can open issues to report bugs and typos and to suggest ideas +* you can ask to become an official contributor by emailing hello@zama.ai. Only approved contributors can end pull requests, so please make sure to get in touch before you do! diff --git a/tfhe/docs/getting_started/benchmarks.md b/tfhe/docs/getting_started/benchmarks.md new file mode 100644 index 000000000..7d8803bc5 --- /dev/null +++ b/tfhe/docs/getting_started/benchmarks.md @@ -0,0 +1,44 @@ +# Benchmarks + +Due to their nature, homomorphic operations are obviously slower than their clear equivalent. In what follows, some timings are exposed for the basic operations. For completeness, some benchmarks of other libraries are also given. + +All the benchmarks had been launched on an AWS m6i.metal with the following specifications: +Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz and 512GB of RAM. + +## Booleans + +This measures the execution time of a single binary boolean gate. + +### thfe.rs::booleans + +| Parameter set | concrete-fft | concrete-fft + avx512 | +| --- | --- | --- | +| DEFAULT_PARAMETERS | 8.8ms | 6.8ms | +| TFHE_LIB_PARAMETERS | 13.6ms | 10.9ms | + +### tfhe-lib + +| Parameter set | fftw | spqlios-fma| +| --- | --- | --- | +| default_128bit_gate_bootstrapping_parameters | 28.9ms | 15.7ms | + +### OpenFHE + +| Parameter set | GINX | GINX (Intel HEXL) | +| --- | --- | --- | +| STD_128 | 172ms | 78ms | +| MEDIUM | 113ms | 50.2ms | + +## Shortints +This measures the execution time for some operations and some parameter sets of shortints. + +### thfe.rs::shortint +This uses the concrete-fft + avx512 configuration. + + +| Parameter set | unchecked_add | unchecked_mul_lsb | keyswitch_programmable_bootstrap | +| --- | --- | --- | --- | +| PARAM_MESSAGE_1_CARRY_1 | 337 ns | 10.1 ms | 9.91 ms | +| PARAM_MESSAGE_2_CARRY_2 | 407 ns | 21.7 ms | 21.4 ms | +| PARAM_MESSAGE_3_CARRY_3 | 3.06 µs | 161 ms | 159 ms | +| PARAM_MESSAGE_4_CARRY_4 | 11.7 µs | 1.03 s | 956 ms | diff --git a/tfhe/docs/getting_started/installation.md b/tfhe/docs/getting_started/installation.md new file mode 100644 index 000000000..5aa27a2bf --- /dev/null +++ b/tfhe/docs/getting_started/installation.md @@ -0,0 +1,80 @@ +# Installation + +## Importing into your project + +To use `TFHE-rs` in your project, you first need to add it as a dependency in your `Cargo.toml`: + + +```toml +tfhe = { version = "0.1.0", features = [ "boolean", "shortint", "x86_64-unix" ] } +``` + +## Choosing your features + +`TFHE-rs` exposes different `cargo features` to customize the types and features used. + +### Kinds. + +This crate exposes two kinds of data types. Each kind is enabled by activating its corresponding feature in the TOML line. Each kind may have multiple types: + +| Kind | Features | Type(s) | +| --------- | ------------- |------------------------------------------| +| Booleans | `boolean` | Booleans | +| ShortInts | `shortint` | Short unsigned integers | + + +### Serialization. + +The different data types and keys exposed by the crate can be serialized / deserialized. + +More information can be found [here](../Booleans/serialization.md) for Booleans and [here](../shortint/serialization.md) for shortint. + +## Supported platforms + +TFHE-rs is supported on Linux (x86, aarch64), macOS (x86, aarch64) and Windows (x86 with `RDSEED` +instruction). + +| OS | x86 | aarch64 | +| --------- | ------------- |------------------| +| Linux | `x86_64-unix` | `aarch64-unix`* | +| macOS | `x86_64-unix` | `aarch64-unix`* | +| Windows | `x86_64` | Unsupported | + +{% hint style="info" %} +Users who have ARM devices can use `TFHE-rs` by compiling using the +`nightly` toolchain. +{% endhint %} + + +### Using TFHE-rs with nightly toolchain + +First, install the needed Rust toolchain: + +```shell +rustup toolchain install nightly +``` + +Then, you can either: + +* Manually specify the toolchain to use in each of the cargo commands: + +For example: + +```shell +cargo +nightly build +cargo +nightly test +``` + +* Or override the toolchain to use for the current project: + +```shell +rustup override set nightly +# cargo will use the `nightly` toolchain. +cargo build +``` + +To check the toolchain that Cargo will use by default, you can use the following command: + +```shell +rustup show +``` diff --git a/tfhe/docs/getting_started/operations.md b/tfhe/docs/getting_started/operations.md new file mode 100644 index 000000000..7667ecf1a --- /dev/null +++ b/tfhe/docs/getting_started/operations.md @@ -0,0 +1,53 @@ +# Supported Operations + +## Boolean + +The list of supported operations by the homomorphic booleans is: + +|Operation Name | type | +| ------ | ------ | +| `not` | Unary | +| `and` | Binary | +| `or` | Binary | +| `xor` | Binary | +| `nor` | Binary | +| `xnor` | Binary | +| `cmux` | Ternary | + + +A walk-through using homomorphic Booleans can be found [here](../Booleans/tutorial.md). + + +## ShortInt + +In TFHE-rs, the shortints represent short unsigned integers encoded over 8 bits maximum. A complete homomorphic arithmetic is provided, along with the possibility to compute univariate and bi-variate functions. Some operations are only available for integers up to 4 bits. More technical details can be found [here](../shortint/operations.md). + + +The list of supported operations is: + +| Operation name | Type | +|--------------- | ------ | +| Negation | Unary | +| Addition | Binary | +| Subtraction | Binary | +| Multiplication | Binary | +| Division* | Binary | +| Modular reduction | Binary | +| Comparisons | Binary | +| Left/Right Shift | Binary | +| And | Binary | +| Or | Binary | +| Xor | Binary | +| Exact Function Evaluation | Unary/Binary | + +{% hint style="info" %} +\* The division operation implements a subtlety: since data is encrypted, it might be possible to compute a division by 0. In this case, the division is tweaked so that dividing by 0 returns 0. +{% endhint %} + +A walk-through example can be found [here](../shortint/tutorial.md) and more examples and +explanations can be found [here](../shortint/operations.md) + + + + + diff --git a/tfhe/docs/getting_started/quick_start.md b/tfhe/docs/getting_started/quick_start.md new file mode 100644 index 000000000..3f0b52bb8 --- /dev/null +++ b/tfhe/docs/getting_started/quick_start.md @@ -0,0 +1,80 @@ +# Quick start + +This library makes it possible to execute **homomorphic operations over encrypted data**, where the data are either Booleans or short integers (named shortints in the rest of this documentation). +It allows one to execute a circuit on an **untrusted server** because both circuit inputs and outputs are kept **private**. +Data are indeed encrypted on the client side, before being sent to the server. On the server side every computation is performed on ciphertexts. + +The server however has to know the circuit to be evaluated. At the end of the computation, the server returns the encryption of the result to the user. She can then decrypt it with her `secret key`. + + +## General method to write homomorphic circuit program + +The overall process to write an homomorphic program is the same for both Boolean and short integers types. +In a nutshell, the basic steps for using the TFHE-rs library are the following: +- Choose a data type (Boolean or shortint) +- Import the library +- Create client and server keys +- Encrypt data with the client key +- Compute over encrypted data using the server key +- Decrypt data with the client key + + +### Boolean example + +Here is an example to illustrate how the library can be used to evaluate a Boolean circuit: + +```rust +use tfhe::boolean::prelude::*; + +fn main() { +// We generate a set of client/server keys, using the default parameters: + let (mut client_key, mut server_key) = gen_keys(); + +// We use the client secret key to encrypt two messages: + let ct_1 = client_key.encrypt(true); + let ct_2 = client_key.encrypt(false); + +// We use the server public key to execute a boolean circuit: +// if ((NOT ct_2) NAND (ct_1 AND ct_2)) then (NOT ct_2) else (ct_1 AND ct_2) + let ct_3 = server_key.not(&ct_2); + let ct_4 = server_key.and(&ct_1, &ct_2); + let ct_5 = server_key.nand(&ct_3, &ct_4); + let ct_6 = server_key.mux(&ct_5, &ct_3, &ct_4); + +// We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_6); + assert_eq!(output, true); +} +``` + +### Shortint example + +and here is a full example using shortints: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); + + let msg1 = 1; + let msg2 = 0; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 + msg2) % modulus as u64); +} +``` + + +The library is pretty simple to use, and can evaluate **homomorphic circuits of arbitrary length**. The description of the algorithms can be found in the [TFHE](https://doi.org/10.1007/s00145-019-09319-x) paper (also available as [ePrint 2018/421](https://ia.cr/2018/421)). diff --git a/tfhe/docs/getting_started/security_and_cryptography.md b/tfhe/docs/getting_started/security_and_cryptography.md new file mode 100644 index 000000000..a6a3aa151 --- /dev/null +++ b/tfhe/docs/getting_started/security_and_cryptography.md @@ -0,0 +1,121 @@ +# Cryptography & Security + +# TFHE + +TFHE-rs is a cryptographic library dedicated to Fully Homomorphic Encryption. As its name +suggests, it is based on the TFHE scheme. + +It is interesting to understand some basics about TFHE, +in order to apprehend where the limitations are coming from both +in terms of precision (number of bits used to represent the plaintext values) +and execution time (why TFHE operations are slower than native operations). + +# LWE Ciphertexts + +Although there are many kinds of ciphertexts in TFHE, +all the encrypted values in TFHE-rs are mainly stored as LWE ciphertexts. + +The security of TFHE relies on the LWE problem which stands for Learning With Errors. +The problem is believed to be secure against quantum attacks. + +An LWE Ciphertext is a collection of 32-bits or 64-bits unsigned integers. +Before encrypting a message in an LWE ciphertext, one needs to first encode it as a plaintext. +This is done by shifting the message to the most significant bits of the unsigned integer type used. + +Then, a little random value called noise is added to the least significant bits. +This noise (also called error for Learning With Errors) is crucial to the security of the ciphertext. + +$$ plaintext = (\Delta * m) + e $$ + +![](../_static/lwe.png) + +To go from a **plaintext** to a **ciphertext** one needs to encrypt the plaintext using a secret key. + +An LWE secret key is a list of `n` random integers: $$S = (s_0, ..., s_n)$$. +$$n$$ is called the $$LweDimension$$ + +A LWE ciphertext, is composed of two parts: +- The mask $$(a_0, ..., a_{n-1})$$ +- The body $$b$$ + +The mask of a _fresh_ ciphertext (one that is the result of an encryption +and not an operation such as ciphertext addition) is a list of `n` uniformly random values. + +The body is computed as follows: + +$$ b = (\sum_{i = 0}^{n-1}{a_i * s_i}) + plaintext $$ + +Now that the encryption scheme is defined, to illustrate why it is slower to compute over encrypted data, +let us show the example of the addition between ciphertexts. + +To add two ciphertexts, we must add their $mask$ and $body$ as done below. + +$$ +ct_0 = (a_{0}, ..., a_{n}, b) \\ +ct_1 = (a_{1}^{'}, ..., a_{n}^{'}, b^{'}) \\ + +ct_{2} = ct_0 + ct_1 \\ +ct_{2} = (a_{0} + a_{0}^{'}, ..., a_{n} + a_{n}^{'}, b + b^{'})\\ + +b + b^{'} = (\sum_{i = 0}^{n-1}{a_i * s_i}) + plaintext + (\sum_{i = 0}^{n-1}{a_i^{'} * s_i}) + plaintext^{'}\\ + +b + b^{'} = (\sum_{i = 0}^{n-1}{(a_i + a_i^{'})* s_i}) + \Delta m + \Delta m^{'} + e + e^{'}\\ +$$ + +To add ciphertexts, it is sufficient to add their masks and bodies. +Instead of just adding 2 integers, one needs to add $$n + 1$$ elements. +The addition is an intuitive example to show the slowdown of FHE computation compared to plaintext +computation but other operations are far more expensive +(e.g., the computation of a lookup table using the Programmable Bootstrapping) + +# Ciphertexts Operations + +## Understanding noise and padding + +In FHE, there are two types of operations that can be applied to ciphertexts: + +* **leveled operations**, which increase the noise in the ciphertext +* **bootstrapped operations**, which reduce the noise in the ciphertext + +In FHE, the noise must be tracked and managed in order to guarantee the correctness of the computation. + +Bootstrapping operations are used across the computation to decrease the noise in the ciphertexts, preventing it from tampering the message. The rest of the operations are called leveled because they do not need bootstrapping operations and thus are most of the time really fast. + +The following sections explain the concept of noise and padding in ciphertexts. + +### Noise + +For it to be secure, LWE requires random noise to be added to the message at encryption time. + +In TFHE, this random noise is drawn from a Centered Normal Distribution parameterized by a standard deviation. This standard deviation is a security parameter. +With all other security parameters set, the larger the standard deviation is, the more secure the encryption is. + +In `TFHE-rs`, the noise is encoded in the least significant bits of the plaintexts. Each leveled computation will increase the noise, and thus if too many computations are done, the noise will eventually overflow onto the significant data bits of the message and lead to an incorrect result. + +The figure below illustrates this problem in case of an addition, where an extra bit of noise is incurred as a result. + +![Noise overtaking on the plaintexts after homomorphic addition. Most Significant bits are on the left.](../_static/fig7.png) + +TFHE-rs offers the possibility to automatically manage the noise, by performing bootstrapping operations to reset the noise when needed. + + +### Padding + +Since encoded values have a fixed precision, operating on them can sometime produce results that are outside the original interval. To avoid losing precision or wrapping around the interval, TFHE-rs uses additional bits by defining bits of **padding** on the most significant bits. + +As an example, consider adding two ciphertexts. Adding two values could en up outside the range of either ciphertexts, and thus necessitate a carry, which would then be carried onto the first padding bit. In the figure below, each plaintext over 32 bits has one bit of padding on its left \(i.e., the most significant bit\). After the addition, the padding bit is no longer available, as it has been used in order for the carry. This is referred to as **consuming** bits of padding. Since no padding is left, there is no guarantee that additional additions would yield correct results. + +![](../_static/fig6.png) + +If you would like to know more about TFHE, you can find more information in our [TFHE Deep Dive](https://www.zama.ai/post/tfhe-deep-dive-part-1). + +## Security + +By default, the cryptographic parameters provided by `TFHE-rs` ensure at least 128 bits of security. +The security has been evaluated using the latest versions of the Lattice Estimator ([repository](https://github.com/malb/lattice-estimator)) with `red_cost_model = reduction.RC.BDGL16`. + +For all sets of parameters, the error probability when computing a univariate function over one ciphertext is $$2^{-40}$$. Note that univariate functions might be performed when arithmetic functions are computed (for instance, the multiplication of two ciphertexts). + +## Public key encryption + +In public key encryption, the public key consists in providing a given number of message encrypting the value 0. By setting the number of encryptions of 0 in the public key at $$m = \lceil (n+1) \log(q) \rceil + \lambda$$, where $$n$$ is the LWE dimension, $$q$$ is the ciphertext modulus and $$\lambda$$ is the number of security bits. In a nutshell, this construction is secure due to the left-over-hash lemma, which is essentially related to the impossibility of breaking the underlying multiple subset sum problem. By using this formula, this guarantees both a high density subset sum and an exponentially large number of possible associated random vectors per LWE sample (a,b) diff --git a/tfhe/docs/shortint/operations.md b/tfhe/docs/shortint/operations.md new file mode 100644 index 000000000..778fecd0c --- /dev/null +++ b/tfhe/docs/shortint/operations.md @@ -0,0 +1,378 @@ +# How Shortint are represented + +In `shortint`, the encrypted data is stored in an LWE ciphertext. + +Conceptually, the message stored in an LWE ciphertext, is divided into +a **carry buffer** and a **message buffer**. + +![](../\_static/ciphertext-representation.svg) + +The message buffer is the space where the actual message is stored. This represents the modulus of the input messages (denoted by `MessageModulus` in the code). When doing computations on a ciphertext, the encrypted message can overflow the message modulus: the exceeding information is stored in the carry buffer. The size of the carry buffer is defined by another modulus, called `CarryModulus`. + +Together, the message modulus and the carry modulus form the plaintext space that is available in a ciphertext. This space cannot be overflowed, otherwise the computation may result in incorrect outputs. + +In order to ensure the computation correctness, we keep track of the maximum value encrypted in a +ciphertext via an associated attribute called the **degree**. When the degree reaches a defined threshold, the carry buffer may be emptied to resume safely the computations. Therefore, in `shortint` the carry modulus is mainly considered as a means to do more computations. + +# Types of operations + +The operations available via a `ServerKey` may come in different variants: + + - operations that take their inputs as encrypted values. + - scalar operations take at least one non-encrypted value as input. + +For example, the addition has both variants: + + - `ServerKey::unchecked_add` which takes two encrypted values and adds them. + - `ServerKey::unchecked_scalar_add` which takes an encrypted value and a clear value (the + so-called scalar) and adds them. + +Each operation may come in different 'flavors': + + - `unchecked`: Always does the operation, without checking if the result may exceed the capacity of + the plaintext space. Using this operations might have an impact on the correctness of the + following operations; + - `checked`: Checks are done before computing the operation, returning an error if operation + cannot be done safely; + - `smart`: Always does the operation, if the operation cannot be computed safely, the smart operation + will clear the carry modulus to make the operation possible. + +Not all operations have these 3 flavors, as some of them are implemented in a way +that the operation is always possible without ever exceeding the plaintext space capacity. + + +# How to use operation types + +Let's try to do a circuit evaluation using the different flavours of operations we already introduced. +For a very small circuit, the `unchecked` flavour may be enough to do the computation correctly. +Otherwise, the `checked` and `smart` are the best options. + +As an example, let's do a scalar multiplication, a subtraction and a multiplication. + + +```rust +use tfhe::shortint::prelude::*; + + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + let msg2 = 3; + let scalar = 4; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + server_key.unchecked_scalar_mul_assign(&mut ct_1, scalar); + server_key.unchecked_sub_assign(&mut ct_1, &ct_2); + server_key.unchecked_mul_lsb_assign(&mut ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_1); + println!("expected {}, found {}", ((msg1 * scalar as u64 - msg2) * msg2) % modulus as u64, output); +} +``` + +During this computation the carry buffer has been overflowed and as all the operations were `unchecked` the output +may be incorrect. + +If we redo this same circuit but using the `checked` flavour, a panic will occur. + +```rust +use tfhe::shortint::prelude::*; + +use std::error::Error; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + let msg2 = 3; + let scalar = 4; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + let mut ops = || -> Result<(), Box> { + server_key.checked_scalar_mul_assign(&mut ct_1, scalar)?; + server_key.checked_sub_assign(&mut ct_1, &ct_2)?; + server_key.checked_mul_lsb_assign(&mut ct_1, &ct_2)?; + Ok(()) + }; + + match ops() { + Ok(_) => (), + Err(e) => { + println!("correctness of operations is not guaranteed due to error: {}", e); + return; + }, + } + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_1); + assert_eq!(output, ((msg1 * scalar as u64 - msg2) * msg2) % modulus as u64); +} +``` + +Therefore, the `checked` flavour permits to manually manage the overflow of the carry buffer +by raising an error if the correctness is not guaranteed. + +Lastly, using the `smart` flavour will output the correct result all the time. However, the +computation may be slower +as the carry buffer may be cleaned during the computations. + +```rust +use tfhe::shortint::prelude::*; + + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + let msg2 = 3; + let scalar = 4; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let mut ct_1 = client_key.encrypt(msg1); + let mut ct_2 = client_key.encrypt(msg2); + + server_key.smart_scalar_mul_assign(&mut ct_1, scalar); + server_key.smart_sub_assign(&mut ct_1, &mut ct_2); + server_key.smart_mul_lsb_assign(&mut ct_1, &mut ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_1); + assert_eq!(output, ((msg1 * scalar as u64 - msg2) * msg2) % modulus as u64); +} +``` +#List of available operations +{% hint style="warning" %} + +Currently, certain operations can only be used if the parameter set chosen is compatible with the +bivariate programmable bootstrapping, meaning the carry buffer is larger or equal than the +message buffer. These operations are marked with a star (*). + +{% endhint %} + + +The list of implemented operations for shortints is: +- addition between two ciphertexts +- addition between a ciphertext and an unencrypted scalar +- comparisons `<`, `<=`, `>`, `>=`, `==` between a ciphertext and an unencrypted scalar +- division of a ciphertext by an unencrypted scalar +- LSB multiplication between two ciphertexts returning the result truncated to fit in the `message buffer` +- multiplication of a ciphertext by an unencrypted scalar +- bitwise shift `<<`, `>>` +- subtraction of a ciphertext by another ciphertext +- subtraction of a ciphertext by an unencrypted scalar +- negation of a ciphertext +- bitwise and, or and xor (*) +- comparisons `<`, `<=`, `>`, `>=`, `==` between two ciphertexts (*) +- division between two ciphertexts (*) +- MSB multiplication between two ciphertexts returning the part overflowing the `message buffer` (*) + +In what follows, some simple code examples are given. + +## Public key encryption +TFHE-rs supports both private and public key encryption methods. Note that the only difference +between both lies into the encryption step: in this case, the encryption method is called using +`public_key` instead of `client_key`. + +Here a small example on how to use public encryption: +```rust +use tfhe::boolean::prelude::*; + +fn main() { + // Generate the client key and the server key: + let (cks, mut sks) = gen_keys(); + let pks = PublicKey::new(&cks); + // Encryption of one message: + let ct = pks.encrypt(true); + // Decryption: + let dec = cks.decrypt(&ct); + assert_eq!(true, dec); +} +``` + + + +In what follows, all examples are related to private key encryption. + +## Arithmetic operations +Classical arithmetic operations are supported by shortints: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys to compute over Z/2^2Z, with 2 carry bits + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 2; + let msg2 = 1; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the private client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 + msg2) % modulus as u64); +} +``` + + + +### Bitwise operations + +Short homomorphic integer types support some bitwise operations. + +A simple example on how to use these operations: +```rust + +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys to compute over Z/2^2Z, with 2 carry bits + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 2; + let msg2 = 1; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the private client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to homomorphically compute a bitwise AND: + let ct_3 = server_key.unchecked_bitand(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 & msg2) % modulus as u64); +} +``` + +### Comparisons + +Short homomorphic integer types support comparison operations. + +A simple example on how to use these operations: + +```rust + +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys to compute over Z/2^2Z, with 2 carry bits + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 2; + let msg2 = 1; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the private client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_greater_or_equal(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 >= msg2) as u64 % modulus as u64); +} +``` + +### Univariate function evaluations + +A simple example on how to use this operation to homomorphically compute +the hamming weight (i.e., the number of bit equals to one) of an encrypted +number. + +```rust + +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys to compute over Z/2^2Z, with 2 carry bits + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the private client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + + //define the accumulator as the + let acc = server_key.generate_accumulator(|n| n.count_ones().into()); + + // add the two ciphertexts + let ct_res = server_key.keyswitch_programmable_bootstrap(&ct_1, &acc); + + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_res); + assert_eq!(output, msg1.count_ones() as u64); +} +``` + +### Bi-variate function evaluations + +Using the shortint types offers the possibility to evaluate bi-variate functions, i.e., +functions that takes two ciphertexts as input. This requires to choose a parameter set +such that the carry buffer size is at least as large as the message one i.e., +PARAM_MESSAGE_X_CARRY_Y with X <= Y. + +In what follows, a simple code example: + +```rust + +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys to compute over Z/2^2Z, with 2 carry bits + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + let msg2 = 2; + + let modulus = client_key.parameters.message_modulus.0 as u64; + + // We use the private client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // Compute the accumulator for the bivariate functions + let acc = server_key.generate_accumulator_bivariate(|x,y| (x.count_ones() + + y.count_ones()) as u64 % modulus ); + + let ct_res = server_key.keyswitch_programmable_bootstrap_bivariate(&ct_1, &ct_2, &acc); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_res); + assert_eq!(output, (msg1.count_ones() as u64 + msg2.count_ones() as u64) % modulus); +} +``` + + diff --git a/tfhe/docs/shortint/parameters.md b/tfhe/docs/shortint/parameters.md new file mode 100644 index 000000000..e3e852764 --- /dev/null +++ b/tfhe/docs/shortint/parameters.md @@ -0,0 +1,88 @@ +# Cryptographic parameters + +All parameter sets provides at least 128-bits of security according to the [Lattice-Estimator](https://github.com/malb/lattice-estimator), with an error probability equals to $$2^{-40}$$ when computing a programmable bootstrapping. This error probability is due to the randomness added at each encryption (see [here](../getting_started/security_and_cryptography.md) for more details about the encryption process). + + +## Parameters and message precision + +`shortint` comes with sets of parameters that permit to use the functionalities of the library securely and efficiently. Each parameter sets is associated to the message and carry precisions. Thus, each key pair is entangled to precision. + +The user is allowed to choose which set of parameters to use when creating the pair of keys. + +The difference between the parameter sets is the total amount of space dedicated to the plaintext and how it is split between the message buffer and the carry buffer. The syntax chosen for the name of a parameter is: +`PARAM_MESSAGE_{number of message bits}_CARRY_{number of carry bits}`. For example, the set of parameters for a message buffer of 5 bits and a carry buffer of 2 bits is `PARAM_MESSAGE_5_CARRY_2`. + +In what follows, there is an example where keys are generated to have messages encoded over 3 bits i.e., computations are done modulus $$2^3 = 8$$), with 3 bits of carry. + + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + + let msg1 = 3; + let msg2 = 7; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); +} +``` + +## Impact of parameters on the operations + +As shown [here](../getting_started/benchmarks.md), the choice of the parameter set impacts the operations available and their efficiency. + +### Generic bi-variate functions + +The computations of bi-variate functions is based on a trick *concatenating* two ciphertexts into one. In the case where the carry buffer is not at least as large as the message one, this trick is not working anymore. Then, many bi-variate operations, such as comparisons cannot be correctly computed anymore. The only exception concerns the multiplication. + +### Multiplication + +In the case of the multiplication, two algorithms are implemented: the first one relies on the bi-variate function trick, where the other one is based on the [quarter square method](https://en.wikipedia.org/wiki/Multiplication_algorithm#Quarter_square_multiplication). In order to correctly compute a multiplication, the only requirement is to have at least one bit of carry (i.e., using parameter sets PARAM_MESSAGE_X_CARRY_Y with Y>=1). This method is in general slower than using the other one. Note that using the `smart` version of the multiplication automatically chooses which algorithm is used depending on the chosen parameters. + +## User-defined parameter sets + +Beyond the predefined parameter sets, this is possible to define new parameter sets. +To do so, it is sufficient to use the function `unsecure_parameters()` or to manually fulfill the +`Parameter` structure fields. + +For instance: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + let param = unsafe { + Parameters::new( + LweDimension(656), + GlweDimension(2), + PolynomialSize(512), + StandardDev(0.000034119201269311964), + StandardDev(0.00000004053919869756513), + DecompositionBaseLog(8), + DecompositionLevelCount(2), + DecompositionBaseLog(3), + DecompositionLevelCount(4), + StandardDev(0.00000000037411618952047216), + DecompositionBaseLog(15), + DecompositionLevelCount(1), + DecompositionLevelCount(0), + DecompositionBaseLog(0), + MessageModulus(4), + CarryModulus(1), + ) + }; +} +``` + + + + + + + + + diff --git a/tfhe/docs/shortint/serialization.md b/tfhe/docs/shortint/serialization.md new file mode 100644 index 000000000..25be5129c --- /dev/null +++ b/tfhe/docs/shortint/serialization.md @@ -0,0 +1,64 @@ +# Serialization / Deserialization + +As explained in the introduction, some types (`Serverkey`, `Ciphertext`) are meant to be shared with the server that does the computations. + +The easiest way to send these data to a server is to use the serialization and deserialization features. tfhe::shortint uses the [serde](https://crates.io/crates/serde) framework, serde's Serialize and Deserialize are implemented on tfhe::shortint's types. + +To be able to serialize our data, we need to pick a [data format](https://serde.rs/#data-formats), for our use case, [bincode](https://crates.io/crates/bincode) is a good choice, mainly because it is binary format. + + +```toml +# Cargo.toml + +[dependencies] +# ... +bincode = "1.3.3" +``` + + +```rust +// main.rs + +use bincode; +use std::io::Cursor; +use tfhe::shortint::prelude::*; + + +fn main() -> Result<(), Box> { + let (client_key, server_key) = gen_keys(Parameters::default()); + + let msg1 = 1; + let msg2 = 0; + + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + let mut serialized_data = Vec::new(); + bincode::serialize_into(&mut serialized_data, &server_key)?; + bincode::serialize_into(&mut serialized_data, &ct_1)?; + bincode::serialize_into(&mut serialized_data, &ct_2)?; + + // Simulate sending serialized data to a server and getting + // back the serialized result + let serialized_result = server_function(&serialized_data)?; + let result: Ciphertext = bincode::deserialize(&serialized_result)?; + + let output = client_key.decrypt(&result); + assert_eq!(output, msg1 + msg2); + Ok(()) +} + + +fn server_function(serialized_data: &[u8]) -> Result, Box> { + let mut serialized_data = Cursor::new(serialized_data); + let server_key: ServerKey = bincode::deserialize_from(&mut serialized_data)?; + let ct_1: Ciphertext = bincode::deserialize_from(&mut serialized_data)?; + let ct_2: Ciphertext = bincode::deserialize_from(&mut serialized_data)?; + + let result = server_key.unchecked_add(&ct_1, &ct_2); + + let serialized_result = bincode::serialize(&result)?; + + Ok(serialized_result) +} +``` diff --git a/tfhe/docs/shortint/tutorial.md b/tfhe/docs/shortint/tutorial.md new file mode 100644 index 000000000..5cc6156ac --- /dev/null +++ b/tfhe/docs/shortint/tutorial.md @@ -0,0 +1,96 @@ +# Tutorial: Writing an homomorphic circuit using shortints + +# 1. Key Generation + +`tfhe::shortint` provides 2 key types: + - `ClientKey` + - `ServerKey` + +The `ClientKey` is the key that encrypts and decrypts messages (integer values up to 8 bits here), thus this key is meant to be kept private and should never be shared. This key is created from parameter values that will dictate both the security and efficiency of computations. The parameters also set the maximum number of bits of message encrypted in a ciphertext. + +The `ServerKey` is the key that is used to actually do the FHE computations. It contains (among other things) a bootstrapping key and a keyswitching key. This key is created from a `ClientKey` that needs to be shared to the server, therefore it is not meant to be kept private. A user with a `ServerKey` can compute on the encrypted data sent by the owner of the associated `ClientKey`. + +To reflect that, computation/operation methods are tied to the `ServerKey` type. + + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); +} +``` + + +# 2. Encrypting values + +Once the keys have been generated, the client key is used to encrypt data: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); + + let msg1 = 1; + let msg2 = 0; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); +} +``` + +# 2 bis. Encrypting values using a public key + +Once the keys have been generated, the client key is used to encrypt data: + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); + let public_key = PublicKey::new(&client_key); + + let msg1 = 1; + let msg2 = 0; + + // We use the client key to encrypt two messages: + let ct_1 = public_key.encrypt(&server_key, msg1); + let ct_2 = public_key.encrypt(&server_key, msg2); +} +``` + + +# 3. Computing and decrypting + +With our `server_key`, and encrypted values, we can now do an addition +and then decrypt the result. + +```rust +use tfhe::shortint::prelude::*; + +fn main() { + // We generate a set of client/server keys, using the default parameters: + let (client_key, server_key) = gen_keys(Parameters::default()); + + let msg1 = 1; + let msg2 = 0; + + let modulus = client_key.parameters.message_modulus.0; + + // We use the client key to encrypt two messages: + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + // We use the server public key to execute an integer circuit: + let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); + + // We use the client key to decrypt the output of the circuit: + let output = client_key.decrypt(&ct_3); + assert_eq!(output, (msg1 + msg2) % modulus as u64); +} +``` diff --git a/tfhe/examples/generates_test_keys.rs b/tfhe/examples/generates_test_keys.rs new file mode 100644 index 000000000..7753e3546 --- /dev/null +++ b/tfhe/examples/generates_test_keys.rs @@ -0,0 +1,31 @@ +use tfhe::shortint::keycache::{FileStorage, NamedParam, PersistentStorage}; + +use tfhe::shortint::parameters::ALL_PARAMETER_VEC; +use tfhe::shortint::{gen_keys, ClientKey, ServerKey}; + +fn client_server_keys() { + let file_storage = FileStorage::new("keys/shortint/client_server".to_string()); + + println!("Generating (ClientKey, ServerKey)"); + for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() { + println!( + "Generating [{} / {}] : {}", + i + 1, + ALL_PARAMETER_VEC.len(), + params.name() + ); + + let keys: Option<(ClientKey, ServerKey)> = file_storage.load(params); + + if keys.is_some() { + continue; + } + + let client_server_keys = gen_keys(params); + file_storage.store(params, &client_server_keys); + } +} + +fn main() { + client_server_keys() +} diff --git a/tfhe/examples/micro_bench_and.rs b/tfhe/examples/micro_bench_and.rs new file mode 100644 index 000000000..27e516fec --- /dev/null +++ b/tfhe/examples/micro_bench_and.rs @@ -0,0 +1,28 @@ +use tfhe::boolean::client_key::ClientKey; +use tfhe::boolean::parameters::TFHE_LIB_PARAMETERS; +use tfhe::boolean::prelude::BinaryBooleanGates; +use tfhe::boolean::server_key::ServerKey; + +fn main() { + // let (cks, sks) = gen_keys(); + let cks = ClientKey::new(&TFHE_LIB_PARAMETERS); + let sks = ServerKey::new(&cks); + + let left = false; + let right = true; + + let ct_left = cks.encrypt(left); + let ct_right = cks.encrypt(right); + + let start = std::time::Instant::now(); + + let num_loops: usize = 10000; + + for _ in 0..num_loops { + let _ = sks.and(&ct_left, &ct_right); + } + let elapsed = start.elapsed().as_millis() as f64; + let mean: f64 = elapsed / num_loops as f64; + + println!("{elapsed:?} ms, mean {mean:?} ms"); +} diff --git a/tfhe/src/boolean/ciphertext/mod.rs b/tfhe/src/boolean/ciphertext/mod.rs new file mode 100644 index 000000000..461f0a1f3 --- /dev/null +++ b/tfhe/src/boolean/ciphertext/mod.rs @@ -0,0 +1,60 @@ +//! An encryption of a boolean message. +//! +//! This module implements the ciphertext structure containing an encryption of a Boolean message. + +use crate::core_crypto::prelude::*; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A structure containing a ciphertext, meant to encrypt a Boolean message. +/// +/// It is used to evaluate a Boolean circuits homomorphically. +#[derive(Clone, Debug)] +pub enum Ciphertext { + Encrypted(LweCiphertext32), + Trivial(bool), +} + +#[derive(Serialize, Deserialize)] +enum SerializableCiphertext { + Encrypted(Vec), + Trivial(bool), +} + +impl Serialize for Ciphertext { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + match self { + Ciphertext::Encrypted(lwe) => { + let ciphertext = ser_eng.serialize(lwe).map_err(serde::ser::Error::custom)?; + SerializableCiphertext::Encrypted(ciphertext) + } + Ciphertext::Trivial(b) => SerializableCiphertext::Trivial(*b), + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Ciphertext { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = SerializableCiphertext::deserialize(deserializer)?; + + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(match thing { + SerializableCiphertext::Encrypted(data) => { + let lwe = de_eng + .deserialize(data.as_slice()) + .map_err(serde::de::Error::custom)?; + Self::Encrypted(lwe) + } + SerializableCiphertext::Trivial(b) => Self::Trivial(b), + }) + } +} diff --git a/tfhe/src/boolean/client_key/mod.rs b/tfhe/src/boolean/client_key/mod.rs new file mode 100644 index 000000000..4cf2bf5d1 --- /dev/null +++ b/tfhe/src/boolean/client_key/mod.rs @@ -0,0 +1,189 @@ +//! The secret key of the client. +//! +//! This module implements the generation of the client' secret keys, together with the +//! encryption and decryption methods. + +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::engine::{CpuBooleanEngine, WithThreadLocalEngine}; +use crate::boolean::parameters::BooleanParameters; +use crate::core_crypto::prelude::*; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::{Debug, Formatter}; + +/// A structure containing the client key, which must be kept secret. +/// +/// In more details, it contains: +/// * `lwe_secret_key` - an LWE secret key, used to encrypt the inputs and decrypt the outputs. +/// This secret key is also used in the generation of bootstrapping and key switching keys. +/// * `glwe_secret_key` - a GLWE secret key, used to generate the bootstrapping keys and key +/// switching keys. +/// * `parameters` - the cryptographic parameter set. +#[derive(Clone)] +pub struct ClientKey { + pub(crate) lwe_secret_key: LweSecretKey32, + pub(crate) glwe_secret_key: GlweSecretKey32, + pub(crate) parameters: BooleanParameters, +} + +impl PartialEq for ClientKey { + fn eq(&self, other: &Self) -> bool { + self.parameters == other.parameters + && self.lwe_secret_key == other.lwe_secret_key + && self.glwe_secret_key == other.glwe_secret_key + } +} + +impl Debug for ClientKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "ClientKey {{ ")?; + write!(f, "lwe_secret_key: {:?}, ", self.lwe_secret_key)?; + write!(f, "glwe_secret_key: {:?}, ", self.glwe_secret_key)?; + write!(f, "parameters: {:?}, ", self.parameters)?; + write!(f, "engine: CoreEngine, ")?; + write!(f, "}}")?; + Ok(()) + } +} + +impl ClientKey { + /// Encrypts a Boolean message using the client key. + /// + /// # Example + /// + /// ```rust + /// # #[cfg(not(feature = "cuda"))] + /// # fn main() { + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key and the server key: + /// let (cks, mut sks) = gen_keys(); + /// + /// // Encryption of one message: + /// let ct = cks.encrypt(true); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(true, dec); + /// # } + /// # #[cfg(feature = "cuda")] + /// # fn main() {} + /// ``` + pub fn encrypt(&self, message: bool) -> Ciphertext { + CpuBooleanEngine::with_thread_local_mut(|engine| engine.encrypt(message, self)) + } + + /// Decrypts a ciphertext encrypting a Boolean message using the client key. + /// + /// # Example + /// + /// ```rust + /// # #[cfg(not(feature = "cuda"))] + /// # fn main() { + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key and the server key: + /// let (cks, mut sks) = gen_keys(); + /// + /// // Encryption of one message: + /// let ct = cks.encrypt(true); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(true, dec); + /// # } + /// # #[cfg(feature = "cuda")] + /// # fn main() {} + /// ``` + pub fn decrypt(&self, ct: &Ciphertext) -> bool { + CpuBooleanEngine::with_thread_local_mut(|engine| engine.decrypt(ct, self)) + } + + /// Allocates and generates a client key. + /// + /// # Panic + /// + /// This will panic when the "cuda" feature is enabled and the parameters + /// uses a GlweDimension > 1 (as it is not yet supported by the cuda backend). + /// + /// # Example + /// + /// ```rust + /// # #[cfg(not(feature = "cuda"))] + /// # fn main() { + /// use tfhe::boolean::client_key::ClientKey; + /// use tfhe::boolean::parameters::TFHE_LIB_PARAMETERS; + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(&TFHE_LIB_PARAMETERS); + /// # } + /// # #[cfg(feature = "cuda")] + /// # fn main() { + /// use tfhe::boolean::client_key::ClientKey; + /// use tfhe::boolean::parameters::GPU_DEFAULT_PARAMETERS; + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(&GPU_DEFAULT_PARAMETERS);} + /// ``` + pub fn new(parameter_set: &BooleanParameters) -> ClientKey { + #[cfg(feature = "cuda")] + { + if parameter_set.glwe_dimension.0 > 1 { + panic!("the cuda backend does not support support GlweSize greater than one"); + } + } + CpuBooleanEngine::with_thread_local_mut(|engine| engine.create_client_key(*parameter_set)) + } +} + +#[derive(Serialize, Deserialize)] +struct SerializableClientKey { + lwe_secret_key: Vec, + glwe_secret_key: Vec, + parameters: BooleanParameters, +} + +impl Serialize for ClientKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let lwe_secret_key = ser_eng + .serialize(&self.lwe_secret_key) + .map_err(serde::ser::Error::custom)?; + let glwe_secret_key = ser_eng + .serialize(&self.glwe_secret_key) + .map_err(serde::ser::Error::custom)?; + + SerializableClientKey { + lwe_secret_key, + glwe_secret_key, + parameters: self.parameters, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ClientKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + SerializableClientKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(Self { + lwe_secret_key: de_eng + .deserialize(thing.lwe_secret_key.as_slice()) + .map_err(serde::de::Error::custom)?, + glwe_secret_key: de_eng + .deserialize(thing.glwe_secret_key.as_slice()) + .map_err(serde::de::Error::custom)?, + parameters: thing.parameters, + }) + } +} diff --git a/tfhe/src/boolean/engine/bootstrapping/cpu.rs b/tfhe/src/boolean/engine/bootstrapping/cpu.rs new file mode 100644 index 000000000..2c401c05c --- /dev/null +++ b/tfhe/src/boolean/engine/bootstrapping/cpu.rs @@ -0,0 +1,317 @@ +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::{ClientKey, PLAINTEXT_TRUE}; +use crate::core_crypto::prelude::*; +use crate::seeders::new_seeder; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::error::Error; + +use super::{BooleanServerKey, Bootstrapper}; + +/// Memory used as buffer for the bootstrap +/// +/// It contains contiguous chunk which is then sliced and converted +/// into core's View types. +#[derive(Default)] +struct Memory { + elements: Vec, +} + +impl Memory { + /// Returns a tuple with buffers that matches the server key. + /// + /// - The first element is the accumulator for bootstrap step. + /// - The second element is a lwe buffer where the result of the of the bootstrap should be + /// written + fn as_buffers( + &mut self, + engine: &mut DefaultEngine, + server_key: &CpuBootstrapKey, + ) -> (GlweCiphertextView32, LweCiphertextMutView32) { + let num_elem_in_accumulator = server_key + .bootstrapping_key + .glwe_dimension() + .to_glwe_size() + .0 + * server_key.bootstrapping_key.polynomial_size().0; + let num_elem_in_lwe = server_key + .bootstrapping_key + .output_lwe_dimension() + .to_lwe_size() + .0; + let total_elem_needed = num_elem_in_accumulator + num_elem_in_lwe; + + let all_elements = if self.elements.len() < total_elem_needed { + self.elements.resize(total_elem_needed, 0u32); + self.elements.as_mut_slice() + } else { + &mut self.elements[..total_elem_needed] + }; + + let (accumulator_elements, lwe_elements) = + all_elements.split_at_mut(num_elem_in_accumulator); + accumulator_elements + [num_elem_in_accumulator - server_key.bootstrapping_key.polynomial_size().0..] + .fill(PLAINTEXT_TRUE); + + let accumulator = engine + .create_glwe_ciphertext_from( + &*accumulator_elements, + server_key.bootstrapping_key.polynomial_size(), + ) + .unwrap(); + let lwe = engine.create_lwe_ciphertext_from(lwe_elements).unwrap(); + + (accumulator, lwe) + } +} + +/// A structure containing the server public key. +/// +/// This server key data lives on the CPU. +/// +/// The server key is generated by the client and is meant to be published: the client +/// sends it to the server so it can compute homomorphic Boolean circuits. +/// +/// In more details, it contains: +/// * `bootstrapping_key` - a public key, used to perform the bootstrapping operation. +/// * `key_switching_key` - a public key, used to perform the key-switching operation. +#[derive(Clone)] +pub struct CpuBootstrapKey { + pub(super) standard_bootstraping_key: LweBootstrapKey32, + pub(super) bootstrapping_key: FftFourierLweBootstrapKey32, + pub(super) key_switching_key: LweKeyswitchKey32, +} + +impl CpuBootstrapKey {} + +impl BooleanServerKey for CpuBootstrapKey { + fn lwe_size(&self) -> LweSize { + self.bootstrapping_key.input_lwe_dimension().to_lwe_size() + } +} + +/// Performs ciphertext bootstraps on the CPU +pub(crate) struct CpuBootstrapper { + memory: Memory, + engine: DefaultEngine, + fourier_engine: FftEngine, +} + +impl CpuBootstrapper { + pub(crate) fn new_server_key( + &mut self, + cks: &ClientKey, + ) -> Result> { + // convert into a variance for rlwe context + let var_rlwe = Variance(cks.parameters.glwe_modular_std_dev.get_variance()); + // creation of the bootstrapping key in the Fourier domain + + let standard_bootstraping_key: LweBootstrapKey32 = + self.engine.generate_new_lwe_bootstrap_key( + &cks.lwe_secret_key, + &cks.glwe_secret_key, + cks.parameters.pbs_base_log, + cks.parameters.pbs_level, + var_rlwe, + )?; + + let fourier_bsk = self + .fourier_engine + .convert_lwe_bootstrap_key(&standard_bootstraping_key)?; + + // Convert the GLWE secret key into an LWE secret key: + let big_lwe_secret_key = self + .engine + .transform_glwe_secret_key_to_lwe_secret_key(cks.glwe_secret_key.clone())?; + + // convert into a variance for lwe context + let var_lwe = Variance(cks.parameters.lwe_modular_std_dev.get_variance()); + // creation of the key switching key + let ksk = self.engine.generate_new_lwe_keyswitch_key( + &big_lwe_secret_key, + &cks.lwe_secret_key, + cks.parameters.ks_level, + cks.parameters.ks_base_log, + var_lwe, + )?; + + Ok(CpuBootstrapKey { + standard_bootstraping_key, + bootstrapping_key: fourier_bsk, + key_switching_key: ksk, + }) + } +} + +impl Default for CpuBootstrapper { + fn default() -> Self { + let engine = + DefaultEngine::new(new_seeder()).expect("Unexpectedly failed to create a core engine"); + + let fourier_engine = FftEngine::new(()).unwrap(); + Self { + memory: Default::default(), + engine, + fourier_engine, + } + } +} + +impl Bootstrapper for CpuBootstrapper { + type ServerKey = CpuBootstrapKey; + + fn bootstrap( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + let (accumulator, mut buffer_lwe_after_pbs) = + self.memory.as_buffers(&mut self.engine, server_key); + + // Need to be able to get view from &Lwe + let inner_data = self + .engine + .consume_retrieve_lwe_ciphertext(input.clone()) + .unwrap(); + let input = self + .engine + .create_lwe_ciphertext_from(inner_data.as_slice()) + .unwrap(); + + self.fourier_engine.discard_bootstrap_lwe_ciphertext( + &mut buffer_lwe_after_pbs, + &input, + &accumulator, + &server_key.bootstrapping_key, + )?; + + let data = self + .engine + .consume_retrieve_lwe_ciphertext(buffer_lwe_after_pbs) + .unwrap() + .to_vec(); + let ct = self.engine.create_lwe_ciphertext_from(data)?; + Ok(ct) + } + + fn keyswitch( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + // Allocate the output of the KS + let mut ct_ks = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0])?; + + self.engine.discard_keyswitch_lwe_ciphertext( + &mut ct_ks, + input, + &server_key.key_switching_key, + )?; + + Ok(ct_ks) + } + + fn bootstrap_keyswitch( + &mut self, + ciphertext: LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + let (accumulator, mut buffer_lwe_after_pbs) = + self.memory.as_buffers(&mut self.engine, server_key); + + let mut inner_data = self + .engine + .consume_retrieve_lwe_ciphertext(ciphertext) + .unwrap(); + let input_view = self + .engine + .create_lwe_ciphertext_from(inner_data.as_slice())?; + + self.fourier_engine.discard_bootstrap_lwe_ciphertext( + &mut buffer_lwe_after_pbs, + &input_view, + &accumulator, + &server_key.bootstrapping_key, + )?; + + // Convert from a mut view to a view + let slice = self + .engine + .consume_retrieve_lwe_ciphertext(buffer_lwe_after_pbs) + .unwrap(); + let buffer_lwe_after_pbs = self.engine.create_lwe_ciphertext_from(&*slice)?; + + let mut output_view = self + .engine + .create_lwe_ciphertext_from(inner_data.as_mut_slice())?; + + // Compute a key switch to get back to input key + self.engine.discard_keyswitch_lwe_ciphertext( + &mut output_view, + &buffer_lwe_after_pbs, + &server_key.key_switching_key, + )?; + + let ciphertext = self.engine.create_lwe_ciphertext_from(inner_data)?; + + Ok(Ciphertext::Encrypted(ciphertext)) + } +} + +#[derive(Serialize, Deserialize)] +struct SerializableCpuServerKey { + pub standard_bootstraping_key: Vec, + pub key_switching_key: Vec, +} + +impl Serialize for CpuBootstrapKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let key_switching_key = ser_eng + .serialize(&self.key_switching_key) + .map_err(serde::ser::Error::custom)?; + let standard_bootstraping_key = ser_eng + .serialize(&self.standard_bootstraping_key) + .map_err(serde::ser::Error::custom)?; + + SerializableCpuServerKey { + key_switching_key, + standard_bootstraping_key, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for CpuBootstrapKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = SerializableCpuServerKey::deserialize(deserializer) + .map_err(serde::de::Error::custom)?; + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + let key_switching_key = ser_eng + .deserialize(thing.key_switching_key.as_slice()) + .map_err(serde::de::Error::custom)?; + let standard_bootstraping_key = ser_eng + .deserialize(thing.standard_bootstraping_key.as_slice()) + .map_err(serde::de::Error::custom)?; + let bootstrapping_key = FftEngine::new(()) + .unwrap() + .convert_lwe_bootstrap_key(&standard_bootstraping_key) + .map_err(serde::de::Error::custom)?; + + Ok(Self { + standard_bootstraping_key, + key_switching_key, + bootstrapping_key, + }) + } +} diff --git a/tfhe/src/boolean/engine/bootstrapping/cuda.rs b/tfhe/src/boolean/engine/bootstrapping/cuda.rs new file mode 100644 index 000000000..227b2a00e --- /dev/null +++ b/tfhe/src/boolean/engine/bootstrapping/cuda.rs @@ -0,0 +1,240 @@ +use super::{BooleanServerKey, Bootstrapper, CpuBootstrapKey}; +use crate::boolean::PLAINTEXT_TRUE; +use crate::core_crypto::prelude::*; +use crate::seeders::new_seeder; + +use std::collections::BTreeMap; + +use crate::boolean::ciphertext::Ciphertext; + +pub(crate) struct CudaBootstrapKey { + bootstrapping_key: CudaFourierLweBootstrapKey32, + key_switching_key: CudaLweKeyswitchKey32, +} + +impl BooleanServerKey for CudaBootstrapKey { + fn lwe_size(&self) -> LweSize { + self.bootstrapping_key.input_lwe_dimension().to_lwe_size() + } +} + +#[derive(PartialOrd, PartialEq, Ord, Eq)] +struct KeyId { + // Both of these are for the accumulator + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_dimension_after_bootstrap: LweDimension, +} + +#[derive(Default)] +struct CudaMemory { + cuda_buffers: BTreeMap, +} + +/// All the buffers needed to do a bootstrap or a keyswitch or bootstrap + keyswitch +struct CudaBuffers { + accumulator: CudaGlweCiphertext32, + // Its size is the one of a ciphertext after pbs + lwe_after_bootstrap: CudaLweCiphertext32, + // Its size is the one of a ciphertext after a keyswitch + // ie the size of a ciphertext before the bootstrap + lwe_after_keyswitch: CudaLweCiphertext32, +} + +impl CudaMemory { + /// Returns the buffers that matches the given key. + fn as_buffers_for_key( + &mut self, + cpu_engine: &mut DefaultEngine, + cuda_engine: &mut CudaEngine, + server_key: &CudaBootstrapKey, + ) -> &mut CudaBuffers { + let key_id = KeyId { + glwe_dimension: server_key.bootstrapping_key.glwe_dimension(), + polynomial_size: server_key.bootstrapping_key.polynomial_size(), + lwe_dimension_after_bootstrap: server_key.bootstrapping_key.output_lwe_dimension(), + }; + + self.cuda_buffers.entry(key_id).or_insert_with(|| { + let output_lwe_size = server_key + .bootstrapping_key + .output_lwe_dimension() + .to_lwe_size(); + let output_ciphertext = cpu_engine + .create_lwe_ciphertext_from(vec![0u32; output_lwe_size.0]) + .unwrap(); + let cuda_lwe_after_bootstrap = cuda_engine + .convert_lwe_ciphertext(&output_ciphertext) + .unwrap(); + + let num_elements = server_key + .bootstrapping_key + .glwe_dimension() + .to_glwe_size() + .0 + * server_key.bootstrapping_key.polynomial_size().0; + let mut acc = vec![0u32; num_elements]; + acc[num_elements - server_key.bootstrapping_key.polynomial_size().0..] + .fill(PLAINTEXT_TRUE); + let accumulator = cpu_engine + .create_glwe_ciphertext_from(acc, server_key.bootstrapping_key.polynomial_size()) + .unwrap(); + let cuda_accumulator = cuda_engine.convert_glwe_ciphertext(&accumulator).unwrap(); + + let lwe_size_after_keyswitch = server_key + .key_switching_key + .output_lwe_dimension() + .to_lwe_size(); + let output_ciphertext = cpu_engine + .create_lwe_ciphertext_from(vec![0u32; lwe_size_after_keyswitch.0]) + .unwrap(); + let cuda_lwe_after_keyswitch = cuda_engine + .convert_lwe_ciphertext(&output_ciphertext) + .unwrap(); + + CudaBuffers { + accumulator: cuda_accumulator, + lwe_after_bootstrap: cuda_lwe_after_bootstrap, + lwe_after_keyswitch: cuda_lwe_after_keyswitch, + } + }) + } +} + +pub(crate) struct CudaBootstrapper { + cuda_engine: CudaEngine, + cpu_engine: DefaultEngine, + memory: CudaMemory, +} + +impl Default for CudaBootstrapper { + fn default() -> Self { + Self { + cuda_engine: CudaEngine::new(()).unwrap(), + // Secret does not matter, we won't generate keys or ciphertext. + cpu_engine: DefaultEngine::new(new_seeder()).unwrap(), + memory: Default::default(), + } + } +} + +impl CudaBootstrapper { + pub(crate) fn new_serverk_key( + &mut self, + server_key: &CpuBootstrapKey, + ) -> Result> { + let bootstrapping_key = self + .cuda_engine + .convert_lwe_bootstrap_key(&server_key.standard_bootstraping_key)?; + + let key_switching_key = self + .cuda_engine + .convert_lwe_keyswitch_key(&server_key.key_switching_key)?; + + Ok(CudaBootstrapKey { + bootstrapping_key, + key_switching_key, + }) + } +} + +impl Bootstrapper for CudaBootstrapper { + type ServerKey = CudaBootstrapKey; + + fn bootstrap( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + let cuda_buffers = + self.memory + .as_buffers_for_key(&mut self.cpu_engine, &mut self.cuda_engine, server_key); + + // The output size of keyswitch is the one of regular boolean ciphertext + // so we can use lwe_after_keyswitch + self.cuda_engine + .discard_convert_lwe_ciphertext(&mut cuda_buffers.lwe_after_keyswitch, input)?; + + self.cuda_engine.discard_bootstrap_lwe_ciphertext( + &mut cuda_buffers.lwe_after_bootstrap, + &cuda_buffers.lwe_after_keyswitch, + &cuda_buffers.accumulator, + &server_key.bootstrapping_key, + )?; + + let output_ciphertext = self + .cuda_engine + .convert_lwe_ciphertext(&cuda_buffers.lwe_after_bootstrap)?; + Ok(output_ciphertext) + } + + fn keyswitch( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + let cuda_buffers = + self.memory + .as_buffers_for_key(&mut self.cpu_engine, &mut self.cuda_engine, server_key); + + // The input of the function we implement must be a ciphertext that result of a bootstrap + // so we can discard convert in the lwe ciphertext after bootstrap + self.cuda_engine + .discard_convert_lwe_ciphertext(&mut cuda_buffers.lwe_after_bootstrap, input)?; + + self.cuda_engine.discard_keyswitch_lwe_ciphertext( + &mut cuda_buffers.lwe_after_keyswitch, + &cuda_buffers.lwe_after_bootstrap, + &server_key.key_switching_key, + )?; + + let output_ciphertext = self + .cuda_engine + .convert_lwe_ciphertext(&cuda_buffers.lwe_after_keyswitch)?; + Ok(output_ciphertext) + } + + fn bootstrap_keyswitch( + &mut self, + ciphertext: LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result> { + // We re-implement instead of calling our bootstrap and then keyswitch fn + // to avoid one extra conversion / copy cpu <-> gpu + + let cuda_buffers = + self.memory + .as_buffers_for_key(&mut self.cpu_engine, &mut self.cuda_engine, server_key); + + // The output size of keyswitch is the one of regular boolean ciphertext + // so we can use it + self.cuda_engine + .discard_convert_lwe_ciphertext(&mut cuda_buffers.lwe_after_keyswitch, &ciphertext)?; + + self.cuda_engine.discard_bootstrap_lwe_ciphertext( + &mut cuda_buffers.lwe_after_bootstrap, + &cuda_buffers.lwe_after_keyswitch, + &cuda_buffers.accumulator, + &server_key.bootstrapping_key, + )?; + + self.cuda_engine.discard_keyswitch_lwe_ciphertext( + &mut cuda_buffers.lwe_after_keyswitch, + &cuda_buffers.lwe_after_bootstrap, + &server_key.key_switching_key, + )?; + + // We write the result from gpu to cpu avoiding an extra allocation + let mut data = self + .cpu_engine + .consume_retrieve_lwe_ciphertext(ciphertext)?; + let mut view = self + .cpu_engine + .create_lwe_ciphertext_from(data.as_mut_slice())?; + self.cuda_engine + .discard_convert_lwe_ciphertext(&mut view, &cuda_buffers.lwe_after_keyswitch)?; + let output_ciphertext = self.cpu_engine.create_lwe_ciphertext_from(data)?; + + Ok(Ciphertext::Encrypted(output_ciphertext)) + } +} diff --git a/tfhe/src/boolean/engine/bootstrapping/mod.rs b/tfhe/src/boolean/engine/bootstrapping/mod.rs new file mode 100644 index 000000000..89d1fe61d --- /dev/null +++ b/tfhe/src/boolean/engine/bootstrapping/mod.rs @@ -0,0 +1,48 @@ +use crate::boolean::ciphertext::Ciphertext; +use crate::core_crypto::prelude::{LweCiphertext32, LweSize}; +mod cpu; +#[cfg(feature = "cuda")] +mod cuda; + +#[cfg(feature = "cuda")] +pub(crate) use cuda::{CudaBootstrapKey, CudaBootstrapper}; + +pub(crate) use cpu::{CpuBootstrapKey, CpuBootstrapper}; + +pub trait BooleanServerKey { + /// The LweSize of the Ciphertexts that this key can bootstrap + fn lwe_size(&self) -> LweSize; +} + +/// Trait for types which implement the bootstrapping + key switching +/// of a ciphertext. +/// +/// Meant to be implemented for different hardware (CPU, GPU) or for other bootstrapping +/// technics. +pub(crate) trait Bootstrapper: Default { + type ServerKey: BooleanServerKey; + + /// Shall return the result of the bootstrapping of the + /// input ciphertext or an error if any + fn bootstrap( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result>; + + /// Shall return the result of the key switching of the + /// input ciphertext or an error if any + fn keyswitch( + &mut self, + input: &LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result>; + + /// Shall do the bootstrapping and key switching of the ciphertext. + /// The result is returned as a new value. + fn bootstrap_keyswitch( + &mut self, + ciphertext: LweCiphertext32, + server_key: &Self::ServerKey, + ) -> Result>; +} diff --git a/tfhe/src/boolean/engine/mod.rs b/tfhe/src/boolean/engine/mod.rs new file mode 100644 index 000000000..3d07d5ca0 --- /dev/null +++ b/tfhe/src/boolean/engine/mod.rs @@ -0,0 +1,936 @@ +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::parameters::BooleanParameters; +use crate::boolean::{ClientKey, PublicKey, PLAINTEXT_FALSE, PLAINTEXT_TRUE}; +use crate::core_crypto::prelude::*; +use bootstrapping::{BooleanServerKey, Bootstrapper, CpuBootstrapper}; +use std::cell::RefCell; +pub mod bootstrapping; +use crate::boolean::engine::bootstrapping::CpuBootstrapKey; +use crate::core_crypto::backends::default::engines::ActivatedRandomGenerator; +use crate::core_crypto::commons::crypto::secret::generators::DeterministicSeeder; +use crate::seeders::new_seeder; + +#[cfg(feature = "cuda")] +use bootstrapping::{CudaBootstrapKey, CudaBootstrapper}; + +pub(crate) trait BinaryGatesEngine { + fn and(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; + fn nand(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; + fn nor(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; + fn or(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; + fn xor(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; + fn xnor(&mut self, ct_left: L, ct_right: R, server_key: &K) -> Ciphertext; +} + +pub(crate) trait BinaryGatesAssignEngine { + fn and_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); + fn nand_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); + fn nor_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); + fn or_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); + fn xor_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); + fn xnor_assign(&mut self, ct_left: L, ct_right: R, server_key: &K); +} + +/// Trait to be able to acces thread_local +/// engines in a generic way +pub(crate) trait WithThreadLocalEngine { + fn with_thread_local_mut(func: F) -> R + where + F: FnOnce(&mut Self) -> R; +} + +pub(crate) type CpuBooleanEngine = BooleanEngine; +#[cfg(feature = "cuda")] +pub(crate) type CudaBooleanEngine = BooleanEngine; + +// All our thread local engines +// that our exposed types will use internally to implement their methods +thread_local! { + static CPU_ENGINE: RefCell> = RefCell::new(BooleanEngine::<_>::new()); + #[cfg(feature = "cuda")] + static CUDA_ENGINE: RefCell> = RefCell::new(BooleanEngine::<_>::new()); +} + +impl WithThreadLocalEngine for CpuBooleanEngine { + fn with_thread_local_mut(func: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + CPU_ENGINE.with(|engine_cell| func(&mut engine_cell.borrow_mut())) + } +} + +#[cfg(feature = "cuda")] +impl WithThreadLocalEngine for CudaBooleanEngine { + fn with_thread_local_mut(func: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + CUDA_ENGINE.with(|engine_cell| func(&mut engine_cell.borrow_mut())) + } +} + +pub(crate) struct BooleanEngine { + pub(crate) engine: DefaultEngine, + bootstrapper: B, +} + +impl BooleanEngine { + pub fn create_server_key(&mut self, cks: &ClientKey) -> CpuBootstrapKey { + let server_key = self.bootstrapper.new_server_key(cks).unwrap(); + + server_key + } +} + +#[cfg(feature = "cuda")] +impl BooleanEngine { + pub fn create_server_key(&mut self, cpu_key: &CpuBootstrapKey) -> CudaBootstrapKey { + let server_key = self.bootstrapper.new_serverk_key(cpu_key).unwrap(); + + server_key + } +} + +// We have q = 32 so log2q = 5 +const LOG2_Q_32: usize = 5; + +impl BooleanEngine { + pub fn create_client_key(&mut self, parameters: BooleanParameters) -> ClientKey { + // generate the lwe secret key + let lwe_secret_key: LweSecretKey32 = self + .engine + .generate_new_lwe_secret_key(parameters.lwe_dimension) + .unwrap(); + + // generate the rlwe secret key + let glwe_secret_key: GlweSecretKey32 = self + .engine + .generate_new_glwe_secret_key(parameters.glwe_dimension, parameters.polynomial_size) + .unwrap(); + + ClientKey { + lwe_secret_key, + glwe_secret_key, + parameters, + } + } + + pub fn create_public_key(&mut self, client_key: &ClientKey) -> PublicKey { + let client_parameters = client_key.parameters; + + // Formula is (n + 1) * log2(q) + 128 + let zero_encryption_count = LwePublicKeyZeroEncryptionCount( + client_parameters.lwe_dimension.to_lwe_size().0 * LOG2_Q_32 + 128, + ); + + PublicKey { + lwe_public_key: self + .engine + .generate_new_lwe_public_key( + &client_key.lwe_secret_key, + Variance(client_key.parameters.lwe_modular_std_dev.get_variance()), + zero_encryption_count, + ) + .unwrap(), + parameters: client_key.parameters.to_owned(), + } + } + + pub fn trivial_encrypt(&mut self, message: bool) -> Ciphertext { + Ciphertext::Trivial(message) + } + + pub fn encrypt(&mut self, message: bool, cks: &ClientKey) -> Ciphertext { + // encode the boolean message + let plain: Plaintext32 = if message { + self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap() + } else { + self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap() + }; + + // convert into a variance + let var = Variance(cks.parameters.lwe_modular_std_dev.get_variance()); + + // encryption + let ct = self + .engine + .encrypt_lwe_ciphertext(&cks.lwe_secret_key, &plain, var) + .unwrap(); + + Ciphertext::Encrypted(ct) + } + + pub fn encrypt_with_public_key(&mut self, message: bool, pks: &PublicKey) -> Ciphertext { + // encode the boolean message + let plain: Plaintext32 = if message { + self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap() + } else { + self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap() + }; + + let mut underlying_ciphertext = self + .engine + .create_lwe_ciphertext_from(vec![0u32; pks.parameters.lwe_dimension.to_lwe_size().0]) + .unwrap(); + + // encryption + self.engine + .discard_encrypt_lwe_ciphertext_with_public_key( + &pks.lwe_public_key, + &mut underlying_ciphertext, + &plain, + ) + .unwrap(); + + Ciphertext::Encrypted(underlying_ciphertext) + } + + pub fn decrypt(&mut self, ct: &Ciphertext, cks: &ClientKey) -> bool { + match ct { + Ciphertext::Trivial(b) => *b, + Ciphertext::Encrypted(ciphertext) => { + // decryption + let decrypted = self + .engine + .decrypt_lwe_ciphertext(&cks.lwe_secret_key, ciphertext) + .unwrap(); + + // cast as a u32 + let mut decrypted_u32: u32 = 0; + self.engine + .discard_retrieve_plaintext(&mut decrypted_u32, &decrypted) + .unwrap(); + + // return + decrypted_u32 < (1 << 31) + } + } + } + + pub fn not(&mut self, ct: &Ciphertext) -> Ciphertext { + match ct { + Ciphertext::Trivial(message) => Ciphertext::Trivial(!*message), + Ciphertext::Encrypted(ct_ct) => { + // Compute the linear combination for NOT: -ct + let mut ct_res = ct_ct.clone(); + self.engine.fuse_opp_lwe_ciphertext(&mut ct_res).unwrap(); // compute the negation + + // Output the result: + Ciphertext::Encrypted(ct_res) + } + } + } + + pub fn not_assign(&mut self, ct: &mut Ciphertext) { + match ct { + Ciphertext::Trivial(message) => *message = !*message, + Ciphertext::Encrypted(ct_ct) => { + self.engine.fuse_opp_lwe_ciphertext(ct_ct).unwrap(); // compute the negation + } + } + } +} + +impl BooleanEngine +where + B: Bootstrapper, +{ + pub fn new() -> Self { + let root_seeder = new_seeder(); + + Self::new_from_seeder(root_seeder) + } + + pub fn new_from_seeder(mut root_seeder: Box) -> Self { + let mut deterministic_seeder = + DeterministicSeeder::::new(root_seeder.seed()); + + let default_engine_seeder = Box::new(DeterministicSeeder::::new( + deterministic_seeder.seed(), + )); + + let engine = + DefaultEngine::new(default_engine_seeder).expect("Failed to create a DefaultEngine"); + Self { + engine, + bootstrapper: Default::default(), + } + } + + /// convert into an actual LWE ciphertext even when trivial + fn convert_into_lwe_ciphertext_32( + &mut self, + ct: &Ciphertext, + server_key: &B::ServerKey, + ) -> LweCiphertext32 { + match ct { + Ciphertext::Encrypted(ct_ct) => ct_ct.clone(), + Ciphertext::Trivial(message) => { + // encode the boolean message + let plain: Plaintext32 = if *message { + self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap() + } else { + self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap() + }; + self.engine + .trivially_encrypt_lwe_ciphertext(server_key.lwe_size(), &plain) + .unwrap() + } + } + } + + pub fn mux( + &mut self, + ct_condition: &Ciphertext, + ct_then: &Ciphertext, + ct_else: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + // In theory MUX gate = (ct_condition AND ct_then) + (!ct_condition AND ct_else) + + match ct_condition { + // in the case of the condition is trivially encrypted + Ciphertext::Trivial(message_condition) => { + if *message_condition { + ct_then.clone() + } else { + ct_else.clone() + } + } + Ciphertext::Encrypted(ct_condition_ct) => { + // condition is actually encrypted + + // take a shortcut if ct_then is trivially encrypted + if let Ciphertext::Trivial(message_then) = ct_then { + return if *message_then { + self.or(ct_condition, ct_else, server_key) + } else { + let ct_not_condition = self.not(ct_condition); + self.and(&ct_not_condition, ct_else, server_key) + }; + } + + // take a shortcut if ct_else is trivially encrypted + if let Ciphertext::Trivial(message_else) = ct_else { + return if *message_else { + let ct_not_condition = self.not(ct_condition); + self.or(ct_then, &ct_not_condition, server_key) + } else { + self.and(ct_condition, ct_then, server_key) + }; + } + + // convert inputs into LweCiphertext32 + let ct_then_ct = self.convert_into_lwe_ciphertext_32(ct_then, server_key); + let ct_else_ct = self.convert_into_lwe_ciphertext_32(ct_else, server_key); + + let mut buffer_lwe_before_pbs_o = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let buffer_lwe_before_pbs = &mut buffer_lwe_before_pbs_o; + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for first AND: ct_condition + ct_then + + // (0,...,0,-1/8) + self.engine + .discard_add_lwe_ciphertext(buffer_lwe_before_pbs, ct_condition_ct, &ct_then_ct) + .unwrap(); // ct_condition + ct_then + let cst = self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(buffer_lwe_before_pbs, &cst) + .unwrap(); // + // - 1/8 + + // Compute the linear combination for second AND: - ct_condition + ct_else + + // (0,...,0,-1/8) + let mut ct_temp_2 = ct_condition_ct.clone(); // ct_condition + self.engine.fuse_opp_lwe_ciphertext(&mut ct_temp_2).unwrap(); // compute the negation + self.engine + .fuse_add_lwe_ciphertext(&mut ct_temp_2, &ct_else_ct) + .unwrap(); // + ct_else + let cst = self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct_temp_2, &cst) + .unwrap(); // + // - 1/8 + + // Compute the first programmable bootstrapping with fixed test polynomial: + let mut ct_pbs_1 = bootstrapper + .bootstrap(buffer_lwe_before_pbs, server_key) + .unwrap(); + + let ct_pbs_2 = bootstrapper.bootstrap(&ct_temp_2, server_key).unwrap(); + + // Compute the linear combination to add the two results: + // buffer_lwe_pbs + ct_pbs_2 + (0,...,0, +1/8) + self.engine + .fuse_add_lwe_ciphertext(&mut ct_pbs_1, &ct_pbs_2) + .unwrap(); // + buffer_lwe_pbs + let cst = self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct_pbs_1, &cst) + .unwrap(); // + 1/8 + + let ct_ks = bootstrapper.keyswitch(&ct_pbs_1, server_key).unwrap(); + + // Output the result: + Ciphertext::Encrypted(ct_ks) + } + } + } +} + +impl BinaryGatesEngine<&Ciphertext, &Ciphertext, B::ServerKey> for BooleanEngine +where + B: Bootstrapper, +{ + fn and( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(*message_left && *message_right) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.and(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.and(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // compute the linear combination for AND: ct_left + ct_right + (0,...,0,-1/8) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + let cst = self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst) + .unwrap(); // + // - 1/8 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } + + fn nand( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(!(*message_left && *message_right)) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.nand(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.nand(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for NAND: - ct_left - ct_right + (0,...,0,1/8) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + self.engine + .fuse_opp_lwe_ciphertext(&mut buffer_lwe_before_pbs) + .unwrap(); // compute the negation + let cst = self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst) + .unwrap(); // + 1/8 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } + + fn nor( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(!(*message_left || *message_right)) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.nor(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.nor(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for NOR: - ct_left - ct_right + (0,...,0,-1/8) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + self.engine + .fuse_opp_lwe_ciphertext(&mut buffer_lwe_before_pbs) + .unwrap(); // compute the negation + let cst = self.engine.create_plaintext_from(&PLAINTEXT_FALSE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst) + .unwrap(); // + // - 1/8 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } + + fn or( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(*message_left || *message_right) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.or(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.or(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for OR: ct_left + ct_right + (0,...,0,+1/8) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + let cst = self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst) + .unwrap(); // + 1/8 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } + + fn xor( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(*message_left ^ *message_right) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.xor(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.xor(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for XOR: 2*(ct_left + ct_right) + (0,...,0,1/4) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + let cst_add = self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst_add) + .unwrap(); // + 1/8 + let cst_mul = self.engine.create_cleartext_from(&2u32).unwrap(); + self.engine + .fuse_mul_lwe_ciphertext_cleartext(&mut buffer_lwe_before_pbs, &cst_mul) + .unwrap(); //* 2 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } + + fn xnor( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + match (ct_left, ct_right) { + (Ciphertext::Trivial(message_left), Ciphertext::Trivial(message_right)) => { + Ciphertext::Trivial(!(*message_left ^ *message_right)) + } + (Ciphertext::Encrypted(_), Ciphertext::Trivial(message_right)) => { + self.xnor(ct_left, *message_right, server_key) + } + (Ciphertext::Trivial(message_left), Ciphertext::Encrypted(_)) => { + self.xnor(*message_left, ct_right, server_key) + } + (Ciphertext::Encrypted(ct_left_ct), Ciphertext::Encrypted(ct_right_ct)) => { + let mut buffer_lwe_before_pbs = self + .engine + .create_lwe_ciphertext_from(vec![0u32; server_key.lwe_size().0]) + .unwrap(); + let bootstrapper = &mut self.bootstrapper; + + // Compute the linear combination for XNOR: 2*(-ct_left - ct_right + (0,...,0,-1/8)) + self.engine + .discard_add_lwe_ciphertext(&mut buffer_lwe_before_pbs, ct_left_ct, ct_right_ct) + .unwrap(); // ct_left + ct_right + let cst_add = self.engine.create_plaintext_from(&PLAINTEXT_TRUE).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut buffer_lwe_before_pbs, &cst_add) + .unwrap(); // + 1/8 + self.engine + .fuse_opp_lwe_ciphertext(&mut buffer_lwe_before_pbs) + .unwrap(); // compute the negation + let cst_mul = self.engine.create_cleartext_from(&2u32).unwrap(); + self.engine + .fuse_mul_lwe_ciphertext_cleartext(&mut buffer_lwe_before_pbs, &cst_mul) + .unwrap(); //* 2 + + // compute the bootstrap and the key switch + bootstrapper + .bootstrap_keyswitch(buffer_lwe_before_pbs, server_key) + .unwrap() + } + } + } +} + +impl BinaryGatesAssignEngine<&mut Ciphertext, &Ciphertext, B::ServerKey> for BooleanEngine +where + B: Bootstrapper, +{ + fn and_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.and(&ct_left_clone, ct_right, server_key); + } + + fn nand_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.nand(&ct_left_clone, ct_right, server_key); + } + + fn nor_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.nor(&ct_left_clone, ct_right, server_key); + } + + fn or_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.or(&ct_left_clone, ct_right, server_key); + } + + fn xor_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.xor(&ct_left_clone, ct_right, server_key); + } + + fn xnor_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.xnor(&ct_left_clone, ct_right, server_key); + } +} + +impl BinaryGatesAssignEngine<&mut Ciphertext, bool, B::ServerKey> for BooleanEngine +where + B: Bootstrapper, +{ + fn and_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.and(&ct_left_clone, ct_right, server_key); + } + + fn nand_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.nand(&ct_left_clone, ct_right, server_key); + } + + fn nor_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.nor(&ct_left_clone, ct_right, server_key); + } + + fn or_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.or(&ct_left_clone, ct_right, server_key); + } + + fn xor_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.xor(&ct_left_clone, ct_right, server_key); + } + + fn xnor_assign(&mut self, ct_left: &mut Ciphertext, ct_right: bool, server_key: &B::ServerKey) { + let ct_left_clone = ct_left.clone(); + *ct_left = self.xnor(&ct_left_clone, ct_right, server_key); + } +} + +impl BinaryGatesAssignEngine for BooleanEngine +where + B: Bootstrapper, +{ + fn and_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.and(ct_left, &ct_right_clone, server_key); + } + + fn nand_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.nand(ct_left, &ct_right_clone, server_key); + } + + fn nor_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.nor(ct_left, &ct_right_clone, server_key); + } + + fn or_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.or(ct_left, &ct_right_clone, server_key); + } + + fn xor_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.xor(ct_left, &ct_right_clone, server_key); + } + + fn xnor_assign(&mut self, ct_left: bool, ct_right: &mut Ciphertext, server_key: &B::ServerKey) { + let ct_right_clone = ct_right.clone(); + *ct_right = self.xnor(ct_left, &ct_right_clone, server_key); + } +} + +impl BinaryGatesEngine<&Ciphertext, bool, B::ServerKey> for BooleanEngine +where + B: Bootstrapper, +{ + fn and( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // ct AND true = ct + ct_left.clone() + } else { + // ct AND false = false + self.trivial_encrypt(false) + } + } + + fn nand( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // NOT (ct AND true) = NOT(ct) + self.not(ct_left) + } else { + // NOT (ct AND false) = NOT(false) = true + self.trivial_encrypt(true) + } + } + + fn nor( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // NOT (ct OR true) = NOT(true) = false + self.trivial_encrypt(false) + } else { + // NOT (ct OR false) = NOT(ct) + self.not(ct_left) + } + } + + fn or( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // ct OR true = true + self.trivial_encrypt(true) + } else { + // ct OR false = ct + ct_left.clone() + } + } + + fn xor( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // ct XOR true = NOT(ct) + self.not(ct_left) + } else { + // ct XOR false = ct + ct_left.clone() + } + } + + fn xnor( + &mut self, + ct_left: &Ciphertext, + ct_right: bool, + _server_key: &B::ServerKey, + ) -> Ciphertext { + if ct_right { + // NOT(ct XOR true) = NOT(NOT(ct)) = ct + ct_left.clone() + } else { + // NOT(ct XOR false) = NOT(ct) + self.not(ct_left) + } + } +} + +impl BinaryGatesEngine for BooleanEngine +where + B: Bootstrapper, +{ + fn and( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.and(ct_right, ct_left, server_key) + } + + fn nand( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.nand(ct_right, ct_left, server_key) + } + + fn nor( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.nor(ct_right, ct_left, server_key) + } + + fn or( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.or(ct_right, ct_left, server_key) + } + + fn xor( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.xor(ct_right, ct_left, server_key) + } + + fn xnor( + &mut self, + ct_left: bool, + ct_right: &Ciphertext, + server_key: &B::ServerKey, + ) -> Ciphertext { + self.xnor(ct_right, ct_left, server_key) + } +} diff --git a/tfhe/src/boolean/mod.rs b/tfhe/src/boolean/mod.rs new file mode 100644 index 000000000..7d714569f --- /dev/null +++ b/tfhe/src/boolean/mod.rs @@ -0,0 +1,136 @@ +#![deny(rustdoc::broken_intra_doc_links)] +//! Welcome the the tfhe.rs `boolean` module documentation! +//! +//! # Description +//! This library makes it possible to execute boolean gates over encrypted bits. +//! It allows to execute a boolean circuit on an untrusted server because both circuit inputs and +//! outputs are kept private. +//! Data are indeed encrypted on the client side, before being sent to the server. +//! On the server side every computation is performed on ciphertexts. +//! The server however has to know the boolean circuit to be evaluated. +//! At the end of the computation, the server returns the encryption of the result to the user. +//! +//! +//! +//! # Quick Example +//! +//! The following piece of code shows how to generate keys and run a small Boolean circuit +//! homomorphically. +//! +//! ```rust +//! # #[cfg(not(feature = "cuda"))] +//! # fn main() { +//! +//! use tfhe::boolean::gen_keys; +//! use tfhe::boolean::prelude::*; +//! +//! // We generate a set of client/server keys, using the default parameters: +//! let (mut client_key, mut server_key) = gen_keys(); +//! +//! // We use the client secret key to encrypt two messages: +//! let ct_1 = client_key.encrypt(true); +//! let ct_2 = client_key.encrypt(false); +//! +//! // We use the server public key to execute a boolean circuit: +//! // if ((NOT ct_2) NAND (ct_1 AND ct_2)) then (NOT ct_2) else (ct_1 AND ct_2) +//! let ct_3 = server_key.not(&ct_2); +//! let ct_4 = server_key.and(&ct_1, &ct_2); +//! let ct_5 = server_key.nand(&ct_3, &ct_4); +//! let ct_6 = server_key.mux(&ct_5, &ct_3, &ct_4); +//! +//! // We use the client key to decrypt the output of the circuit: +//! let output_1 = client_key.decrypt(&ct_6); +//! assert_eq!(output_1, true); +//! +//! // It is possible to compute gates with one input unencrypted +//! let ct_7 = server_key.and(&ct_6, true); +//! let output_2 = client_key.decrypt(&ct_7); +//! assert_eq!(output_2, true); +//! +//! // It is possible to trivially encrypt on the server side +//! // i.e. to not encrypt but still generate a compatible Ciphertext +//! let ct_8 = server_key.trivial_encrypt(false); +//! let ct_9 = server_key.mux(&ct_7, &ct_3, &ct_8); +//! let output_3 = client_key.decrypt(&ct_9); +//! assert_eq!(output_3, true); +//! # } +//! +//! # #[cfg(feature = "cuda")] +//! # fn main() {} +//! ``` + +use crate::boolean::client_key::ClientKey; +use crate::boolean::parameters::DEFAULT_PARAMETERS; +use crate::boolean::public_key::PublicKey; +use crate::boolean::server_key::ServerKey; +#[cfg(test)] +use rand::Rng; + +pub mod ciphertext; +pub mod client_key; +pub mod engine; +pub mod parameters; +pub mod prelude; +pub mod public_key; +pub mod server_key; + +/// The scaling factor used for the plaintext +pub(crate) const PLAINTEXT_LOG_SCALING_FACTOR: usize = 3; + +/// The plaintext associated with true: 1/8 +pub(crate) const PLAINTEXT_TRUE: u32 = 1 << (32 - PLAINTEXT_LOG_SCALING_FACTOR); + +/// The plaintext associated with false: -1/8 +pub(crate) const PLAINTEXT_FALSE: u32 = 7 << (32 - PLAINTEXT_LOG_SCALING_FACTOR); + +/// tool to generate random booleans +#[cfg(test)] +pub(crate) fn random_boolean() -> bool { + // create a random generator + let mut rng = rand::thread_rng(); + + // generate a random bit + let n: u32 = (rng.gen::()) % 2; + + // convert it to boolean and return + n != 0 +} + +/// tool to generate random integers +#[cfg(test)] +pub(crate) fn random_integer() -> u32 { + // create a random generator + let mut rng = rand::thread_rng(); + + // generate a random u32 + rng.gen::() +} + +/// Generate a couple of client and server keys with the default cryptographic parameters: +/// `DEFAULT_PARAMETERS`. +/// The client is the one generating both keys. +/// * the client key is used to encrypt and decrypt and has to be kept secret; +/// * the server key is used to perform homomorphic operations on the server side and it is +/// meant to be published (the client sends it to the server). +/// +/// ```rust +/// # #[cfg(not(feature = "cuda"))] +/// # fn main() { +/// use tfhe::boolean::gen_keys; +/// use tfhe::boolean::prelude::*; +/// // generate the client key and the server key: +/// let (cks, sks) = gen_keys(); +/// # } +/// # #[cfg(feature = "cuda")] +/// # fn main() {} +/// ``` +pub fn gen_keys() -> (ClientKey, ServerKey) { + // generate the client key + let cks = ClientKey::new(&DEFAULT_PARAMETERS); + + // generate the server key + let sks = ServerKey::new(&cks); + + // return + (cks, sks) +} diff --git a/tfhe/src/boolean/parameters/mod.rs b/tfhe/src/boolean/parameters/mod.rs new file mode 100644 index 000000000..04f1438c5 --- /dev/null +++ b/tfhe/src/boolean/parameters/mod.rs @@ -0,0 +1,124 @@ +//! The cryptographic parameter set. +//! +//! This module provides the structure containing the cryptographic parameters required for the +//! homomorphic evaluation of Boolean circuit as well as a list of secure cryptographic parameter +//! sets. +//! +//! Two parameter sets are provided: +//! * `tfhe::boolean::parameters::DEFAULT_PARAMETERS` +//! * `tfhe::boolean::parameters::TFHE_LIB_PARAMETERS` +//! +//! They ensure the correctness of the Boolean circuit evaluation result (up to a certain +//! probability) along with 128-bits of security. +//! +//! The two parameter sets offer a trade-off in terms of execution time versus error probability. +//! The `DEFAULT_PARAMETERS` set offers better performances on homomorphic circuit evaluation +//! with an higher probability error in comparison with the `TFHE_LIB_PARAMETERS`. +//! Note that if you desire, you can also create your own set of parameters. +//! This is an unsafe operation as failing to properly fix the parameters will potentially result +//! with an incorrect and/or insecure computation. + +pub use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, +}; +use serde::{Deserialize, Serialize}; + +/// A set of cryptographic parameters for homomorphic Boolean circuit evaluation. +#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct BooleanParameters { + pub lwe_dimension: LweDimension, + pub glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub lwe_modular_std_dev: StandardDev, + pub glwe_modular_std_dev: StandardDev, + pub pbs_base_log: DecompositionBaseLog, + pub pbs_level: DecompositionLevelCount, + pub ks_base_log: DecompositionBaseLog, + pub ks_level: DecompositionLevelCount, +} + +impl BooleanParameters { + /// Constructs a new set of parameters for boolean circuit evaluation. + /// + /// # Safety + /// + /// This function is unsafe, as failing to fix the parameters properly would yield incorrect + /// and insecure computation. Unless you are a cryptographer who really knows the impact of each + /// of those parameters, you __must__ stick with the provided parameters [`DEFAULT_PARAMETERS`] + /// and [`TFHE_LIB_PARAMETERS`], which both offer correct results with 128 bits of security. + #[allow(clippy::too_many_arguments)] + pub unsafe fn new( + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_modular_std_dev: StandardDev, + glwe_modular_std_dev: StandardDev, + pbs_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + ks_level: DecompositionLevelCount, + ) -> BooleanParameters { + BooleanParameters { + lwe_dimension, + glwe_dimension, + polynomial_size, + lwe_modular_std_dev, + glwe_modular_std_dev, + pbs_base_log, + pbs_level, + ks_level, + ks_base_log, + } + } +} + +/// Default parameter set. +/// +/// This parameter set ensures 128-bits of security, and a probability of error is upper-bounded by +/// $2^{-40}$. The secret keys generated with this parameter set are uniform binary. +/// This parameter set allows to evaluate faster Boolean circuits than the `TFHE_LIB_PARAMETERS` +/// one. +pub const DEFAULT_PARAMETERS: BooleanParameters = BooleanParameters { + lwe_dimension: LweDimension(777), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.000003725679281679651), + glwe_modular_std_dev: StandardDev(0.0000000000034525330484572114), + pbs_base_log: DecompositionBaseLog(18), + pbs_level: DecompositionLevelCount(1), + ks_base_log: DecompositionBaseLog(4), + ks_level: DecompositionLevelCount(3), +}; + +/// The secret keys generated with this parameter set are uniform binary. +/// This parameter set ensures a probability of error upper-bounded by $2^{-165}$ as the ones +/// proposed into [TFHE library](https://tfhe.github.io/tfhe/) for for 128-bits of security. +/// They are updated to the last security standards, so they differ from the original +/// publication. +pub const TFHE_LIB_PARAMETERS: BooleanParameters = BooleanParameters { + lwe_dimension: LweDimension(830), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000001412290588219445), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_base_log: DecompositionBaseLog(5), + ks_level: DecompositionLevelCount(3), +}; + +/// This parameter set ensures a probability of error upper-bounded by $2^{-40}$ and 128-bit +/// security. +/// They are GPU-compliant. +pub const GPU_DEFAULT_PARAMETERS: BooleanParameters = BooleanParameters { + lwe_dimension: LweDimension(686), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000019703241702943194), + glwe_modular_std_dev: StandardDev(0.00000004053919869756513), + pbs_base_log: DecompositionBaseLog(6), + pbs_level: DecompositionLevelCount(3), + ks_base_log: DecompositionBaseLog(2), + ks_level: DecompositionLevelCount(6), +}; diff --git a/tfhe/src/boolean/prelude.rs b/tfhe/src/boolean/prelude.rs new file mode 100644 index 000000000..8a778b5cf --- /dev/null +++ b/tfhe/src/boolean/prelude.rs @@ -0,0 +1,7 @@ +#![doc(hidden)] +pub use super::ciphertext::Ciphertext; +pub use super::client_key::ClientKey; +pub use super::gen_keys; +pub use super::parameters::*; +pub use super::public_key::PublicKey; +pub use super::server_key::{BinaryBooleanGates, ServerKey}; diff --git a/tfhe/src/boolean/public_key/mod.rs b/tfhe/src/boolean/public_key/mod.rs new file mode 100644 index 000000000..9e36446cf --- /dev/null +++ b/tfhe/src/boolean/public_key/mod.rs @@ -0,0 +1,116 @@ +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::client_key::ClientKey; +use crate::boolean::engine::{CpuBooleanEngine, WithThreadLocalEngine}; +use crate::boolean::parameters::BooleanParameters; +use crate::core_crypto::prelude::*; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +/// A structure containing a public key. +#[derive(Clone)] +pub struct PublicKey { + pub(crate) lwe_public_key: LwePublicKey32, + pub(crate) parameters: BooleanParameters, +} + +impl PublicKey { + /// Encrypts a Boolean message using the client key. + /// + /// # Example + /// + /// ```rust + /// # #[cfg(not(feature = "cuda"))] + /// # fn main() { + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(); + /// + /// let pks = PublicKey::new(&cks); + /// + /// // Encryption of one message: + /// let ct1 = pks.encrypt(true); + /// let ct2 = pks.encrypt(false); + /// let ct_res = sks.and(&ct1, &ct2); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(false, dec); + /// # } + /// # #[cfg(feature = "cuda")] + /// # fn main() {} + /// ``` + pub fn encrypt(&self, message: bool) -> Ciphertext { + CpuBooleanEngine::with_thread_local_mut(|engine| { + engine.encrypt_with_public_key(message, self) + }) + } + + /// Allocates and generates a client key. + /// + /// # Panic + /// + /// This will panic when the "cuda" feature is enabled and the parameters + /// uses a GlweDimension > 1 (as it is not yet supported by the cuda backend). + /// + /// # Example + /// + /// ```rust + /// # #[cfg(not(feature = "cuda"))] + /// # fn main() { + /// use tfhe::boolean::prelude::*; + /// + /// // Generate the client key and the server key: + /// let (cks, mut sks) = gen_keys(); + /// + /// let pks = PublicKey::new(&cks); + /// # } + /// # #[cfg(feature = "cuda")] + /// # fn main() {} + /// ``` + pub fn new(client_key: &ClientKey) -> PublicKey { + CpuBooleanEngine::with_thread_local_mut(|engine| engine.create_public_key(client_key)) + } +} + +#[derive(Serialize, Deserialize)] +struct SerializablePublicKey { + lwe_public_key: Vec, + parameters: BooleanParameters, +} + +impl Serialize for PublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let lwe_public_key = ser_eng + .serialize(&self.lwe_public_key) + .map_err(serde::ser::Error::custom)?; + + SerializablePublicKey { + lwe_public_key, + parameters: self.parameters, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for PublicKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + SerializablePublicKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(Self { + lwe_public_key: de_eng + .deserialize(thing.lwe_public_key.as_slice()) + .map_err(serde::de::Error::custom)?, + parameters: thing.parameters, + }) + } +} diff --git a/tfhe/src/boolean/server_key/mod.rs b/tfhe/src/boolean/server_key/mod.rs new file mode 100644 index 000000000..f80ac3741 --- /dev/null +++ b/tfhe/src/boolean/server_key/mod.rs @@ -0,0 +1,264 @@ +//! The public key for homomorphic computation. +//! +//! This module implements the generation of the server's public key, together with all the +//! available homomorphic Boolean gates ($\mathrm{AND}$, $\mathrm{MUX}$, $\mathrm{NAND}$, +//! $\mathrm{NOR}$, +//! $\mathrm{NOT}$, $\mathrm{OR}$, $\mathrm{XNOR}$, $\mathrm{XOR}$). + +#[cfg(test)] +mod tests; + +use serde::{Deserialize, Serialize}; + +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::client_key::ClientKey; +use crate::boolean::engine::bootstrapping::CpuBootstrapKey; +#[cfg(feature = "cuda")] +use crate::boolean::engine::{bootstrapping::CudaBootstrapKey, CudaBooleanEngine}; +use crate::boolean::engine::{ + BinaryGatesAssignEngine, BinaryGatesEngine, CpuBooleanEngine, WithThreadLocalEngine, +}; +#[cfg(feature = "cuda")] +use std::sync::Arc; + +pub trait BinaryBooleanGates { + fn and(&self, ct_left: L, ct_right: R) -> Ciphertext; + fn nand(&self, ct_left: L, ct_right: R) -> Ciphertext; + fn nor(&self, ct_left: L, ct_right: R) -> Ciphertext; + fn or(&self, ct_left: L, ct_right: R) -> Ciphertext; + fn xor(&self, ct_left: L, ct_right: R) -> Ciphertext; + fn xnor(&self, ct_left: L, ct_right: R) -> Ciphertext; +} + +pub trait BinaryBooleanGatesAssign { + fn and_assign(&self, ct_left: L, ct_right: R); + fn nand_assign(&self, ct_left: L, ct_right: R); + fn nor_assign(&self, ct_left: L, ct_right: R); + fn or_assign(&self, ct_left: L, ct_right: R); + fn xor_assign(&self, ct_left: L, ct_right: R); + fn xnor_assign(&self, ct_left: L, ct_right: R); +} + +trait RefFromServerKey { + fn get_ref(server_key: &ServerKey) -> &Self; +} + +trait DefaultImplementation { + type Engine: WithThreadLocalEngine; + type BootsrapKey: RefFromServerKey; +} + +#[derive(Clone)] +pub struct ServerKey { + cpu_key: CpuBootstrapKey, + #[cfg(feature = "cuda")] + cuda_key: Arc, +} + +#[cfg(not(feature = "cuda"))] +mod implementation { + use super::*; + + impl RefFromServerKey for CpuBootstrapKey { + fn get_ref(server_key: &ServerKey) -> &Self { + &server_key.cpu_key + } + } + + impl DefaultImplementation for ServerKey { + type Engine = CpuBooleanEngine; + type BootsrapKey = CpuBootstrapKey; + } +} + +#[cfg(feature = "cuda")] +mod implementation { + use super::*; + + impl RefFromServerKey for CudaBootstrapKey { + fn get_ref(server_key: &ServerKey) -> &Self { + &server_key.cuda_key + } + } + + impl DefaultImplementation for ServerKey { + type Engine = CudaBooleanEngine; + type BootsrapKey = CudaBootstrapKey; + } +} + +impl BinaryBooleanGates for ServerKey +where + ::Engine: + BinaryGatesEngine::BootsrapKey>, +{ + fn and(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.and(ct_left, ct_right, bootstrap_key) + }) + } + + fn nand(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.nand(ct_left, ct_right, bootstrap_key) + }) + } + + fn nor(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.nor(ct_left, ct_right, bootstrap_key) + }) + } + + fn or(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.or(ct_left, ct_right, bootstrap_key) + }) + } + + fn xor(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.xor(ct_left, ct_right, bootstrap_key) + }) + } + + fn xnor(&self, ct_left: Lhs, ct_right: Rhs) -> Ciphertext { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.xnor(ct_left, ct_right, bootstrap_key) + }) + } +} + +impl BinaryBooleanGatesAssign for ServerKey +where + ::Engine: + BinaryGatesAssignEngine::BootsrapKey>, +{ + fn and_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.and_assign(ct_left, ct_right, bootstrap_key) + }) + } + + fn nand_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.nand_assign(ct_left, ct_right, bootstrap_key) + }) + } + + fn nor_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.nor_assign(ct_left, ct_right, bootstrap_key) + }) + } + + fn or_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.or_assign(ct_left, ct_right, bootstrap_key) + }) + } + + fn xor_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.xor_assign(ct_left, ct_right, bootstrap_key) + }) + } + + fn xnor_assign(&self, ct_left: Lhs, ct_right: Rhs) { + let bootstrap_key = ::BootsrapKey::get_ref(self); + ::Engine::with_thread_local_mut(|engine| { + engine.xnor_assign(ct_left, ct_right, bootstrap_key) + }) + } +} + +impl ServerKey { + pub fn new(cks: &ClientKey) -> Self { + let cpu_key = + CpuBooleanEngine::with_thread_local_mut(|engine| engine.create_server_key(cks)); + + Self::from(cpu_key) + } + + pub fn trivial_encrypt(&self, message: bool) -> Ciphertext { + Ciphertext::Trivial(message) + } + + pub fn not(&self, ct: &Ciphertext) -> Ciphertext { + CpuBooleanEngine::with_thread_local_mut(|engine| engine.not(ct)) + } + + pub fn not_assign(&self, ct: &mut Ciphertext) { + CpuBooleanEngine::with_thread_local_mut(|engine| engine.not_assign(ct)) + } + + pub fn mux( + &self, + ct_condition: &Ciphertext, + ct_then: &Ciphertext, + ct_else: &Ciphertext, + ) -> Ciphertext { + #[cfg(feature = "cuda")] + { + CudaBooleanEngine::with_thread_local_mut(|engine| { + engine.mux(ct_condition, ct_then, ct_else, &self.cuda_key) + }) + } + #[cfg(not(feature = "cuda"))] + { + CpuBooleanEngine::with_thread_local_mut(|engine| { + engine.mux(ct_condition, ct_then, ct_else, &self.cpu_key) + }) + } + } +} + +impl From for ServerKey { + fn from(cpu_key: CpuBootstrapKey) -> Self { + #[cfg(feature = "cuda")] + { + let cuda_key = CudaBooleanEngine::with_thread_local_mut(|engine| { + engine.create_server_key(&cpu_key) + }); + + let cuda_key = Arc::new(cuda_key); + + Self { cpu_key, cuda_key } + } + #[cfg(not(feature = "cuda"))] + { + Self { cpu_key } + } + } +} + +impl Serialize for ServerKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.cpu_key.serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ServerKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let cpu_key = CpuBootstrapKey::deserialize(deserializer)?; + + Ok(Self::from(cpu_key)) + } +} diff --git a/tfhe/src/boolean/server_key/tests.rs b/tfhe/src/boolean/server_key/tests.rs new file mode 100644 index 000000000..d1cdc318d --- /dev/null +++ b/tfhe/src/boolean/server_key/tests.rs @@ -0,0 +1,913 @@ +use crate::boolean::ciphertext::Ciphertext; +use crate::boolean::client_key::ClientKey; +use crate::boolean::parameters::BooleanParameters; +use crate::boolean::server_key::{BinaryBooleanGates, BinaryBooleanGatesAssign, ServerKey}; +use crate::boolean::{random_boolean, random_integer}; + +/// Number of assert in randomized tests +const NB_TEST: usize = 128; + +/// Number of ciphertext in the deep circuit test +const NB_CT: usize = 8; + +/// Number of gates computed in the deep circuit test +const NB_GATE: usize = 1 << 11; + +#[cfg(not(feature = "cuda"))] +mod default_parameters_tests { + use super::*; + use crate::boolean::parameters::DEFAULT_PARAMETERS; + + #[test] + fn test_encrypt_decrypt_lwe_secret_key_default_parameters() { + test_encrypt_decrypt_lwe_secret_key(DEFAULT_PARAMETERS); + } + #[test] + fn test_and_gate_default_parameters() { + test_and_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_nand_gate_default_parameters() { + test_nand_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_or_gate_default_parameters() { + test_or_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_nor_gate_default_parameters() { + test_nor_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_xor_gate_default_parameters() { + test_xor_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_xnor_gate_default_parameters() { + test_xnor_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_not_gate_default_parameters() { + test_not_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_mux_gate_default_parameters() { + test_mux_gate(DEFAULT_PARAMETERS); + } + #[test] + fn test_deep_circuit_default_parameters() { + test_deep_circuit(DEFAULT_PARAMETERS); + } +} + +#[cfg(not(feature = "cuda"))] +mod tfhe_lib_parameters_tests { + use super::*; + use crate::boolean::parameters::TFHE_LIB_PARAMETERS; + + #[test] + fn test_encrypt_decrypt_lwe_secret_key_tfhe_lib_parameters() { + test_encrypt_decrypt_lwe_secret_key(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_and_gate_tfhe_lib_parameters() { + test_and_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_nand_gate_tfhe_lib_parameters() { + test_nand_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_or_gate_tfhe_lib_parameters() { + test_or_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_nor_gate_tfhe_lib_parameters() { + test_nor_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_xor_gate_tfhe_lib_parameters() { + test_xor_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_xnor_gate_tfhe_lib_parameters() { + test_xnor_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_not_gate_tfhe_lib_parameters() { + test_not_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_mux_gate_tfhe_lib_parameters() { + test_mux_gate(TFHE_LIB_PARAMETERS); + } + #[test] + fn test_deep_circuit_tfhe_lib_parameters() { + test_deep_circuit(TFHE_LIB_PARAMETERS); + } +} + +mod gpu_default_parameters_tests { + use super::*; + use crate::boolean::parameters::GPU_DEFAULT_PARAMETERS; + + #[test] + fn test_encrypt_decrypt_lwe_secret_key_default_parameters() { + test_encrypt_decrypt_lwe_secret_key(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_and_gate_default_parameters() { + test_and_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_nand_gate_default_parameters() { + test_nand_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_or_gate_default_parameters() { + test_or_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_nor_gate_default_parameters() { + test_nor_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_xor_gate_default_parameters() { + test_xor_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_xnor_gate_default_parameters() { + test_xnor_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_not_gate_default_parameters() { + test_not_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_mux_gate_default_parameters() { + test_mux_gate(GPU_DEFAULT_PARAMETERS); + } + #[test] + fn test_deep_circuit_default_parameters() { + test_deep_circuit(GPU_DEFAULT_PARAMETERS); + } +} + +/// test encryption and decryption with the LWE secret key +fn test_encrypt_decrypt_lwe_secret_key(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // encryption of false + let ct_false = cks.encrypt(false); + + // encryption of true + let ct_true = cks.encrypt(true); + + // decryption of false + let dec_false = cks.decrypt(&ct_false); + + // decryption of true + let dec_true = cks.decrypt(&ct_true); + + // assert + assert!(!dec_false); + assert!(dec_true); + + // encryption of false + let ct_false = sks.trivial_encrypt(false); + + // encryption of true + let ct_true = sks.trivial_encrypt(true); + + // decryption of false + let dec_false = cks.decrypt(&ct_false); + + // decryption of true + let dec_true = cks.decrypt(&ct_true); + + // assert + assert!(!dec_false); + assert!(dec_true); + } +} + +/// This function randomly either computes a regular encryption of the message or a trivial +/// encryption of the message +fn random_enum_encryption(cks: &ClientKey, sks: &ServerKey, message: bool) -> Ciphertext { + if random_boolean() { + cks.encrypt(message) + } else { + sks.trivial_encrypt(message) + } +} + +fn test_and_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = b1 && b2; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // AND gate -> "left: {:?}, right: {:?}",ct1, ct2 + let ct_res = sks.and(&ct1, &ct2); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_and, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // AND gate -> left: Ciphertext, right: bool + let ct_res = sks.and(&ct1, b2); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_and, "left: {:?}, right: {:?}", ct1, b2); + + // AND gate -> left: bool, right: Ciphertext + let ct_res = sks.and(b1, &ct2); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_and, "left: {:?}, right: {:?}", b1, ct2); + + // AND gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.and_assign(&mut ct_res, &ct2); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_and, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // AND gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.and_assign(&mut ct_res, b2); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_and, "left: {:?}, right: {:?}", ct1, b2); + + // AND gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.and_assign(b1, &mut ct_res); + + // decryption + let dec_and = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_and, "left: {:?}, right: {:?}", b1, ct2); + } +} + +fn test_mux_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of three random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let b3 = random_boolean(); + let expected_result = if b1 { b2 } else { b3 }; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // encryption of b3 + let ct3 = random_enum_encryption(&cks, &sks, b3); + + // MUX gate + let ct_res = sks.mux(&ct1, &ct2, &ct3); + + // decryption + let dec_mux = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_mux, + "cond: {:?}, then: {:?}, else: {:?}", + ct1, ct2, ct3 + ); + } +} + +fn test_nand_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = !(b1 && b2); + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // NAND gate -> left: Ciphertext, right: Ciphertext + let ct_res = sks.nand(&ct1, &ct2); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // NAND gate -> left: Ciphertext, right: bool + let ct_res = sks.nand(&ct1, b2); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + ct1, b2 + ); + + // NAND gate -> left: bool, right: Ciphertext + let ct_res = sks.nand(b1, &ct2); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + b1, ct2 + ); + + // NAND gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.nand_assign(&mut ct_res, &ct2); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // NAND gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.nand_assign(&mut ct_res, b2); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + ct1, b2 + ); + + // NAND gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.nand_assign(b1, &mut ct_res); + + // decryption + let dec_nand = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nand, + "left: {:?}, right: {:?}", + b1, ct2 + ); + } +} + +fn test_nor_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = !(b1 || b2); + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // NOR gate -> left: Ciphertext, right: Ciphertext + let ct_res = sks.nor(&ct1, &ct2); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // NOR gate -> left: Ciphertext, right: bool + let ct_res = sks.nor(&ct1, b2); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_nor, "left: {:?}, right: {:?}", ct1, b2); + + // NOR gate -> left: bool, right: Ciphertext + let ct_res = sks.nor(b1, &ct2); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_nor, "left: {:?}, right: {:?}", b1, ct2); + + // NOR gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.nor_assign(&mut ct_res, &ct2); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_nor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // NOR gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.nor_assign(&mut ct_res, b2); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_nor, "left: {:?}, right: {:?}", ct1, b2); + + // NOR gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.nor_assign(b1, &mut ct_res); + + // decryption + let dec_nor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_nor, "left: {:?}, right: {:?}", b1, ct2); + } +} + +fn test_not_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of one random booleans + let b1 = random_boolean(); + let expected_result = !b1; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // NOT gate + let ct_res = sks.not(&ct1); + + // decryption + let dec_not = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_not); + + // NOT gate + let mut ct_res = ct1.clone(); + sks.not_assign(&mut ct_res); + + // decryption + let dec_not = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_not); + } +} + +fn test_or_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = b1 || b2; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // OR gate -> left: Ciphertext, right: Ciphertext + let ct_res = sks.or(&ct1, &ct2); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", ct1, ct2); + + // OR gate -> left: Ciphertext, right: bool + let ct_res = sks.or(&ct1, b2); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", ct1, b2); + + // OR gate -> left: bool, right: Ciphertext + let ct_res = sks.or(b1, &ct2); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", b1, ct2); + + // OR gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.or_assign(&mut ct_res, &ct2); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", ct1, ct2); + + // OR gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.or_assign(&mut ct_res, b2); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", ct1, b2); + + // OR gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.or_assign(b1, &mut ct_res); + + // decryption + let dec_or = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_or, "left: {:?}, right: {:?}", b1, ct2); + } +} + +fn test_xnor_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = b1 == b2; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // XNOR gate -> left: Ciphertext, right: Ciphertext + let ct_res = sks.xnor(&ct1, &ct2); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // XNOR gate -> left: Ciphertext, right: bool + let ct_res = sks.xnor(&ct1, b2); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + ct1, b2 + ); + + // XNOR gate -> left: bool, right: Ciphertext + let ct_res = sks.xnor(b1, &ct2); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + b1, ct2 + ); + + // XNOR gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.xnor_assign(&mut ct_res, &ct2); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // XNOR gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.xnor_assign(&mut ct_res, b2); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + ct1, b2 + ); + + // XNOR gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.xnor_assign(b1, &mut ct_res); + + // decryption + let dec_xnor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xnor, + "left: {:?}, right: {:?}", + b1, ct2 + ); + } +} + +fn test_xor_gate(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + for _ in 0..NB_TEST { + // generation of two random booleans + let b1 = random_boolean(); + let b2 = random_boolean(); + let expected_result = b1 ^ b2; + + // encryption of b1 + let ct1 = random_enum_encryption(&cks, &sks, b1); + + // encryption of b2 + let ct2 = random_enum_encryption(&cks, &sks, b2); + + // XOR gate -> left: Ciphertext, right: Ciphertext + let ct_res = sks.xor(&ct1, &ct2); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // XOR gate -> left: Ciphertext, right: bool + let ct_res = sks.xor(&ct1, b2); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_xor, "left: {:?}, right: {:?}", ct1, b2); + + // XOR gate -> left: bool, right: Ciphertext + let ct_res = sks.xor(b1, &ct2); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_xor, "left: {:?}, right: {:?}", b1, ct2); + + // XOR gate -> "left: {:?}, right: {:?}",ct1, ct2 + let mut ct_res = ct1.clone(); + sks.xor_assign(&mut ct_res, &ct2); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!( + expected_result, dec_xor, + "left: {:?}, right: {:?}", + ct1, ct2 + ); + + // XOR gate -> left: Ciphertext, right: bool + let mut ct_res = ct1.clone(); + sks.xor_assign(&mut ct_res, b2); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_xor, "left: {:?}, right: {:?}", ct1, b2); + + // XOR gate -> left: bool, right: Ciphertext + let mut ct_res = ct2.clone(); + sks.xor_assign(b1, &mut ct_res); + + // decryption + let dec_xor = cks.decrypt(&ct_res); + + // assert + assert_eq!(expected_result, dec_xor, "left: {:?}, right: {:?}", b1, ct2); + } +} + +/// generate a random index for the table in the long run tests +fn random_index() -> usize { + (random_integer() % (NB_CT as u32)) as usize +} + +/// randomly select a gate, randomly select inputs and the output, +/// compute the selected gate with the selected inputs +/// and write in the selected output +fn random_gate_all(ct_tab: &mut [Ciphertext], bool_tab: &mut [bool], sks: &ServerKey) { + // select a random gate in the array [NOT,CMUX,AND,NAND,NOR,OR,XOR,XNOR] + let gate_id = random_integer() % 8; + + let index_1: usize = random_index(); + let index_2: usize = random_index(); + + if gate_id == 0 { + // NOT gate + bool_tab[index_2] = !bool_tab[index_1]; + ct_tab[index_2] = sks.not(&ct_tab[index_1]); + } else if gate_id == 1 { + // MUX gate + let index_3: usize = random_index(); + let index_4: usize = random_index(); + bool_tab[index_4] = if bool_tab[index_1] { + bool_tab[index_2] + } else { + bool_tab[index_3] + }; + ct_tab[index_4] = sks.mux(&ct_tab[index_1], &ct_tab[index_2], &ct_tab[index_3]); + } else { + // 2-input gate + let index_3: usize = random_index(); + + if gate_id == 2 { + // AND gate + bool_tab[index_3] = bool_tab[index_1] && bool_tab[index_2]; + ct_tab[index_3] = sks.and(&ct_tab[index_1], &ct_tab[index_2]); + } else if gate_id == 3 { + // NAND gate + bool_tab[index_3] = !(bool_tab[index_1] && bool_tab[index_2]); + ct_tab[index_3] = sks.nand(&ct_tab[index_1], &ct_tab[index_2]); + } else if gate_id == 4 { + // NOR gate + bool_tab[index_3] = !(bool_tab[index_1] || bool_tab[index_2]); + ct_tab[index_3] = sks.nor(&ct_tab[index_1], &ct_tab[index_2]); + } else if gate_id == 5 { + // OR gate + bool_tab[index_3] = bool_tab[index_1] || bool_tab[index_2]; + ct_tab[index_3] = sks.or(&ct_tab[index_1], &ct_tab[index_2]); + } else if gate_id == 6 { + // XOR gate + bool_tab[index_3] = bool_tab[index_1] ^ bool_tab[index_2]; + ct_tab[index_3] = sks.xor(&ct_tab[index_1], &ct_tab[index_2]); + } else { + // XNOR gate + bool_tab[index_3] = !(bool_tab[index_1] ^ bool_tab[index_2]); + ct_tab[index_3] = sks.xnor(&ct_tab[index_1], &ct_tab[index_2]); + } + } +} + +fn test_deep_circuit(parameters: BooleanParameters) { + // generate the client key set + let cks = ClientKey::new(¶meters); + + // generate the server key set + let sks = ServerKey::new(&cks); + + // create an array of ciphertexts + let mut ct_tab: Vec = vec![cks.encrypt(true); NB_CT]; + + // create an array of booleans + let mut bool_tab: Vec = vec![true; NB_CT]; + + // randomly fill both arrays + for (ct, boolean) in ct_tab.iter_mut().zip(bool_tab.iter_mut()) { + *boolean = random_boolean(); + *ct = cks.encrypt(*boolean); + } + + // compute NB_GATE gates + for _ in 0..NB_GATE { + random_gate_all(&mut ct_tab, &mut bool_tab, &sks); + } + + // decrypt and assert equality + for (ct, boolean) in ct_tab.iter().zip(bool_tab.iter()) { + let dec = cks.decrypt(ct); + assert_eq!(*boolean, dec); + } +} diff --git a/tfhe/src/c_api/boolean/ciphertext.rs b/tfhe/src/c_api/boolean/ciphertext.rs new file mode 100644 index 000000000..a946a6623 --- /dev/null +++ b/tfhe/src/c_api/boolean/ciphertext.rs @@ -0,0 +1,44 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::boolean; + +pub struct BooleanCiphertext(pub(in crate::c_api) boolean::ciphertext::Ciphertext); + +#[no_mangle] +pub unsafe extern "C" fn boolean_serialize_ciphertext( + ciphertext: *const BooleanCiphertext, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let ciphertext = get_ref_checked(ciphertext).unwrap(); + + let buffer: Buffer = bincode::serialize(&ciphertext.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_deserialize_ciphertext( + buffer_view: BufferView, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let ciphertext: boolean::ciphertext::Ciphertext = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_ciphertext = Box::new(BooleanCiphertext(ciphertext)); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} diff --git a/tfhe/src/c_api/boolean/client_key.rs b/tfhe/src/c_api/boolean/client_key.rs new file mode 100644 index 000000000..3d1c5a84d --- /dev/null +++ b/tfhe/src/c_api/boolean/client_key.rs @@ -0,0 +1,106 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use bincode; +use std::os::raw::c_int; + +use crate::boolean; + +use super::BooleanCiphertext; +pub struct BooleanClientKey(pub(in crate::c_api) boolean::client_key::ClientKey); + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_client_key( + boolean_parameters: *const super::parameters::BooleanParameters, + result_client_key: *mut *mut BooleanClientKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + + let boolean_parameters = get_ref_checked(boolean_parameters).unwrap(); + + let client_key = boolean::client_key::ClientKey::new(&boolean_parameters.0); + + let heap_allocated_client_key = Box::new(BooleanClientKey(client_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_client_key_encrypt( + client_key: *const BooleanClientKey, + value_to_encrypt: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let heap_allocated_ciphertext = + Box::new(BooleanCiphertext(client_key.0.encrypt(value_to_encrypt))); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_client_key_decrypt( + client_key: *const BooleanClientKey, + ciphertext_to_decrypt: *const BooleanCiphertext, + result: *mut bool, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let client_key = get_ref_checked(client_key).unwrap(); + let ciphertext_to_decrypt = get_ref_checked(ciphertext_to_decrypt).unwrap(); + + *result = client_key.0.decrypt(&ciphertext_to_decrypt.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_serialize_client_key( + client_key: *const BooleanClientKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&client_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_deserialize_client_key( + buffer_view: BufferView, + result: *mut *mut BooleanClientKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key: boolean::client_key::ClientKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_client_key = Box::new(BooleanClientKey(client_key)); + + *result = Box::into_raw(heap_allocated_client_key); + }) +} diff --git a/tfhe/src/c_api/boolean/destroy.rs b/tfhe/src/c_api/boolean/destroy.rs new file mode 100644 index 000000000..8c2fb2279 --- /dev/null +++ b/tfhe/src/c_api/boolean/destroy.rs @@ -0,0 +1,54 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::parameters::BooleanParameters; +use super::{BooleanCiphertext, BooleanClientKey, BooleanPublicKey, BooleanServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn destroy_boolean_client_key(client_key: *mut BooleanClientKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(client_key).unwrap(); + + drop(Box::from_raw(client_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_boolean_server_key(server_key: *mut BooleanServerKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(server_key).unwrap(); + + drop(Box::from_raw(server_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_boolean_public_key(public_key: *mut BooleanPublicKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(public_key).unwrap(); + + drop(Box::from_raw(public_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_boolean_parameters( + boolean_parameters: *mut BooleanParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(boolean_parameters).unwrap(); + + drop(Box::from_raw(boolean_parameters)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_boolean_ciphertext( + boolean_ciphertext: *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(boolean_ciphertext).unwrap(); + + drop(Box::from_raw(boolean_ciphertext)); + }) +} diff --git a/tfhe/src/c_api/boolean/mod.rs b/tfhe/src/c_api/boolean/mod.rs new file mode 100644 index 000000000..0507a30af --- /dev/null +++ b/tfhe/src/c_api/boolean/mod.rs @@ -0,0 +1,119 @@ +pub mod ciphertext; +pub mod client_key; +pub mod destroy; +pub mod parameters; +pub mod public_key; +pub mod server_key; + +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::boolean; + +pub use ciphertext::BooleanCiphertext; +pub use client_key::BooleanClientKey; +pub use public_key::BooleanPublicKey; +pub use server_key::BooleanServerKey; + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_keys_with_default_parameters( + result_client_key: *mut *mut BooleanClientKey, + result_server_key: *mut *mut BooleanServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + *result_server_key = std::ptr::null_mut(); + + let (client_key, server_key) = boolean::gen_keys(); + let heap_allocated_client_key = Box::new(BooleanClientKey(client_key)); + let heap_allocated_server_key = Box::new(BooleanServerKey(server_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_keys_with_parameters( + boolean_parameters: *const parameters::BooleanParameters, + result_client_key: *mut *mut BooleanClientKey, + result_server_key: *mut *mut BooleanServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + *result_server_key = std::ptr::null_mut(); + + let boolean_parameters = get_ref_checked(boolean_parameters).unwrap(); + + let client_key = boolean::client_key::ClientKey::new(&boolean_parameters.0); + let server_key = boolean::server_key::ServerKey::new(&client_key); + + let heap_allocated_client_key = Box::new(BooleanClientKey(client_key)); + let heap_allocated_server_key = Box::new(BooleanServerKey(server_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_keys_with_predefined_parameters_set( + boolean_parameters_set: c_int, + result_client_key: *mut *mut BooleanClientKey, + result_server_key: *mut *mut BooleanServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + *result_server_key = std::ptr::null_mut(); + + let boolean_parameters_set_as_enum = + parameters::BooleanParametersSet::try_from(boolean_parameters_set).unwrap(); + + let boolean_parameters = + parameters::BooleanParameters::from(boolean_parameters_set_as_enum); + + let client_key = boolean::client_key::ClientKey::new(&boolean_parameters.0); + let server_key = boolean::server_key::ServerKey::new(&client_key); + + let heap_allocated_client_key = Box::new(BooleanClientKey(client_key)); + let heap_allocated_server_key = Box::new(BooleanServerKey(server_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_trivial_encrypt( + message: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + use boolean::engine::WithThreadLocalEngine; + + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let heap_allocated_result = Box::new(BooleanCiphertext( + boolean::engine::CpuBooleanEngine::with_thread_local_mut(|engine| { + engine.trivial_encrypt(message) + }), + )); + + *result = Box::into_raw(heap_allocated_result); + }) +} diff --git a/tfhe/src/c_api/boolean/parameters.rs b/tfhe/src/c_api/boolean/parameters.rs new file mode 100644 index 000000000..620832287 --- /dev/null +++ b/tfhe/src/c_api/boolean/parameters.rs @@ -0,0 +1,104 @@ +use crate::c_api::utils::*; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, +}; +use std::os::raw::c_int; + +use crate::boolean; + +pub struct BooleanParameters(pub(in crate::c_api) boolean::parameters::BooleanParameters); + +#[no_mangle] +pub unsafe extern "C" fn boolean_get_parameters( + boolean_parameters_set: c_int, + result: *mut *mut BooleanParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let boolean_parameters_set_as_enum = + BooleanParametersSet::try_from(boolean_parameters_set).unwrap(); + + let boolean_parameters = Box::new(BooleanParameters::from(boolean_parameters_set_as_enum)); + + *result = Box::into_raw(boolean_parameters); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_create_parameters( + lwe_dimension: usize, + glwe_dimension: usize, + polynomial_size: usize, + lwe_modular_std_dev: f64, + glwe_modular_std_dev: f64, + pbs_base_log: usize, + pbs_level: usize, + ks_base_log: usize, + ks_level: usize, + result_parameters: *mut *mut BooleanParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_parameters).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_parameters = std::ptr::null_mut(); + + let heap_allocated_parameters = + Box::new(BooleanParameters(boolean::parameters::BooleanParameters { + lwe_dimension: LweDimension(lwe_dimension), + glwe_dimension: GlweDimension(glwe_dimension), + polynomial_size: PolynomialSize(polynomial_size), + lwe_modular_std_dev: StandardDev(lwe_modular_std_dev), + glwe_modular_std_dev: StandardDev(glwe_modular_std_dev), + pbs_base_log: DecompositionBaseLog(pbs_base_log), + pbs_level: DecompositionLevelCount(pbs_level), + ks_base_log: DecompositionBaseLog(ks_base_log), + ks_level: DecompositionLevelCount(ks_level), + })); + + *result_parameters = Box::into_raw(heap_allocated_parameters); + }) +} + +pub(in crate::c_api) enum BooleanParametersSet { + DefaultParameters, + TfheLibParameters, +} + +pub const BOOLEAN_PARAMETERS_SET_DEFAULT_PARAMETERS: c_int = 0; +pub const BOOLEAN_PARAMETERS_SET_THFE_LIB_PARAMETERS: c_int = 1; + +impl TryFrom for BooleanParametersSet { + type Error = String; + + fn try_from(value: c_int) -> Result { + match value { + BOOLEAN_PARAMETERS_SET_DEFAULT_PARAMETERS => { + Ok(BooleanParametersSet::DefaultParameters) + } + BOOLEAN_PARAMETERS_SET_THFE_LIB_PARAMETERS => { + Ok(BooleanParametersSet::TfheLibParameters) + } + _ => Err(format!( + "Invalid value '{value}' for BooleansParametersSet, use \ + BOOLEAN_PARAMETERS_SET constants" + )), + } + } +} + +impl From for BooleanParameters { + fn from(boolean_parameters_set: BooleanParametersSet) -> Self { + match boolean_parameters_set { + BooleanParametersSet::DefaultParameters => { + BooleanParameters(boolean::parameters::DEFAULT_PARAMETERS) + } + BooleanParametersSet::TfheLibParameters => { + BooleanParameters(boolean::parameters::TFHE_LIB_PARAMETERS) + } + } + } +} diff --git a/tfhe/src/c_api/boolean/public_key.rs b/tfhe/src/c_api/boolean/public_key.rs new file mode 100644 index 000000000..51d71524e --- /dev/null +++ b/tfhe/src/c_api/boolean/public_key.rs @@ -0,0 +1,91 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use bincode; +use std::os::raw::c_int; + +use crate::boolean; + +use super::{BooleanCiphertext, BooleanClientKey}; + +pub struct BooleanPublicKey(pub(in crate::c_api) boolean::public_key::PublicKey); + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_public_key( + client_key: *const BooleanClientKey, + result: *mut *mut BooleanPublicKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let heap_allocated_public_key = Box::new(BooleanPublicKey( + boolean::public_key::PublicKey::new(&client_key.0), + )); + + *result = Box::into_raw(heap_allocated_public_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_public_key_encrypt( + public_key: *const BooleanPublicKey, + value_to_encrypt: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let public_key = get_ref_checked(public_key).unwrap(); + + let heap_allocated_ciphertext = + Box::new(BooleanCiphertext(public_key.0.encrypt(value_to_encrypt))); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_serialize_public_key( + public_key: *const BooleanPublicKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let public_key = get_ref_checked(public_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&public_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_deserialize_public_key( + buffer_view: BufferView, + result: *mut *mut BooleanPublicKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let public_key: boolean::public_key::PublicKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_public_key = Box::new(BooleanPublicKey(public_key)); + + *result = Box::into_raw(heap_allocated_public_key); + }) +} diff --git a/tfhe/src/c_api/boolean/server_key.rs b/tfhe/src/c_api/boolean/server_key.rs new file mode 100644 index 000000000..bb6551ea6 --- /dev/null +++ b/tfhe/src/c_api/boolean/server_key.rs @@ -0,0 +1,604 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::boolean; +use crate::boolean::server_key::{BinaryBooleanGates, BinaryBooleanGatesAssign}; + +use super::BooleanCiphertext; + +pub struct BooleanServerKey(pub(in crate::c_api) boolean::server_key::ServerKey); + +#[no_mangle] +pub unsafe extern "C" fn boolean_gen_server_key( + client_key: *const super::BooleanClientKey, + result_server_key: *mut *mut BooleanServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_server_key = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let server_key = boolean::server_key::ServerKey::new(&client_key.0); + + let heap_allocated_server_key = Box::new(BooleanServerKey(server_key)); + + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_and( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.and(&ct_left.0, &ct_right.0))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nand( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = Box::new(BooleanCiphertext( + server_key.0.nand(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nor( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.nor(&ct_left.0, &ct_right.0))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_or( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.or(&ct_left.0, &ct_right.0))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xor( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.xor(&ct_left.0, &ct_right.0))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xnor( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + ct_right: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + let heap_allocated_result = Box::new(BooleanCiphertext( + server_key.0.xnor(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_not( + server_key: *const BooleanServerKey, + ct_input: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_input = get_ref_checked(ct_input).unwrap(); + + let heap_allocated_result = Box::new(BooleanCiphertext(server_key.0.not(&ct_input.0))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_and_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.and_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nand_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.nand_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nor_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.nor_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_or_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.or_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xor_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.xor_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xnor_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + ct_right: *const BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_ref_checked(ct_right).unwrap(); + + server_key.0.xnor_assign(&mut ct_left.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_not_assign( + server_key: *const BooleanServerKey, + ct_input: *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_input = get_mut_checked(ct_input).unwrap(); + + server_key.0.not_assign(&mut ct_input.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_mux( + server_key: *const BooleanServerKey, + ct_condition: *const BooleanCiphertext, + ct_then: *const BooleanCiphertext, + ct_else: *const BooleanCiphertext, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_condition = get_ref_checked(ct_condition).unwrap(); + let ct_then = get_ref_checked(ct_then).unwrap(); + let ct_else = get_ref_checked(ct_else).unwrap(); + + let heap_allocated_result = Box::new(BooleanCiphertext(server_key.0.mux( + &ct_condition.0, + &ct_then.0, + &ct_else.0, + ))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_and_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.and(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nand_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.nand(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nor_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.nor(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_or_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.or(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xor_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.xor(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xnor_scalar( + server_key: *const BooleanServerKey, + ct_left: *const BooleanCiphertext, + scalar: bool, + result: *mut *mut BooleanCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + + let heap_allocated_result = + Box::new(BooleanCiphertext(server_key.0.xnor(&ct_left.0, scalar))); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_and_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.and_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nand_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.nand_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_nor_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.nor_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_or_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.or_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xor_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.xor_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_server_key_xnor_scalar_assign( + server_key: *const BooleanServerKey, + ct_left: *mut BooleanCiphertext, + scalar: bool, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + server_key.0.xnor_assign(&mut ct_left.0, scalar); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_serialize_server_key( + server_key: *const BooleanServerKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&server_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn boolean_deserialize_server_key( + buffer_view: BufferView, + result: *mut *mut BooleanServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key: boolean::server_key::ServerKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_server_key = Box::new(BooleanServerKey(server_key)); + + *result = Box::into_raw(heap_allocated_server_key); + }) +} diff --git a/tfhe/src/c_api/buffer.rs b/tfhe/src/c_api/buffer.rs new file mode 100644 index 000000000..c9c7df1b8 --- /dev/null +++ b/tfhe/src/c_api/buffer.rs @@ -0,0 +1,81 @@ +//! Module providing some common `C` FFI utilities for key serialization and deserialization. + +use crate::c_api::utils::*; +use std::os::raw::c_int; + +#[repr(C)] +pub struct Buffer { + pointer: *mut u8, + length: usize, +} + +#[repr(C)] +pub struct BufferView { + pointer: *const u8, + length: usize, +} + +impl From> for Buffer { + fn from(a: Vec) -> Self { + let a = a.leak(); + + Self { + pointer: a.as_mut_ptr(), + length: a.len(), + } + } +} + +impl From for &[u8] { + fn from(bf: BufferView) -> &'static [u8] { + unsafe { std::slice::from_raw_parts(bf.pointer, bf.length) } + } +} + +impl From<&[u8]> for BufferView { + fn from(a: &[u8]) -> Self { + Self { + pointer: a.as_ptr(), + length: a.len(), + } + } +} + +/// Deallocate the memory pointed to by a [`Buffer`]. +/// +/// The [`Buffer`] `pointer` is set to `NULL` and `length` is set to `0` to signal it was freed in +/// addition to the function's return code. +/// +/// This function is [checked](crate#safety-checked-and-unchecked-functions). +#[no_mangle] +pub unsafe extern "C" fn destroy_buffer(buffer: *mut Buffer) -> c_int { + catch_panic(|| { + let buffer = get_mut_checked(buffer).unwrap(); + + let pointer = get_mut_checked(buffer.pointer).unwrap(); + let length = buffer.length; + + // Reconstruct a vector that will be dropped so that the memory gets freed + Vec::from_raw_parts(pointer, length, length); + + buffer.length = 0; + buffer.pointer = std::ptr::null_mut(); + }) +} + +/// [Unchecked](crate#safety-checked-and-unchecked-functions) version of [`destroy_buffer`]. +#[no_mangle] +pub unsafe extern "C" fn destroy_buffer_unchecked(buffer: *mut Buffer) -> c_int { + catch_panic(|| { + let buffer = &mut (*buffer); + + let pointer = &mut (*buffer.pointer); + let length = buffer.length; + + // Reconstruct a vector that will be dropped so that the memory gets freed + Vec::from_raw_parts(pointer, length, length); + + buffer.length = 0; + buffer.pointer = std::ptr::null_mut(); + }) +} diff --git a/tfhe/src/c_api/mod.rs b/tfhe/src/c_api/mod.rs new file mode 100644 index 000000000..9cc665e71 --- /dev/null +++ b/tfhe/src/c_api/mod.rs @@ -0,0 +1,8 @@ +#![deny(rustdoc::broken_intra_doc_links)] +#![allow(clippy::missing_safety_doc)] +#[cfg(feature = "boolean-c-api")] +pub mod boolean; +pub mod buffer; +#[cfg(feature = "shortint-c-api")] +pub mod shortint; +pub(crate) mod utils; diff --git a/tfhe/src/c_api/shortint/ciphertext.rs b/tfhe/src/c_api/shortint/ciphertext.rs new file mode 100644 index 000000000..e2abf51e5 --- /dev/null +++ b/tfhe/src/c_api/shortint/ciphertext.rs @@ -0,0 +1,70 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::shortint; + +pub struct ShortintCiphertext(pub(in crate::c_api) shortint::ciphertext::Ciphertext); + +#[no_mangle] +pub unsafe extern "C" fn shortint_ciphertext_set_degree( + ciphertext: *mut ShortintCiphertext, + degree: usize, +) -> c_int { + catch_panic(|| { + let ciphertext = get_mut_checked(ciphertext).unwrap(); + + ciphertext.0.degree.0 = degree; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_ciphertext_get_degree( + ciphertext: *const ShortintCiphertext, + result: *mut usize, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let ciphertext = get_ref_checked(ciphertext).unwrap(); + + *result = ciphertext.0.degree.0; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_serialize_ciphertext( + ciphertext: *const ShortintCiphertext, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let ciphertext = get_ref_checked(ciphertext).unwrap(); + + let buffer: Buffer = bincode::serialize(&ciphertext.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_deserialize_ciphertext( + buffer_view: BufferView, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let ciphertext: shortint::ciphertext::Ciphertext = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_ciphertext = Box::new(ShortintCiphertext(ciphertext)); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} diff --git a/tfhe/src/c_api/shortint/client_key.rs b/tfhe/src/c_api/shortint/client_key.rs new file mode 100644 index 000000000..4d32929b6 --- /dev/null +++ b/tfhe/src/c_api/shortint/client_key.rs @@ -0,0 +1,106 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use bincode; +use std::os::raw::c_int; + +use crate::shortint; + +use super::ShortintCiphertext; +pub struct ShortintClientKey(pub(in crate::c_api) shortint::client_key::ClientKey); + +#[no_mangle] +pub unsafe extern "C" fn shortint_gen_client_key( + shortint_parameters: *const super::parameters::ShortintParameters, + result_client_key: *mut *mut ShortintClientKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + + let shortint_parameters = get_ref_checked(shortint_parameters).unwrap(); + + let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned()); + + let heap_allocated_client_key = Box::new(ShortintClientKey(client_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_client_key_encrypt( + client_key: *const ShortintClientKey, + value_to_encrypt: u64, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let heap_allocated_ciphertext = + Box::new(ShortintCiphertext(client_key.0.encrypt(value_to_encrypt))); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_client_key_decrypt( + client_key: *const ShortintClientKey, + ciphertext_to_decrypt: *const ShortintCiphertext, + result: *mut u64, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let client_key = get_ref_checked(client_key).unwrap(); + let ciphertext_to_decrypt = get_ref_checked(ciphertext_to_decrypt).unwrap(); + + *result = client_key.0.decrypt(&ciphertext_to_decrypt.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_serialize_client_key( + client_key: *const ShortintClientKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&client_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_deserialize_client_key( + buffer_view: BufferView, + result: *mut *mut ShortintClientKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key: shortint::client_key::ClientKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_client_key = Box::new(ShortintClientKey(client_key)); + + *result = Box::into_raw(heap_allocated_client_key); + }) +} diff --git a/tfhe/src/c_api/shortint/destroy.rs b/tfhe/src/c_api/shortint/destroy.rs new file mode 100644 index 000000000..079e59b80 --- /dev/null +++ b/tfhe/src/c_api/shortint/destroy.rs @@ -0,0 +1,79 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::parameters::ShortintParameters; +use super::{ + ShortintBivariatePBSAccumulator, ShortintCiphertext, ShortintClientKey, ShortintPBSAccumulator, + ShortintPublicKey, ShortintServerKey, +}; + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_client_key(client_key: *mut ShortintClientKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(client_key).unwrap(); + + drop(Box::from_raw(client_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_server_key(server_key: *mut ShortintServerKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(server_key).unwrap(); + + drop(Box::from_raw(server_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_public_key(public_key: *mut ShortintPublicKey) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(public_key).unwrap(); + + drop(Box::from_raw(public_key)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_parameters( + shortint_parameters: *mut ShortintParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(shortint_parameters).unwrap(); + + drop(Box::from_raw(shortint_parameters)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_ciphertext( + shortint_ciphertext: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(shortint_ciphertext).unwrap(); + + drop(Box::from_raw(shortint_ciphertext)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_pbs_accumulator( + pbs_accumulator: *mut ShortintPBSAccumulator, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(pbs_accumulator).unwrap(); + + drop(Box::from_raw(pbs_accumulator)); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn destroy_shortint_bivariate_pbs_accumulator( + pbs_accumulator: *mut ShortintBivariatePBSAccumulator, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(pbs_accumulator).unwrap(); + + drop(Box::from_raw(pbs_accumulator)); + }) +} diff --git a/tfhe/src/c_api/shortint/mod.rs b/tfhe/src/c_api/shortint/mod.rs new file mode 100644 index 000000000..60a96afc1 --- /dev/null +++ b/tfhe/src/c_api/shortint/mod.rs @@ -0,0 +1,45 @@ +pub mod ciphertext; +pub mod client_key; +pub mod destroy; +pub mod parameters; +pub mod public_key; +pub mod server_key; + +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::shortint; + +pub use ciphertext::ShortintCiphertext; +pub use client_key::ShortintClientKey; +pub use public_key::ShortintPublicKey; +pub use server_key::pbs::{ShortintBivariatePBSAccumulator, ShortintPBSAccumulator}; +pub use server_key::ShortintServerKey; + +#[no_mangle] +pub unsafe extern "C" fn shortint_gen_keys_with_parameters( + shortint_parameters: *const parameters::ShortintParameters, + result_client_key: *mut *mut ShortintClientKey, + result_server_key: *mut *mut ShortintServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_client_key).unwrap(); + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_client_key = std::ptr::null_mut(); + *result_server_key = std::ptr::null_mut(); + + let shortint_parameters = get_ref_checked(shortint_parameters).unwrap(); + + let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned()); + let server_key = shortint::server_key::ServerKey::new(&client_key); + + let heap_allocated_client_key = Box::new(ShortintClientKey(client_key)); + let heap_allocated_server_key = Box::new(ShortintServerKey(server_key)); + + *result_client_key = Box::into_raw(heap_allocated_client_key); + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} diff --git a/tfhe/src/c_api/shortint/parameters.rs b/tfhe/src/c_api/shortint/parameters.rs new file mode 100644 index 000000000..94d6a9de6 --- /dev/null +++ b/tfhe/src/c_api/shortint/parameters.rs @@ -0,0 +1,119 @@ +use crate::c_api::utils::*; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, +}; +use std::os::raw::c_int; + +use crate::shortint; + +pub struct ShortintParameters(pub(in crate::c_api) shortint::parameters::Parameters); + +#[no_mangle] +pub unsafe extern "C" fn shortint_get_parameters( + message_bits: u32, + carry_bits: u32, + result: *mut *mut ShortintParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + let params: Option<_> = match (message_bits, carry_bits) { + (1, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_0), + (1, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_1), + (2, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_0), + (1, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_2), + (2, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_1), + (3, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_0), + (1, 3) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_3), + (2, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2), + (3, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_1), + (4, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_0), + (1, 4) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_4), + (2, 3) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_3), + (3, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_2), + (4, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_1), + (5, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_0), + (1, 5) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_5), + (2, 4) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_4), + (3, 3) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_3), + (4, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_2), + (5, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_1), + (6, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_0), + (1, 6) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_6), + (2, 5) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_5), + (3, 4) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_4), + (4, 3) => Some(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_3), + (5, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_2), + (6, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_1), + (7, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_7_CARRY_0), + (1, 7) => Some(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_7), + (2, 6) => Some(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_6), + (3, 5) => Some(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_5), + (4, 4) => Some(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_4), + (5, 3) => Some(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_3), + (6, 2) => Some(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_2), + (7, 1) => Some(crate::shortint::parameters::PARAM_MESSAGE_7_CARRY_1), + (8, 0) => Some(crate::shortint::parameters::PARAM_MESSAGE_8_CARRY_0), + _ => None, + }; + + match params { + Some(params) => { + let params = Box::new(ShortintParameters(params)); + *result = Box::into_raw(params); + } + None => *result = std::ptr::null_mut(), + } + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_create_parameters( + lwe_dimension: usize, + glwe_dimension: usize, + polynomial_size: usize, + lwe_modular_std_dev: f64, + glwe_modular_std_dev: f64, + pbs_base_log: usize, + pbs_level: usize, + ks_base_log: usize, + ks_level: usize, + pfks_level: usize, + pfks_base_log: usize, + pfks_modular_std_dev: f64, + cbs_level: usize, + cbs_base_log: usize, + message_modulus: usize, + carry_modulus: usize, + result: *mut *mut ShortintParameters, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let heap_allocated_parameters = + Box::new(ShortintParameters(shortint::parameters::Parameters { + lwe_dimension: LweDimension(lwe_dimension), + glwe_dimension: GlweDimension(glwe_dimension), + polynomial_size: PolynomialSize(polynomial_size), + lwe_modular_std_dev: StandardDev(lwe_modular_std_dev), + glwe_modular_std_dev: StandardDev(glwe_modular_std_dev), + pbs_base_log: DecompositionBaseLog(pbs_base_log), + pbs_level: DecompositionLevelCount(pbs_level), + ks_base_log: DecompositionBaseLog(ks_base_log), + ks_level: DecompositionLevelCount(ks_level), + pfks_level: DecompositionLevelCount(pfks_level), + pfks_base_log: DecompositionBaseLog(pfks_base_log), + pfks_modular_std_dev: StandardDev(pfks_modular_std_dev), + cbs_level: DecompositionLevelCount(cbs_level), + cbs_base_log: DecompositionBaseLog(cbs_base_log), + message_modulus: crate::shortint::parameters::MessageModulus(message_modulus), + carry_modulus: crate::shortint::parameters::CarryModulus(carry_modulus), + })); + + *result = Box::into_raw(heap_allocated_parameters); + }) +} diff --git a/tfhe/src/c_api/shortint/public_key.rs b/tfhe/src/c_api/shortint/public_key.rs new file mode 100644 index 000000000..3c4928d3c --- /dev/null +++ b/tfhe/src/c_api/shortint/public_key.rs @@ -0,0 +1,94 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use bincode; +use std::os::raw::c_int; + +use crate::shortint; + +use super::{ShortintCiphertext, ShortintClientKey, ShortintServerKey}; + +pub struct ShortintPublicKey(pub(in crate::c_api) shortint::public_key::PublicKey); + +#[no_mangle] +pub unsafe extern "C" fn shortint_gen_public_key( + client_key: *const ShortintClientKey, + result: *mut *mut ShortintPublicKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let heap_allocated_public_key = Box::new(ShortintPublicKey( + shortint::public_key::PublicKey::new(&client_key.0), + )); + + *result = Box::into_raw(heap_allocated_public_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_public_key_encrypt( + public_key: *const ShortintPublicKey, + server_key: *const ShortintServerKey, + value_to_encrypt: u64, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let public_key = get_ref_checked(public_key).unwrap(); + let server_key = get_ref_checked(server_key).unwrap(); + + let heap_allocated_ciphertext = Box::new(ShortintCiphertext( + public_key.0.encrypt(&server_key.0, value_to_encrypt), + )); + + *result = Box::into_raw(heap_allocated_ciphertext); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_serialize_public_key( + public_key: *const ShortintPublicKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let public_key = get_ref_checked(public_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&public_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_deserialize_public_key( + buffer_view: BufferView, + result: *mut *mut ShortintPublicKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let public_key: shortint::public_key::PublicKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_public_key = Box::new(ShortintPublicKey(public_key)); + + *result = Box::into_raw(heap_allocated_public_key); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/add.rs b/tfhe/src/c_api/shortint/server_key/add.rs new file mode 100644 index 000000000..956d53980 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/add.rs @@ -0,0 +1,82 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_add( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_add(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_add( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_add(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_add_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_add_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_add_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_add_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/bitwise_op.rs b/tfhe/src/c_api/shortint/server_key/bitwise_op.rs new file mode 100644 index 000000000..30f963433 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/bitwise_op.rs @@ -0,0 +1,238 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitand( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_bitand(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitand( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_bitand(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitand_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_bitand_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitand_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_bitand_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitxor( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_bitxor(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitxor( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_bitxor(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitxor_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_bitxor_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitxor_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_bitxor_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitor( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_bitor(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitor( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_bitor(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_bitor_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_bitor_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_bitor_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_bitor_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/comp_op.rs b/tfhe/src/c_api/shortint/server_key/comp_op.rs new file mode 100644 index 000000000..f46a09fb4 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/comp_op.rs @@ -0,0 +1,406 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_greater( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_greater(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_greater( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_greater(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_greater_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .smart_greater_or_equal(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_greater_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .unchecked_greater_or_equal(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_less( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_less(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_less( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_less(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_less_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .smart_less_or_equal(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_less_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .unchecked_less_or_equal(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_equal(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_equal(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_not_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .smart_not_equal(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_not_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_not_equal(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_greater( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_greater(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_greater_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key + .0 + .smart_scalar_greater_or_equal(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_less( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_less(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_less_or_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_less_or_equal(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_equal(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_not_equal( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_not_equal(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/div_mod.rs b/tfhe/src/c_api/shortint/server_key/div_mod.rs new file mode 100644 index 000000000..938cf0165 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/div_mod.rs @@ -0,0 +1,156 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_div( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_div(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_div( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_div(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_div_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_div_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_div_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_div_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_div( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_div(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_div_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .unchecked_scalar_div_assign(&mut ct_left_and_result.0, right); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_mod( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_mod(&ct_left.0, right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_mod_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .unchecked_scalar_mod_assign(&mut ct_left_and_result.0, right); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/mod.rs b/tfhe/src/c_api/shortint/server_key/mod.rs new file mode 100644 index 000000000..a8e0510f0 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/mod.rs @@ -0,0 +1,81 @@ +use crate::c_api::buffer::*; +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use crate::shortint; + +use super::ShortintCiphertext; + +pub mod add; +pub mod bitwise_op; +pub mod comp_op; +pub mod div_mod; +pub mod mul; +pub mod neg; +pub mod pbs; +pub mod scalar_add; +pub mod scalar_mul; +pub mod scalar_sub; +pub mod shift; +pub mod sub; + +pub struct ShortintServerKey(pub(in crate::c_api) shortint::server_key::ServerKey); + +#[no_mangle] +pub unsafe extern "C" fn shortint_gen_server_key( + client_key: *const super::ShortintClientKey, + result_server_key: *mut *mut ShortintServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result_server_key).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result_server_key = std::ptr::null_mut(); + + let client_key = get_ref_checked(client_key).unwrap(); + + let server_key = shortint::server_key::ServerKey::new(&client_key.0); + + let heap_allocated_server_key = Box::new(ShortintServerKey(server_key)); + + *result_server_key = Box::into_raw(heap_allocated_server_key); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_serialize_server_key( + server_key: *const ShortintServerKey, + result: *mut Buffer, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + + let buffer: Buffer = bincode::serialize(&server_key.0).unwrap().into(); + + *result = buffer; + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_deserialize_server_key( + buffer_view: BufferView, + result: *mut *mut ShortintServerKey, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key: shortint::server_key::ServerKey = + bincode::deserialize(buffer_view.into()).unwrap(); + + let heap_allocated_server_key = Box::new(ShortintServerKey(server_key)); + + *result = Box::into_raw(heap_allocated_server_key); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/mul.rs b/tfhe/src/c_api/shortint/server_key/mul.rs new file mode 100644 index 000000000..b6e5b216d --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/mul.rs @@ -0,0 +1,82 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_mul( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_mul_lsb(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_mul( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_mul_lsb(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_mul_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_mul_lsb_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_mul_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_mul_lsb_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/neg.rs b/tfhe/src/c_api/shortint/server_key/neg.rs new file mode 100644 index 000000000..c397b125c --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/neg.rs @@ -0,0 +1,68 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_neg( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = + Box::new(ShortintCiphertext(server_key.0.smart_neg(&mut ct_left.0))); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_neg( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = + Box::new(ShortintCiphertext(server_key.0.unchecked_neg(&ct_left.0))); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_neg_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key.0.smart_neg_assign(&mut ct_left_and_result.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_neg_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key.0.unchecked_neg_assign(&mut ct_left_and_result.0); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/pbs.rs b/tfhe/src/c_api/shortint/server_key/pbs.rs new file mode 100644 index 000000000..afe8aea2f --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/pbs.rs @@ -0,0 +1,176 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +// This is the accepted way to declare a pointer to a C function/callback in cbindgen +pub type AccumulatorCallback = Option u64>; +pub type BivariateAccumulatorCallback = Option u64>; + +pub struct ShortintPBSAccumulator( + pub(in crate::c_api) crate::core_crypto::prelude::GlweCiphertext64, +); +pub struct ShortintBivariatePBSAccumulator( + pub(in crate::c_api) crate::core_crypto::prelude::GlweCiphertext64, +); + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_generate_pbs_accumulator( + server_key: *const ShortintServerKey, + accumulator_callback: AccumulatorCallback, + result: *mut *mut ShortintPBSAccumulator, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let accumulator_callback = accumulator_callback.unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + + // Closure is required as extern "C" fn does not implement the Fn trait + #[allow(clippy::redundant_closure)] + let heap_allocated_accumulator = Box::new(ShortintPBSAccumulator( + server_key + .0 + .generate_accumulator(|x: u64| accumulator_callback(x)), + )); + + *result = Box::into_raw(heap_allocated_accumulator); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_programmable_bootstrap( + server_key: *const ShortintServerKey, + accumulator: *const ShortintPBSAccumulator, + ct_in: *const ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let accumulator = get_ref_checked(accumulator).unwrap(); + let ct_in = get_ref_checked(ct_in).unwrap(); + + let heap_allocated_result = Box::new(ShortintCiphertext( + server_key + .0 + .keyswitch_programmable_bootstrap(&ct_in.0, &accumulator.0), + )); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_programmable_bootstrap_assign( + server_key: *const ShortintServerKey, + accumulator: *const ShortintPBSAccumulator, + ct_in_and_result: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let accumulator = get_ref_checked(accumulator).unwrap(); + let ct_in_and_result = get_mut_checked(ct_in_and_result).unwrap(); + + server_key + .0 + .keyswitch_programmable_bootstrap_assign(&mut ct_in_and_result.0, &accumulator.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_generate_bivariate_pbs_accumulator( + server_key: *const ShortintServerKey, + accumulator_callback: BivariateAccumulatorCallback, + result: *mut *mut ShortintBivariatePBSAccumulator, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let accumulator_callback = accumulator_callback.unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + + // Closure is required as extern "C" fn does not implement the Fn trait + #[allow(clippy::redundant_closure)] + let heap_allocated_accumulator = Box::new(ShortintBivariatePBSAccumulator( + server_key + .0 + .generate_accumulator_bivariate(|x: u64, y: u64| accumulator_callback(x, y)), + )); + + *result = Box::into_raw(heap_allocated_accumulator); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap( + server_key: *const ShortintServerKey, + accumulator: *const ShortintBivariatePBSAccumulator, + ct_left: *const ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + // First fill the result with a null ptr so that if we fail and the return code is not + // checked, then any access to the result pointer will segfault (mimics malloc on failure) + *result = std::ptr::null_mut(); + + let server_key = get_ref_checked(server_key).unwrap(); + let accumulator = get_ref_checked(accumulator).unwrap(); + let ct_left = get_ref_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_result = Box::new(ShortintCiphertext( + crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_bivariate_pbs(&server_key.0, &ct_left.0, &mut ct_right.0, &accumulator.0) + .unwrap() + }), + )); + + *result = Box::into_raw(heap_allocated_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap_assign( + server_key: *const ShortintServerKey, + accumulator: *const ShortintBivariatePBSAccumulator, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let accumulator = get_ref_checked(accumulator).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_bivariate_pbs_assign( + &server_key.0, + &mut ct_left_and_result.0, + &mut ct_right.0, + &accumulator.0, + ) + .unwrap() + }); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/scalar_add.rs b/tfhe/src/c_api/shortint/server_key/scalar_add.rs new file mode 100644 index 000000000..edc7e3f68 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/scalar_add.rs @@ -0,0 +1,78 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_add( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_add(&mut ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_add( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_add(&ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_add_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .smart_scalar_add_assign(&mut ct_left_and_result.0, scalar_right); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_add_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .unchecked_scalar_add_assign(&mut ct_left_and_result.0, scalar_right); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/scalar_mul.rs b/tfhe/src/c_api/shortint/server_key/scalar_mul.rs new file mode 100644 index 000000000..9b596ab37 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/scalar_mul.rs @@ -0,0 +1,78 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_mul( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_mul(&mut ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_mul( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_mul(&ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_mul_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .smart_scalar_mul_assign(&mut ct_left_and_result.0, scalar_right); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_mul_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .unchecked_scalar_mul_assign(&mut ct_left_and_result.0, scalar_right); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/scalar_sub.rs b/tfhe/src/c_api/shortint/server_key/scalar_sub.rs new file mode 100644 index 000000000..b7d5c4b14 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/scalar_sub.rs @@ -0,0 +1,78 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_sub( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_sub(&mut ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_sub( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + scalar_right: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_sub(&ct_left.0, scalar_right), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_sub_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .smart_scalar_sub_assign(&mut ct_left_and_result.0, scalar_right); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_sub_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + scalar_right: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + + server_key + .0 + .unchecked_scalar_sub_assign(&mut ct_left_and_result.0, scalar_right); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/shift.rs b/tfhe/src/c_api/shortint/server_key/shift.rs new file mode 100644 index 000000000..8137108c3 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/shift.rs @@ -0,0 +1,134 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_left_shift( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_scalar_left_shift(&mut ct.0, shift), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_left_shift( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_left_shift(&ct.0, shift), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_right_shift( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + shortint_server_key_unchecked_scalar_right_shift(server_key, ct, shift, result) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_right_shift( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_scalar_right_shift(&ct.0, shift), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_left_shift_assign( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + server_key + .0 + .smart_scalar_left_shift_assign(&mut ct.0, shift); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_left_shift_assign( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + server_key + .0 + .unchecked_scalar_left_shift_assign(&mut ct.0, shift); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_scalar_right_shift_assign( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, +) -> c_int { + shortint_server_key_unchecked_scalar_right_shift_assign(server_key, ct, shift) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_scalar_right_shift_assign( + server_key: *const ShortintServerKey, + ct: *mut ShortintCiphertext, + shift: u8, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct = get_mut_checked(ct).unwrap(); + + server_key + .0 + .unchecked_scalar_right_shift_assign(&mut ct.0, shift); + }) +} diff --git a/tfhe/src/c_api/shortint/server_key/sub.rs b/tfhe/src/c_api/shortint/server_key/sub.rs new file mode 100644 index 000000000..bf952d1b3 --- /dev/null +++ b/tfhe/src/c_api/shortint/server_key/sub.rs @@ -0,0 +1,82 @@ +use crate::c_api::utils::*; +use std::os::raw::c_int; + +use super::{ShortintCiphertext, ShortintServerKey}; + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_sub( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.smart_sub(&mut ct_left.0, &mut ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_sub( + server_key: *const ShortintServerKey, + ct_left: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, + result: *mut *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + check_ptr_is_non_null_and_aligned(result).unwrap(); + + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left = get_mut_checked(ct_left).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + let heap_allocated_ct_result = Box::new(ShortintCiphertext( + server_key.0.unchecked_sub(&ct_left.0, &ct_right.0), + )); + + *result = Box::into_raw(heap_allocated_ct_result); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_smart_sub_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .smart_sub_assign(&mut ct_left_and_result.0, &mut ct_right.0); + }) +} + +#[no_mangle] +pub unsafe extern "C" fn shortint_server_key_unchecked_sub_assign( + server_key: *const ShortintServerKey, + ct_left_and_result: *mut ShortintCiphertext, + ct_right: *mut ShortintCiphertext, +) -> c_int { + catch_panic(|| { + let server_key = get_ref_checked(server_key).unwrap(); + let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap(); + let ct_right = get_mut_checked(ct_right).unwrap(); + + server_key + .0 + .unchecked_sub_assign(&mut ct_left_and_result.0, &ct_right.0); + }) +} diff --git a/tfhe/src/c_api/utils.rs b/tfhe/src/c_api/utils.rs new file mode 100644 index 000000000..f4752ff91 --- /dev/null +++ b/tfhe/src/c_api/utils.rs @@ -0,0 +1,47 @@ +use std::os::raw::c_int; + +pub fn catch_panic(closure: F) -> c_int +where + F: FnOnce(), +{ + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(closure)) { + Ok(_) => 0, + _ => 1, + } +} + +pub fn check_ptr_is_non_null_and_aligned(ptr: *const T) -> Result<(), String> { + if ptr.is_null() { + return Err(format!("pointer is null, got: {ptr:p}")); + } + let expected_alignment = std::mem::align_of::(); + if ptr as usize % expected_alignment != 0 { + return Err(format!( + "pointer is misaligned, expected {} bytes alignement, got pointer: {:p}. \ + You May have mixed some pointers in your function call. If that's not the case \ + check tfhe.h for alignment constants for plain data types allocation.", + expected_alignment, ptr + )); + } + Ok(()) +} + +pub fn get_mut_checked<'a, T>(ptr: *mut T) -> Result<&'a mut T, String> { + match check_ptr_is_non_null_and_aligned(ptr) { + Ok(()) => unsafe { + ptr.as_mut() + .ok_or_else(|| "Error while converting to mut reference".into()) + }, + Err(e) => Err(e), + } +} + +pub fn get_ref_checked<'a, T>(ptr: *const T) -> Result<&'a T, String> { + match check_ptr_is_non_null_and_aligned(ptr) { + Ok(()) => unsafe { + ptr.as_ref() + .ok_or_else(|| "Error while converting to reference".into()) + }, + Err(e) => Err(e), + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/glwe_ciphertext_conversion.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/glwe_ciphertext_conversion.rs new file mode 100644 index 000000000..dc12c424a --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/glwe_ciphertext_conversion.rs @@ -0,0 +1,340 @@ +use crate::core_crypto::backends::cuda::engines::{CudaEngine, CudaError}; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaGlweCiphertext32, CudaGlweCiphertext64, +}; +use crate::core_crypto::backends::cuda::private::crypto::glwe::ciphertext::CudaGlweCiphertext; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::prelude::{GlweCiphertext32, GlweCiphertext64, GlweCiphertextView64}; +use crate::core_crypto::specification::engines::{ + GlweCiphertextConversionEngine, GlweCiphertextConversionError, +}; +use crate::core_crypto::specification::entities::GlweCiphertextEntity; + +impl From for GlweCiphertextConversionError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// Convert a GLWE ciphertext with 32 bits of precision from CPU to GPU 0. +/// Only this conversion is necessary to run the bootstrap on the GPU. +impl GlweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u32 << 20; polynomial_size.0]; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_plaintext_vector: PlaintextVector32 = + /// default_engine.create_plaintext_vector_from(&input)?; + /// let mut h_ciphertext: GlweCiphertext32 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &h_plaintext_vector)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaGlweCiphertext32 = cuda_engine.convert_glwe_ciphertext(&h_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(d_ciphertext.polynomial_size(), polynomial_size); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_glwe_ciphertext( + &mut self, + input: &GlweCiphertext32, + ) -> Result> { + let stream = &self.streams[0]; + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_glwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_glwe_ciphertext_unchecked( + &mut self, + input: &GlweCiphertext32, + ) -> CudaGlweCiphertext32 { + // Copy the entire input vector over all GPUs + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let stream = &self.streams[0]; + let mut vec = stream.malloc::(data_per_gpu as u32); + let input_slice = input.0.as_tensor().as_slice(); + stream.copy_to_gpu::(&mut vec, input_slice); + CudaGlweCiphertext32(CudaGlweCiphertext:: { + d_vec: vec, + glwe_dimension: input.glwe_dimension(), + polynomial_size: input.polynomial_size(), + }) + } +} + +/// # Description +/// Convert a GLWE ciphertext vector with 32 bits of precision from GPU 0 to CPU. +/// This conversion is not necessary to run the bootstrap on the GPU. +/// It is implemented for testing purposes only. +impl GlweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u32 << 20; polynomial_size.0]; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_plaintext_vector: PlaintextVector32 = + /// default_engine.create_plaintext_vector_from(&input)?; + /// let mut h_ciphertext: GlweCiphertext32 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &h_plaintext_vector)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaGlweCiphertext32 = cuda_engine.convert_glwe_ciphertext(&h_ciphertext)?; + /// let h_output_ciphertext: GlweCiphertext32 = + /// cuda_engine.convert_glwe_ciphertext(&d_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(d_ciphertext.polynomial_size(), polynomial_size); + /// assert_eq!(h_ciphertext, h_output_ciphertext); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_glwe_ciphertext( + &mut self, + input: &CudaGlweCiphertext32, + ) -> Result> { + Ok(unsafe { self.convert_glwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_glwe_ciphertext_unchecked( + &mut self, + input: &CudaGlweCiphertext32, + ) -> GlweCiphertext32 { + // Copy the data from GPU 0 back to the CPU + let mut output = + vec![0u32; input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0]; + let stream = &self.streams[0]; + stream.copy_to_cpu::(&mut output, &input.0.d_vec); + GlweCiphertext32(GlweCiphertext::from_container( + output, + input.polynomial_size(), + )) + } +} + +/// # Description +/// Convert a GLWE ciphertext with 64 bits of precision from CPU to GPU 0. +/// Only this conversion is necessary to run the bootstrap on the GPU. +impl GlweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u64 << 20; polynomial_size.0]; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_plaintext_vector: PlaintextVector64 = + /// default_engine.create_plaintext_vector_from(&input)?; + /// let mut h_ciphertext: GlweCiphertext64 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &h_plaintext_vector)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaGlweCiphertext64 = cuda_engine.convert_glwe_ciphertext(&h_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(d_ciphertext.polynomial_size(), polynomial_size); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_glwe_ciphertext( + &mut self, + input: &GlweCiphertext64, + ) -> Result> { + let stream = &self.streams[0]; + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_glwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_glwe_ciphertext_unchecked( + &mut self, + input: &GlweCiphertext64, + ) -> CudaGlweCiphertext64 { + // Copy the entire input vector over all GPUs + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let stream = &self.streams[0]; + let mut vec = stream.malloc::(data_per_gpu as u32); + let input_slice = input.0.as_tensor().as_slice(); + stream.copy_to_gpu::(&mut vec, input_slice); + CudaGlweCiphertext64(CudaGlweCiphertext:: { + d_vec: vec, + glwe_dimension: input.glwe_dimension(), + polynomial_size: input.polynomial_size(), + }) + } +} + +/// # Description +/// Convert a GLWE ciphertext vector with 64 bits of precision from GPU 0 to CPU. +/// This conversion is not necessary to run the bootstrap on the GPU. +/// It is implemented for testing purposes only. +impl GlweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u64 << 20; polynomial_size.0]; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_plaintext_vector: PlaintextVector64 = + /// default_engine.create_plaintext_vector_from(&input)?; + /// let mut h_ciphertext: GlweCiphertext64 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &h_plaintext_vector)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaGlweCiphertext64 = cuda_engine.convert_glwe_ciphertext(&h_ciphertext)?; + /// let h_output_ciphertext: GlweCiphertext64 = + /// cuda_engine.convert_glwe_ciphertext(&d_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(d_ciphertext.polynomial_size(), polynomial_size); + /// assert_eq!(h_ciphertext, h_output_ciphertext); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_glwe_ciphertext( + &mut self, + input: &CudaGlweCiphertext64, + ) -> Result> { + Ok(unsafe { self.convert_glwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_glwe_ciphertext_unchecked( + &mut self, + input: &CudaGlweCiphertext64, + ) -> GlweCiphertext64 { + // Copy the data from GPU 0 back to the CPU + let mut output = + vec![0u64; input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0]; + let stream = &self.streams[0]; + stream.copy_to_cpu::(&mut output, &input.0.d_vec); + GlweCiphertext64(GlweCiphertext::from_container( + output, + input.polynomial_size(), + )) + } +} + +/// # Description +/// Convert a view of a GLWE ciphertext with 64 bits of precision from CPU to GPU 0. +impl GlweCiphertextConversionEngine, CudaGlweCiphertext64> for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u64 << 20; polynomial_size.0]; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_plaintext_vector: PlaintextVector64 = + /// default_engine.create_plaintext_vector_from(&input)?; + /// let mut h_ciphertext: GlweCiphertext64 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &h_plaintext_vector)?; + /// let h_raw_ciphertext: Vec = + /// default_engine.consume_retrieve_glwe_ciphertext(h_ciphertext)?; + /// let mut h_view_ciphertext: GlweCiphertextView64 = + /// default_engine.create_glwe_ciphertext_from(h_raw_ciphertext.as_slice(), polynomial_size)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaGlweCiphertext64 = + /// cuda_engine.convert_glwe_ciphertext(&h_view_ciphertext)?; + /// let h_output_ciphertext: GlweCiphertext64 = + /// cuda_engine.convert_glwe_ciphertext(&d_ciphertext)?; + /// + /// // Extracts the internal container + /// let h_raw_output_ciphertext: Vec = + /// default_engine.consume_retrieve_glwe_ciphertext(h_output_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(d_ciphertext.polynomial_size(), polynomial_size); + /// assert_eq!(h_raw_ciphertext, h_raw_output_ciphertext); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_glwe_ciphertext( + &mut self, + input: &GlweCiphertextView64, + ) -> Result> { + let stream = &self.streams[0]; + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_glwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_glwe_ciphertext_unchecked( + &mut self, + input: &GlweCiphertextView64, + ) -> CudaGlweCiphertext64 { + // Copy the entire input vector over all GPUs + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 * input.polynomial_size().0; + let stream = &self.streams[0]; + let mut vec = stream.malloc::(data_per_gpu as u32); + let input_slice = input.0.as_tensor().as_slice(); + stream.copy_to_gpu::(&mut vec, input_slice); + CudaGlweCiphertext64(CudaGlweCiphertext:: { + d_vec: vec, + glwe_dimension: input.glwe_dimension(), + polynomial_size: input.polynomial_size(), + }) + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_bootstrap_key_conversion.rs new file mode 100644 index 000000000..255fbedf0 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_bootstrap_key_conversion.rs @@ -0,0 +1,196 @@ +use crate::core_crypto::backends::cuda::engines::CudaError; +use crate::core_crypto::backends::cuda::implementation::engines::{ + check_base_log, check_glwe_dim, CudaEngine, +}; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaFourierLweBootstrapKey32, CudaFourierLweBootstrapKey64, +}; +use crate::core_crypto::backends::cuda::private::crypto::bootstrap::{ + convert_lwe_bootstrap_key_from_cpu_to_gpu, CudaBootstrapKey, +}; +use crate::core_crypto::prelude::{LweBootstrapKey32, LweBootstrapKey64}; +use crate::core_crypto::specification::engines::{ + LweBootstrapKeyConversionEngine, LweBootstrapKeyConversionError, +}; +use crate::core_crypto::specification::entities::LweBootstrapKeyEntity; +use std::marker::PhantomData; + +impl From for LweBootstrapKeyConversionError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// Convert an LWE bootstrap key corresponding to 32 bits of precision from the CPU to the GPU. +/// The bootstrap key is copied entirely to all the GPUs and converted from the standard to the +/// Fourier domain. + +impl LweBootstrapKeyConversionEngine + for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_fourier_bsk: CudaFourierLweBootstrapKey32 = + /// cuda_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// assert_eq!(d_fourier_bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(d_fourier_bsk.polynomial_size(), poly_size); + /// assert_eq!(d_fourier_bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(d_fourier_bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(d_fourier_bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_bootstrap_key( + &mut self, + input: &LweBootstrapKey32, + ) -> Result> { + let poly_size = input.0.polynomial_size(); + check_poly_size!(poly_size); + let glwe_dim = input.0.glwe_size().to_glwe_dimension(); + check_glwe_dim!(glwe_dim); + let base_log = input.0.base_log(); + check_base_log!(base_log); + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 + * input.glwe_dimension().to_glwe_size().0 + * input.input_lwe_dimension().0 + * input.decomposition_level_count().0 + * input.polynomial_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + for stream in self.streams.iter() { + stream.check_device_memory(size)?; + } + Ok(unsafe { self.convert_lwe_bootstrap_key_unchecked(input) }) + } + + unsafe fn convert_lwe_bootstrap_key_unchecked( + &mut self, + input: &LweBootstrapKey32, + ) -> CudaFourierLweBootstrapKey32 { + let vecs = convert_lwe_bootstrap_key_from_cpu_to_gpu::( + self.get_cuda_streams(), + &input.0, + self.get_number_of_gpus(), + ); + CudaFourierLweBootstrapKey32(CudaBootstrapKey:: { + d_vecs: vecs, + polynomial_size: input.polynomial_size(), + input_lwe_dimension: input.input_lwe_dimension(), + glwe_dimension: input.glwe_dimension(), + decomp_level: input.decomposition_level_count(), + decomp_base_log: input.decomposition_base_log(), + _phantom: PhantomData::default(), + }) + } +} + +/// # Description +/// Convert an LWE bootstrap key corresponding to 64 bits of precision from the CPU to the GPU. +/// The bootstrap key is copied entirely to all the GPUs and converted from the standard to the +/// Fourier domain. + +impl LweBootstrapKeyConversionEngine + for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_fourier_bsk: CudaFourierLweBootstrapKey64 = + /// cuda_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// assert_eq!(d_fourier_bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(d_fourier_bsk.polynomial_size(), poly_size); + /// assert_eq!(d_fourier_bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(d_fourier_bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(d_fourier_bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_bootstrap_key( + &mut self, + input: &LweBootstrapKey64, + ) -> Result> { + let poly_size = input.0.polynomial_size(); + check_poly_size!(poly_size); + let glwe_dim = input.0.glwe_size().to_glwe_dimension(); + check_glwe_dim!(glwe_dim); + let data_per_gpu = input.glwe_dimension().to_glwe_size().0 + * input.glwe_dimension().to_glwe_size().0 + * input.input_lwe_dimension().0 + * input.decomposition_level_count().0 + * input.polynomial_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + for stream in self.streams.iter() { + stream.check_device_memory(size)?; + } + Ok(unsafe { self.convert_lwe_bootstrap_key_unchecked(input) }) + } + + unsafe fn convert_lwe_bootstrap_key_unchecked( + &mut self, + input: &LweBootstrapKey64, + ) -> CudaFourierLweBootstrapKey64 { + let vecs = convert_lwe_bootstrap_key_from_cpu_to_gpu::( + self.get_cuda_streams(), + &input.0, + self.get_number_of_gpus(), + ); + CudaFourierLweBootstrapKey64(CudaBootstrapKey:: { + d_vecs: vecs, + polynomial_size: input.polynomial_size(), + input_lwe_dimension: input.input_lwe_dimension(), + glwe_dimension: input.glwe_dimension(), + decomp_level: input.decomposition_level_count(), + decomp_base_log: input.decomposition_base_log(), + _phantom: PhantomData::default(), + }) + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_conversion.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_conversion.rs new file mode 100644 index 000000000..2358761d5 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_conversion.rs @@ -0,0 +1,301 @@ +use crate::core_crypto::backends::cuda::implementation::engines::{CudaEngine, CudaError}; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaLweCiphertext32, CudaLweCiphertext64, +}; +use crate::core_crypto::backends::cuda::private::crypto::lwe::ciphertext::CudaLweCiphertext; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::prelude::{LweCiphertext32, LweCiphertext64, LweCiphertextView64}; +use crate::core_crypto::specification::engines::{ + LweCiphertextConversionEngine, LweCiphertextConversionError, +}; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +impl From for LweCiphertextConversionError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// Convert an LWE ciphertext with 32 bits of precision from CPU to GPU 0. +impl LweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_ciphertext( + &mut self, + input: &LweCiphertext32, + ) -> Result> { + let stream = self.streams.first().unwrap(); + let data_per_gpu = input.lwe_dimension().to_lwe_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_lwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_lwe_ciphertext_unchecked( + &mut self, + input: &LweCiphertext32, + ) -> CudaLweCiphertext32 { + let alloc_size = input.lwe_dimension().to_lwe_size().0 as u32; + let input_slice = input.0.as_tensor().as_slice(); + let stream = self.streams.first().unwrap(); + let mut vec = stream.malloc::(alloc_size); + stream.copy_to_gpu::(&mut vec, input_slice); + CudaLweCiphertext32(CudaLweCiphertext:: { + d_vec: vec, + lwe_dimension: input.lwe_dimension(), + }) + } +} + +/// # Description +/// Convert an LWE ciphertext with 32 bits of precision from GPU 0 to CPU. +impl LweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// let h_ciphertext_output: LweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&d_ciphertext)?; + /// assert_eq!(h_ciphertext_output.lwe_dimension(), lwe_dimension); + /// assert_eq!(h_ciphertext, h_ciphertext_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_ciphertext( + &mut self, + input: &CudaLweCiphertext32, + ) -> Result> { + Ok(unsafe { self.convert_lwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_lwe_ciphertext_unchecked( + &mut self, + input: &CudaLweCiphertext32, + ) -> LweCiphertext32 { + let mut output = vec![0_u32; input.lwe_dimension().to_lwe_size().0]; + let stream = self.streams.first().unwrap(); + stream.copy_to_cpu::(&mut output, &input.0.d_vec); + LweCiphertext32(LweCiphertext::from_container(output)) + } +} + +/// # Description +/// Convert an LWE ciphertext with 64 bits of precision from CPU to GPU 0. +impl LweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext64 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_ciphertext( + &mut self, + input: &LweCiphertext64, + ) -> Result> { + let stream = self.streams.first().unwrap(); + let data_per_gpu = input.lwe_dimension().to_lwe_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_lwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_lwe_ciphertext_unchecked( + &mut self, + input: &LweCiphertext64, + ) -> CudaLweCiphertext64 { + let alloc_size = input.lwe_dimension().to_lwe_size().0 as u32; + let input_slice = input.0.as_tensor().as_slice(); + let stream = self.streams.first().unwrap(); + let mut vec = stream.malloc::(alloc_size); + stream.copy_to_gpu::(&mut vec, input_slice); + CudaLweCiphertext64(CudaLweCiphertext:: { + d_vec: vec, + lwe_dimension: input.lwe_dimension(), + }) + } +} + +/// # Description +/// Convert an LWE ciphertext with 64 bits of precision from GPU 0 to CPU. +impl LweCiphertextConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext64 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// let h_ciphertext_output: LweCiphertext64 = cuda_engine.convert_lwe_ciphertext(&d_ciphertext)?; + /// assert_eq!(h_ciphertext_output.lwe_dimension(), lwe_dimension); + /// assert_eq!(h_ciphertext, h_ciphertext_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_ciphertext( + &mut self, + input: &CudaLweCiphertext64, + ) -> Result> { + Ok(unsafe { self.convert_lwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_lwe_ciphertext_unchecked( + &mut self, + input: &CudaLweCiphertext64, + ) -> LweCiphertext64 { + let mut output = vec![0_u64; input.lwe_dimension().to_lwe_size().0]; + let stream = self.streams.first().unwrap(); + stream.copy_to_cpu::(&mut output, &input.0.d_vec); + LweCiphertext64(LweCiphertext::from_container(output)) + } +} + +/// # Description +/// Convert a view of an LWE ciphertext with 64 bits of precision from CPU to GPU 0. +impl LweCiphertextConversionEngine, CudaLweCiphertext64> for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// // Creates a LweCiphertextView64 object from LweCiphertext64 + /// let h_raw_ciphertext: Vec = + /// default_engine.consume_retrieve_lwe_ciphertext(h_ciphertext)?; + /// let mut h_view_ciphertext: LweCiphertextView64 = + /// default_engine.create_lwe_ciphertext_from(h_raw_ciphertext.as_slice())?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext64 = + /// cuda_engine.convert_lwe_ciphertext(&h_view_ciphertext)?; + /// + /// assert_eq!(d_ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_ciphertext( + &mut self, + input: &LweCiphertextView64, + ) -> Result> { + let stream = &self.streams[0]; + let data_per_gpu = input.lwe_dimension().to_lwe_size().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + Ok(unsafe { self.convert_lwe_ciphertext_unchecked(input) }) + } + + unsafe fn convert_lwe_ciphertext_unchecked( + &mut self, + input: &LweCiphertextView64, + ) -> CudaLweCiphertext64 { + let alloc_size = input.lwe_dimension().to_lwe_size().0 as u32; + let input_slice = input.0.as_tensor().as_slice(); + let stream = &self.streams[0]; + let mut d_vec = stream.malloc::(alloc_size); + stream.copy_to_gpu::(&mut d_vec, input_slice); + CudaLweCiphertext64(CudaLweCiphertext:: { + d_vec, + lwe_dimension: input.lwe_dimension(), + }) + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_bootstrap.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_bootstrap.rs new file mode 100644 index 000000000..baba5a9d0 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_bootstrap.rs @@ -0,0 +1,298 @@ +use crate::core_crypto::backends::cuda::engines::CudaError; +use crate::core_crypto::backends::cuda::implementation::engines::{ + check_base_log, check_glwe_dim, CudaEngine, +}; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaFourierLweBootstrapKey32, CudaFourierLweBootstrapKey64, CudaGlweCiphertext32, + CudaGlweCiphertext64, CudaLweCiphertext32, CudaLweCiphertext64, +}; +use crate::core_crypto::backends::cuda::private::device::NumberOfSamples; +use crate::core_crypto::prelude::LweCiphertextIndex; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingBootstrapEngine, LweCiphertextDiscardingBootstrapError, +}; +use crate::core_crypto::specification::entities::LweBootstrapKeyEntity; + +impl From for LweCiphertextDiscardingBootstrapError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// A discard bootstrap on an input ciphertext with 32 bits of precision. +/// The input bootstrap key is in the Fourier domain. +impl + LweCiphertextDiscardingBootstrapEngine< + CudaFourierLweBootstrapKey32, + CudaGlweCiphertext32, + CudaLweCiphertext32, + CudaLweCiphertext32, + > for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, PolynomialSize, + /// }; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(130), + /// LweDimension(512), + /// GlweDimension(1), + /// PolynomialSize(512), + /// ); + /// let log_degree = f64::log2(poly_size.0 as f64) as i32; + /// let val: u32 = ((poly_size.0 as f64 - (10. * f64::sqrt((lwe_dim.0 as f64) / 16.0))) + /// * 2_f64.powi(32 - log_degree - 1)) as u32; + /// let input = val; + /// let noise = Variance(2_f64.powf(-29.)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(7)); + /// // An identity function is applied during the bootstrap + /// let mut lut = vec![0u32; poly_size.0]; + /// for i in 0..poly_size.0 { + /// let l = (i as f64 * 2_f64.powi(32 - log_degree - 1)) as u32; + /// lut[i] = l; + /// } + /// + /// // 1. default engine + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// // create a vector of LWE ciphertexts + /// let h_input_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let h_input_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_input_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_input_key, &h_input_plaintext, noise)?; + /// // create a GLWE ciphertext containing an encryption of the LUT + /// let h_lut_plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let h_lut_key: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let mut h_lut: GlweCiphertext32 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &h_lut_plaintext_vector)?; + /// // create a BSK + /// let h_bootstrap_key: LweBootstrapKey32 = default_engine.generate_new_lwe_bootstrap_key( + /// &h_input_key, + /// &h_lut_key, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// // initialize an output LWE ciphertext vector + /// let h_dummy_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// + /// // 2. cuda engine + /// let mut cuda_engine = CudaEngine::new(())?; + /// // convert input to GPU 0 + /// let d_input_ciphertext: CudaLweCiphertext32 = + /// cuda_engine.convert_lwe_ciphertext(&h_input_ciphertext)?; + /// // convert accumulator to GPU + /// let d_input_lut: CudaGlweCiphertext32 = cuda_engine.convert_glwe_ciphertext(&h_lut)?; + /// // convert BSK to GPU (and from Standard to Fourier representations) + /// let d_fourier_bsk: CudaFourierLweBootstrapKey32 = + /// cuda_engine.convert_lwe_bootstrap_key(&h_bootstrap_key)?; + /// // launch bootstrap on GPU + /// let h_zero_output_ciphertext: LweCiphertext32 = + /// default_engine.zero_encrypt_lwe_ciphertext(&h_dummy_key, noise)?; + /// let mut d_output_ciphertext: CudaLweCiphertext32 = + /// cuda_engine.convert_lwe_ciphertext(&h_zero_output_ciphertext)?; + /// cuda_engine.discard_bootstrap_lwe_ciphertext( + /// &mut d_output_ciphertext, + /// &d_input_ciphertext, + /// &d_input_lut, + /// &d_fourier_bsk, + /// )?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut CudaLweCiphertext32, + input: &CudaLweCiphertext32, + acc: &CudaGlweCiphertext32, + bsk: &CudaFourierLweBootstrapKey32, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + let poly_size = bsk.polynomial_size(); + check_poly_size!(poly_size); + let glwe_dim = bsk.glwe_dimension(); + check_glwe_dim!(glwe_dim); + let base_log = bsk.decomposition_base_log(); + check_base_log!(base_log); + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut CudaLweCiphertext32, + input: &CudaLweCiphertext32, + acc: &CudaGlweCiphertext32, + bsk: &CudaFourierLweBootstrapKey32, + ) { + let stream = self.streams.first().unwrap(); + let mut test_vector_indexes = stream.malloc::(1); + stream.copy_to_gpu(&mut test_vector_indexes, &[0]); + + stream.discard_bootstrap_low_latency_lwe_ciphertext_vector::( + &mut output.0.d_vec, + &acc.0.d_vec, + &test_vector_indexes, + &input.0.d_vec, + bsk.0.d_vecs.first().unwrap(), + input.0.lwe_dimension, + bsk.glwe_dimension(), + bsk.polynomial_size(), + bsk.decomposition_base_log(), + bsk.decomposition_level_count(), + NumberOfSamples(1), + LweCiphertextIndex(0), + self.get_cuda_shared_memory(), + ); + } +} + +/// # Description +/// A discard bootstrap on an input ciphertext with 64 bits of precision. +/// The input bootstrap key is in the Fourier domain. +impl + LweCiphertextDiscardingBootstrapEngine< + CudaFourierLweBootstrapKey64, + CudaGlweCiphertext64, + CudaLweCiphertext64, + CudaLweCiphertext64, + > for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, PolynomialSize, + /// }; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(130), + /// LweDimension(512), + /// GlweDimension(1), + /// PolynomialSize(512), + /// ); + /// let log_degree = f64::log2(poly_size.0 as f64) as i64; + /// let val: u64 = ((poly_size.0 as f64 - (10. * f64::sqrt((lwe_dim.0 as f64) / 16.0))) + /// * 2_f64.powi((64 - log_degree - 1) as i32)) as u64; + /// let input = val; + /// let noise = Variance(2_f64.powf(-29.)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(7)); + /// // An identity function is applied during the bootstrap + /// let mut lut = vec![0u64; poly_size.0]; + /// for i in 0..poly_size.0 { + /// let l = (i as f64 * 2_f64.powi((64 - log_degree - 1) as i32)) as u64; + /// lut[i] = l; + /// } + /// + /// // 1. default engine + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// // create a vector of LWE ciphertexts + /// let h_input_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let h_input_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_input_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&h_input_key, &h_input_plaintext, noise)?; + /// // create a GLWE ciphertext containing an encryption of the LUT + /// let h_lut_plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let h_lut_key: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let mut h_lut: GlweCiphertext64 = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &h_lut_plaintext_vector)?; + /// // create a BSK + /// let h_bootstrap_key: LweBootstrapKey64 = default_engine.generate_new_lwe_bootstrap_key( + /// &h_input_key, + /// &h_lut_key, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// // initialize an output LWE ciphertext vector + /// let h_dummy_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// + /// // 2. cuda engine + /// let mut cuda_engine = CudaEngine::new(())?; + /// // convert input to GPU 0 + /// let d_input_ciphertext: CudaLweCiphertext64 = + /// cuda_engine.convert_lwe_ciphertext(&h_input_ciphertext)?; + /// // convert accumulator to GPU + /// let d_input_lut: CudaGlweCiphertext64 = cuda_engine.convert_glwe_ciphertext(&h_lut)?; + /// // convert BSK to GPU (and from Standard to Fourier representations) + /// let d_fourier_bsk: CudaFourierLweBootstrapKey64 = + /// cuda_engine.convert_lwe_bootstrap_key(&h_bootstrap_key)?; + /// // launch bootstrap on GPU + /// let h_zero_output_ciphertext: LweCiphertext64 = + /// default_engine.zero_encrypt_lwe_ciphertext(&h_dummy_key, noise)?; + /// let mut d_output_ciphertext: CudaLweCiphertext64 = + /// cuda_engine.convert_lwe_ciphertext(&h_zero_output_ciphertext)?; + /// cuda_engine.discard_bootstrap_lwe_ciphertext( + /// &mut d_output_ciphertext, + /// &d_input_ciphertext, + /// &d_input_lut, + /// &d_fourier_bsk, + /// )?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut CudaLweCiphertext64, + input: &CudaLweCiphertext64, + acc: &CudaGlweCiphertext64, + bsk: &CudaFourierLweBootstrapKey64, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + let poly_size = bsk.polynomial_size(); + check_poly_size!(poly_size); + let glwe_dim = bsk.glwe_dimension(); + check_glwe_dim!(glwe_dim); + let base_log = bsk.decomposition_base_log(); + check_base_log!(base_log); + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut CudaLweCiphertext64, + input: &CudaLweCiphertext64, + acc: &CudaGlweCiphertext64, + bsk: &CudaFourierLweBootstrapKey64, + ) { + let stream = self.streams.first().unwrap(); + let mut test_vector_indexes = stream.malloc::(1); + stream.copy_to_gpu(&mut test_vector_indexes, &[0]); + + stream.discard_bootstrap_low_latency_lwe_ciphertext_vector::( + &mut output.0.d_vec, + &acc.0.d_vec, + &test_vector_indexes, + &input.0.d_vec, + bsk.0.d_vecs.first().unwrap(), + input.0.lwe_dimension, + bsk.glwe_dimension(), + bsk.polynomial_size(), + bsk.decomposition_base_log(), + bsk.decomposition_level_count(), + NumberOfSamples(1), + LweCiphertextIndex(0), + self.get_cuda_shared_memory(), + ); + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_conversion.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_conversion.rs new file mode 100644 index 000000000..c0844b1cd --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_conversion.rs @@ -0,0 +1,278 @@ +use crate::core_crypto::backends::cuda::implementation::engines::{CudaEngine, CudaError}; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaLweCiphertext32, CudaLweCiphertext64, +}; +use crate::core_crypto::commons::math::tensor::{AsMutSlice, AsRefSlice}; +use crate::core_crypto::prelude::{ + LweCiphertext32, LweCiphertextMutView32, LweCiphertextMutView64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingConversionEngine, LweCiphertextDiscardingConversionError, +}; + +/// # Description +/// +/// Convert an LWE ciphertext with 32 bits of precision from GPU 0 to a view on the CPU. +impl LweCiphertextDiscardingConversionEngine> + for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// use std::borrow::BorrowMut; + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 25 bits) + /// let input = 3_u32 << 25; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// // Prepares the output container + /// let mut h_raw_output_ciphertext = vec![0_u32; lwe_dimension.0 + 1]; + /// let mut h_output_view_ciphertext: LweCiphertextMutView32 = + /// default_engine.create_lwe_ciphertext_from(h_raw_output_ciphertext.as_mut_slice())?; + /// + /// cuda_engine.discard_convert_lwe_ciphertext(&mut h_output_view_ciphertext, &d_ciphertext)?; + /// + /// assert_eq!(h_output_view_ciphertext.lwe_dimension(), lwe_dimension); + /// // Extracts the internal container + /// let h_raw_input_ciphertext: Vec = + /// default_engine.consume_retrieve_lwe_ciphertext(h_ciphertext)?; + /// let h_raw_output_ciphertext: &[u32] = + /// default_engine.consume_retrieve_lwe_ciphertext(h_output_view_ciphertext)?; + /// assert_eq!(h_raw_input_ciphertext.as_slice(), h_raw_output_ciphertext); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_convert_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView32, + input: &CudaLweCiphertext32, + ) -> Result<(), LweCiphertextDiscardingConversionError> { + unsafe { self.discard_convert_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_convert_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView32, + input: &CudaLweCiphertext32, + ) { + let stream = &self.streams[0]; + stream.copy_to_cpu::(output.0.tensor.as_mut_slice(), &input.0.d_vec); + } +} + +/// # Description +/// +/// Convert an LWE ciphertext with 32 bits of precision from GPU 0 to a ciphertext on the CPU. +impl LweCiphertextDiscardingConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// use std::borrow::BorrowMut; + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 25 bits) + /// let input = 3_u32 << 25; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// // Prepares the output container + /// let h_raw_output_ciphertext = vec![0_u32; lwe_dimension.0 + 1]; + /// let mut h_output_ciphertext: LweCiphertext32 = + /// default_engine.create_lwe_ciphertext_from(h_raw_output_ciphertext)?; + /// + /// cuda_engine.discard_convert_lwe_ciphertext(&mut h_output_ciphertext, &d_ciphertext)?; + /// + /// assert_eq!(h_output_ciphertext.lwe_dimension(), lwe_dimension); + /// // Extracts the internal container + /// let h_raw_input_ciphertext: Vec = + /// default_engine.consume_retrieve_lwe_ciphertext(h_ciphertext)?; + /// let h_raw_output_ciphertext: Vec = + /// default_engine.consume_retrieve_lwe_ciphertext(h_output_ciphertext)?; + /// assert_eq!(h_raw_input_ciphertext, h_raw_output_ciphertext); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_convert_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input: &CudaLweCiphertext32, + ) -> Result<(), LweCiphertextDiscardingConversionError> { + unsafe { self.discard_convert_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_convert_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &CudaLweCiphertext32, + ) { + let stream = &self.streams[0]; + stream.copy_to_cpu::(output.0.tensor.as_mut_slice(), &input.0.d_vec); + } +} + +/// # Description +/// +/// Convert an LWE ciphertext with 32 bits of precision from CPU to a ciphertext on the GPU 0. +impl LweCiphertextDiscardingConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// use std::borrow::BorrowMut; + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 25 bits) + /// let input = 3_u32 << 25; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let mut d_ciphertext: CudaLweCiphertext32 = + /// cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// let h_ciphertext_out: LweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&d_ciphertext)?; + /// + /// assert_eq!(h_ciphertext, h_ciphertext_out); + /// + /// // Prepare input for discarding convert + /// let input_2 = 5_u32 << 25; + /// let h_plaintext_2: Plaintext32 = default_engine.create_plaintext_from(&input_2)?; + /// let mut h_ciphertext_2: LweCiphertext32 = default_engine + /// .trivially_encrypt_lwe_ciphertext(lwe_dimension.to_lwe_size(), &h_plaintext_2)?; + /// + /// cuda_engine.discard_convert_lwe_ciphertext(&mut d_ciphertext, &h_ciphertext_2)?; + /// + /// let h_ciphertext_out_2: LweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&d_ciphertext)?; + /// + /// assert_eq!(h_ciphertext_2, h_ciphertext_out_2); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_convert_lwe_ciphertext( + &mut self, + output: &mut CudaLweCiphertext32, + input: &LweCiphertext32, + ) -> Result<(), LweCiphertextDiscardingConversionError> { + unsafe { self.discard_convert_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_convert_lwe_ciphertext_unchecked( + &mut self, + output: &mut CudaLweCiphertext32, + input: &LweCiphertext32, + ) { + let stream = &self.streams[0]; + stream.copy_to_gpu::(&mut output.0.d_vec, input.0.tensor.as_slice()); + } +} + +/// # Description +/// +/// Convert an LWE ciphertext with 64 bits of precision from GPU 0 to a view on the CPU. +impl LweCiphertextDiscardingConversionEngine> + for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// use std::borrow::BorrowMut; + /// let lwe_dimension = LweDimension(6); + /// // Here a hard-set encoding is applied (shift by 25 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let h_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let h_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&h_key, &h_plaintext, noise)?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext64 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// + /// // Prepares the output container + /// let mut h_raw_output_ciphertext = vec![0_u64; lwe_dimension.0 + 1]; + /// let mut h_view_output_ciphertext: LweCiphertextMutView64 = + /// default_engine.create_lwe_ciphertext_from(h_raw_output_ciphertext.as_mut_slice())?; + /// + /// cuda_engine + /// .discard_convert_lwe_ciphertext(h_view_output_ciphertext.borrow_mut(), &d_ciphertext)?; + /// + /// assert_eq!(h_view_output_ciphertext.lwe_dimension(), lwe_dimension); + /// // Extracts the internal container + /// let h_raw_input_ciphertext: Vec = + /// default_engine.consume_retrieve_lwe_ciphertext(h_ciphertext)?; + /// let h_raw_output_ciphertext: &[u64] = + /// default_engine.consume_retrieve_lwe_ciphertext(h_view_output_ciphertext)?; + /// assert_eq!(h_raw_input_ciphertext, h_raw_output_ciphertext.to_vec()); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_convert_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView64, + input: &CudaLweCiphertext64, + ) -> Result<(), LweCiphertextDiscardingConversionError> { + unsafe { self.discard_convert_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_convert_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView64, + input: &CudaLweCiphertext64, + ) { + let stream = &self.streams[0]; + stream.copy_to_cpu::(output.0.tensor.as_mut_slice(), &input.0.d_vec); + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_keyswitch.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_keyswitch.rs new file mode 100644 index 000000000..1e72bb511 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_ciphertext_discarding_keyswitch.rs @@ -0,0 +1,230 @@ +use crate::core_crypto::backends::cuda::engines::CudaError; +use crate::core_crypto::backends::cuda::implementation::engines::CudaEngine; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaLweCiphertext32, CudaLweCiphertext64, CudaLweKeyswitchKey32, CudaLweKeyswitchKey64, +}; +use crate::core_crypto::backends::cuda::private::device::NumberOfSamples; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingKeyswitchEngine, LweCiphertextDiscardingKeyswitchError, +}; +use crate::core_crypto::specification::entities::LweKeyswitchKeyEntity; + +impl From for LweCiphertextDiscardingKeyswitchError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// A discard keyswitch on a vector of input ciphertext vectors with 32 bits of precision. +impl + LweCiphertextDiscardingKeyswitchEngine< + CudaLweKeyswitchKey32, + CudaLweCiphertext32, + CudaLweCiphertext32, + > for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// // Generate two secret keys + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// // Generate keyswitch keys to switch between first_key and second_key + /// let h_ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// // Encrypt something + /// let h_plaintext: Plaintext32 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext32 = + /// default_engine.encrypt_lwe_ciphertext(&input_key, &h_plaintext, noise)?; + /// + /// // Copy to the GPU + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext32 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// let d_ksk: CudaLweKeyswitchKey32 = cuda_engine.convert_lwe_keyswitch_key(&h_ksk)?; + /// + /// // launch keyswitch on GPU + /// let h_dummy_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let h_zero_ciphertext: LweCiphertext32 = + /// default_engine.zero_encrypt_lwe_ciphertext(&h_dummy_key, noise)?; + /// + /// let mut d_keyswitched_ciphertext: CudaLweCiphertext32 = + /// cuda_engine.convert_lwe_ciphertext(&h_zero_ciphertext)?; + /// cuda_engine.discard_keyswitch_lwe_ciphertext( + /// &mut d_keyswitched_ciphertext, + /// &d_ciphertext, + /// &d_ksk, + /// )?; + /// + /// assert_eq!( + /// d_keyswitched_ciphertext.lwe_dimension(), + /// output_lwe_dimension + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut CudaLweCiphertext32, + input: &CudaLweCiphertext32, + ksk: &CudaLweKeyswitchKey32, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut CudaLweCiphertext32, + input: &CudaLweCiphertext32, + ksk: &CudaLweKeyswitchKey32, + ) { + let stream = &self.streams[0]; + + stream.discard_keyswitch_lwe_ciphertext_vector::( + &mut output.0.d_vec, + &input.0.d_vec, + input.0.lwe_dimension, + output.0.lwe_dimension, + ksk.0.d_vecs.first().unwrap(), + ksk.decomposition_base_log(), + ksk.decomposition_level_count(), + NumberOfSamples(1), + ); + } +} + +/// # Description +/// A discard keyswitch on a vector of input ciphertext vectors with 64 bits of precision. +impl + LweCiphertextDiscardingKeyswitchEngine< + CudaLweKeyswitchKey64, + CudaLweCiphertext64, + CudaLweCiphertext64, + > for CudaEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// // Generate two secret keys + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// // Generate keyswitch keys to switch between first_key and second_key + /// let h_ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// // Encrypt something + /// let h_plaintext: Plaintext64 = default_engine.create_plaintext_from(&input)?; + /// let mut h_ciphertext: LweCiphertext64 = + /// default_engine.encrypt_lwe_ciphertext(&input_key, &h_plaintext, noise)?; + /// + /// // Copy to the GPU + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ciphertext: CudaLweCiphertext64 = cuda_engine.convert_lwe_ciphertext(&h_ciphertext)?; + /// let d_ksk: CudaLweKeyswitchKey64 = cuda_engine.convert_lwe_keyswitch_key(&h_ksk)?; + /// + /// // launch keyswitch on GPU + /// let h_dummy_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let h_zero_ciphertext: LweCiphertext64 = + /// default_engine.zero_encrypt_lwe_ciphertext(&h_dummy_key, noise)?; + /// + /// let mut d_keyswitched_ciphertext: CudaLweCiphertext64 = + /// cuda_engine.convert_lwe_ciphertext(&h_zero_ciphertext)?; + /// cuda_engine.discard_keyswitch_lwe_ciphertext( + /// &mut d_keyswitched_ciphertext, + /// &d_ciphertext, + /// &d_ksk, + /// )?; + /// + /// assert_eq!( + /// d_keyswitched_ciphertext.lwe_dimension(), + /// output_lwe_dimension + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut CudaLweCiphertext64, + input: &CudaLweCiphertext64, + ksk: &CudaLweKeyswitchKey64, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut CudaLweCiphertext64, + input: &CudaLweCiphertext64, + ksk: &CudaLweKeyswitchKey64, + ) { + let stream = &self.streams[0]; + + stream.discard_keyswitch_lwe_ciphertext_vector::( + &mut output.0.d_vec, + &input.0.d_vec, + input.0.lwe_dimension, + output.0.lwe_dimension, + ksk.0.d_vecs.first().unwrap(), + ksk.decomposition_base_log(), + ksk.decomposition_level_count(), + NumberOfSamples(1), + ); + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_keyswitch_key_conversion.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_keyswitch_key_conversion.rs new file mode 100644 index 000000000..6bae61c75 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/lwe_keyswitch_key_conversion.rs @@ -0,0 +1,346 @@ +use crate::core_crypto::backends::cuda::engines::CudaError; +use crate::core_crypto::backends::cuda::implementation::engines::CudaEngine; +use crate::core_crypto::backends::cuda::implementation::entities::{ + CudaLweKeyswitchKey32, CudaLweKeyswitchKey64, +}; +use crate::core_crypto::backends::cuda::private::crypto::keyswitch::CudaLweKeyswitchKey; +use crate::core_crypto::commons::crypto::lwe::LweKeyswitchKey; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::prelude::{LweKeyswitchKey32, LweKeyswitchKey64}; +use crate::core_crypto::specification::engines::{ + LweKeyswitchKeyConversionEngine, LweKeyswitchKeyConversionError, +}; +use crate::core_crypto::specification::entities::LweKeyswitchKeyEntity; + +impl From for LweKeyswitchKeyConversionError { + fn from(err: CudaError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// Convert an LWE keyswitch key corresponding to 32 bits of precision from the CPU to the GPU. +/// We only support the conversion from CPU to GPU: the conversion from GPU to CPU is not +/// necessary at this stage to support the keyswitch. The keyswitch key is copied entirely to all +/// the GPUs. +impl LweKeyswitchKeyConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ksk: CudaLweKeyswitchKey32 = cuda_engine.convert_lwe_keyswitch_key(&ksk)?; + /// + /// assert_eq!(d_ksk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(d_ksk.output_lwe_dimension(), output_lwe_dimension); + /// assert_eq!(d_ksk.decomposition_level_count(), decomposition_level_count); + /// assert_eq!(d_ksk.decomposition_base_log(), decomposition_base_log); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_keyswitch_key( + &mut self, + input: &LweKeyswitchKey32, + ) -> Result> { + for gpu_index in 0..self.get_number_of_gpus().0 { + let stream = &self.streams[gpu_index]; + let data_per_gpu = input.decomposition_level_count().0 + * (input.output_lwe_dimension().0 + 1) + * input.input_lwe_dimension().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + } + Ok(unsafe { self.convert_lwe_keyswitch_key_unchecked(input) }) + } + + unsafe fn convert_lwe_keyswitch_key_unchecked( + &mut self, + input: &LweKeyswitchKey32, + ) -> CudaLweKeyswitchKey32 { + // Copy the entire input vector over all GPUs + let mut d_vecs = Vec::with_capacity(self.get_number_of_gpus().0); + + let data_per_gpu = input.decomposition_level_count().0 + * input.output_lwe_dimension().to_lwe_size().0 + * input.input_lwe_dimension().0; + for stream in self.streams.iter() { + let mut d_vec = stream.malloc::(data_per_gpu as u32); + stream.copy_to_gpu(&mut d_vec, input.0.as_tensor().as_slice()); + d_vecs.push(d_vec); + } + CudaLweKeyswitchKey32(CudaLweKeyswitchKey:: { + d_vecs, + input_lwe_dimension: input.input_lwe_dimension(), + output_lwe_dimension: input.output_lwe_dimension(), + decomp_level: input.decomposition_level_count(), + decomp_base_log: input.decomposition_base_log(), + }) + } +} + +/// # Description +/// Convert an LWE keyswitch key corresponding to 32 bits of precision from the GPU to the CPU. +/// We assume consistency between all the available GPUs and simply copy what is in the one with +/// index 0. +impl LweKeyswitchKeyConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let h_ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ksk: CudaLweKeyswitchKey32 = cuda_engine.convert_lwe_keyswitch_key(&h_ksk)?; + /// let h_output_ksk: LweKeyswitchKey32 = cuda_engine.convert_lwe_keyswitch_key(&d_ksk)?; + /// + /// assert_eq!(d_ksk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(d_ksk.output_lwe_dimension(), output_lwe_dimension); + /// assert_eq!(d_ksk.decomposition_level_count(), decomposition_level_count); + /// assert_eq!(d_ksk.decomposition_base_log(), decomposition_base_log); + /// assert_eq!(h_output_ksk, h_ksk); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_keyswitch_key( + &mut self, + input: &CudaLweKeyswitchKey32, + ) -> Result> { + Ok(unsafe { self.convert_lwe_keyswitch_key_unchecked(input) }) + } + + unsafe fn convert_lwe_keyswitch_key_unchecked( + &mut self, + input: &CudaLweKeyswitchKey32, + ) -> LweKeyswitchKey32 { + let data_per_gpu = input.decomposition_level_count().0 + * input.output_lwe_dimension().to_lwe_size().0 + * input.input_lwe_dimension().0; + + // Copy the data from GPU 0 back to the CPU + let mut output = vec![0u32; data_per_gpu]; + let stream = self.streams.first().unwrap(); + stream.copy_to_cpu::(&mut output, input.0.d_vecs.first().unwrap()); + + LweKeyswitchKey32(LweKeyswitchKey::from_container( + output, + input.decomposition_base_log(), + input.decomposition_level_count(), + input.output_lwe_dimension(), + )) + } +} + +/// # Description +/// Convert an LWE keyswitch key corresponding to 64 bits of precision from the CPU to the GPU. +/// We only support the conversion from CPU to GPU: the conversion from GPU to CPU is not +/// necessary at this stage to support the keyswitch. The keyswitch key is copied entirely to all +/// the GPUs. +impl LweKeyswitchKeyConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ksk: CudaLweKeyswitchKey64 = cuda_engine.convert_lwe_keyswitch_key(&ksk)?; + /// + /// assert_eq!(d_ksk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(d_ksk.output_lwe_dimension(), output_lwe_dimension); + /// assert_eq!(d_ksk.decomposition_level_count(), decomposition_level_count); + /// assert_eq!(d_ksk.decomposition_base_log(), decomposition_base_log); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_keyswitch_key( + &mut self, + input: &LweKeyswitchKey64, + ) -> Result> { + for stream in self.streams.iter() { + let data_per_gpu = input.decomposition_level_count().0 + * input.output_lwe_dimension().to_lwe_size().0 + * input.input_lwe_dimension().0; + let size = data_per_gpu as u64 * std::mem::size_of::() as u64; + stream.check_device_memory(size)?; + } + Ok(unsafe { self.convert_lwe_keyswitch_key_unchecked(input) }) + } + + unsafe fn convert_lwe_keyswitch_key_unchecked( + &mut self, + input: &LweKeyswitchKey64, + ) -> CudaLweKeyswitchKey64 { + // Copy the entire input vector over all GPUs + let mut d_vecs = Vec::with_capacity(self.get_number_of_gpus().0); + + let data_per_gpu = input.decomposition_level_count().0 + * input.output_lwe_dimension().to_lwe_size().0 + * input.input_lwe_dimension().0; + for stream in self.streams.iter() { + let mut d_vec = stream.malloc::(data_per_gpu as u32); + stream.copy_to_gpu(&mut d_vec, input.0.as_tensor().as_slice()); + d_vecs.push(d_vec); + } + CudaLweKeyswitchKey64(CudaLweKeyswitchKey:: { + d_vecs, + input_lwe_dimension: input.input_lwe_dimension(), + output_lwe_dimension: input.output_lwe_dimension(), + decomp_level: input.decomposition_level_count(), + decomp_base_log: input.decomposition_base_log(), + }) + } +} + +impl LweKeyswitchKeyConversionEngine for CudaEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::backends::cuda::private::device::GpuIndex; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let h_ksk = default_engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut cuda_engine = CudaEngine::new(())?; + /// let d_ksk: CudaLweKeyswitchKey64 = cuda_engine.convert_lwe_keyswitch_key(&h_ksk)?; + /// let h_output_ksk: LweKeyswitchKey64 = cuda_engine.convert_lwe_keyswitch_key(&d_ksk)?; + /// + /// assert_eq!(d_ksk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(d_ksk.output_lwe_dimension(), output_lwe_dimension); + /// assert_eq!(d_ksk.decomposition_level_count(), decomposition_level_count); + /// assert_eq!(d_ksk.decomposition_base_log(), decomposition_base_log); + /// assert_eq!(h_output_ksk, h_ksk); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_keyswitch_key( + &mut self, + input: &CudaLweKeyswitchKey64, + ) -> Result> { + Ok(unsafe { self.convert_lwe_keyswitch_key_unchecked(input) }) + } + + unsafe fn convert_lwe_keyswitch_key_unchecked( + &mut self, + input: &CudaLweKeyswitchKey64, + ) -> LweKeyswitchKey64 { + let data_per_gpu = input.decomposition_level_count().0 + * input.output_lwe_dimension().to_lwe_size().0 + * input.input_lwe_dimension().0; + + // Copy the data from GPU 0 back to the CPU + let mut output = vec![0u64; data_per_gpu]; + let stream = self.streams.first().unwrap(); + stream.copy_to_cpu::(&mut output, input.0.d_vecs.first().unwrap()); + + LweKeyswitchKey64(LweKeyswitchKey::from_container( + output, + input.decomposition_base_log(), + input.decomposition_level_count(), + input.output_lwe_dimension(), + )) + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/mod.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/mod.rs new file mode 100644 index 000000000..8424b06d3 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/cuda_engine/mod.rs @@ -0,0 +1,77 @@ +use crate::core_crypto::backends::cuda::private::device::{CudaStream, GpuIndex, NumberOfGpus}; +use crate::core_crypto::prelude::sealed::AbstractEngineSeal; +use crate::core_crypto::prelude::{AbstractEngine, CudaError, SharedMemoryAmount}; +use concrete_cuda::cuda_bind::cuda_get_number_of_gpus; +/// The main engine exposed by the cuda backend. +/// +/// This engine handles single-GPU and multi-GPU computations for the user. It always associates +/// one Cuda stream to each available Nvidia GPU, and splits the input ciphertexts evenly over +/// the GPUs (the last GPU may be a bit more loaded if the number of GPUs does not divide the +/// number of input ciphertexts). This engine does not give control over the streams, nor the GPU +/// load balancing. In this way, we can overlap computations done on different GPUs, but not +/// computations done on a given GPU, which are executed in a sequence. +// A finer access to streams could allow for more overlapping of computations +// on a given device. We'll probably want to support it in the future, in an AdvancedCudaEngine +// for example. +#[derive(Debug, Clone)] +pub struct CudaEngine { + streams: Vec, + max_shared_memory: usize, +} + +impl AbstractEngineSeal for CudaEngine {} + +impl AbstractEngine for CudaEngine { + type EngineError = CudaError; + + type Parameters = (); + + fn new(_parameters: Self::Parameters) -> Result { + let number_of_gpus = unsafe { cuda_get_number_of_gpus() as usize }; + if number_of_gpus == 0 { + Err(CudaError::DeviceNotFound) + } else { + let mut streams: Vec = Vec::new(); + for gpu_index in 0..number_of_gpus { + streams.push(CudaStream::new(GpuIndex(gpu_index))?); + } + let max_shared_memory = streams[0].get_max_shared_memory()?; + + Ok(CudaEngine { + streams, + max_shared_memory: max_shared_memory as usize, + }) + } + } +} + +impl CudaEngine { + /// Get the number of available GPUs from the engine + pub fn get_number_of_gpus(&self) -> NumberOfGpus { + NumberOfGpus(self.streams.len()) + } + /// Get the Cuda streams from the engine + pub fn get_cuda_streams(&self) -> &Vec { + &self.streams + } + /// Get the size of the shared memory (on device 0) + pub fn get_cuda_shared_memory(&self) -> SharedMemoryAmount { + SharedMemoryAmount(self.max_shared_memory) + } +} + +macro_rules! check_poly_size { + ($poly_size: ident) => { + if $poly_size.0 != 512 && $poly_size.0 != 1024 && $poly_size.0 != 2048 { + return Err(CudaError::PolynomialSizeNotSupported.into()); + } + }; +} + +mod glwe_ciphertext_conversion; +mod lwe_bootstrap_key_conversion; +mod lwe_ciphertext_conversion; +mod lwe_ciphertext_discarding_bootstrap; +mod lwe_ciphertext_discarding_conversion; +mod lwe_ciphertext_discarding_keyswitch; +mod lwe_keyswitch_key_conversion; diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/engines/mod.rs b/tfhe/src/core_crypto/backends/cuda/implementation/engines/mod.rs new file mode 100644 index 000000000..00af9e29a --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/engines/mod.rs @@ -0,0 +1,87 @@ +//! A module containing the [engines](crate::core_crypto::specification::engines) exposed by +//! the cuda backend. + +use crate::core_crypto::backends::cuda::private::device::GpuIndex; + +use std::error::Error; +use std::fmt::{Display, Formatter}; + +mod cuda_engine; +pub use cuda_engine::*; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct SharedMemoryAmount(pub usize); + +#[derive(Debug)] +pub enum CudaError { + DeviceNotFound, + SharedMemoryNotFound(GpuIndex), + NotEnoughDeviceMemory(GpuIndex), + InvalidDeviceIndex(GpuIndex), + UnspecifiedDeviceError(GpuIndex), + PolynomialSizeNotSupported, + GlweDimensionNotSupported, + BaseLogNotSupported, +} +impl Display for CudaError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CudaError::DeviceNotFound => { + write!(f, "No GPU detected on the machine.") + } + CudaError::SharedMemoryNotFound(gpu_index) => { + write!(f, "No shared memory detected on the GPU #{}.", gpu_index.0) + } + CudaError::NotEnoughDeviceMemory(gpu_index) => { + write!( + f, + "The GPU #{} does not have enough global memory to hold all the data.", + gpu_index.0 + ) + } + CudaError::InvalidDeviceIndex(gpu_index) => { + write!( + f, + "The specified GPU index, {}, does not exist.", + gpu_index.0 + ) + } + CudaError::PolynomialSizeNotSupported => { + write!( + f, + "The polynomial size should be a power of 2. Values strictly lower than \ + 512, and strictly greater than 8192, are not supported." + ) + } + CudaError::GlweDimensionNotSupported => { + write!(f, "The only supported GLWE dimension is 1.") + } + CudaError::BaseLogNotSupported => { + write!(f, "The base log has to be lower or equal to 16.") + } + CudaError::UnspecifiedDeviceError(gpu_index) => { + write!(f, "Unspecified device error on GPU #{}.", gpu_index.0) + } + } + } +} + +impl Error for CudaError {} + +macro_rules! check_glwe_dim { + ($glwe_dimension: ident) => { + if $glwe_dimension.0 != 1 { + return Err(CudaError::GlweDimensionNotSupported.into()); + } + }; +} + +macro_rules! check_base_log { + ($base_log: ident) => { + if $base_log.0 > 16 { + return Err(CudaError::BaseLogNotSupported.into()); + } + }; +} + +pub(crate) use {check_base_log, check_glwe_dim}; diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/entities/glwe_ciphertext.rs b/tfhe/src/core_crypto/backends/cuda/implementation/entities/glwe_ciphertext.rs new file mode 100644 index 000000000..1856cc15f --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/entities/glwe_ciphertext.rs @@ -0,0 +1,45 @@ +use std::fmt::Debug; + +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; + +use crate::core_crypto::backends::cuda::private::crypto::glwe::ciphertext::CudaGlweCiphertext; +use crate::core_crypto::specification::entities::markers::GlweCiphertextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, GlweCiphertextEntity}; + +/// A structure representing a vector of GLWE ciphertexts with 32 bits of precision on the GPU. +/// It is used as input to the Cuda bootstrap for the array of lookup tables. +#[derive(Debug)] +pub struct CudaGlweCiphertext32(pub(crate) CudaGlweCiphertext); + +impl AbstractEntity for CudaGlweCiphertext32 { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for CudaGlweCiphertext32 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_dimension + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size + } +} + +/// A structure representing a vector of GLWE ciphertexts with 64 bits of precision on the GPU. +/// It is used as input to the Cuda bootstrap for the array of lookup tables. +#[derive(Debug)] +pub struct CudaGlweCiphertext64(pub(crate) CudaGlweCiphertext); + +impl AbstractEntity for CudaGlweCiphertext64 { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for CudaGlweCiphertext64 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_dimension + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_bootstrap_key.rs b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_bootstrap_key.rs new file mode 100644 index 000000000..c84f2a92d --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_bootstrap_key.rs @@ -0,0 +1,67 @@ +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, +}; + +use crate::core_crypto::backends::cuda::private::crypto::bootstrap::CudaBootstrapKey; +use crate::core_crypto::specification::entities::markers::LweBootstrapKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweBootstrapKeyEntity}; + +/// A structure representing a Fourier bootstrap key for 32 bits precision ciphertexts on the GPU. +#[derive(Debug)] +pub struct CudaFourierLweBootstrapKey32(pub(crate) CudaBootstrapKey); + +impl AbstractEntity for CudaFourierLweBootstrapKey32 { + type Kind = LweBootstrapKeyKind; +} + +impl LweBootstrapKeyEntity for CudaFourierLweBootstrapKey32 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_dimension + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_dimension + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomp_base_log + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomp_level + } +} + +/// A structure representing a Fourier bootstrap key for 64 bits precision ciphertexts on the GPU. +#[derive(Debug)] +pub struct CudaFourierLweBootstrapKey64(pub(crate) CudaBootstrapKey); + +impl AbstractEntity for CudaFourierLweBootstrapKey64 { + type Kind = LweBootstrapKeyKind; +} + +impl LweBootstrapKeyEntity for CudaFourierLweBootstrapKey64 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_dimension + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_dimension + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomp_base_log + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomp_level + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_ciphertext.rs b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_ciphertext.rs new file mode 100644 index 000000000..f6a7b4c7a --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_ciphertext.rs @@ -0,0 +1,35 @@ +use std::fmt::Debug; + +use crate::core_crypto::prelude::LweDimension; + +use crate::core_crypto::backends::cuda::private::crypto::lwe::ciphertext::CudaLweCiphertext; +use crate::core_crypto::specification::entities::markers::LweCiphertextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweCiphertextEntity}; + +/// A structure representing a vector of LWE ciphertexts with 32 bits of precision on the GPU. +#[derive(Debug)] +pub struct CudaLweCiphertext32(pub(crate) CudaLweCiphertext); + +impl AbstractEntity for CudaLweCiphertext32 { + type Kind = LweCiphertextKind; +} + +impl LweCiphertextEntity for CudaLweCiphertext32 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_dimension + } +} + +/// A structure representing a vector of LWE ciphertexts with 64 bits of precision on the GPU. +#[derive(Debug)] +pub struct CudaLweCiphertext64(pub(crate) CudaLweCiphertext); + +impl AbstractEntity for CudaLweCiphertext64 { + type Kind = LweCiphertextKind; +} + +impl LweCiphertextEntity for CudaLweCiphertext64 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_dimension + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_keyswitch_key.rs b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_keyswitch_key.rs new file mode 100644 index 000000000..6ac98fceb --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/entities/lwe_keyswitch_key.rs @@ -0,0 +1,57 @@ +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + +use crate::core_crypto::backends::cuda::private::crypto::keyswitch::CudaLweKeyswitchKey; +use crate::core_crypto::specification::entities::markers::LweKeyswitchKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweKeyswitchKeyEntity}; + +/// A structure representing a keyswitch key for 32 bits precision ciphertexts on the GPU. +#[derive(Debug)] +pub struct CudaLweKeyswitchKey32(pub(crate) CudaLweKeyswitchKey); + +impl AbstractEntity for CudaLweKeyswitchKey32 { + type Kind = LweKeyswitchKeyKind; +} + +impl LweKeyswitchKeyEntity for CudaLweKeyswitchKey32 { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_dimension + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.output_lwe_dimension + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomp_level + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomp_base_log + } +} + +/// A structure representing a keyswitch key for 64 bits precision ciphertexts on the GPU. +#[derive(Debug)] +pub struct CudaLweKeyswitchKey64(pub(crate) CudaLweKeyswitchKey); + +impl AbstractEntity for CudaLweKeyswitchKey64 { + type Kind = LweKeyswitchKeyKind; +} + +impl LweKeyswitchKeyEntity for CudaLweKeyswitchKey64 { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_dimension + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.output_lwe_dimension + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomp_level + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomp_base_log + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/entities/mod.rs b/tfhe/src/core_crypto/backends/cuda/implementation/entities/mod.rs new file mode 100644 index 000000000..ba83c016b --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/entities/mod.rs @@ -0,0 +1,12 @@ +//! A module containing all the [entities](crate::core_crypto::specification::entities) +//! exposed by the cuda backend. + +mod glwe_ciphertext; +mod lwe_bootstrap_key; +mod lwe_ciphertext; +mod lwe_keyswitch_key; + +pub use glwe_ciphertext::*; +pub use lwe_bootstrap_key::*; +pub use lwe_ciphertext::*; +pub use lwe_keyswitch_key::*; diff --git a/tfhe/src/core_crypto/backends/cuda/implementation/mod.rs b/tfhe/src/core_crypto/backends/cuda/implementation/mod.rs new file mode 100644 index 000000000..49169443f --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/implementation/mod.rs @@ -0,0 +1,2 @@ +pub mod engines; +pub mod entities; diff --git a/tfhe/src/core_crypto/backends/cuda/mod.rs b/tfhe/src/core_crypto/backends/cuda/mod.rs new file mode 100644 index 000000000..c4ec7b210 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/mod.rs @@ -0,0 +1,13 @@ +//! A module containing the cuda backend implementation. +//! +//! This module contains CUDA GPU implementations of some FHE cryptographic primitives. In +//! particular, it makes it possible to execute bootstraps on a vector of ciphertext vectors, with a +//! vector of LUT and a bootstrap key as other inputs. To do so, the backend also exposes functions +//! to transfer data to and from the GPU, via conversion functions. + +#[doc(hidden)] +pub mod private; + +pub(crate) mod implementation; + +pub use implementation::{engines, entities}; diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/bootstrap/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/bootstrap/mod.rs new file mode 100644 index 000000000..88093096d --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/bootstrap/mod.rs @@ -0,0 +1,67 @@ +//! Bootstrap key with Cuda. +use crate::core_crypto::backends::cuda::private::device::{CudaStream, NumberOfGpus}; +use crate::core_crypto::backends::cuda::private::vec::CudaVec; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, +}; +use std::marker::PhantomData; + +#[derive(Debug)] +pub(crate) struct CudaBootstrapKey { + // Pointers to GPU data: one cuda vec per GPU + pub(crate) d_vecs: Vec>, + // Input LWE dimension + pub(crate) input_lwe_dimension: LweDimension, + // Size of polynomials in the key + pub(crate) polynomial_size: PolynomialSize, + // GLWE dimension + pub(crate) glwe_dimension: GlweDimension, + // Number of decomposition levels + pub(crate) decomp_level: DecompositionLevelCount, + // Value of the base log for the decomposition + pub(crate) decomp_base_log: DecompositionBaseLog, + // Field to hold type T + pub(crate) _phantom: PhantomData, +} + +unsafe impl Send for CudaBootstrapKey where T: Send + UnsignedInteger {} +unsafe impl Sync for CudaBootstrapKey where T: Sync + UnsignedInteger {} + +pub(crate) unsafe fn convert_lwe_bootstrap_key_from_cpu_to_gpu( + streams: &[CudaStream], + input: &StandardBootstrapKey, + number_of_gpus: NumberOfGpus, +) -> Vec> +where + Cont: AsRefSlice, +{ + // Copy the entire input vector over all GPUs + let mut vecs = Vec::with_capacity(number_of_gpus.0); + // TODO + // Check if it would be better to have GPU 0 compute the BSK and copy it back to the + // CPU, then copy the BSK to the other GPUs. The order of instructions varies depending on + // the Cuda warp scheduling, which we cannot assume is deterministic, so we'll end up with + // slightly different BSKs on the GPUs. It is unclear how significantly this affects the + // noise after the bootstrap. + let total_polynomials = + input.key_size().0 * input.glwe_size().0 * input.glwe_size().0 * input.level_count().0; + let alloc_size = total_polynomials * input.polynomial_size().0; + for stream in streams.iter() { + let mut d_vec = stream.malloc::(alloc_size as u32); + let input_slice = input.as_tensor().as_slice(); + stream.initialize_twiddles(input.polynomial_size()); + stream.convert_lwe_bootstrap_key::( + &mut d_vec, + input_slice, + input.key_size(), + input.glwe_size().to_glwe_dimension(), + input.level_count(), + input.polynomial_size(), + ); + vecs.push(d_vec); + } + vecs +} diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/ciphertext.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/ciphertext.rs new file mode 100644 index 000000000..971ae1361 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/ciphertext.rs @@ -0,0 +1,18 @@ +use crate::core_crypto::backends::cuda::private::vec::CudaVec; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; + +/// One GLWE ciphertext on GPU 0. +/// +/// There is no multi GPU support at this stage since the user cannot +/// specify on which GPU to convert the data. +// Fields with `d_` are data in the GPU +#[derive(Debug)] +pub(crate) struct CudaGlweCiphertext { + // Pointer to GPU data: one cuda vec on GPU 0 + pub(crate) d_vec: CudaVec, + // Glwe dimension + pub(crate) glwe_dimension: GlweDimension, + // Polynomial size + pub(crate) polynomial_size: PolynomialSize, +} diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/mod.rs new file mode 100644 index 000000000..43c710a72 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/glwe/mod.rs @@ -0,0 +1,3 @@ +//! GLWE ciphertexts and ciphertext vectors with Cuda. + +pub(crate) mod ciphertext; diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/keyswitch/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/keyswitch/mod.rs new file mode 100644 index 000000000..609e9518c --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/keyswitch/mod.rs @@ -0,0 +1,21 @@ +//! Keyswitch key with Cuda. +use crate::core_crypto::backends::cuda::private::vec::CudaVec; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + +#[derive(Debug)] +pub(crate) struct CudaLweKeyswitchKey { + // Pointers to GPU data: one cuda vec per GPU + pub(crate) d_vecs: Vec>, + // Input LWE dimension + pub(crate) input_lwe_dimension: LweDimension, + // Output LWE dimension + pub(crate) output_lwe_dimension: LweDimension, + // Number of decomposition levels + pub(crate) decomp_level: DecompositionLevelCount, + // Value of the base log for the decomposition + pub(crate) decomp_base_log: DecompositionBaseLog, +} + +unsafe impl Send for CudaLweKeyswitchKey where T: Send + UnsignedInteger {} +unsafe impl Sync for CudaLweKeyswitchKey where T: Sync + UnsignedInteger {} diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/ciphertext.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/ciphertext.rs new file mode 100644 index 000000000..58e039be9 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/ciphertext.rs @@ -0,0 +1,17 @@ +use crate::core_crypto::backends::cuda::private::vec::CudaVec; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::LweDimension; + +/// An LWE ciphertext on the GPU 0. +/// +/// There is no multi GPU support at this stage since the user cannot +/// specify on which GPU to convert the data. + +// Fields with `d_` are data in the GPU +#[derive(Debug)] +pub(crate) struct CudaLweCiphertext { + // Pointers to GPU data: one cuda vec on GPU 0 + pub(crate) d_vec: CudaVec, + // Lwe dimension + pub(crate) lwe_dimension: LweDimension, +} diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/mod.rs new file mode 100644 index 000000000..a426d083d --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/lwe/mod.rs @@ -0,0 +1,3 @@ +//! LWE ciphertexts and ciphertext vectors with Cuda. + +pub(crate) mod ciphertext; diff --git a/tfhe/src/core_crypto/backends/cuda/private/crypto/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/crypto/mod.rs new file mode 100644 index 000000000..f3f632af5 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/crypto/mod.rs @@ -0,0 +1,8 @@ +//! Low-overhead homomorphic primitives. +//! +//! This module implements low-overhead fully homomorphic operations. + +pub mod bootstrap; +pub mod glwe; +pub mod keyswitch; +pub mod lwe; diff --git a/tfhe/src/core_crypto/backends/cuda/private/device.rs b/tfhe/src/core_crypto/backends/cuda/private/device.rs new file mode 100644 index 000000000..18d56e5a6 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/device.rs @@ -0,0 +1,398 @@ +use crate::core_crypto::backends::cuda::engines::CudaError; +use crate::core_crypto::backends::cuda::private::pointers::StreamPointer; +use crate::core_crypto::backends::cuda::private::vec::CudaVec; +use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger}; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweCiphertextIndex, LweDimension, + PolynomialSize, SharedMemoryAmount, +}; +use concrete_cuda::cuda_bind::*; +use std::ffi::c_void; +use std::marker::PhantomData; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct GpuIndex(pub usize); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct NumberOfSamples(pub usize); + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct NumberOfGpus(pub usize); + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct CudaStream { + gpu_index: GpuIndex, + stream: StreamPointer, +} + +impl CudaStream { + /// Creates a new stream attached to GPU at gpu_index + pub(crate) fn new(gpu_index: GpuIndex) -> Result { + if gpu_index.0 >= unsafe { cuda_get_number_of_gpus() } as usize { + Err(CudaError::InvalidDeviceIndex(gpu_index)) + } else { + let stream = StreamPointer(unsafe { cuda_create_stream(gpu_index.0 as u32) }); + Ok(CudaStream { gpu_index, stream }) + } + } + + /// Gets the GPU index the stream is associated to + pub(crate) fn gpu_index(&self) -> GpuIndex { + self.gpu_index + } + + /// Gets the stream handle + pub(crate) fn stream_handle(&self) -> StreamPointer { + self.stream + } + + /// Check that the GPU has enough global memory + pub(crate) fn check_device_memory(&self, size: u64) -> Result<(), CudaError> { + let valid = unsafe { cuda_check_valid_malloc(size, self.gpu_index().0 as u32) }; + match valid { + 0 => Ok(()), + -1 => Err(CudaError::NotEnoughDeviceMemory(self.gpu_index())), + -2 => Err(CudaError::InvalidDeviceIndex(self.gpu_index())), + _ => Err(CudaError::UnspecifiedDeviceError(self.gpu_index())), + } + } + + /// Allocates `elements` on the GPU + pub(crate) fn malloc(&self, elements: u32) -> CudaVec + where + T: Numeric, + { + let size = elements as u64 * std::mem::size_of::() as u64; + let ptr = unsafe { cuda_malloc(size, self.gpu_index().0 as u32) }; + CudaVec { + ptr, + idx: self.gpu_index.0 as u32, + len: elements as usize, + _phantom: PhantomData::default(), + } + } + + /// Copies data from slice into GPU pointer + /// + /// # Safety + /// + /// - `dest` __must__ be a valid pointer + /// - [CudaStream::cuda_synchronize_device] __must__ have been called before + /// - [CudaStream::cuda_synchronize_device] __must__ be called after the copy + /// as soon as synchronization is required + pub(crate) unsafe fn copy_to_gpu_async(&self, dest: &mut CudaVec, src: &[T]) + where + T: Numeric, + { + let size = (src.len() * std::mem::size_of::()) as u64; + cuda_memcpy_async_to_gpu( + dest.as_mut_c_ptr(), + src.as_ptr() as *const c_void, + size, + self.stream_handle().0, + self.gpu_index().0 as u32, + ); + } + + /// Copies data from slice into GPU pointer + /// + /// # Safety + /// + /// - `dest` __must__ be a valid pointer + /// - [CudaStream::cuda_synchronize_device] __must__ have been called before + pub(crate) unsafe fn copy_to_gpu(&self, dest: &mut CudaVec, src: &[T]) + where + T: Numeric, + { + self.copy_to_gpu_async(dest, src); + self.synchronize_device(); + } + + /// Copies data from GPU pointer into slice + /// + /// # Safety + /// + /// - `dest` __must__ be a valid pointer + /// - [CudaStream::cuda_synchronize_device] __must__ have been called before + /// - [CudaStream::cuda_synchronize_device] __must__ be called as soon as synchronization is + /// required + pub(crate) unsafe fn copy_to_cpu_async(&self, dest: &mut [T], src: &CudaVec) + where + T: Numeric, + { + let size = (dest.len() * std::mem::size_of::()) as u64; + cuda_memcpy_async_to_cpu( + dest.as_mut_ptr() as *mut c_void, + src.as_c_ptr(), + size, + self.stream_handle().0, + self.gpu_index().0 as u32, + ); + } + + /// Copies data from GPU pointer into slice + /// + /// # Safety + /// + /// - `dest` __must__ be a valid pointer + /// - [CudaStream::cuda_synchronize_device] __must__ have been called before + pub(crate) unsafe fn copy_to_cpu(&self, dest: &mut [T], src: &CudaVec) + where + T: Numeric, + { + self.copy_to_cpu_async(dest, src); + self.synchronize_device(); + } + + /// Synchronizes the device + #[allow(dead_code)] + pub(crate) fn synchronize_device(&self) { + unsafe { cuda_synchronize_device(self.gpu_index().0 as u32) }; + } + + /// Get the maximum amount of shared memory + pub(crate) fn get_max_shared_memory(&self) -> Result { + let max_shared_memory = unsafe { cuda_get_max_shared_memory(self.gpu_index().0 as u32) }; + match max_shared_memory { + 0 => Err(CudaError::SharedMemoryNotFound(self.gpu_index())), + -2 => Err(CudaError::InvalidDeviceIndex(self.gpu_index())), + _ => Ok(max_shared_memory), + } + } + + /// Initialize twiddles + #[allow(dead_code)] + pub fn initialize_twiddles(&self, polynomial_size: PolynomialSize) { + unsafe { cuda_initialize_twiddles(polynomial_size.0 as u32, self.gpu_index.0 as u32) }; + } + + /// Convert bootstrap key + #[allow(dead_code)] + pub unsafe fn convert_lwe_bootstrap_key( + &self, + dest: &mut CudaVec, + src: &[T], + input_lwe_dim: LweDimension, + glwe_dim: GlweDimension, + l_gadget: DecompositionLevelCount, + polynomial_size: PolynomialSize, + ) { + if T::BITS == 32 { + cuda_convert_lwe_bootstrap_key_32( + dest.as_mut_c_ptr(), + src.as_ptr() as *mut c_void, + self.stream.0, + self.gpu_index.0 as u32, + input_lwe_dim.0 as u32, + glwe_dim.0 as u32, + l_gadget.0 as u32, + polynomial_size.0 as u32, + ) + } else if T::BITS == 64 { + cuda_convert_lwe_bootstrap_key_64( + dest.as_mut_c_ptr(), + src.as_ptr() as *mut c_void, + self.stream.0, + self.gpu_index.0 as u32, + input_lwe_dim.0 as u32, + glwe_dim.0 as u32, + l_gadget.0 as u32, + polynomial_size.0 as u32, + ) + } + } + + /// Discarding bootstrap on a vector of LWE ciphertexts + #[allow(dead_code, clippy::too_many_arguments)] + pub unsafe fn discard_bootstrap_amortized_lwe_ciphertext_vector( + &self, + lwe_array_out: &mut CudaVec, + test_vector: &CudaVec, + test_vector_indexes: &CudaVec, + lwe_array_in: &CudaVec, + bootstrapping_key: &CudaVec, + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + num_samples: NumberOfSamples, + lwe_idx: LweCiphertextIndex, + max_shared_memory: SharedMemoryAmount, + ) { + if T::BITS == 32 { + cuda_bootstrap_amortized_lwe_ciphertext_vector_32( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples.0 as u32, + num_samples.0 as u32, + lwe_idx.0 as u32, + max_shared_memory.0 as u32, + ) + } else if T::BITS == 64 { + cuda_bootstrap_amortized_lwe_ciphertext_vector_64( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples.0 as u32, + num_samples.0 as u32, + lwe_idx.0 as u32, + max_shared_memory.0 as u32, + ) + } + } + + /// Discarding bootstrap on a vector of LWE ciphertexts + #[allow(dead_code, clippy::too_many_arguments)] + pub unsafe fn discard_bootstrap_low_latency_lwe_ciphertext_vector( + &self, + lwe_array_out: &mut CudaVec, + test_vector: &CudaVec, + test_vector_indexes: &CudaVec, + lwe_array_in: &CudaVec, + bootstrapping_key: &CudaVec, + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + num_samples: NumberOfSamples, + lwe_idx: LweCiphertextIndex, + max_shared_memory: SharedMemoryAmount, + ) { + if T::BITS == 32 { + cuda_bootstrap_low_latency_lwe_ciphertext_vector_32( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples.0 as u32, + num_samples.0 as u32, + lwe_idx.0 as u32, + max_shared_memory.0 as u32, + ) + } else if T::BITS == 64 { + cuda_bootstrap_low_latency_lwe_ciphertext_vector_64( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + test_vector.as_c_ptr(), + test_vector_indexes.as_c_ptr(), + lwe_array_in.as_c_ptr(), + bootstrapping_key.as_c_ptr(), + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + num_samples.0 as u32, + num_samples.0 as u32, + lwe_idx.0 as u32, + max_shared_memory.0 as u32, + ) + } + } + + /// Discarding keyswitch on a vector of LWE ciphertexts + #[allow(dead_code, clippy::too_many_arguments)] + pub unsafe fn discard_keyswitch_lwe_ciphertext_vector( + &self, + lwe_array_out: &mut CudaVec, + lwe_array_in: &CudaVec, + input_lwe_dimension: LweDimension, + output_lwe_dimension: LweDimension, + keyswitch_key: &CudaVec, + base_log: DecompositionBaseLog, + l_gadget: DecompositionLevelCount, + num_samples: NumberOfSamples, + ) { + if T::BITS == 32 { + cuda_keyswitch_lwe_ciphertext_vector_32( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + lwe_array_in.as_c_ptr(), + keyswitch_key.as_c_ptr(), + input_lwe_dimension.0 as u32, + output_lwe_dimension.0 as u32, + base_log.0 as u32, + l_gadget.0 as u32, + num_samples.0 as u32, + ) + } else if T::BITS == 64 { + cuda_keyswitch_lwe_ciphertext_vector_64( + self.stream.0, + lwe_array_out.as_mut_c_ptr(), + lwe_array_in.as_c_ptr(), + keyswitch_key.as_c_ptr(), + input_lwe_dimension.0 as u32, + output_lwe_dimension.0 as u32, + base_log.0 as u32, + l_gadget.0 as u32, + num_samples.0 as u32, + ) + } + } +} + +impl Drop for CudaStream { + fn drop(&mut self) { + unsafe { + cuda_destroy_stream(self.stream_handle().0, self.gpu_index().0 as u32); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn print_gpu_info() { + println!("Number of GPUs: {}", unsafe { cuda_get_number_of_gpus() }); + let gpu_index = GpuIndex(0); + let stream = CudaStream::new(gpu_index).unwrap(); + println!( + "Max shared memory: {}", + stream.get_max_shared_memory().unwrap() + ) + } + #[test] + fn allocate_and_copy() { + let vec = vec![1_u64, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + let gpu_index = GpuIndex(0); + let stream = CudaStream::new(gpu_index).unwrap(); + stream.check_device_memory(vec.len() as u64).unwrap(); + let mut d_vec: CudaVec = stream.malloc::(vec.len() as u32); + unsafe { + stream.copy_to_gpu(&mut d_vec, &vec); + } + let mut empty = vec![0_u64; vec.len()]; + unsafe { + stream.copy_to_cpu(&mut empty, &d_vec); + } + assert_eq!(vec, empty); + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/private/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/mod.rs new file mode 100644 index 000000000..06506464c --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/mod.rs @@ -0,0 +1,5 @@ +pub mod crypto; +pub mod device; +pub mod pointers; +pub mod vec; +pub mod wopbs; diff --git a/tfhe/src/core_crypto/backends/cuda/private/pointers.rs b/tfhe/src/core_crypto/backends/cuda/private/pointers.rs new file mode 100644 index 000000000..ba03a4c88 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/pointers.rs @@ -0,0 +1,5 @@ +use std::ffi::c_void; + +#[repr(transparent)] +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub struct StreamPointer(pub *mut c_void); diff --git a/tfhe/src/core_crypto/backends/cuda/private/vec.rs b/tfhe/src/core_crypto/backends/cuda/private/vec.rs new file mode 100644 index 000000000..84ef12056 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/vec.rs @@ -0,0 +1,52 @@ +use crate::core_crypto::commons::numeric::Numeric; +use concrete_cuda::cuda_bind::cuda_drop; +use std::ffi::c_void; +use std::marker::PhantomData; + +/// A contiguous array type stored in the gpu memory. +/// +/// Note: +/// ----- +/// +/// Such a structure: +/// + can be created via the `CudaStream::malloc` function +/// + can not be copied or cloned but can be (mutably) borrowed +/// + frees the gpu memory on drop. +/// +/// Put differently, it owns a region of the gpu memory at a given time. For this reason, regarding +/// memory, it is pretty close to a `Vec`. That being said, it only present a very very limited api. +#[derive(Debug)] +pub struct CudaVec { + pub(super) ptr: *mut c_void, + pub(super) idx: u32, + pub(super) len: usize, + pub(super) _phantom: PhantomData, +} + +impl CudaVec { + /// Returns a raw pointer to the vector’s buffer. + pub fn as_c_ptr(&self) -> *const c_void { + self.ptr as *const c_void + } + + /// Returns an unsafe mutable pointer to the vector’s buffer. + pub fn as_mut_c_ptr(&mut self) -> *mut c_void { + self.ptr + } + + /// Returns the number of elements in the vector, also referred to as its ‘length’. + pub fn len(&self) -> usize { + self.len + } + + /// Returns `true` if the CudaVec contains no elements. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl Drop for CudaVec { + fn drop(&mut self) { + unsafe { cuda_drop(self.ptr, self.idx) }; + } +} diff --git a/tfhe/src/core_crypto/backends/cuda/private/wopbs/mod.rs b/tfhe/src/core_crypto/backends/cuda/private/wopbs/mod.rs new file mode 100644 index 000000000..d0435bce8 --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/wopbs/mod.rs @@ -0,0 +1,2 @@ +#[cfg(test)] +mod test; diff --git a/tfhe/src/core_crypto/backends/cuda/private/wopbs/test.rs b/tfhe/src/core_crypto/backends/cuda/private/wopbs/test.rs new file mode 100644 index 000000000..a153451ab --- /dev/null +++ b/tfhe/src/core_crypto/backends/cuda/private/wopbs/test.rs @@ -0,0 +1,407 @@ +use crate::core_crypto::backends::cuda::private::device::{CudaStream, GpuIndex}; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::crypto::lwe::{LweCiphertext, LweKeyswitchKey, LweList}; +use crate::core_crypto::commons::crypto::secret::generators::{ + EncryptionRandomGenerator, SecretRandomGenerator, +}; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +use crate::core_crypto::commons::math::decomposition::SignedDecomposer; +use crate::core_crypto::commons::math::polynomial::PolynomialList; +use crate::core_crypto::commons::math::tensor::{AsMutTensor, AsRefSlice, AsRefTensor}; +use crate::core_crypto::commons::test_tools; +use crate::core_crypto::prelude::*; +use concrete_csprng::generators::SoftwareRandomGenerator; +use concrete_csprng::seeders::UnixSeeder; +use concrete_cuda::cuda_bind::{ + cuda_cmux_tree_64, cuda_convert_lwe_bootstrap_key_64, cuda_extract_bits_64, + cuda_initialize_twiddles, cuda_synchronize_device, +}; +use std::os::raw::c_void; + +#[test] +pub fn test_cuda_cmux_tree() { + let polynomial_size = PolynomialSize(512); + let glwe_dimension = GlweDimension(1); + let level = DecompositionLevelCount(3); + let base_log = DecompositionBaseLog(6); + let delta_log = 60; + + let std = LogStandardDev::from_log_standard_dev(-60.); + + println!( + "polynomial_size: {}, glwe_dimension: {}, level: {}, base_log: {}", + polynomial_size.0, glwe_dimension.0, level.0, base_log.0 + ); + + let r = 10; // Depth of the tree + let num_lut = 1 << r; + + // Size of a GGSW ciphertext + // N * (k+1) * (k+1) * ell + let ggsw_size = polynomial_size.0 + * glwe_dimension.to_glwe_size().0 + * glwe_dimension.to_glwe_size().0 + * level.0; + // Size of a GLWE ciphertext + // (k+1) * N + let glwe_size = glwe_dimension.to_glwe_size().0 * polynomial_size.0; + + println!("r: {}", r); + println!("glwe_size: {}, ggsw_size: {}", glwe_size, ggsw_size); + + // Engines + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + // Key + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + let rlwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(glwe_dimension, polynomial_size, &mut secret_generator); + + // Instantiate the LUTs + // We need 2^r GLWEs + let mut h_concatenated_luts = vec![]; + let mut h_luts = PolynomialList::allocate(0u64, PolynomialCount(num_lut), polynomial_size); + for (i, mut polynomial) in h_luts.polynomial_iter_mut().enumerate() { + polynomial + .as_mut_tensor() + .fill_with_element((i as u64 % (1 << (64 - delta_log))) << delta_log); + + let mut h_lut = polynomial.as_tensor().as_slice().to_vec(); + let mut h_zeroes = vec![0_u64; polynomial_size.0]; + // println!("lut {}) {}", i, h_lut[0]); + + // Mask is zero + h_concatenated_luts.append(&mut h_zeroes); + // Body is something else + h_concatenated_luts.append(&mut h_lut); + } + + // Now we have (2**r GLWE ciphertexts) + assert_eq!(h_concatenated_luts.len(), num_lut * glwe_size); + println!("\nWe have {} LUTs", num_lut); + + // Copy to Device + let gpu_index = GpuIndex(0); + let stream = CudaStream::new(gpu_index).unwrap(); + + let mut d_concatenated_luts = stream.malloc::(h_concatenated_luts.len() as u32); + unsafe { + stream.copy_to_gpu::(&mut d_concatenated_luts, h_concatenated_luts.as_slice()); + } + + // Instantiate the GGSW m^tree ciphertexts + // We need r GGSW ciphertexts + // Bit decomposition of the value from MSB to LSB + let mut value = 0b111101; + let witness = value; + //bit decomposition of the value + let mut vec_message = vec![Plaintext(0); r as usize]; + for i in 0..r { + vec_message[i as usize] = Plaintext(value & 1); + value >>= 1; + } + + // bit decomposition are stored in ggsw + let mut h_concatenated_ggsw = vec![]; + for vec_msg in vec_message.iter().take(r as usize) { + println!("vec_msg: {}", vec_msg.0); + + let mut ggsw = StandardGgswCiphertext::allocate( + 0 as u64, + polynomial_size, + glwe_dimension.to_glwe_size(), + level, + base_log, + ); + rlwe_sk.encrypt_constant_ggsw(&mut ggsw, vec_msg, std, &mut encryption_generator); + + let ggsw_slice = ggsw.as_tensor().as_slice(); + h_concatenated_ggsw.append(&mut ggsw_slice.to_vec()); + } + + assert_eq!(h_concatenated_ggsw.len(), (r as usize) * ggsw_size); + println!("We have {} ggsw", r); + + // Copy to Device + let mut d_concatenated_mtree = stream.malloc::(h_concatenated_ggsw.len() as u32); + unsafe { + stream.copy_to_gpu::(&mut d_concatenated_mtree, h_concatenated_ggsw.as_slice()); + } + + let mut d_result = stream.malloc::(glwe_size as u32); + unsafe { + cuda_cmux_tree_64( + stream.stream_handle().0, + d_result.as_mut_c_ptr(), + d_concatenated_mtree.as_c_ptr(), + d_concatenated_luts.as_c_ptr(), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + base_log.0 as u32, + level.0 as u32, + r as u32, + stream.get_max_shared_memory().unwrap() as u32, + ); + } + + let mut h_result = vec![49u64; glwe_size]; + unsafe { + stream.copy_to_cpu::(&mut h_result, &d_result); + } + assert_eq!(h_result.len(), glwe_size); + + let glwe_result = GlweCiphertext::from_container(h_result, polynomial_size); + + let mut decrypted_result = + PlaintextList::from_container(vec![0_u64; rlwe_sk.polynomial_size().0]); + rlwe_sk.decrypt_glwe(&mut decrypted_result, &glwe_result); + let lut_number = + ((*decrypted_result.tensor.first() as f64) / (1u64 << delta_log) as f64).round(); + + println!("\nresult: {:?}", decrypted_result.tensor.first()); + // println!("\nresult: {:?}", decrypted_result.tensor.as_container()); + println!("witness : {:?}", witness % (1 << (64 - delta_log))); + println!("lut_number: {}", lut_number); + // println!( + // "lut value : {:?}", + // h_luts[witness as usize] + // ); + println!("Done!"); + assert_eq!(lut_number as u64, witness % (1 << (64 - delta_log))) +} + +#[test] +pub fn test_cuda_extract_bits() { + // Define settings for an insecure toy example + let polynomial_size = PolynomialSize(1024); + let glwe_dimension = GlweDimension(1); + let lwe_dimension = LweDimension(585); + + let level_bsk = DecompositionLevelCount(2); + let base_log_bsk = DecompositionBaseLog(7); + + let level_ksk = DecompositionLevelCount(2); + let base_log_ksk = DecompositionBaseLog(11); + + let std = LogStandardDev::from_log_standard_dev(-60.); + + let number_of_bits_of_message_including_padding = 5_usize; + // Tests take about 2-3 seconds on a laptop with this number + let nos: u32 = 1; + let number_of_test_runs = 10; + + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + + // allocation and generation of the key in coef domain: + let rlwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(glwe_dimension, polynomial_size, &mut secret_generator); + let lwe_small_sk: LweSecretKey<_, Vec> = + LweSecretKey::generate_binary(lwe_dimension, &mut secret_generator); + + let mut coef_bsk = StandardBootstrapKey::allocate( + 0_u64, + glwe_dimension.to_glwe_size(), + polynomial_size, + level_bsk, + base_log_bsk, + lwe_dimension, + ); + coef_bsk.fill_with_new_key(&lwe_small_sk, &rlwe_sk, std, &mut encryption_generator); + + /* + // allocation for the bootstrapping key + let mut fourier_bsk: FourierBootstrapKey<_, u64> = FourierBootstrapKey::allocate( + Complex64::new(0., 0.), + rlwe_dimension.to_glwe_size(), + polynomial_size, + level_bsk, + base_log_bsk, + lwe_dimension, + ); + */ + + let mut h_coef_bsk: Vec = vec![]; + let mut h_ksk: Vec = vec![]; + h_coef_bsk.append(&mut coef_bsk.tensor.as_slice().to_vec()); + let gpu_index = GpuIndex(0); + let stream = CudaStream::new(gpu_index).unwrap(); + + let bsk_size = (glwe_dimension.0 + 1) + * (glwe_dimension.0 + 1) + * polynomial_size.0 + * level_bsk.0 + * lwe_dimension.0; + let ksksize = level_ksk.0 * polynomial_size.0 * (lwe_dimension.0 + 1); + + let mut h_lut_vector_indexes = vec![0 as u32; 1]; + + let mut d_lwe_array_out = stream.malloc::( + nos * (lwe_dimension.0 as u32 + 1) * (number_of_bits_of_message_including_padding) as u32, + ); + let mut d_lwe_array_in = stream.malloc::(nos * (polynomial_size.0 + 1) as u32); + let mut d_lwe_array_in_buffer = stream.malloc::(nos * (polynomial_size.0 + 1) as u32); + let mut d_lwe_array_in_shifted_buffer = + stream.malloc::(nos * (polynomial_size.0 + 1) as u32); + let mut d_lwe_array_out_ks_buffer = stream.malloc::(nos * (lwe_dimension.0 + 1) as u32); + let mut d_lwe_array_out_pbs_buffer = stream.malloc::(nos * (polynomial_size.0 + 1) as u32); + let mut d_lut_pbs = stream.malloc::((2 * polynomial_size.0) as u32); + let mut d_lut_vector_indexes = stream.malloc::(1); + let mut d_ksk = stream.malloc::(ksksize as u32); + let mut d_bsk_fourier = stream.malloc::(bsk_size as u32); + //decomp_size.0 * (output_size.0 + 1) * input_size.0 + unsafe { + cuda_initialize_twiddles(polynomial_size.0 as u32, gpu_index.0 as u32); + cuda_convert_lwe_bootstrap_key_64( + d_bsk_fourier.as_mut_c_ptr(), + h_coef_bsk.as_ptr() as *mut c_void, + stream.stream_handle().0, + gpu_index.0 as u32, + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + level_bsk.0 as u32, + polynomial_size.0 as u32, + ); + stream.copy_to_gpu::(&mut d_lut_vector_indexes, &mut h_lut_vector_indexes); + } + //let mut buffers = FourierBuffers::new(fourier_bsk.polynomial_size(), + // fourier_bsk.glwe_size()); fourier_bsk.fill_with_forward_fourier(&coef_bsk, &mut buffers); + + let lwe_big_sk = LweSecretKey::binary_from_container(rlwe_sk.as_tensor().as_slice()); + let mut ksk_lwe_big_to_small = LweKeyswitchKey::allocate( + 0_u64, + level_ksk, + base_log_ksk, + lwe_big_sk.key_size(), + lwe_small_sk.key_size(), + ); + ksk_lwe_big_to_small.fill_with_keyswitch_key( + &lwe_big_sk, + &lwe_small_sk, + std, + &mut encryption_generator, + ); + + let delta_log = DeltaLog(64 - number_of_bits_of_message_including_padding); + // Decomposer to manage the rounding after decrypting the extracted bit + + let decomposer = SignedDecomposer::new(DecompositionBaseLog(1), DecompositionLevelCount(1)); + + h_ksk.clone_from(&ksk_lwe_big_to_small.into_container()); + + //////////////////////////////////////////////////////////////////////////////////////////////// + + use std::time::Instant; + let mut now = Instant::now(); + let mut elapsed = now.elapsed(); + + for _ in 0..number_of_test_runs { + // Generate a random plaintext in [0; 2^{number_of_bits_of_message_including_padding}[ + let val = test_tools::random_uint_between( + 0..2u64.pow(number_of_bits_of_message_including_padding as u32), + ); + + // Encryption + let message = Plaintext(val << delta_log.0); + println!("{:?}", message); + let mut lwe_array_in = LweCiphertext::allocate(0u64, LweSize(polynomial_size.0 + 1)); + lwe_big_sk.encrypt_lwe(&mut lwe_array_in, &message, std, &mut encryption_generator); + + // Bit extraction + // Extract all the bits + let number_values_to_extract = ExtractedBitsCount(64 - delta_log.0); + + let mut _lwe_array_out_list = LweList::allocate( + 0u64, + lwe_dimension.to_lwe_size(), + CiphertextCount(number_values_to_extract.0), + ); + /* + extract_bits( + delta_log, + &mut lwe_array_out_list, + &lwe_array_in, + &ksk_lwe_big_to_small, + &fourier_bsk, + &mut buffers, + number_values_to_extract, + ); + */ + + unsafe { + stream.copy_to_gpu::(&mut d_ksk, &mut h_ksk); + //println!("rust_lwe_array_in: {:?}", lwe_array_in); + stream.copy_to_gpu::(&mut d_lwe_array_in, &mut lwe_array_in.tensor.as_slice()); + + now = Instant::now(); + cuda_extract_bits_64( + stream.stream_handle().0, + d_lwe_array_out.as_mut_c_ptr(), + d_lwe_array_in.as_c_ptr(), + d_lwe_array_in_buffer.as_mut_c_ptr(), + d_lwe_array_in_shifted_buffer.as_mut_c_ptr(), + d_lwe_array_out_ks_buffer.as_mut_c_ptr(), + d_lwe_array_out_pbs_buffer.as_mut_c_ptr(), + d_lut_pbs.as_mut_c_ptr(), + d_lut_vector_indexes.as_mut_c_ptr(), + d_ksk.as_c_ptr(), + d_bsk_fourier.as_c_ptr(), + number_values_to_extract.0 as u32, + delta_log.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + glwe_dimension.0 as u32, + base_log_bsk.0 as u32, + level_bsk.0 as u32, + base_log_ksk.0 as u32, + level_ksk.0 as u32, + nos, + ); + elapsed += now.elapsed(); + println!("elapsed: {:?}", elapsed); + + let mut h_result = vec![0u64; (lwe_dimension.0 + 1) * number_values_to_extract.0]; + stream.copy_to_cpu::(&mut h_result, &d_lwe_array_out); + + cuda_synchronize_device(gpu_index.0 as u32); + + let mut i = 0; + for result_h in h_result.chunks(lwe_dimension.0 + 1).rev() { + let result_ct = LweCiphertext::from_container(result_h); + let mut decrypted_message = Plaintext(0_u64); + lwe_small_sk.decrypt_lwe(&mut decrypted_message, &result_ct); + // Round after decryption using decomposer + let decrypted_rounded = decomposer.closest_representable(decrypted_message.0); + // Bring back the extracted bit found in the MSB in the LSB + let decrypted_extract_bit = decrypted_rounded >> 63; + println!("extracted bit : {:?}", decrypted_extract_bit); + println!("{:?}", decrypted_message); + + // TODO decomposition algorithm should be changed for keyswitch and amortized pbs. + + assert_eq!( + ((message.0 >> delta_log.0) >> i) & 1, + decrypted_extract_bit, + "Bit #{}, for plaintext {:#066b}", + delta_log.0 + i, + message.0 + ); + + i += 1; + } + } + } + println!("number of tests: {}", number_of_test_runs); + println!("total_time: {:?}", elapsed); + println!("average time {:?}", elapsed / number_of_test_runs); +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/activated_generator.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/activated_generator.rs new file mode 100644 index 000000000..2d60c38d2 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/activated_generator.rs @@ -0,0 +1,19 @@ +#[cfg(feature = "backend_default_generator_x86_64_aesni")] +use concrete_csprng::generators::AesniRandomGenerator; +#[cfg(feature = "backend_default_generator_aarch64_aes")] +use concrete_csprng::generators::NeonAesRandomGenerator; +#[cfg(all( + not(feature = "backend_default_generator_x86_64_aesni"), + not(feature = "backend_default_generator_aarch64_aes") +))] +use concrete_csprng::generators::SoftwareRandomGenerator; + +#[cfg(feature = "backend_default_generator_x86_64_aesni")] +pub type ActivatedRandomGenerator = AesniRandomGenerator; +#[cfg(feature = "backend_default_generator_aarch64_aes")] +pub type ActivatedRandomGenerator = NeonAesRandomGenerator; +#[cfg(all( + not(feature = "backend_default_generator_x86_64_aesni"), + not(feature = "backend_default_generator_aarch64_aes") +))] +pub type ActivatedRandomGenerator = SoftwareRandomGenerator; diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/cleartext_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/cleartext_creation.rs new file mode 100644 index 000000000..4badaeec4 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/cleartext_creation.rs @@ -0,0 +1,104 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{Cleartext32, Cleartext64}; +use crate::core_crypto::commons::crypto::encoding::Cleartext as ImplCleartext; +use crate::core_crypto::prelude::CleartextF64; +use crate::core_crypto::specification::engines::{CleartextCreationEngine, CleartextCreationError}; + +/// # Description: +/// Implementation of [`CleartextCreationEngine`] for [`DefaultEngine`] that operates on 32 bits +/// integers. +impl CleartextCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u32 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext32 = engine.create_cleartext_from(&input)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_cleartext_from( + &mut self, + input: &u32, + ) -> Result> { + Ok(unsafe { self.create_cleartext_from_unchecked(input) }) + } + + unsafe fn create_cleartext_from_unchecked(&mut self, input: &u32) -> Cleartext32 { + Cleartext32(ImplCleartext(*input)) + } +} + +/// # Description: +/// Implementation of [`CleartextCreationEngine`] for [`DefaultEngine`] that operates on 64 bits +/// integers. +impl CleartextCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u64 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext64 = engine.create_cleartext_from(&input)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_cleartext_from( + &mut self, + input: &u64, + ) -> Result> { + Ok(unsafe { self.create_cleartext_from_unchecked(input) }) + } + + unsafe fn create_cleartext_from_unchecked(&mut self, input: &u64) -> Cleartext64 { + Cleartext64(ImplCleartext(*input)) + } +} + +/// # Description: +/// Implementation of [`CleartextCreationEngine`] for [`DefaultEngine`] that operates on 64 bits +/// floating point numbers. +impl CleartextCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: f64 = 3.; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: CleartextF64 = engine.create_cleartext_from(&input)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_cleartext_from( + &mut self, + value: &f64, + ) -> Result> { + Ok(unsafe { self.create_cleartext_from_unchecked(value) }) + } + + unsafe fn create_cleartext_from_unchecked(&mut self, value: &f64) -> CleartextF64 { + CleartextF64(ImplCleartext(*value)) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_consuming_retrieval.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_consuming_retrieval.rs new file mode 100644 index 000000000..e51b4b93d --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_consuming_retrieval.rs @@ -0,0 +1,301 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweCiphertext32, GlweCiphertext64, GlweCiphertextMutView32, GlweCiphertextMutView64, + GlweCiphertextView32, GlweCiphertextView64, +}; +use crate::core_crypto::commons::math::tensor::IntoTensor; +use crate::core_crypto::specification::engines::{ + GlweCiphertextConsumingRetrievalEngine, GlweCiphertextConsumingRetrievalError, +}; + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying vec of a [`GlweCiphertext32`] consuming it in the process +impl GlweCiphertextConsumingRetrievalEngine> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: GlweCiphertext32 = + /// engine.create_glwe_ciphertext_from(owned_container, polynomial_size)?; + /// let retrieved_container = engine.consume_retrieve_glwe_ciphertext(ciphertext)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertext32, + ) -> Result, GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertext32, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying vec of a [`GlweCiphertext64`] consuming it in the process +impl GlweCiphertextConsumingRetrievalEngine> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u64; glwe_size.0 * polynomial_size.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: GlweCiphertext64 = + /// engine.create_glwe_ciphertext_from(owned_container, polynomial_size)?; + /// let retrieved_container = engine.consume_retrieve_glwe_ciphertext(ciphertext)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertext64, + ) -> Result, GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertext64, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying slice of a [`GlweCiphertextView32`] consuming it in the process +impl<'data> GlweCiphertextConsumingRetrievalEngine, &'data [u32]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextView32 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// let retrieved_slice = engine.consume_retrieve_glwe_ciphertext(ciphertext_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertextView32<'data>, + ) -> Result<&'data [u32], GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertextView32<'data>, + ) -> &'data [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying slice of a [`GlweCiphertextMutView32`] consuming it in the process +impl<'data> GlweCiphertextConsumingRetrievalEngine, &'data mut [u32]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextMutView32 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// let retrieved_slice = engine.consume_retrieve_glwe_ciphertext(ciphertext_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertextMutView32<'data>, + ) -> Result<&'data mut [u32], GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertextMutView32<'data>, + ) -> &'data mut [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying slice of a [`GlweCiphertextView64`] consuming it in the process +impl<'data> GlweCiphertextConsumingRetrievalEngine, &'data [u64]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u64; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextView64 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// let retrieved_slice = engine.consume_retrieve_glwe_ciphertext(ciphertext_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertextView64<'data>, + ) -> Result<&'data [u64], GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertextView64<'data>, + ) -> &'data [u64] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying slice of a [`GlweCiphertextMutView64`] consuming it in the process +impl<'data> GlweCiphertextConsumingRetrievalEngine, &'data mut [u64]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u64; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextMutView64 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// let retrieved_slice = engine.consume_retrieve_glwe_ciphertext(ciphertext_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: GlweCiphertextMutView64<'data>, + ) -> Result<&'data mut [u64], GlweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_glwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: GlweCiphertextMutView64<'data>, + ) -> &'data mut [u64] { + ciphertext.0.into_tensor().into_container() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_creation.rs new file mode 100644 index 000000000..1738af27b --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_creation.rs @@ -0,0 +1,340 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweCiphertext32, GlweCiphertext64, GlweCiphertextMutView32, GlweCiphertextMutView64, + GlweCiphertextView32, GlweCiphertextView64, +}; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext as ImplGlweCiphertext; +use crate::core_crypto::prelude::PolynomialSize; +use crate::core_crypto::specification::engines::{ + GlweCiphertextCreationEngine, GlweCiphertextCreationError, +}; + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a +/// [`GlweCiphertext32`]. +impl GlweCiphertextCreationEngine, GlweCiphertext32> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: GlweCiphertext32 = + /// engine.create_glwe_ciphertext_from(owned_container, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: Vec, + polynomial_size: PolynomialSize, + ) -> Result> { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: Vec, + polynomial_size: PolynomialSize, + ) -> GlweCiphertext32 { + GlweCiphertext32(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a +/// [`GlweCiphertext64`]. +impl GlweCiphertextCreationEngine, GlweCiphertext64> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let owned_container = vec![0_u64; glwe_size.0 * polynomial_size.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: GlweCiphertext64 = + /// engine.create_glwe_ciphertext_from(owned_container, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: Vec, + polynomial_size: PolynomialSize, + ) -> Result> { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: Vec, + polynomial_size: PolynomialSize, + ) -> GlweCiphertext64 { + GlweCiphertext64(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`GlweCiphertextView32`] that does not own its memory. +impl<'data> GlweCiphertextCreationEngine<&'data [u32], GlweCiphertextView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextView32 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: &'data [u32], + polynomial_size: PolynomialSize, + ) -> Result, GlweCiphertextCreationError> { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: &'data [u32], + polynomial_size: PolynomialSize, + ) -> GlweCiphertextView32<'data> { + GlweCiphertextView32(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a mutable +/// [`GlweCiphertextMutView32`] that does not own its memory. +impl<'data> GlweCiphertextCreationEngine<&'data mut [u32], GlweCiphertextMutView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u32; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextMutView32 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: &'data mut [u32], + polynomial_size: PolynomialSize, + ) -> Result, GlweCiphertextCreationError> + { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: &'data mut [u32], + polynomial_size: PolynomialSize, + ) -> GlweCiphertextMutView32<'data> { + GlweCiphertextMutView32(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`GlweCiphertextView64`] that does not own its memory. +impl<'data> GlweCiphertextCreationEngine<&'data [u64], GlweCiphertextView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = 600_usize; + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u64; (glwe_size + 1) * polynomial_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextView64 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: &'data [u64], + polynomial_size: PolynomialSize, + ) -> Result, GlweCiphertextCreationError> { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: &'data [u64], + polynomial_size: PolynomialSize, + ) -> GlweCiphertextView64<'data> { + GlweCiphertextView64(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} + +/// # Description: +/// Implementation of [`GlweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a mutable +/// [`GlweCiphertextMutView64`] that does not own its memory. +impl<'data> GlweCiphertextCreationEngine<&'data mut [u64], GlweCiphertextMutView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let glwe_size = GlweSize(600); + /// let polynomial_size = PolynomialSize(1024); + /// + /// // You have to make sure you size the container properly + /// let mut owned_container = vec![0_u64; glwe_size.0 * polynomial_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: GlweCiphertextMutView64 = + /// engine.create_glwe_ciphertext_from(slice, polynomial_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_glwe_ciphertext_from( + &mut self, + container: &'data mut [u64], + polynomial_size: PolynomialSize, + ) -> Result, GlweCiphertextCreationError> + { + GlweCiphertextCreationError::::perform_generic_checks( + container.len(), + polynomial_size, + )?; + Ok(unsafe { self.create_glwe_ciphertext_from_unchecked(container, polynomial_size) }) + } + + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: &'data mut [u64], + polynomial_size: PolynomialSize, + ) -> GlweCiphertextMutView64<'data> { + GlweCiphertextMutView64(ImplGlweCiphertext::from_container( + container, + polynomial_size, + )) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_trivial_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_trivial_encryption.rs new file mode 100644 index 000000000..f507f499f --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_ciphertext_trivial_encryption.rs @@ -0,0 +1,105 @@ +use crate::core_crypto::prelude::GlweSize; + +use crate::core_crypto::backends::default::entities::{ + GlweCiphertext32, GlweCiphertext64, PlaintextVector32, PlaintextVector64, +}; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext as ImplGlweCiphertext; +use crate::core_crypto::specification::engines::{ + GlweCiphertextTrivialEncryptionEngine, GlweCiphertextTrivialEncryptionError, +}; + +use crate::core_crypto::backends::default::engines::DefaultEngine; + +impl GlweCiphertextTrivialEncryptionEngine for DefaultEngine { + /// # Example: + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, Variance, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// let input = vec![3_u32 << 20; polynomial_size.0]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector32 = engine.create_plaintext_vector_from(&input)?; + /// // DISCLAIMER: trivial encryption is NOT secure, and DOES NOT hide the message at all. + /// let ciphertext: GlweCiphertext32 = engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &plaintext_vector)?; + /// + /// assert_eq!(ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(ciphertext.polynomial_size(), polynomial_size); + /// + /// # Ok(()) + /// # } + /// ``` + fn trivially_encrypt_glwe_ciphertext( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector32, + ) -> Result> { + unsafe { Ok(self.trivially_encrypt_glwe_ciphertext_unchecked(glwe_size, input)) } + } + + unsafe fn trivially_encrypt_glwe_ciphertext_unchecked( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector32, + ) -> GlweCiphertext32 { + let ciphertext: ImplGlweCiphertext> = + ImplGlweCiphertext::new_trivial_encryption(glwe_size, &input.0); + GlweCiphertext32(ciphertext) + } +} + +impl GlweCiphertextTrivialEncryptionEngine for DefaultEngine { + /// # Example: + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, Variance, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// let input = vec![3_u64 << 20; polynomial_size.0]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector64 = engine.create_plaintext_vector_from(&input)?; + /// // DISCLAIMER: trivial encryption is NOT secure, and DOES NOT hide the message at all. + /// let ciphertext: GlweCiphertext64 = engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dimension.to_glwe_size(), &plaintext_vector)?; + /// + /// assert_eq!(ciphertext.glwe_dimension(), glwe_dimension); + /// assert_eq!(ciphertext.polynomial_size(), polynomial_size); + /// + /// # Ok(()) + /// # } + /// ``` + fn trivially_encrypt_glwe_ciphertext( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector64, + ) -> Result> { + unsafe { Ok(self.trivially_encrypt_glwe_ciphertext_unchecked(glwe_size, input)) } + } + + unsafe fn trivially_encrypt_glwe_ciphertext_unchecked( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector64, + ) -> GlweCiphertext64 { + let ciphertext: ImplGlweCiphertext> = + ImplGlweCiphertext::new_trivial_encryption(glwe_size, &input.0); + GlweCiphertext64(ciphertext) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_secret_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_secret_key_generation.rs new file mode 100644 index 000000000..72977e147 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_secret_key_generation.rs @@ -0,0 +1,110 @@ +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, +}; +use crate::core_crypto::commons::crypto::secret::GlweSecretKey as ImplGlweSecretKey; +use crate::core_crypto::specification::engines::{ + GlweSecretKeyGenerationEngine, GlweSecretKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`GlweSecretKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl GlweSecretKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// # + /// assert_eq!(glwe_secret_key.glwe_dimension(), glwe_dimension); + /// assert_eq!(glwe_secret_key.polynomial_size(), polynomial_size); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_glwe_secret_key( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> Result> { + GlweSecretKeyGenerationError::perform_generic_checks(glwe_dimension, polynomial_size)?; + Ok(unsafe { self.generate_new_glwe_secret_key_unchecked(glwe_dimension, polynomial_size) }) + } + + unsafe fn generate_new_glwe_secret_key_unchecked( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> GlweSecretKey32 { + GlweSecretKey32(ImplGlweSecretKey::generate_binary( + glwe_dimension, + polynomial_size, + &mut self.secret_generator, + )) + } +} + +/// # Description: +/// Implementation of [`GlweSecretKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl GlweSecretKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// # + /// assert_eq!(glwe_secret_key.glwe_dimension(), glwe_dimension); + /// assert_eq!(glwe_secret_key.polynomial_size(), polynomial_size); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_glwe_secret_key( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> Result> { + GlweSecretKeyGenerationError::perform_generic_checks(glwe_dimension, polynomial_size)?; + Ok(unsafe { self.generate_new_glwe_secret_key_unchecked(glwe_dimension, polynomial_size) }) + } + + unsafe fn generate_new_glwe_secret_key_unchecked( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> GlweSecretKey64 { + GlweSecretKey64(ImplGlweSecretKey::generate_binary( + glwe_dimension, + polynomial_size, + &mut self.secret_generator, + )) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_to_lwe_secret_key_transformation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_to_lwe_secret_key_transformation.rs new file mode 100644 index 000000000..b3e323568 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/glwe_to_lwe_secret_key_transformation.rs @@ -0,0 +1,91 @@ +use crate::core_crypto::backends::default::engines::DefaultEngine; +use crate::core_crypto::backends::default::entities::{ + GlweSecretKey32, GlweSecretKey64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::specification::engines::{ + GlweToLweSecretKeyTransformationEngine, GlweToLweSecretKeyTransformationError, +}; + +impl GlweToLweSecretKeyTransformationEngine for DefaultEngine { + /// # Example + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// use tfhe::core_crypto::prelude::{GlweDimension, LweDimension, PolynomialSize, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// + /// let glwe_secret_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// assert_eq!(glwe_secret_key.glwe_dimension(), glwe_dimension); + /// assert_eq!(glwe_secret_key.polynomial_size(), polynomial_size); + /// + /// let lwe_secret_key = engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_secret_key)?; + /// assert_eq!(lwe_secret_key.lwe_dimension(), LweDimension(8)); + /// + /// # Ok(()) + /// # } + /// ``` + fn transform_glwe_secret_key_to_lwe_secret_key( + &mut self, + glwe_secret_key: GlweSecretKey32, + ) -> Result> { + Ok(unsafe { self.transform_glwe_secret_key_to_lwe_secret_key_unchecked(glwe_secret_key) }) + } + + unsafe fn transform_glwe_secret_key_to_lwe_secret_key_unchecked( + &mut self, + glwe_secret_key: GlweSecretKey32, + ) -> LweSecretKey32 { + LweSecretKey32(glwe_secret_key.0.into_lwe_secret_key()) + } +} + +impl GlweToLweSecretKeyTransformationEngine for DefaultEngine { + /// # Example + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// use tfhe::core_crypto::prelude::{GlweDimension, LweDimension, PolynomialSize, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// + /// let glwe_secret_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// assert_eq!(glwe_secret_key.glwe_dimension(), glwe_dimension); + /// assert_eq!(glwe_secret_key.polynomial_size(), polynomial_size); + /// + /// let lwe_secret_key = engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_secret_key)?; + /// assert_eq!(lwe_secret_key.lwe_dimension(), LweDimension(8)); + /// + /// # Ok(()) + /// # } + /// ``` + fn transform_glwe_secret_key_to_lwe_secret_key( + &mut self, + glwe_secret_key: GlweSecretKey64, + ) -> Result> { + Ok(unsafe { self.transform_glwe_secret_key_to_lwe_secret_key_unchecked(glwe_secret_key) }) + } + + unsafe fn transform_glwe_secret_key_to_lwe_secret_key_unchecked( + &mut self, + glwe_secret_key: GlweSecretKey64, + ) -> LweSecretKey64 { + LweSecretKey64(glwe_secret_key.0.into_lwe_secret_key()) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_bootstrap_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_bootstrap_key_generation.rs new file mode 100644 index 000000000..2d2411837 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_bootstrap_key_generation.rs @@ -0,0 +1,192 @@ +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, Variance}; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, LweBootstrapKey32, LweBootstrapKey64, LweSecretKey32, + LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey as ImplStandardBootstrapKey; +use crate::core_crypto::prelude::{GlweSecretKeyEntity, LweSecretKeyEntity}; +use crate::core_crypto::specification::engines::{ + LweBootstrapKeyGenerationEngine, LweBootstrapKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LweBootstrapKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. It outputs a bootstrap key in the standard domain. +impl LweBootstrapKeyGenerationEngine + for DefaultEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey32 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_bootstrap_key( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> { + LweBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 32, + )?; + Ok(unsafe { + self.generate_new_lwe_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweBootstrapKey32 { + let mut key = ImplStandardBootstrapKey::allocate( + 0, + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + ); + key.fill_with_new_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweBootstrapKey32(key) + } +} + +/// # Description: +/// Implementation of [`LweBootstrapKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. It outputs a bootstrap key in the standard domain. +impl LweBootstrapKeyGenerationEngine + for DefaultEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey64 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_bootstrap_key( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> { + LweBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 64, + )?; + Ok(unsafe { + self.generate_new_lwe_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweBootstrapKey64 { + let mut key = ImplStandardBootstrapKey::allocate( + 0, + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + ); + key.fill_with_new_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweBootstrapKey64(key) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_cleartext_fusing_multiplication.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_cleartext_fusing_multiplication.rs new file mode 100644 index 000000000..12c71070d --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_cleartext_fusing_multiplication.rs @@ -0,0 +1,116 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + Cleartext32, Cleartext64, LweCiphertext32, LweCiphertext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextCleartextFusingMultiplicationEngine, + LweCiphertextCleartextFusingMultiplicationError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextCleartextFusingMultiplicationEngine`] for [`DefaultEngine`] +/// that operates on 32 bits integers. +impl LweCiphertextCleartextFusingMultiplicationEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let cleartext_input = 12_u32; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext32 = engine.create_cleartext_from(&cleartext_input)?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.fuse_mul_lwe_ciphertext_cleartext(&mut ciphertext, &cleartext)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_mul_lwe_ciphertext_cleartext( + &mut self, + output: &mut LweCiphertext32, + input: &Cleartext32, + ) -> Result<(), LweCiphertextCleartextFusingMultiplicationError> { + unsafe { self.fuse_mul_lwe_ciphertext_cleartext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_mul_lwe_ciphertext_cleartext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &Cleartext32, + ) { + output.0.update_with_scalar_mul(input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCleartextFusingMultiplicationEngine`] for [`DefaultEngine`] +/// that operates on 64 bits integers. +impl LweCiphertextCleartextFusingMultiplicationEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let cleartext_input = 12_u64; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext64 = engine.create_cleartext_from(&cleartext_input)?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.fuse_mul_lwe_ciphertext_cleartext(&mut ciphertext, &cleartext)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_mul_lwe_ciphertext_cleartext( + &mut self, + output: &mut LweCiphertext64, + input: &Cleartext64, + ) -> Result<(), LweCiphertextCleartextFusingMultiplicationError> { + unsafe { self.fuse_mul_lwe_ciphertext_cleartext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_mul_lwe_ciphertext_cleartext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &Cleartext64, + ) { + output.0.update_with_scalar_mul(input.0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_consuming_retrieval.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_consuming_retrieval.rs new file mode 100644 index 000000000..771c3aea1 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_consuming_retrieval.rs @@ -0,0 +1,277 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextMutView32, LweCiphertextMutView64, + LweCiphertextView32, LweCiphertextView64, +}; +use crate::core_crypto::commons::math::tensor::IntoTensor; +use crate::core_crypto::specification::engines::{ + LweCiphertextConsumingRetrievalEngine, LweCiphertextConsumingRetrievalError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying vec of a [`LweCiphertext32`] consuming it in the process +impl LweCiphertextConsumingRetrievalEngine> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u32; lwe_size.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: LweCiphertext32 = engine.create_lwe_ciphertext_from(owned_container)?; + /// let retrieved_container = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertext32, + ) -> Result, LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertext32, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying vec of a [`LweCiphertext64`] consuming it in the process +impl LweCiphertextConsumingRetrievalEngine> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u64; lwe_size.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: LweCiphertext64 = engine.create_lwe_ciphertext_from(owned_container)?; + /// let retrieved_container = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertext64, + ) -> Result, LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertext64, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying container of a [`LweCiphertextView32`] consuming it in the process +impl<'data> LweCiphertextConsumingRetrievalEngine, &'data [u32]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u32; lwe_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextView32 = engine.create_lwe_ciphertext_from(slice)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertextView32<'data>, + ) -> Result<&'data [u32], LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertextView32<'data>, + ) -> &'data [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying container of a [`LweCiphertextMutView32`] consuming it in the process +impl<'data> LweCiphertextConsumingRetrievalEngine, &'data mut [u32]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u32; lwe_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextMutView32 = engine.create_lwe_ciphertext_from(slice)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertextMutView32<'data>, + ) -> Result<&'data mut [u32], LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertextMutView32<'data>, + ) -> &'data mut [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying container of a [`LweCiphertextView64`] consuming it in the process +impl<'data> LweCiphertextConsumingRetrievalEngine, &'data [u64]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u64; lwe_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextView64 = engine.create_lwe_ciphertext_from(slice)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertextView64<'data>, + ) -> Result<&'data [u64], LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertextView64<'data>, + ) -> &'data [u64] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextConsumingRetrievalEngine`] for [`DefaultEngine`] that returns +/// the underlying container of a [`LweCiphertextMutView64`] consuming it in the process +impl<'data> LweCiphertextConsumingRetrievalEngine, &'data mut [u64]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u64; lwe_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextMutView64 = engine.create_lwe_ciphertext_from(slice)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: LweCiphertextMutView64<'data>, + ) -> Result<&'data mut [u64], LweCiphertextConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: LweCiphertextMutView64<'data>, + ) -> &'data mut [u64] { + ciphertext.0.into_tensor().into_container() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_creation.rs new file mode 100644 index 000000000..ab8a28603 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_creation.rs @@ -0,0 +1,264 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextMutView32, LweCiphertextMutView64, + LweCiphertextView32, LweCiphertextView64, +}; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext as ImplLweCiphertext; +use crate::core_crypto::specification::engines::{ + LweCiphertextCreationEngine, LweCiphertextCreationError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// [`LweCiphertext32`]. +impl LweCiphertextCreationEngine, LweCiphertext32> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let owned_container = vec![0_u32; lwe_size.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: LweCiphertext32 = engine.create_lwe_ciphertext_from(owned_container)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: Vec, + ) -> Result> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: Vec, + ) -> LweCiphertext32 { + LweCiphertext32(ImplLweCiphertext::from_container(container)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// [`LweCiphertext64`]. +impl LweCiphertextCreationEngine, LweCiphertext64> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let owned_container = vec![0_u64; lwe_size.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext: LweCiphertext64 = engine.create_lwe_ciphertext_from(owned_container)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: Vec, + ) -> Result> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: Vec, + ) -> LweCiphertext64 { + LweCiphertext64(ImplLweCiphertext::from_container(container)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`LweCiphertextView32`] that does not own its memory. +impl<'data> LweCiphertextCreationEngine<&'data [u32], LweCiphertextView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u32; lwe_size.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextView32 = engine.create_lwe_ciphertext_from(slice)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: &'data [u32], + ) -> Result, LweCiphertextCreationError> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: &'data [u32], + ) -> LweCiphertextView32<'data> { + LweCiphertextView32(ImplLweCiphertext::from_container(container)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a mutable +/// [`LweCiphertextMutView32`] that does not own its memory. +impl<'data> LweCiphertextCreationEngine<&'data mut [u32], LweCiphertextMutView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u32; lwe_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextMutView32 = engine.create_lwe_ciphertext_from(slice)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: &'data mut [u32], + ) -> Result, LweCiphertextCreationError> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: &'data mut [u32], + ) -> LweCiphertextMutView32<'data> { + LweCiphertextMutView32(ImplLweCiphertext::from_container(container)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`LweCiphertextView64`] that does not own its memory. +impl<'data> LweCiphertextCreationEngine<&'data [u64], LweCiphertextView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let mut owned_container = vec![0_u64; 128]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextView64 = engine.create_lwe_ciphertext_from(slice)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: &'data [u64], + ) -> Result, LweCiphertextCreationError> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: &'data [u64], + ) -> LweCiphertextView64<'data> { + LweCiphertextView64(ImplLweCiphertext::from_container(container)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextCreationEngine`] for [`DefaultEngine`] which returns a mutable +/// [`LweCiphertextMutView64`] that does not own its memory. +impl<'data> LweCiphertextCreationEngine<&'data mut [u64], LweCiphertextMutView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(128); + /// let mut owned_container = vec![0_u64; lwe_size.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_view: LweCiphertextMutView64 = engine.create_lwe_ciphertext_from(slice)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_from( + &mut self, + container: &'data mut [u64], + ) -> Result, LweCiphertextCreationError> { + LweCiphertextCreationError::::perform_generic_checks(container.len())?; + Ok(unsafe { self.create_lwe_ciphertext_from_unchecked(container) }) + } + + unsafe fn create_lwe_ciphertext_from_unchecked( + &mut self, + container: &'data mut [u64], + ) -> LweCiphertextMutView64<'data> { + LweCiphertextMutView64(ImplLweCiphertext::from_container(container)) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_decryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_decryption.rs new file mode 100644 index 000000000..9c3d7caf4 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_decryption.rs @@ -0,0 +1,227 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextView32, LweCiphertextView64, LweSecretKey32, + LweSecretKey64, Plaintext32, Plaintext64, +}; +use crate::core_crypto::commons::crypto::encoding::Plaintext as ImplPlaintext; +use crate::core_crypto::specification::engines::{ + LweCiphertextDecryptionEngine, LweCiphertextDecryptionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextDecryptionEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LweCiphertextDecryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let decrypted_plaintext = engine.decrypt_lwe_ciphertext(&key, &ciphertext)?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn decrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + input: &LweCiphertext32, + ) -> Result> { + Ok(unsafe { self.decrypt_lwe_ciphertext_unchecked(key, input) }) + } + + unsafe fn decrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + input: &LweCiphertext32, + ) -> Plaintext32 { + let mut plaintext = ImplPlaintext(0u32); + key.0.decrypt_lwe(&mut plaintext, &input.0); + Plaintext32(plaintext) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDecryptionEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LweCiphertextDecryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let decrypted_plaintext = engine.decrypt_lwe_ciphertext(&key, &ciphertext)?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn decrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + input: &LweCiphertext64, + ) -> Result> { + Ok(unsafe { self.decrypt_lwe_ciphertext_unchecked(key, input) }) + } + + unsafe fn decrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + input: &LweCiphertext64, + ) -> Plaintext64 { + let mut plaintext = ImplPlaintext(0u64); + key.0.decrypt_lwe(&mut plaintext, &input.0); + Plaintext64(plaintext) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDecryptionEngine`] for [`DefaultEngine`] that operates on +/// an [`LweCiphertextView32`] containing 32 bits integers. +impl LweCiphertextDecryptionEngine, Plaintext32> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut raw_ciphertext = vec![0_u32; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_view: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_view, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// let ciphertext_view: LweCiphertextView32 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext[..])?; + /// + /// let decrypted_plaintext = engine.decrypt_lwe_ciphertext(&key, &ciphertext_view)?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn decrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + input: &LweCiphertextView32<'_>, + ) -> Result> { + Ok(unsafe { self.decrypt_lwe_ciphertext_unchecked(key, input) }) + } + + unsafe fn decrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + input: &LweCiphertextView32<'_>, + ) -> Plaintext32 { + let mut plaintext = ImplPlaintext(0u32); + key.0.decrypt_lwe(&mut plaintext, &input.0); + Plaintext32(plaintext) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDecryptionEngine`] for [`DefaultEngine`] that operates on +/// an [`LweCiphertextView64`] containing 64 bits integers. +impl LweCiphertextDecryptionEngine, Plaintext64> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut raw_ciphertext = vec![0_u64; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_view: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_view, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext = engine.consume_retrieve_lwe_ciphertext(ciphertext_view)?; + /// let ciphertext_view: LweCiphertextView64 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext[..])?; + /// + /// let decrypted_plaintext = engine.decrypt_lwe_ciphertext(&key, &ciphertext_view)?; + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn decrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + input: &LweCiphertextView64<'_>, + ) -> Result> { + Ok(unsafe { self.decrypt_lwe_ciphertext_unchecked(key, input) }) + } + + unsafe fn decrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + input: &LweCiphertextView64<'_>, + ) -> Plaintext64 { + let mut plaintext = ImplPlaintext(0u64); + key.0.decrypt_lwe(&mut plaintext, &input.0); + Plaintext64(plaintext) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_addition.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_addition.rs new file mode 100644 index 000000000..150bef05b --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_addition.rs @@ -0,0 +1,293 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextMutView32, LweCiphertextMutView64, + LweCiphertextView32, LweCiphertextView64, +}; +use crate::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingAdditionEngine, LweCiphertextDiscardingAdditionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingAdditionEngine`] for [`DefaultEngine`] that operates +/// on 32 bits integers. +impl LweCiphertextDiscardingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input_1 = 3_u32 << 20; + /// let input_2 = 7_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// let mut ciphertext_3 = engine.zero_encrypt_lwe_ciphertext(&key, noise)?; + /// + /// engine.discard_add_lwe_ciphertext(&mut ciphertext_3, &ciphertext_1, &ciphertext_2)?; + /// # + /// assert_eq!(ciphertext_3.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input_1: &LweCiphertext32, + input_2: &LweCiphertext32, + ) -> Result<(), LweCiphertextDiscardingAdditionError> { + LweCiphertextDiscardingAdditionError::perform_generic_checks(output, input_1, input_2)?; + unsafe { self.discard_add_lwe_ciphertext_unchecked(output, input_1, input_2) }; + Ok(()) + } + + unsafe fn discard_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input_1: &LweCiphertext32, + input_2: &LweCiphertext32, + ) { + output + .0 + .as_mut_tensor() + .fill_with_copy(input_1.0.as_tensor()); + output.0.update_with_add(&input_2.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingAdditionEngine`] for [`DefaultEngine`] that operates +/// on 64 bits integers. +impl LweCiphertextDiscardingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input_1 = 3_u64 << 50; + /// let input_2 = 7_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// let mut ciphertext_3 = engine.zero_encrypt_lwe_ciphertext(&key, noise)?; + /// + /// engine.discard_add_lwe_ciphertext(&mut ciphertext_3, &ciphertext_1, &ciphertext_2)?; + /// # + /// assert_eq!(ciphertext_3.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext64, + input_1: &LweCiphertext64, + input_2: &LweCiphertext64, + ) -> Result<(), LweCiphertextDiscardingAdditionError> { + LweCiphertextDiscardingAdditionError::perform_generic_checks(output, input_1, input_2)?; + unsafe { self.discard_add_lwe_ciphertext_unchecked(output, input_1, input_2) }; + Ok(()) + } + + unsafe fn discard_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input_1: &LweCiphertext64, + input_2: &LweCiphertext64, + ) { + output + .0 + .as_mut_tensor() + .fill_with_copy(input_1.0.as_tensor()); + output.0.update_with_add(&input_2.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingAdditionEngine`] for [`DefaultEngine`] that operates +/// on views containing 32 bits integers. +impl LweCiphertextDiscardingAdditionEngine, LweCiphertextMutView32<'_>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input_1 = 3_u32 << 20; + /// let input_2 = 7_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// + /// let mut ciphertext_1_container = vec![0_u32; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_1: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_1_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_1, &plaintext_1, noise)?; + /// let mut ciphertext_2_container = vec![0_u32; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_2: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_2_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_2, &plaintext_2, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext_1 = engine.consume_retrieve_lwe_ciphertext(ciphertext_1)?; + /// let ciphertext_1: LweCiphertextView32 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_1[..])?; + /// let raw_ciphertext_2 = engine.consume_retrieve_lwe_ciphertext(ciphertext_2)?; + /// let ciphertext_2: LweCiphertextView32 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_2[..])?; + /// + /// let mut ciphertext_3_container = vec![0_u32; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_3: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_3_container[..])?; + /// + /// engine.discard_add_lwe_ciphertext(&mut ciphertext_3, &ciphertext_1, &ciphertext_2)?; + /// # + /// assert_eq!(ciphertext_3.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView32, + input_1: &LweCiphertextView32, + input_2: &LweCiphertextView32, + ) -> Result<(), LweCiphertextDiscardingAdditionError> { + LweCiphertextDiscardingAdditionError::perform_generic_checks(output, input_1, input_2)?; + unsafe { self.discard_add_lwe_ciphertext_unchecked(output, input_1, input_2) }; + Ok(()) + } + + unsafe fn discard_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView32, + input_1: &LweCiphertextView32, + input_2: &LweCiphertextView32, + ) { + output + .0 + .as_mut_tensor() + .fill_with_copy(input_1.0.as_tensor()); + output.0.update_with_add(&input_2.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingAdditionEngine`] for [`DefaultEngine`] that operates +/// on on views containing 64 bits integers. +impl LweCiphertextDiscardingAdditionEngine, LweCiphertextMutView64<'_>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input_1 = 3_u64 << 50; + /// let input_2 = 7_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// + /// let mut ciphertext_1_container = vec![0_u64; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_1: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_1_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_1, &plaintext_1, noise)?; + /// let mut ciphertext_2_container = vec![0_u64; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_2: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_2_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext_2, &plaintext_2, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext_1 = engine.consume_retrieve_lwe_ciphertext(ciphertext_1)?; + /// let ciphertext_1: LweCiphertextView64 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_1[..])?; + /// let raw_ciphertext_2 = engine.consume_retrieve_lwe_ciphertext(ciphertext_2)?; + /// let ciphertext_2: LweCiphertextView64 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_2[..])?; + /// + /// let mut ciphertext_3_container = vec![0_u64; key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_3: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut ciphertext_3_container[..])?; + /// + /// engine.discard_add_lwe_ciphertext(&mut ciphertext_3, &ciphertext_1, &ciphertext_2)?; + /// # + /// assert_eq!(ciphertext_3.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView64, + input_1: &LweCiphertextView64, + input_2: &LweCiphertextView64, + ) -> Result<(), LweCiphertextDiscardingAdditionError> { + LweCiphertextDiscardingAdditionError::perform_generic_checks(output, input_1, input_2)?; + unsafe { self.discard_add_lwe_ciphertext_unchecked(output, input_1, input_2) }; + Ok(()) + } + + unsafe fn discard_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView64, + input_1: &LweCiphertextView64, + input_2: &LweCiphertextView64, + ) { + output + .0 + .as_mut_tensor() + .fill_with_copy(input_1.0.as_tensor()); + output.0.update_with_add(&input_2.0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_encryption.rs new file mode 100644 index 000000000..55551ba75 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_encryption.rs @@ -0,0 +1,264 @@ +use crate::core_crypto::prelude::Variance; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextMutView32, LweCiphertextMutView64, + LweSecretKey32, LweSecretKey64, Plaintext32, Plaintext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingEncryptionEngine, LweCiphertextDiscardingEncryptionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl LweCiphertextDiscardingEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext, &plaintext, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + output: &mut LweCiphertext32, + input: &Plaintext32, + noise: Variance, + ) -> Result<(), LweCiphertextDiscardingEncryptionError> { + LweCiphertextDiscardingEncryptionError::perform_generic_checks(key, output)?; + unsafe { self.discard_encrypt_lwe_ciphertext_unchecked(key, output, input, noise) }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + output: &mut LweCiphertext32, + input: &Plaintext32, + noise: Variance, + ) { + key.0.encrypt_lwe( + &mut output.0, + &input.0, + noise, + &mut self.encryption_generator, + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl LweCiphertextDiscardingEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut ciphertext, &plaintext, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + output: &mut LweCiphertext64, + input: &Plaintext64, + noise: Variance, + ) -> Result<(), LweCiphertextDiscardingEncryptionError> { + LweCiphertextDiscardingEncryptionError::perform_generic_checks(key, output)?; + unsafe { self.discard_encrypt_lwe_ciphertext_unchecked(key, output, input, noise) }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + output: &mut LweCiphertext64, + input: &Plaintext64, + noise: Variance, + ) { + key.0.encrypt_lwe( + &mut output.0, + &input.0, + noise, + &mut self.encryption_generator, + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl + LweCiphertextDiscardingEncryptionEngine> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut output_cipertext_container = vec![0_32; lwe_dimension.to_lwe_size().0]; + /// let mut output_ciphertext: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut output_cipertext_container[..])?; + /// + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut output_ciphertext, &plaintext, noise)?; + /// # + /// assert_eq!(output_ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + output: &mut LweCiphertextMutView32, + input: &Plaintext32, + noise: Variance, + ) -> Result<(), LweCiphertextDiscardingEncryptionError> { + LweCiphertextDiscardingEncryptionError::perform_generic_checks(key, output)?; + unsafe { self.discard_encrypt_lwe_ciphertext_unchecked(key, output, input, noise) }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + output: &mut LweCiphertextMutView32, + input: &Plaintext32, + noise: Variance, + ) { + key.0.encrypt_lwe( + &mut output.0, + &input.0, + noise, + &mut self.encryption_generator, + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl + LweCiphertextDiscardingEncryptionEngine> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut output_cipertext_container = vec![0_64; lwe_dimension.to_lwe_size().0]; + /// let mut output_ciphertext: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut output_cipertext_container[..])?; + /// + /// engine.discard_encrypt_lwe_ciphertext(&key, &mut output_ciphertext, &plaintext, noise)?; + /// # + /// assert_eq!(output_ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + output: &mut LweCiphertextMutView64, + input: &Plaintext64, + noise: Variance, + ) -> Result<(), LweCiphertextDiscardingEncryptionError> { + LweCiphertextDiscardingEncryptionError::perform_generic_checks(key, output)?; + unsafe { self.discard_encrypt_lwe_ciphertext_unchecked(key, output, input, noise) }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + output: &mut LweCiphertextMutView64, + input: &Plaintext64, + noise: Variance, + ) { + key.0.encrypt_lwe( + &mut output.0, + &input.0, + noise, + &mut self.encryption_generator, + ); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_keyswitch.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_keyswitch.rs new file mode 100644 index 000000000..c30bd28f5 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_keyswitch.rs @@ -0,0 +1,316 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweCiphertextMutView32, LweCiphertextMutView64, + LweCiphertextView32, LweCiphertextView64, LweKeyswitchKey32, LweKeyswitchKey64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingKeyswitchEngine, LweCiphertextDiscardingKeyswitchError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingKeyswitchEngine`] for [`DefaultEngine`] that operates +/// on 32 bits integers. +impl LweCiphertextDiscardingKeyswitchEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&input_key, &plaintext, noise)?; + /// let mut ciphertext_2 = engine.zero_encrypt_lwe_ciphertext(&output_key, noise)?; + /// + /// engine.discard_keyswitch_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1, &keyswitch_key)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ksk: &LweKeyswitchKey32, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ksk: &LweKeyswitchKey32, + ) { + ksk.0.keyswitch_ciphertext(&mut output.0, &input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingKeyswitchEngine`] for [`DefaultEngine`] that operates +/// on 64 bits integers. +impl LweCiphertextDiscardingKeyswitchEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-50.)); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&input_key, &plaintext, noise)?; + /// let mut ciphertext_2 = engine.zero_encrypt_lwe_ciphertext(&output_key, noise)?; + /// + /// engine.discard_keyswitch_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1, &keyswitch_key)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ksk: &LweKeyswitchKey64, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ksk: &LweKeyswitchKey64, + ) { + ksk.0.keyswitch_ciphertext(&mut output.0, &input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingKeyswitchEngine`] for [`DefaultEngine`] that operates +/// on views containing 32 bits integers. +impl + LweCiphertextDiscardingKeyswitchEngine< + LweKeyswitchKey32, + LweCiphertextView32<'_>, + LweCiphertextMutView32<'_>, + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut raw_ciphertext_1_container = vec![0_u32; input_key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_1: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext_1_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&input_key, &mut ciphertext_1, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext_1 = engine.consume_retrieve_lwe_ciphertext(ciphertext_1)?; + /// let ciphertext_1: LweCiphertextView32 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_1[..])?; + /// + /// let mut raw_ciphertext_2_container = vec![0_u32; output_key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_2: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext_2_container[..])?; + /// + /// engine.discard_keyswitch_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1, &keyswitch_key)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView32<'_>, + input: &LweCiphertextView32<'_>, + ksk: &LweKeyswitchKey32, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView32<'_>, + input: &LweCiphertextView32<'_>, + ksk: &LweKeyswitchKey32, + ) { + ksk.0.keyswitch_ciphertext(&mut output.0, &input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingKeyswitchEngine`] for [`DefaultEngine`] that operates +/// on views containing 64 bits integers. +impl + LweCiphertextDiscardingKeyswitchEngine< + LweKeyswitchKey64, + LweCiphertextView64<'_>, + LweCiphertextMutView64<'_>, + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let mut raw_ciphertext_1_container = vec![0_u64; input_key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_1: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext_1_container[..])?; + /// engine.discard_encrypt_lwe_ciphertext(&input_key, &mut ciphertext_1, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_ciphertext_1 = engine.consume_retrieve_lwe_ciphertext(ciphertext_1)?; + /// let ciphertext_1: LweCiphertextView64 = + /// engine.create_lwe_ciphertext_from(&raw_ciphertext_1[..])?; + /// + /// let mut raw_ciphertext_2_container = vec![0_u64; output_key.lwe_dimension().to_lwe_size().0]; + /// let mut ciphertext_2: LweCiphertextMutView64 = + /// engine.create_lwe_ciphertext_from(&mut raw_ciphertext_2_container[..])?; + /// + /// engine.discard_keyswitch_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1, &keyswitch_key)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView64<'_>, + input: &LweCiphertextView64<'_>, + ksk: &LweKeyswitchKey64, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError> { + LweCiphertextDiscardingKeyswitchError::perform_generic_checks(output, input, ksk)?; + unsafe { self.discard_keyswitch_lwe_ciphertext_unchecked(output, input, ksk) }; + Ok(()) + } + + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView64<'_>, + input: &LweCiphertextView64<'_>, + ksk: &LweKeyswitchKey64, + ) { + ksk.0.keyswitch_ciphertext(&mut output.0, &input.0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_public_key_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_public_key_encryption.rs new file mode 100644 index 000000000..7f74e388b --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_discarding_public_key_encryption.rs @@ -0,0 +1,175 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LwePublicKey32, LwePublicKey64, Plaintext32, Plaintext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingPublicKeyEncryptionEngine, + LweCiphertextDiscardingPublicKeyEncryptionError, +}; +use crate::core_crypto::specification::entities::LwePublicKeyEntity; + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingPublicKeyEncryptionEngine`] for [`DefaultEngine`] +/// that operates on 32 bits integers. +impl LweCiphertextDiscardingPublicKeyEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(7); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let public_key: LwePublicKey32 = engine.generate_new_lwe_public_key( + /// &secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext_container = vec![0u32; lwe_dimension.to_lwe_size().0]; + /// + /// let mut ciphertext = engine.create_lwe_ciphertext_from(ciphertext_container)?; + /// + /// engine.discard_encrypt_lwe_ciphertext_with_public_key( + /// &public_key, + /// &mut ciphertext, + /// &plaintext, + /// )?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext_with_public_key( + &mut self, + key: &LwePublicKey32, + output: &mut LweCiphertext32, + input: &Plaintext32, + ) -> Result<(), LweCiphertextDiscardingPublicKeyEncryptionError> { + LweCiphertextDiscardingPublicKeyEncryptionError::perform_generic_checks(key, output)?; + unsafe { + self.discard_encrypt_lwe_ciphertext_with_public_key_unchecked(key, output, input) + }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_with_public_key_unchecked( + &mut self, + key: &LwePublicKey32, + output: &mut LweCiphertext32, + input: &Plaintext32, + ) { + // Fills output masks with zeros, store input in the body + output.0.fill_with_trivial_encryption(&input.0); + let ct_choice = self + .secret_generator + .random_binary_tensor::(key.lwe_zero_encryption_count().0); + + // Add the public encryption of zeros to get the encryption + for (&chosen, public_encryption_of_zero) in + ct_choice.as_container().iter().zip(key.0.ciphertext_iter()) + { + if chosen == 1 { + output.0.update_with_add(&public_encryption_of_zero); + } + } + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingPublicKeyEncryptionEngine`] for [`DefaultEngine`] +/// that operates on 64 bits integers. +impl LweCiphertextDiscardingPublicKeyEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(7); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-50.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let public_key: LwePublicKey64 = engine.generate_new_lwe_public_key( + /// &secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext_container = vec![0u64; lwe_dimension.to_lwe_size().0]; + /// + /// let mut ciphertext = engine.create_lwe_ciphertext_from(ciphertext_container)?; + /// + /// engine.discard_encrypt_lwe_ciphertext_with_public_key( + /// &public_key, + /// &mut ciphertext, + /// &plaintext, + /// )?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_encrypt_lwe_ciphertext_with_public_key( + &mut self, + key: &LwePublicKey64, + output: &mut LweCiphertext64, + input: &Plaintext64, + ) -> Result<(), LweCiphertextDiscardingPublicKeyEncryptionError> { + LweCiphertextDiscardingPublicKeyEncryptionError::perform_generic_checks(key, output)?; + unsafe { + self.discard_encrypt_lwe_ciphertext_with_public_key_unchecked(key, output, input) + }; + Ok(()) + } + + unsafe fn discard_encrypt_lwe_ciphertext_with_public_key_unchecked( + &mut self, + key: &LwePublicKey64, + output: &mut LweCiphertext64, + input: &Plaintext64, + ) { + // Fills output masks with zeros, store input in the body + output.0.fill_with_trivial_encryption(&input.0); + let ct_choice = self + .secret_generator + .random_binary_tensor::(key.lwe_zero_encryption_count().0); + + // Add the public encryption of zeros to get the encryption + for (&chosen, public_encryption_of_zero) in + ct_choice.as_container().iter().zip(key.0.ciphertext_iter()) + { + if chosen == 1 { + output.0.update_with_add(&public_encryption_of_zero); + } + } + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_encryption.rs new file mode 100644 index 000000000..1460bee96 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_encryption.rs @@ -0,0 +1,125 @@ +use crate::core_crypto::prelude::Variance; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweSecretKey32, LweSecretKey64, Plaintext32, Plaintext64, +}; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext as ImplLweCiphertext; +use crate::core_crypto::specification::engines::{ + LweCiphertextEncryptionEngine, LweCiphertextEncryptionError, +}; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +/// # Description: +/// Implementation of [`LweCiphertextEncryptionEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LweCiphertextEncryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + input: &Plaintext32, + noise: Variance, + ) -> Result> { + Ok(unsafe { self.encrypt_lwe_ciphertext_unchecked(key, input, noise) }) + } + + unsafe fn encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + input: &Plaintext32, + noise: Variance, + ) -> LweCiphertext32 { + let mut ciphertext = ImplLweCiphertext::allocate(0u32, key.lwe_dimension().to_lwe_size()); + key.0.encrypt_lwe( + &mut ciphertext, + &input.0, + noise, + &mut self.encryption_generator, + ); + LweCiphertext32(ciphertext) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextEncryptionEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LweCiphertextEncryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + input: &Plaintext64, + noise: Variance, + ) -> Result> { + Ok(unsafe { self.encrypt_lwe_ciphertext_unchecked(key, input, noise) }) + } + + unsafe fn encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + input: &Plaintext64, + noise: Variance, + ) -> LweCiphertext64 { + let mut ciphertext = ImplLweCiphertext::allocate(0u64, key.lwe_dimension().to_lwe_size()); + key.0.encrypt_lwe( + &mut ciphertext, + &input.0, + noise, + &mut self.encryption_generator, + ); + LweCiphertext64(ciphertext) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_addition.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_addition.rs new file mode 100644 index 000000000..f5598fb5f --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_addition.rs @@ -0,0 +1,115 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextFusingAdditionEngine, LweCiphertextFusingAdditionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextFusingAdditionEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LweCiphertextFusingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input_1 = 3_u32 << 20; + /// let input_2 = 5_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let mut ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// + /// engine.fuse_add_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ) -> Result<(), LweCiphertextFusingAdditionError> { + LweCiphertextFusingAdditionError::perform_generic_checks(output, input)?; + unsafe { self.fuse_add_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ) { + output.0.update_with_add(&input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextFusingAdditionEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LweCiphertextFusingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input_1 = 3_u64 << 50; + /// let input_2 = 5_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let mut ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// + /// engine.fuse_add_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_add_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ) -> Result<(), LweCiphertextFusingAdditionError> { + LweCiphertextFusingAdditionError::perform_generic_checks(output, input)?; + unsafe { self.fuse_add_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ) { + output.0.update_with_add(&input.0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_opposite.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_opposite.rs new file mode 100644 index 000000000..0e50875d0 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_opposite.rs @@ -0,0 +1,97 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextFusingOppositeEngine, LweCiphertextFusingOppositeError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextFusingOppositeEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LweCiphertextFusingOppositeEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.fuse_opp_lwe_ciphertext(&mut ciphertext)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_opp_lwe_ciphertext( + &mut self, + input: &mut LweCiphertext32, + ) -> Result<(), LweCiphertextFusingOppositeError> { + unsafe { self.fuse_opp_lwe_ciphertext_unchecked(input) }; + Ok(()) + } + + unsafe fn fuse_opp_lwe_ciphertext_unchecked(&mut self, input: &mut LweCiphertext32) { + input.0.update_with_neg(); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextFusingOppositeEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LweCiphertextFusingOppositeEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// engine.fuse_opp_lwe_ciphertext(&mut ciphertext)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_opp_lwe_ciphertext( + &mut self, + input: &mut LweCiphertext64, + ) -> Result<(), LweCiphertextFusingOppositeError> { + unsafe { self.fuse_opp_lwe_ciphertext_unchecked(input) }; + Ok(()) + } + + unsafe fn fuse_opp_lwe_ciphertext_unchecked(&mut self, input: &mut LweCiphertext64) { + input.0.update_with_neg(); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_subtraction.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_subtraction.rs new file mode 100644 index 000000000..f8d0258a7 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_fusing_subtraction.rs @@ -0,0 +1,115 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextFusingSubtractionEngine, LweCiphertextFusingSubtractionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextFusingSubtractionEngine`] for [`DefaultEngine`] that operates +/// on 32 bits integers. +impl LweCiphertextFusingSubtractionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input_1 = 3_u32 << 20; + /// let input_2 = 5_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let mut ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// + /// engine.fuse_sub_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_sub_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ) -> Result<(), LweCiphertextFusingSubtractionError> { + LweCiphertextFusingSubtractionError::perform_generic_checks(output, input)?; + unsafe { self.fuse_sub_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_sub_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + ) { + output.0.update_with_sub(&input.0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextFusingSubtractionEngine`] for [`DefaultEngine`] that operates +/// on 64 bits integers. +impl LweCiphertextFusingSubtractionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input_1 = 3_u64 << 50; + /// let input_2 = 5_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let ciphertext_1 = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// let mut ciphertext_2 = engine.encrypt_lwe_ciphertext(&key, &plaintext_2, noise)?; + /// + /// engine.fuse_sub_lwe_ciphertext(&mut ciphertext_2, &ciphertext_1)?; + /// # + /// assert_eq!(ciphertext_2.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_sub_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ) -> Result<(), LweCiphertextFusingSubtractionError> { + LweCiphertextFusingSubtractionError::perform_generic_checks(output, input)?; + unsafe { self.fuse_sub_lwe_ciphertext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_sub_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + ) { + output.0.update_with_sub(&input.0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_plaintext_fusing_addition.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_plaintext_fusing_addition.rs new file mode 100644 index 000000000..261a71a14 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_plaintext_fusing_addition.rs @@ -0,0 +1,111 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, Plaintext32, Plaintext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextPlaintextFusingAdditionEngine, LweCiphertextPlaintextFusingAdditionError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextPlaintextFusingAdditionEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl LweCiphertextPlaintextFusingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input_1 = 3_u32 << 20; + /// let input_2 = 5_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// + /// engine.fuse_add_lwe_ciphertext_plaintext(&mut ciphertext, &plaintext_2)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_add_lwe_ciphertext_plaintext( + &mut self, + output: &mut LweCiphertext32, + input: &Plaintext32, + ) -> Result<(), LweCiphertextPlaintextFusingAdditionError> { + unsafe { self.fuse_add_lwe_ciphertext_plaintext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_add_lwe_ciphertext_plaintext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &Plaintext32, + ) { + output.0.get_mut_body().0 = output.0.get_body().0.wrapping_add(input.0 .0); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextPlaintextFusingAdditionEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl LweCiphertextPlaintextFusingAdditionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 40 bits) + /// let input_1 = 3_u64 << 40; + /// let input_2 = 5_u64 << 40; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext_1 = engine.create_plaintext_from(&input_1)?; + /// let plaintext_2 = engine.create_plaintext_from(&input_2)?; + /// let mut ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext_1, noise)?; + /// + /// engine.fuse_add_lwe_ciphertext_plaintext(&mut ciphertext, &plaintext_2)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn fuse_add_lwe_ciphertext_plaintext( + &mut self, + output: &mut LweCiphertext64, + input: &Plaintext64, + ) -> Result<(), LweCiphertextPlaintextFusingAdditionError> { + unsafe { self.fuse_add_lwe_ciphertext_plaintext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn fuse_add_lwe_ciphertext_plaintext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &Plaintext64, + ) { + output.0.get_mut_body().0 = output.0.get_body().0.wrapping_add(input.0 .0); + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_trivial_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_trivial_encryption.rs new file mode 100644 index 000000000..907963e45 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_trivial_encryption.rs @@ -0,0 +1,97 @@ +use crate::core_crypto::prelude::{ + DefaultEngine, LweCiphertext32, LweCiphertext64, LweSize, Plaintext32, Plaintext64, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextTrivialEncryptionEngine, LweCiphertextTrivialEncryptionError, +}; + +use crate::core_crypto::commons::crypto::lwe::LweCiphertext as ImplLweCiphertext; + +impl LweCiphertextTrivialEncryptionEngine for DefaultEngine { + /// # Example: + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// + /// use tfhe::core_crypto::prelude::{LweSize, Variance, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_size = LweSize(10); + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext32 = engine.create_plaintext_from(&input)?; + /// // DISCLAIMER: trivial encryption is NOT secure, and DOES NOT hide the message at all. + /// let ciphertext: LweCiphertext32 = + /// engine.trivially_encrypt_lwe_ciphertext(lwe_size, &plaintext)?; + /// + /// assert_eq!(ciphertext.lwe_dimension().to_lwe_size(), lwe_size); + /// + /// # Ok(()) + /// # } + /// ``` + fn trivially_encrypt_lwe_ciphertext( + &mut self, + lwe_size: LweSize, + input: &Plaintext32, + ) -> Result> { + unsafe { Ok(self.trivially_encrypt_lwe_ciphertext_unchecked(lwe_size, input)) } + } + + unsafe fn trivially_encrypt_lwe_ciphertext_unchecked( + &mut self, + lwe_size: LweSize, + input: &Plaintext32, + ) -> LweCiphertext32 { + let ciphertext = ImplLweCiphertext::new_trivial_encryption(lwe_size, &input.0); + LweCiphertext32(ciphertext) + } +} + +impl LweCiphertextTrivialEncryptionEngine for DefaultEngine { + /// # Example: + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize, Variance, *}; + /// + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_size = LweSize(10); + /// let input = 3_u64 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext64 = engine.create_plaintext_from(&input)?; + /// // DISCLAIMER: trivial encryption is NOT secure, and DOES NOT hide the message at all. + /// let ciphertext: LweCiphertext64 = + /// engine.trivially_encrypt_lwe_ciphertext(lwe_size, &plaintext)?; + /// + /// assert_eq!(ciphertext.lwe_dimension().to_lwe_size(), lwe_size); + /// + /// # Ok(()) + /// # } + /// ``` + fn trivially_encrypt_lwe_ciphertext( + &mut self, + lwe_size: LweSize, + input: &Plaintext64, + ) -> Result> { + unsafe { Ok(self.trivially_encrypt_lwe_ciphertext_unchecked(lwe_size, input)) } + } + + unsafe fn trivially_encrypt_lwe_ciphertext_unchecked( + &mut self, + lwe_size: LweSize, + input: &Plaintext64, + ) -> LweCiphertext64 { + let ciphertext = ImplLweCiphertext::new_trivial_encryption(lwe_size, &input.0); + LweCiphertext64(ciphertext) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_consuming_retrieval.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_consuming_retrieval.rs new file mode 100644 index 000000000..1693944f2 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_consuming_retrieval.rs @@ -0,0 +1,307 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertextVector64, LweCiphertextVectorMutView64, LweCiphertextVectorView64, +}; +use crate::core_crypto::commons::math::tensor::IntoTensor; +use crate::core_crypto::prelude::{ + LweCiphertextVector32, LweCiphertextVectorMutView32, LweCiphertextVectorView32, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorConsumingRetrievalEngine, LweCiphertextVectorConsumingRetrievalError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVector32`] consuming it in the process +impl LweCiphertextVectorConsumingRetrievalEngine> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// let lwe_size = LweSize(128); + /// let lwe_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_count.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector: LweCiphertextVector32 = + /// engine.create_lwe_ciphertext_vector_from(owned_container, lwe_size)?; + /// let retrieved_container = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVector32, + ) -> Result, LweCiphertextVectorConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVector32, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVector64`] consuming it in the process +impl LweCiphertextVectorConsumingRetrievalEngine> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// let lwe_size = LweSize(128); + /// let lwe_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_count.0]; + /// let original_vec_ptr = owned_container.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector: LweCiphertextVector64 = + /// engine.create_lwe_ciphertext_vector_from(owned_container, lwe_size)?; + /// let retrieved_container = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector)?; + /// assert_eq!(original_vec_ptr, retrieved_container.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVector64, + ) -> Result, LweCiphertextVectorConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVector64, + ) -> Vec { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVectorView32`] consuming it in the process +impl<'data> + LweCiphertextVectorConsumingRetrievalEngine, &'data [u32]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_ciphertext_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_ciphertext_count.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorView32 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVectorView32<'data>, + ) -> Result<&'data [u32], LweCiphertextVectorConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVectorView32<'data>, + ) -> &'data [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVectorView64`] consuming it in the process +impl<'data> + LweCiphertextVectorConsumingRetrievalEngine, &'data [u64]> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_ciphertext_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_ciphertext_count.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorView64 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector_view)?; + /// assert_eq!(slice, retrieved_slice); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVectorView64<'data>, + ) -> Result<&'data [u64], LweCiphertextVectorConsumingRetrievalError> { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVectorView64<'data>, + ) -> &'data [u64] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVectorMutView32`] consuming it in the process +impl<'data> + LweCiphertextVectorConsumingRetrievalEngine< + LweCiphertextVectorMutView32<'data>, + &'data mut [u32], + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_ciphertext_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_ciphertext_count.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorMutView32 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVectorMutView32<'data>, + ) -> Result<&'data mut [u32], LweCiphertextVectorConsumingRetrievalError> + { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVectorMutView32<'data>, + ) -> &'data mut [u32] { + ciphertext.0.into_tensor().into_container() + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorConsumingRetrievalEngine`] for [`DefaultEngine`] that +/// returns the underlying slice of a [`LweCiphertextVectorMutView64`] consuming it in the process +impl<'data> + LweCiphertextVectorConsumingRetrievalEngine< + LweCiphertextVectorMutView64<'data>, + &'data mut [u64], + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_ciphertext_count = LweCiphertextCount(8); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_ciphertext_count.0]; + /// + /// let slice = &mut owned_container[..]; + /// // Required as we can't borrow a mut slice more than once + /// let underlying_ptr = slice.as_ptr(); + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorMutView64 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// let retrieved_slice = engine.consume_retrieve_lwe_ciphertext_vector(ciphertext_vector_view)?; + /// assert_eq!(underlying_ptr, retrieved_slice.as_ptr()); + /// # + /// # Ok(()) + /// # } + /// ``` + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: LweCiphertextVectorMutView64<'data>, + ) -> Result<&'data mut [u64], LweCiphertextVectorConsumingRetrievalError> + { + Ok(unsafe { self.consume_retrieve_lwe_ciphertext_vector_unchecked(ciphertext) }) + } + + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: LweCiphertextVectorMutView64<'data>, + ) -> &'data mut [u64] { + ciphertext.0.into_tensor().into_container() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_creation.rs new file mode 100644 index 000000000..c3dbbfc67 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_creation.rs @@ -0,0 +1,312 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertextVectorMutView64, LweCiphertextVectorView64, +}; +use crate::core_crypto::commons::crypto::lwe::LweList as ImplLweList; +use crate::core_crypto::prelude::{ + LweCiphertextVector32, LweCiphertextVector64, LweCiphertextVectorMutView32, + LweCiphertextVectorView32, LweSize, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorCreationEngine, LweCiphertextVectorCreationError, +}; + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns a +/// [`LweCiphertextVector32`]. +impl LweCiphertextVectorCreationEngine, LweCiphertextVector32> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_count.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector: LweCiphertextVector32 = + /// engine.create_lwe_ciphertext_vector_from(owned_container, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: Vec, + lwe_size: LweSize, + ) -> Result> { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: Vec, + lwe_size: LweSize, + ) -> LweCiphertextVector32 { + LweCiphertextVector32(ImplLweList::from_container(container, lwe_size)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns a +/// [`LweCiphertextVector64`]. +impl LweCiphertextVectorCreationEngine, LweCiphertextVector64> for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_count.0]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector: LweCiphertextVector64 = + /// engine.create_lwe_ciphertext_vector_from(owned_container, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: Vec, + lwe_size: LweSize, + ) -> Result> { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: Vec, + lwe_size: LweSize, + ) -> LweCiphertextVector64 { + LweCiphertextVector64(ImplLweList::from_container(container, lwe_size)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`LweCiphertextVectorView32`] that does not own its memory. +impl<'data> LweCiphertextVectorCreationEngine<&'data [u32], LweCiphertextVectorView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_count.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorView32 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: &'data [u32], + lwe_size: LweSize, + ) -> Result, LweCiphertextVectorCreationError> + { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: &'data [u32], + lwe_size: LweSize, + ) -> LweCiphertextVectorView32<'data> { + LweCiphertextVectorView32(ImplLweList::from_container(container, lwe_size)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns a +/// mutable [`LweCiphertextVectorMutView32`] that does not own its memory. +impl<'data> LweCiphertextVectorCreationEngine<&'data mut [u32], LweCiphertextVectorMutView32<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u32; lwe_size.0 * lwe_count.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorMutView32 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: &'data mut [u32], + lwe_size: LweSize, + ) -> Result< + LweCiphertextVectorMutView32<'data>, + LweCiphertextVectorCreationError, + > { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: &'data mut [u32], + lwe_size: LweSize, + ) -> LweCiphertextVectorMutView32<'data> { + LweCiphertextVectorMutView32(ImplLweList::from_container(container, lwe_size)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns an +/// immutable [`LweCiphertextVectorView64`] that does not own its memory. +impl<'data> LweCiphertextVectorCreationEngine<&'data [u64], LweCiphertextVectorView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_count.0]; + /// + /// let slice = &owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorView64 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: &'data [u64], + lwe_size: LweSize, + ) -> Result, LweCiphertextVectorCreationError> + { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: &'data [u64], + lwe_size: LweSize, + ) -> LweCiphertextVectorView64<'data> { + LweCiphertextVectorView64(ImplLweList::from_container(container, lwe_size)) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorCreationEngine`] for [`DefaultEngine`] which returns a +/// mutable [`LweCiphertextVectorMutView64`] that does not own its memory. +impl<'data> LweCiphertextVectorCreationEngine<&'data mut [u64], LweCiphertextVectorMutView64<'data>> + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here we create a container outside of the engine + /// // Note that the size here is just for demonstration purposes and should not be chosen + /// // without proper security analysis for production + /// let lwe_size = LweSize(16); + /// let lwe_count = LweCiphertextCount(3); + /// let mut owned_container = vec![0_u64; lwe_size.0 * lwe_count.0]; + /// + /// let slice = &mut owned_container[..]; + /// + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let ciphertext_vector_view: LweCiphertextVectorMutView64 = + /// engine.create_lwe_ciphertext_vector_from(slice, lwe_size)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_lwe_ciphertext_vector_from( + &mut self, + container: &'data mut [u64], + lwe_size: LweSize, + ) -> Result< + LweCiphertextVectorMutView64<'data>, + LweCiphertextVectorCreationError, + > { + LweCiphertextVectorCreationError::::perform_generic_checks( + container.len(), + )?; + Ok(unsafe { self.create_lwe_ciphertext_vector_from_unchecked(container, lwe_size) }) + } + + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: &'data mut [u64], + lwe_size: LweSize, + ) -> LweCiphertextVectorMutView64<'data> { + LweCiphertextVectorMutView64(ImplLweList::from_container(container, lwe_size)) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_zero_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_zero_encryption.rs new file mode 100644 index 000000000..db98bbe4e --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_vector_zero_encryption.rs @@ -0,0 +1,144 @@ +use crate::core_crypto::prelude::{CiphertextCount, LweCiphertextCount, PlaintextCount, Variance}; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertextVector32, LweCiphertextVector64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::encoding::PlaintextList as ImplPlaintextList; +use crate::core_crypto::commons::crypto::lwe::LweList as ImplLweList; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorZeroEncryptionEngine, LweCiphertextVectorZeroEncryptionError, +}; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +/// # Description: +/// Implementation of [`LweCiphertextVectorZeroEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl LweCiphertextVectorZeroEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let ciphertext_count = LweCiphertextCount(3); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext_vector = + /// engine.zero_encrypt_lwe_ciphertext_vector(&key, noise, ciphertext_count)?; + /// # + /// assert_eq!(ciphertext_vector.lwe_dimension(), lwe_dimension); + /// assert_eq!(ciphertext_vector.lwe_ciphertext_count(), ciphertext_count); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext_vector( + &mut self, + key: &LweSecretKey32, + noise: Variance, + count: LweCiphertextCount, + ) -> Result> + { + LweCiphertextVectorZeroEncryptionError::perform_generic_checks(count)?; + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_vector_unchecked(key, noise, count) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_vector_unchecked( + &mut self, + key: &LweSecretKey32, + noise: Variance, + count: LweCiphertextCount, + ) -> LweCiphertextVector32 { + let mut vector = ImplLweList::allocate( + 0u32, + key.lwe_dimension().to_lwe_size(), + CiphertextCount(count.0), + ); + let plaintexts = ImplPlaintextList::allocate(0u32, PlaintextCount(count.0)); + key.0.encrypt_lwe_list( + &mut vector, + &plaintexts, + noise, + &mut self.encryption_generator, + ); + LweCiphertextVector32(vector) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorZeroEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl LweCiphertextVectorZeroEncryptionEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let ciphertext_count = LweCiphertextCount(3); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext_vector = + /// engine.zero_encrypt_lwe_ciphertext_vector(&key, noise, ciphertext_count)?; + /// # + /// assert_eq!(ciphertext_vector.lwe_dimension(), lwe_dimension); + /// assert_eq!(ciphertext_vector.lwe_ciphertext_count(), ciphertext_count); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext_vector( + &mut self, + key: &LweSecretKey64, + noise: Variance, + count: LweCiphertextCount, + ) -> Result> + { + LweCiphertextVectorZeroEncryptionError::perform_generic_checks(count)?; + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_vector_unchecked(key, noise, count) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_vector_unchecked( + &mut self, + key: &LweSecretKey64, + noise: Variance, + count: LweCiphertextCount, + ) -> LweCiphertextVector64 { + let mut vector = ImplLweList::allocate( + 0u64, + key.lwe_dimension().to_lwe_size(), + CiphertextCount(count.0), + ); + let plaintexts = ImplPlaintextList::allocate(0u64, PlaintextCount(count.0)); + key.0.encrypt_lwe_list( + &mut vector, + &plaintexts, + noise, + &mut self.encryption_generator, + ); + LweCiphertextVector64(vector) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_zero_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_zero_encryption.rs new file mode 100644 index 000000000..38f41b807 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_ciphertext_zero_encryption.rs @@ -0,0 +1,116 @@ +use crate::core_crypto::prelude::Variance; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertext32, LweCiphertext64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::encoding::Plaintext as ImplPlaintext; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext as ImplLweCiphertext; +use crate::core_crypto::specification::engines::{ + LweCiphertextZeroEncryptionEngine, LweCiphertextZeroEncryptionError, +}; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +/// # Description: +/// Implementation of [`LweCiphertextZeroEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl LweCiphertextZeroEncryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext = engine.zero_encrypt_lwe_ciphertext(&key, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey32, + noise: Variance, + ) -> Result> { + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_unchecked(key, noise) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey32, + noise: Variance, + ) -> LweCiphertext32 { + let mut ciphertext = ImplLweCiphertext::allocate(0u32, key.lwe_dimension().to_lwe_size()); + key.0.encrypt_lwe( + &mut ciphertext, + &ImplPlaintext(0u32), + noise, + &mut self.encryption_generator, + ); + LweCiphertext32(ciphertext) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextZeroEncryptionEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl LweCiphertextZeroEncryptionEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext = engine.zero_encrypt_lwe_ciphertext(&key, noise)?; + /// # + /// assert_eq!(ciphertext.lwe_dimension(), lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext( + &mut self, + key: &LweSecretKey64, + noise: Variance, + ) -> Result> { + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_unchecked(key, noise) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &LweSecretKey64, + noise: Variance, + ) -> LweCiphertext64 { + let mut ciphertext = ImplLweCiphertext::allocate(0u64, key.lwe_dimension().to_lwe_size()); + key.0.encrypt_lwe( + &mut ciphertext, + &ImplPlaintext(0u64), + noise, + &mut self.encryption_generator, + ); + LweCiphertext64(ciphertext) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs new file mode 100644 index 000000000..f54542497 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs @@ -0,0 +1,240 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList as ImplLwePrivateFunctionalPackingKeyswitchKeyList; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, Variance, +}; +use crate::core_crypto::specification::engines::{ + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, +}; +use crate::core_crypto::specification::entities::{GlweSecretKeyEntity, LweSecretKeyEntity}; + +/// # Description: +/// Implementation of [`LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine`] +/// for [`DefaultEngine`] that operates on 32 bits integers. +impl + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine< + LweSecretKey32, + GlweSecretKey32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, GlweDimension,FunctionalPackingKeyswitchKeyCount + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey32 = engine.generate_new_glwe_secret_key(output_glwe_dimension, + /// polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.input_lwe_dimension(), + /// input_lwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.output_glwe_dimension(), + /// output_glwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.key_count().0, + /// output_glwe_dimension.to_glwe_size().0); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &mut self, + input_lwe_key: &LweSecretKey32, + output_glwe_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, + > { + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError::perform_generic_checks( + decomposition_level_count, decomposition_base_log, 32)?; + Ok(unsafe { + self.generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked(input_lwe_key, output_glwe_key, decomposition_base_log, decomposition_level_count, noise) + }) + } + + unsafe fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked( + &mut self, + input_lwe_key: &LweSecretKey32, + output_glwe_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + let mut fpksk_list = ImplLwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u32, + decomposition_level_count, + decomposition_base_log, + input_lwe_key.lwe_dimension(), + output_glwe_key.glwe_dimension(), + output_glwe_key.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(output_glwe_key.glwe_dimension().to_glwe_size().0), + ); + + fpksk_list.fill_with_fpksk_for_circuit_bootstrap( + &input_lwe_key.0, + &output_glwe_key.0, + noise, + &mut self.encryption_generator, + ); + + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32(fpksk_list) + } +} + +/// # Description: +/// Implementation of [`LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine`] +/// for [`DefaultEngine`] that operates on 64 bits integers. +impl + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine< + LweSecretKey64, + GlweSecretKey64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + > for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, GlweDimension,FunctionalPackingKeyswitchKeyCount + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey64 = engine.generate_new_glwe_secret_key(output_glwe_dimension, + /// polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.input_lwe_dimension(), + /// input_lwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.output_glwe_dimension(), + /// output_glwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.key_count().0, + /// output_glwe_dimension.to_glwe_size().0); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &mut self, + input_lwe_key: &LweSecretKey64, + output_glwe_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, + > { + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError::perform_generic_checks( + decomposition_level_count, decomposition_base_log, 64)?; + Ok(unsafe { + self.generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked(input_lwe_key, output_glwe_key, decomposition_base_log, decomposition_level_count, noise) + }) + } + + unsafe fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked( + &mut self, + input_lwe_key: &LweSecretKey64, + output_glwe_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + let mut fpksk_list = ImplLwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u64, + decomposition_level_count, + decomposition_base_log, + input_lwe_key.lwe_dimension(), + output_glwe_key.glwe_dimension(), + output_glwe_key.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(output_glwe_key.glwe_dimension().to_glwe_size().0), + ); + + fpksk_list.fill_with_fpksk_for_circuit_bootstrap( + &input_lwe_key.0, + &output_glwe_key.0, + noise, + &mut self.encryption_generator, + ); + + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64(fpksk_list) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_keyswitch_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_keyswitch_key_generation.rs new file mode 100644 index 000000000..250a870d1 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_keyswitch_key_generation.rs @@ -0,0 +1,215 @@ +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, Variance}; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweKeyswitchKey32, LweKeyswitchKey64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::lwe::LweKeyswitchKey as ImplLweKeyswitchKey; +use crate::core_crypto::specification::engines::{ + LweKeyswitchKeyGenerationEngine, LweKeyswitchKeyGenerationError, +}; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +/// # Description: +/// Implementation of [`LweKeyswitchKeyGenerationEngine`] for [`DefaultEngine`] that +/// operates on 32 bits integers. +impl LweKeyswitchKeyGenerationEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(keyswitch_key.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(keyswitch_key.output_lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_keyswitch_key( + &mut self, + input_key: &LweSecretKey32, + output_key: &LweSecretKey32, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> Result> { + LweKeyswitchKeyGenerationError::perform_generic_checks( + decomposition_level_count, + decomposition_base_log, + 32, + )?; + Ok(unsafe { + self.generate_new_lwe_keyswitch_key_unchecked( + input_key, + output_key, + decomposition_level_count, + decomposition_base_log, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_keyswitch_key_unchecked( + &mut self, + input_key: &LweSecretKey32, + output_key: &LweSecretKey32, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> LweKeyswitchKey32 { + let mut ksk = ImplLweKeyswitchKey::allocate( + 0, + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + output_key.lwe_dimension(), + ); + ksk.fill_with_keyswitch_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweKeyswitchKey32(ksk) + } +} + +/// # Description: +/// Implementation of [`LweKeyswitchKeyGenerationEngine`] for [`DefaultEngine`] that +/// operates on 64 bits integers. +impl LweKeyswitchKeyGenerationEngine + for DefaultEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(keyswitch_key.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!(keyswitch_key.output_lwe_dimension(), output_lwe_dimension); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_keyswitch_key( + &mut self, + input_key: &LweSecretKey64, + output_key: &LweSecretKey64, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> Result> { + LweKeyswitchKeyGenerationError::perform_generic_checks( + decomposition_level_count, + decomposition_base_log, + 64, + )?; + Ok(unsafe { + self.generate_new_lwe_keyswitch_key_unchecked( + input_key, + output_key, + decomposition_level_count, + decomposition_base_log, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_keyswitch_key_unchecked( + &mut self, + input_key: &LweSecretKey64, + output_key: &LweSecretKey64, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> LweKeyswitchKey64 { + let mut ksk = ImplLweKeyswitchKey::allocate( + 0, + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + output_key.lwe_dimension(), + ); + ksk.fill_with_keyswitch_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweKeyswitchKey64(ksk) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_public_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_public_key_generation.rs new file mode 100644 index 000000000..e4e280bf9 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_public_key_generation.rs @@ -0,0 +1,145 @@ +use crate::core_crypto::backends::default::engines::DefaultEngine; +use crate::core_crypto::backends::default::entities::{ + LwePublicKey32, LwePublicKey64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::prelude::{LweCiphertextCount, LwePublicKeyZeroEncryptionCount, Variance}; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorZeroEncryptionEngine, LwePublicKeyGenerationEngine, + LwePublicKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LwePublicKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LwePublicKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey32 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// assert_eq!(public_key.lwe_dimension(), lwe_dimension); + /// assert_eq!( + /// public_key.lwe_zero_encryption_count(), + /// lwe_public_key_zero_encryption_count + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_public_key( + &mut self, + lwe_secret_key: &LweSecretKey32, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result> { + LwePublicKeyGenerationError::perform_generic_checks(lwe_public_key_zero_encryption_count)?; + Ok(unsafe { + self.generate_new_lwe_public_key_unchecked( + lwe_secret_key, + noise, + lwe_public_key_zero_encryption_count, + ) + }) + } + + unsafe fn generate_new_lwe_public_key_unchecked( + &mut self, + lwe_secret_key: &LweSecretKey32, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> LwePublicKey32 { + let encrypted_zeros = self.zero_encrypt_lwe_ciphertext_vector_unchecked( + lwe_secret_key, + noise, + LweCiphertextCount(lwe_public_key_zero_encryption_count.0), + ); + LwePublicKey32(encrypted_zeros.0) + } +} + +/// # Description: +/// Implementation of [`LwePublicKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LwePublicKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey64 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// assert_eq!(public_key.lwe_dimension(), lwe_dimension); + /// assert_eq!( + /// public_key.lwe_zero_encryption_count(), + /// lwe_public_key_zero_encryption_count + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_public_key( + &mut self, + lwe_secret_key: &LweSecretKey64, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result> { + LwePublicKeyGenerationError::perform_generic_checks(lwe_public_key_zero_encryption_count)?; + Ok(unsafe { + self.generate_new_lwe_public_key_unchecked( + lwe_secret_key, + noise, + lwe_public_key_zero_encryption_count, + ) + }) + } + + unsafe fn generate_new_lwe_public_key_unchecked( + &mut self, + lwe_secret_key: &LweSecretKey64, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> LwePublicKey64 { + let encrypted_zeros = self.zero_encrypt_lwe_ciphertext_vector_unchecked( + lwe_secret_key, + noise, + LweCiphertextCount(lwe_public_key_zero_encryption_count.0), + ); + LwePublicKey64(encrypted_zeros.0) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_secret_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_secret_key_generation.rs new file mode 100644 index 000000000..49cd6017e --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/lwe_secret_key_generation.rs @@ -0,0 +1,96 @@ +use crate::core_crypto::prelude::LweDimension; + +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::secret::LweSecretKey as ImplLweSecretKey; +use crate::core_crypto::specification::engines::{ + LweSecretKeyGenerationEngine, LweSecretKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LweSecretKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl LweSecretKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// # + /// assert_eq!(lwe_secret_key.lwe_dimension(), lwe_dimension); + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_secret_key( + &mut self, + lwe_dimension: LweDimension, + ) -> Result> { + LweSecretKeyGenerationError::perform_generic_checks(lwe_dimension)?; + Ok(unsafe { self.generate_new_lwe_secret_key_unchecked(lwe_dimension) }) + } + + unsafe fn generate_new_lwe_secret_key_unchecked( + &mut self, + lwe_dimension: LweDimension, + ) -> LweSecretKey32 { + LweSecretKey32(ImplLweSecretKey::generate_binary( + lwe_dimension, + &mut self.secret_generator, + )) + } +} + +/// # Description: +/// Implementation of [`LweSecretKeyGenerationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl LweSecretKeyGenerationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// # + /// assert_eq!(lwe_secret_key.lwe_dimension(), lwe_dimension); + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_secret_key( + &mut self, + lwe_dimension: LweDimension, + ) -> Result> { + LweSecretKeyGenerationError::perform_generic_checks(lwe_dimension)?; + Ok(unsafe { self.generate_new_lwe_secret_key_unchecked(lwe_dimension) }) + } + + unsafe fn generate_new_lwe_secret_key_unchecked( + &mut self, + lwe_dimension: LweDimension, + ) -> LweSecretKey64 { + LweSecretKey64(ImplLweSecretKey::generate_binary( + lwe_dimension, + &mut self.secret_generator, + )) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/mod.rs new file mode 100644 index 000000000..46ed80e9f --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/mod.rs @@ -0,0 +1,101 @@ +use super::ActivatedRandomGenerator; +use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder as ImplDeterministicSeeder, + EncryptionRandomGenerator as ImplEncryptionRandomGenerator, + SecretRandomGenerator as ImplSecretRandomGenerator, +}; +use crate::core_crypto::specification::engines::sealed::AbstractEngineSeal; +use crate::core_crypto::specification::engines::AbstractEngine; +use concrete_csprng::seeders::Seeder; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// The error which can occur in the execution of FHE operations, due to the default implementation. +/// +/// # Note: +/// +/// There is currently no such case, as the default implementation is not expected to undergo some +/// major issues unrelated to FHE. +#[derive(Debug)] +pub enum DefaultError {} + +impl Display for DefaultError { + fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { + Ok(()) + } +} + +impl Error for DefaultError {} + +pub struct DefaultEngine { + /// A structure containing a single CSPRNG to generate secret key coefficients. + secret_generator: ImplSecretRandomGenerator, + /// A structure containing two CSPRNGs to generate material for encryption like public masks + /// and secret errors. + /// + /// The [`ImplEncryptionRandomGenerator`] contains two CSPRNGs, one publicly seeded used to + /// generate mask coefficients and one privately seeded used to generate errors during + /// encryption. + encryption_generator: ImplEncryptionRandomGenerator, + // /// A seeder that can be called to generate 128 bits seeds, useful to create new + // /// [`ImplEncryptionRandomGenerator`] to encrypt seeded types. + // seeder: ImplDeterministicSeeder, +} +impl AbstractEngineSeal for DefaultEngine {} + +impl AbstractEngine for DefaultEngine { + type EngineError = DefaultError; + + type Parameters = Box; + + fn new(mut parameters: Self::Parameters) -> Result { + let mut deterministic_seeder = + ImplDeterministicSeeder::::new(parameters.seed()); + + // Note that the operands are evaluated from left to right for Rust Struct expressions + // See: https://doc.rust-lang.org/stable/reference/expressions.html?highlight=left#evaluation-order-of-operands + // So parameters is moved in seeder after the calls to seed and the potential calls when it + // is passed as_mut in ImplEncryptionRandomGenerator::new + Ok(DefaultEngine { + secret_generator: ImplSecretRandomGenerator::new(deterministic_seeder.seed()), + encryption_generator: ImplEncryptionRandomGenerator::new( + deterministic_seeder.seed(), + &mut deterministic_seeder, + ), + // seeder: deterministic_seeder, + }) + } +} + +mod cleartext_creation; +mod glwe_ciphertext_consuming_retrieval; +mod glwe_ciphertext_creation; +mod glwe_ciphertext_trivial_encryption; +mod glwe_secret_key_generation; +mod glwe_to_lwe_secret_key_transformation; +mod lwe_bootstrap_key_generation; +mod lwe_ciphertext_cleartext_fusing_multiplication; +mod lwe_ciphertext_consuming_retrieval; +mod lwe_ciphertext_creation; +mod lwe_ciphertext_decryption; +mod lwe_ciphertext_discarding_addition; +mod lwe_ciphertext_discarding_encryption; +mod lwe_ciphertext_discarding_keyswitch; +mod lwe_ciphertext_discarding_public_key_encryption; +mod lwe_ciphertext_encryption; +mod lwe_ciphertext_fusing_addition; +mod lwe_ciphertext_fusing_opposite; +mod lwe_ciphertext_fusing_subtraction; +mod lwe_ciphertext_plaintext_fusing_addition; +mod lwe_ciphertext_trivial_encryption; +mod lwe_ciphertext_vector_consuming_retrieval; +mod lwe_ciphertext_vector_creation; +mod lwe_ciphertext_vector_zero_encryption; +mod lwe_ciphertext_zero_encryption; +mod lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation; +mod lwe_keyswitch_key_generation; +mod lwe_public_key_generation; +mod lwe_secret_key_generation; +mod plaintext_creation; +mod plaintext_discarding_retrieval; +mod plaintext_vector_creation; diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_creation.rs new file mode 100644 index 000000000..b6cb627c2 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_creation.rs @@ -0,0 +1,72 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{Plaintext32, Plaintext64}; +use crate::core_crypto::commons::crypto::encoding::Plaintext as ImplPlaintext; +use crate::core_crypto::specification::engines::{PlaintextCreationEngine, PlaintextCreationError}; + +/// # Description: +/// Implementation of [`PlaintextCreationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl PlaintextCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext32 = engine.create_plaintext_from(&input)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_plaintext_from( + &mut self, + input: &u32, + ) -> Result> { + Ok(unsafe { self.create_plaintext_from_unchecked(input) }) + } + + unsafe fn create_plaintext_from_unchecked(&mut self, input: &u32) -> Plaintext32 { + Plaintext32(ImplPlaintext(*input)) + } +} + +/// # Description: +/// Implementation of [`PlaintextCreationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl PlaintextCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext64 = engine.create_plaintext_from(&input)?; + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_plaintext_from( + &mut self, + input: &u64, + ) -> Result> { + Ok(unsafe { self.create_plaintext_from_unchecked(input) }) + } + + unsafe fn create_plaintext_from_unchecked(&mut self, input: &u64) -> Plaintext64 { + Plaintext64(ImplPlaintext(*input)) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_discarding_retrieval.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_discarding_retrieval.rs new file mode 100644 index 000000000..96230bdff --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_discarding_retrieval.rs @@ -0,0 +1,90 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{Plaintext32, Plaintext64}; +use crate::core_crypto::specification::engines::{ + PlaintextDiscardingRetrievalEngine, PlaintextDiscardingRetrievalError, +}; + +/// # Description: +/// Implementation of [`PlaintextDiscardingRetrievalEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl PlaintextDiscardingRetrievalEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let mut output = 0_u32; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext32 = engine.create_plaintext_from(&input)?; + /// engine.discard_retrieve_plaintext(&mut output, &plaintext)?; + /// + /// assert_eq!(output, 3_u32 << 20); + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_retrieve_plaintext( + &mut self, + output: &mut u32, + input: &Plaintext32, + ) -> Result<(), PlaintextDiscardingRetrievalError> { + unsafe { self.discard_retrieve_plaintext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_retrieve_plaintext_unchecked( + &mut self, + output: &mut u32, + input: &Plaintext32, + ) { + *output = input.0 .0; + } +} + +impl PlaintextDiscardingRetrievalEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 20; + /// let mut output = 0_u64; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext64 = engine.create_plaintext_from(&input)?; + /// engine.discard_retrieve_plaintext(&mut output, &plaintext)?; + /// + /// assert_eq!(output, 3_u64 << 20); + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_retrieve_plaintext( + &mut self, + output: &mut u64, + input: &Plaintext64, + ) -> Result<(), PlaintextDiscardingRetrievalError> { + unsafe { self.discard_retrieve_plaintext_unchecked(output, input) }; + Ok(()) + } + + unsafe fn discard_retrieve_plaintext_unchecked( + &mut self, + output: &mut u64, + input: &Plaintext64, + ) { + *output = input.0 .0; + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_vector_creation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_vector_creation.rs new file mode 100644 index 000000000..1d414eccd --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_engine/plaintext_vector_creation.rs @@ -0,0 +1,92 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + PlaintextVector32, PlaintextVector64, +}; +use crate::core_crypto::commons::crypto::encoding::PlaintextList as ImplPlaintextList; +use crate::core_crypto::specification::engines::{ + PlaintextVectorCreationEngine, PlaintextVectorCreationError, +}; + +/// # Description: +/// Implementation of [`PlaintextVectorCreationEngine`] for [`DefaultEngine`] that operates on +/// 32 bits integers. +impl PlaintextVectorCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u32 << 20; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector32 = engine.create_plaintext_vector_from(&input)?; + /// # + /// assert_eq!(plaintext_vector.plaintext_count(), PlaintextCount(3)); + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_plaintext_vector_from( + &mut self, + input: &[u32], + ) -> Result> { + if input.is_empty() { + return Err(PlaintextVectorCreationError::EmptyInput); + } + Ok(unsafe { self.create_plaintext_vector_from_unchecked(input) }) + } + + unsafe fn create_plaintext_vector_from_unchecked( + &mut self, + input: &[u32], + ) -> PlaintextVector32 { + PlaintextVector32(ImplPlaintextList::from_container(input.to_vec())) + } +} + +/// # Description: +/// Implementation of [`PlaintextVectorCreationEngine`] for [`DefaultEngine`] that operates on +/// 64 bits integers. +impl PlaintextVectorCreationEngine for DefaultEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = vec![3_u64 << 50; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector64 = engine.create_plaintext_vector_from(&input)?; + /// # + /// assert_eq!(plaintext_vector.plaintext_count(), PlaintextCount(3)); + /// # + /// # Ok(()) + /// # } + /// ``` + fn create_plaintext_vector_from( + &mut self, + input: &[u64], + ) -> Result> { + if input.is_empty() { + return Err(PlaintextVectorCreationError::EmptyInput); + } + Ok(unsafe { self.create_plaintext_vector_from_unchecked(input) }) + } + + unsafe fn create_plaintext_vector_from_unchecked( + &mut self, + input: &[u64], + ) -> PlaintextVector64 { + PlaintextVector64(ImplPlaintextList::from_container(input.to_vec())) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_bootstrap_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_bootstrap_key_generation.rs new file mode 100644 index 000000000..7b5dd5d1d --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_bootstrap_key_generation.rs @@ -0,0 +1,198 @@ +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, LweBootstrapKey32, LweBootstrapKey64, LweSecretKey32, + LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey as ImplStandardBootstrapKey; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, DefaultParallelEngine, GlweSecretKeyEntity, + LweSecretKeyEntity, Variance, +}; +use crate::core_crypto::specification::engines::{ + LweBootstrapKeyGenerationEngine, LweBootstrapKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LweBootstrapKeyGenerationEngine`] for [`DefaultParallelEngine`] that +/// operates on 32 bits integers. It outputs a bootstrap key in the standard domain. +impl LweBootstrapKeyGenerationEngine + for DefaultParallelEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey32 = default_parallel_engine + /// .generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_bootstrap_key( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> { + LweBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 32, + )?; + Ok(unsafe { + self.generate_new_lwe_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweBootstrapKey32 { + let mut key = ImplStandardBootstrapKey::allocate( + 0, + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + ); + key.par_fill_with_new_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweBootstrapKey32(key) + } +} + +/// # Description: +/// Implementation of [`LweBootstrapKeyGenerationEngine`] for [`DefaultParallelEngine`] that +/// operates on 64 bits integers. It outputs a bootstrap key in the standard domain. +impl LweBootstrapKeyGenerationEngine + for DefaultParallelEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey64 = default_parallel_engine + /// .generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_bootstrap_key( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> { + LweBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 64, + )?; + Ok(unsafe { + self.generate_new_lwe_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweBootstrapKey64 { + let mut key = ImplStandardBootstrapKey::allocate( + 0, + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + ); + key.par_fill_with_new_key( + &input_key.0, + &output_key.0, + noise, + &mut self.encryption_generator, + ); + LweBootstrapKey64(key) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_ciphertext_vector_zero_encryption.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_ciphertext_vector_zero_encryption.rs new file mode 100644 index 000000000..33524f536 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_ciphertext_vector_zero_encryption.rs @@ -0,0 +1,147 @@ +use crate::core_crypto::prelude::{ + CiphertextCount, DefaultParallelEngine, LweCiphertextCount, PlaintextCount, Variance, +}; + +use crate::core_crypto::backends::default::implementation::entities::{ + LweCiphertextVector32, LweCiphertextVector64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::encoding::PlaintextList as ImplPlaintextList; +use crate::core_crypto::commons::crypto::lwe::LweList as ImplLweList; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorZeroEncryptionEngine, LweCiphertextVectorZeroEncryptionError, +}; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +/// # Description: +/// Implementation of [`LweCiphertextVectorZeroEncryptionEngine`] for [`DefaultParallelEngine`] that +/// operates on 32 bits integers. +impl LweCiphertextVectorZeroEncryptionEngine + for DefaultParallelEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let ciphertext_count = LweCiphertextCount(3); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut par_engine = DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext_vector = + /// par_engine.zero_encrypt_lwe_ciphertext_vector(&key, noise, ciphertext_count)?; + /// # + /// assert_eq!(ciphertext_vector.lwe_dimension(), lwe_dimension); + /// assert_eq!(ciphertext_vector.lwe_ciphertext_count(), ciphertext_count); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext_vector( + &mut self, + key: &LweSecretKey32, + noise: Variance, + count: LweCiphertextCount, + ) -> Result> + { + LweCiphertextVectorZeroEncryptionError::perform_generic_checks(count)?; + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_vector_unchecked(key, noise, count) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_vector_unchecked( + &mut self, + key: &LweSecretKey32, + noise: Variance, + count: LweCiphertextCount, + ) -> LweCiphertextVector32 { + let mut vector = ImplLweList::allocate( + 0u32, + key.lwe_dimension().to_lwe_size(), + CiphertextCount(count.0), + ); + let plaintexts = ImplPlaintextList::allocate(0u32, PlaintextCount(count.0)); + key.0.par_encrypt_lwe_list( + &mut vector, + &plaintexts, + noise, + &mut self.encryption_generator, + ); + LweCiphertextVector32(vector) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextVectorZeroEncryptionEngine`] for [`DefaultParallelEngine`] that +/// operates on 64 bits integers. +impl LweCiphertextVectorZeroEncryptionEngine + for DefaultParallelEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweCiphertextCount, LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// let ciphertext_count = LweCiphertextCount(3); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut par_engine = DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let ciphertext_vector = + /// par_engine.zero_encrypt_lwe_ciphertext_vector(&key, noise, ciphertext_count)?; + /// # + /// assert_eq!(ciphertext_vector.lwe_dimension(), lwe_dimension); + /// assert_eq!(ciphertext_vector.lwe_ciphertext_count(), ciphertext_count); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn zero_encrypt_lwe_ciphertext_vector( + &mut self, + key: &LweSecretKey64, + noise: Variance, + count: LweCiphertextCount, + ) -> Result> + { + LweCiphertextVectorZeroEncryptionError::perform_generic_checks(count)?; + Ok(unsafe { self.zero_encrypt_lwe_ciphertext_vector_unchecked(key, noise, count) }) + } + + unsafe fn zero_encrypt_lwe_ciphertext_vector_unchecked( + &mut self, + key: &LweSecretKey64, + noise: Variance, + count: LweCiphertextCount, + ) -> LweCiphertextVector64 { + let mut vector = ImplLweList::allocate( + 0u64, + key.lwe_dimension().to_lwe_size(), + CiphertextCount(count.0), + ); + let plaintexts = ImplPlaintextList::allocate(0u64, PlaintextCount(count.0)); + key.0.par_encrypt_lwe_list( + &mut vector, + &plaintexts, + noise, + &mut self.encryption_generator, + ); + LweCiphertextVector64(vector) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs new file mode 100644 index 000000000..51c1d13b2 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs @@ -0,0 +1,244 @@ +use crate::core_crypto::backends::default::implementation::engines::DefaultParallelEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList as ImplLwePrivateFunctionalPackingKeyswitchKeyList; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, Variance, +}; +use crate::core_crypto::specification::engines::{ + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, +}; +use crate::core_crypto::specification::entities::{GlweSecretKeyEntity, LweSecretKeyEntity}; + +/// # Description: +/// Implementation of [`LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine`] +/// for [`DefaultParallelEngine`] that operates on 32 bits integers. +impl + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine< + LweSecretKey32, + GlweSecretKey32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + > for DefaultParallelEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, GlweDimension,FunctionalPackingKeyswitchKeyCount + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey32 = default_engine.generate_new_glwe_secret_key(output_glwe_dimension, + /// polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = + /// default_engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.input_lwe_dimension(), + /// input_lwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.output_glwe_dimension(), + /// output_glwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.key_count().0, + /// output_glwe_dimension.to_glwe_size().0); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &mut self, + input_lwe_key: &LweSecretKey32, + output_glwe_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, + > { + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError::perform_generic_checks( + decomposition_level_count, decomposition_base_log, 32)?; + Ok(unsafe { + self.generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked(input_lwe_key, output_glwe_key, decomposition_base_log, decomposition_level_count, noise) + }) + } + + unsafe fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked( + &mut self, + input_lwe_key: &LweSecretKey32, + output_glwe_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + let mut fpksk_list = ImplLwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u32, + decomposition_level_count, + decomposition_base_log, + input_lwe_key.lwe_dimension(), + output_glwe_key.glwe_dimension(), + output_glwe_key.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(output_glwe_key.glwe_dimension().to_glwe_size().0), + ); + + fpksk_list.par_fill_with_fpksk_for_circuit_bootstrap( + &input_lwe_key.0, + &output_glwe_key.0, + noise, + &mut self.encryption_generator, + ); + + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32(fpksk_list) + } +} + +/// # Description: +/// Implementation of [`LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine`] +/// for [`DefaultParallelEngine`] that operates on 64 bits integers. +impl + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine< + LweSecretKey64, + GlweSecretKey64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + > for DefaultParallelEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, GlweDimension,FunctionalPackingKeyswitchKeyCount + /// }; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey64 = default_engine.generate_new_glwe_secret_key(output_glwe_dimension, + /// polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = + /// default_engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// # + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_level_count(), + /// # decomposition_level_count + /// # ); + /// assert_eq!( + /// # cbs_private_functional_packing_keyswitch_key.decomposition_base_log(), + /// # decomposition_base_log + /// # ); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.input_lwe_dimension(), + /// input_lwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.output_glwe_dimension(), + /// output_glwe_dimension); + /// assert_eq!(cbs_private_functional_packing_keyswitch_key.key_count().0, + /// output_glwe_dimension.to_glwe_size().0); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &mut self, + input_lwe_key: &LweSecretKey64, + output_glwe_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, + > { + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError::perform_generic_checks( + decomposition_level_count, decomposition_base_log, 64)?; + Ok(unsafe { + self.generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked(input_lwe_key, output_glwe_key, decomposition_base_log, decomposition_level_count, noise) + }) + } + + unsafe fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked( + &mut self, + input_lwe_key: &LweSecretKey64, + output_glwe_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + let mut fpksk_list = ImplLwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u64, + decomposition_level_count, + decomposition_base_log, + input_lwe_key.lwe_dimension(), + output_glwe_key.glwe_dimension(), + output_glwe_key.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(output_glwe_key.glwe_dimension().to_glwe_size().0), + ); + + fpksk_list.par_fill_with_fpksk_for_circuit_bootstrap( + &input_lwe_key.0, + &output_glwe_key.0, + noise, + &mut self.encryption_generator, + ); + + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64(fpksk_list) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_public_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_public_key_generation.rs new file mode 100644 index 000000000..62aeb151a --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_public_key_generation.rs @@ -0,0 +1,148 @@ +use crate::core_crypto::backends::default::entities::{ + LwePublicKey32, LwePublicKey64, LweSecretKey32, LweSecretKey64, +}; +use crate::core_crypto::prelude::{ + DefaultParallelEngine, LweCiphertextCount, LwePublicKeyZeroEncryptionCount, Variance, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorZeroEncryptionEngine, LwePublicKeyGenerationEngine, + LwePublicKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LwePublicKeyGenerationEngine`] for [`DefaultParallelEngine`] that operates +/// on 32 bits integers. +impl LwePublicKeyGenerationEngine for DefaultParallelEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut par_engine = DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey32 = par_engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// assert_eq!(public_key.lwe_dimension(), lwe_dimension); + /// assert_eq!( + /// public_key.lwe_zero_encryption_count(), + /// lwe_public_key_zero_encryption_count + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_public_key( + &mut self, + lwe_secret_key: &LweSecretKey32, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result> { + LwePublicKeyGenerationError::perform_generic_checks(lwe_public_key_zero_encryption_count)?; + Ok(unsafe { + self.generate_new_lwe_public_key_unchecked( + lwe_secret_key, + noise, + lwe_public_key_zero_encryption_count, + ) + }) + } + + unsafe fn generate_new_lwe_public_key_unchecked( + &mut self, + lwe_secret_key: &LweSecretKey32, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> LwePublicKey32 { + let encrypted_zeros = self.zero_encrypt_lwe_ciphertext_vector_unchecked( + lwe_secret_key, + noise, + LweCiphertextCount(lwe_public_key_zero_encryption_count.0), + ); + LwePublicKey32(encrypted_zeros.0) + } +} + +/// # Description: +/// Implementation of [`LwePublicKeyGenerationEngine`] for [`DefaultParallelEngine`] that operates +/// on 64 bits integers. +impl LwePublicKeyGenerationEngine for DefaultParallelEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut par_engine = DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey64 = par_engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// assert_eq!(public_key.lwe_dimension(), lwe_dimension); + /// assert_eq!( + /// public_key.lwe_zero_encryption_count(), + /// lwe_public_key_zero_encryption_count + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_public_key( + &mut self, + lwe_secret_key: &LweSecretKey64, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result> { + LwePublicKeyGenerationError::perform_generic_checks(lwe_public_key_zero_encryption_count)?; + Ok(unsafe { + self.generate_new_lwe_public_key_unchecked( + lwe_secret_key, + noise, + lwe_public_key_zero_encryption_count, + ) + }) + } + + unsafe fn generate_new_lwe_public_key_unchecked( + &mut self, + lwe_secret_key: &LweSecretKey64, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> LwePublicKey64 { + let encrypted_zeros = self.zero_encrypt_lwe_ciphertext_vector_unchecked( + lwe_secret_key, + noise, + LweCiphertextCount(lwe_public_key_zero_encryption_count.0), + ); + LwePublicKey64(encrypted_zeros.0) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_seeded_bootstrap_key_generation.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_seeded_bootstrap_key_generation.rs new file mode 100644 index 000000000..71a328f80 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/lwe_seeded_bootstrap_key_generation.rs @@ -0,0 +1,207 @@ +use super::ActivatedRandomGenerator; +use crate::core_crypto::backends::default::implementation::engines::DefaultParallelEngine; +use crate::core_crypto::backends::default::implementation::entities::{ + GlweSecretKey32, GlweSecretKey64, LweSecretKey32, LweSecretKey64, LweSeededBootstrapKey32, + LweSeededBootstrapKey64, +}; +use crate::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey as ImplStandardSeededBootstrapKey; +use crate::core_crypto::commons::math::random::{CompressionSeed, Seeder}; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSecretKeyEntity, LweSecretKeyEntity, + Variance, +}; +use crate::core_crypto::specification::engines::{ + LweSeededBootstrapKeyGenerationEngine, LweSeededBootstrapKeyGenerationError, +}; + +/// # Description: +/// Implementation of [`LweSeededBootstrapKeyGenerationEngine`] for [`DefaultParallelEngine`] that +/// operates on 32 bits integers. It outputs a seeded bootstrap key in the standard domain. +impl LweSeededBootstrapKeyGenerationEngine + for DefaultParallelEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweSeededBootstrapKey32 = default_parallel_engine + /// .generate_new_lwe_seeded_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_seeded_bootstrap_key( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> + { + LweSeededBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 32, + )?; + Ok(unsafe { + self.generate_new_lwe_seeded_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_seeded_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey32, + output_key: &GlweSecretKey32, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweSeededBootstrapKey32 { + let mut key = ImplStandardSeededBootstrapKey::>::allocate( + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + CompressionSeed { + seed: self.seeder.seed(), + }, + ); + key.par_fill_with_new_key::<_, _, _, _, _, ActivatedRandomGenerator>( + &input_key.0, + &output_key.0, + noise, + &mut self.seeder, + ); + LweSeededBootstrapKey32(key) + } +} + +/// # Description: +/// Implementation of [`LweSeededBootstrapKeyGenerationEngine`] for [`DefaultParallelEngine`] that +/// operates on 64 bits integers. It outputs a seeded bootstrap key in the standard domain. +impl LweSeededBootstrapKeyGenerationEngine + for DefaultParallelEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweSeededBootstrapKey64 = default_parallel_engine + /// .generate_new_lwe_seeded_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// # + /// assert_eq!(bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(bsk.polynomial_size(), poly_size); + /// assert_eq!(bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn generate_new_lwe_seeded_bootstrap_key( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result> + { + LweSeededBootstrapKeyGenerationError::perform_generic_checks( + decomposition_base_log, + decomposition_level_count, + 64, + )?; + Ok(unsafe { + self.generate_new_lwe_seeded_bootstrap_key_unchecked( + input_key, + output_key, + decomposition_base_log, + decomposition_level_count, + noise, + ) + }) + } + + unsafe fn generate_new_lwe_seeded_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey64, + output_key: &GlweSecretKey64, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> LweSeededBootstrapKey64 { + let mut key = ImplStandardSeededBootstrapKey::>::allocate( + output_key.glwe_dimension().to_glwe_size(), + output_key.polynomial_size(), + decomposition_level_count, + decomposition_base_log, + input_key.lwe_dimension(), + CompressionSeed { + seed: self.seeder.seed(), + }, + ); + key.par_fill_with_new_key::<_, _, _, _, _, ActivatedRandomGenerator>( + &input_key.0, + &output_key.0, + noise, + &mut self.seeder, + ); + LweSeededBootstrapKey64(key) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/mod.rs new file mode 100644 index 000000000..629ff11b2 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_parallel_engine/mod.rs @@ -0,0 +1,68 @@ +use super::ActivatedRandomGenerator; +use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder as ImplDeterministicSeeder, + EncryptionRandomGenerator as ImplEncryptionRandomGenerator, +}; +use crate::core_crypto::specification::engines::sealed::AbstractEngineSeal; +use crate::core_crypto::specification::engines::AbstractEngine; +use concrete_csprng::seeders::Seeder; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// The error which can occur in the execution of FHE operations, due to the default +/// parallel implementation. +/// +/// # Note: +/// +/// There is currently no such case, as the default parallel implementation is not expected to +/// undergo major issues unrelated to FHE. +#[derive(Debug)] +pub enum DefaultParallelError {} + +impl Display for DefaultParallelError { + fn fmt(&self, _f: &mut Formatter<'_>) -> std::fmt::Result { + match *self {} + } +} + +impl Error for DefaultParallelError {} + +pub struct DefaultParallelEngine { + /// A structure containing two CSPRNGs to generate material for encryption like public masks + /// and secret errors. + /// + /// The [`ImplEncryptionRandomGenerator`] contains two CSPRNGs, one publicly seeded used to + /// generate mask coefficients and one privately seeded used to generate errors during + /// encryption. + pub(crate) encryption_generator: ImplEncryptionRandomGenerator, + // /// A seeder that can be called to generate 128 bits seeds, useful to create new + // /// [`ImplEncryptionRandomGenerator`] to encrypt seeded types. + // seeder: ImplDeterministicSeeder, +} + +impl AbstractEngineSeal for DefaultParallelEngine {} + +impl AbstractEngine for DefaultParallelEngine { + type EngineError = DefaultParallelError; + + type Parameters = Box; + + fn new(mut parameters: Self::Parameters) -> Result { + let mut deterministic_seeder = + ImplDeterministicSeeder::::new(parameters.seed()); + + Ok(DefaultParallelEngine { + encryption_generator: ImplEncryptionRandomGenerator::new( + deterministic_seeder.seed(), + &mut deterministic_seeder, + ), + // seeder: deterministic_seeder, + }) + } +} + +mod lwe_bootstrap_key_generation; +mod lwe_ciphertext_vector_zero_encryption; +mod lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation; +mod lwe_public_key_generation; +// mod lwe_seeded_bootstrap_key_generation; diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_deserialization.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_deserialization.rs new file mode 100644 index 000000000..cb345e7b0 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_deserialization.rs @@ -0,0 +1,1411 @@ +#![allow(clippy::missing_safety_doc)] +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey as ImplStandardBootstrapKey; +use crate::core_crypto::commons::crypto::encoding::{ + Cleartext as ImplCleartext, Plaintext as ImplPlaintext, PlaintextList as ImplPlaintextList, +}; +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList as ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys; +use crate::core_crypto::commons::crypto::lwe::{ + LweCiphertext as ImplLweCiphertext, LweKeyswitchKey as ImplLweKeyswitchKey, + LweList as ImplLweList, +}; +use crate::core_crypto::commons::crypto::secret::{ + GlweSecretKey as ImplGlweSecretKey, LweSecretKey as ImplLweSecretKey, +}; +use crate::core_crypto::prelude::{ + BinaryKeyKind, Cleartext32, Cleartext32Version, Cleartext64, Cleartext64Version, CleartextF64, + CleartextF64Version, DefaultSerializationEngine, DefaultSerializationError, + EntityDeserializationEngine, EntityDeserializationError, GlweSecretKey32, + GlweSecretKey32Version, GlweSecretKey64, GlweSecretKey64Version, LweBootstrapKey32, + LweBootstrapKey32Version, LweBootstrapKey64, LweBootstrapKey64Version, LweCiphertext32, + LweCiphertext32Version, LweCiphertext64, LweCiphertext64Version, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version, LweKeyswitchKey32, + LweKeyswitchKey32Version, LweKeyswitchKey64, LweKeyswitchKey64Version, LwePublicKey32, + LwePublicKey32Version, LwePublicKey64, LwePublicKey64Version, LweSecretKey32, + LweSecretKey32Version, LweSecretKey64, LweSecretKey64Version, Plaintext32, Plaintext32Version, + Plaintext64, Plaintext64Version, PlaintextVector32, PlaintextVector32Version, + PlaintextVector64, PlaintextVector64Version, +}; +use serde::Deserialize; + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a cleartext entity. +impl EntityDeserializationEngine<&[u8], Cleartext32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u32 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext32 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableCleartext32 { + version: Cleartext32Version, + inner: ImplCleartext, + } + let deserialized: DeserializableCleartext32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableCleartext32 { + version: Cleartext32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableCleartext32 { + version: Cleartext32Version::V0, + inner, + } => Ok(Cleartext32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> Cleartext32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a cleartext entity. +impl EntityDeserializationEngine<&[u8], Cleartext64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u64 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext64 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableCleartext64 { + version: Cleartext64Version, + inner: ImplCleartext, + } + let deserialized: DeserializableCleartext64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableCleartext64 { + version: Cleartext64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableCleartext64 { + version: Cleartext64Version::V0, + inner, + } => Ok(Cleartext64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> Cleartext64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a floating point cleartext entity. +impl EntityDeserializationEngine<&[u8], CleartextF64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: f64 = 3.; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: CleartextF64 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableCleartextF64 { + version: CleartextF64Version, + inner: ImplCleartext, + } + let deserialized: DeserializableCleartextF64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableCleartextF64 { + version: CleartextF64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableCleartextF64 { + version: CleartextF64Version::V0, + inner, + } => Ok(CleartextF64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> CleartextF64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a GLWE secret key entity. +impl EntityDeserializationEngine<&[u8], GlweSecretKey32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&glwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(glwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableGlweSecretKey32 { + version: GlweSecretKey32Version, + inner: ImplGlweSecretKey>, + } + let deserialized: DeserializableGlweSecretKey32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableGlweSecretKey32 { + version: GlweSecretKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableGlweSecretKey32 { + version: GlweSecretKey32Version::V0, + inner, + } => Ok(GlweSecretKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> GlweSecretKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a GLWE secret key entity. +impl EntityDeserializationEngine<&[u8], GlweSecretKey64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&glwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(glwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableGlweSecretKey64 { + version: GlweSecretKey64Version, + inner: ImplGlweSecretKey>, + } + let deserialized: DeserializableGlweSecretKey64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableGlweSecretKey64 { + version: GlweSecretKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableGlweSecretKey64 { + version: GlweSecretKey64Version::V0, + inner, + } => Ok(GlweSecretKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> GlweSecretKey64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a LWE bootstrap key entity. +impl EntityDeserializationEngine<&[u8], LweBootstrapKey32> for DefaultSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey32 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweBootstrapKey32 { + version: LweBootstrapKey32Version, + inner: ImplStandardBootstrapKey>, + } + let deserialized: DeserializableLweBootstrapKey32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweBootstrapKey32 { + version: LweBootstrapKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweBootstrapKey32 { + version: LweBootstrapKey32Version::V0, + inner, + } => Ok(LweBootstrapKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweBootstrapKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a LWE bootstrap key entity. +impl EntityDeserializationEngine<&[u8], LweBootstrapKey64> for DefaultSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey64 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweBootstrapKey64 { + version: LweBootstrapKey64Version, + inner: ImplStandardBootstrapKey>, + } + let deserialized: DeserializableLweBootstrapKey64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweBootstrapKey64 { + version: LweBootstrapKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweBootstrapKey64 { + version: LweBootstrapKey64Version::V0, + inner, + } => Ok(LweBootstrapKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweBootstrapKey64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a LWE ciphertext entity. +impl EntityDeserializationEngine<&[u8], LweCiphertext32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&ciphertext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(ciphertext, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweCiphertext32 { + version: LweCiphertext32Version, + inner: ImplLweCiphertext>, + } + let deserialized: DeserializableLweCiphertext32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweCiphertext32 { + version: LweCiphertext32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweCiphertext32 { + version: LweCiphertext32Version::V0, + inner, + } => Ok(LweCiphertext32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweCiphertext32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a LWE ciphertext entity. +impl EntityDeserializationEngine<&[u8], LweCiphertext64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::LweDimension; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&ciphertext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(ciphertext, recovered); + /// + /// # + /// # Ok(()) + /// # } + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweCiphertext64 { + version: LweCiphertext64Version, + inner: ImplLweCiphertext>, + } + let deserialized: DeserializableLweCiphertext64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweCiphertext64 { + version: LweCiphertext64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweCiphertext64 { + version: LweCiphertext64Version::V0, + inner, + } => Ok(LweCiphertext64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweCiphertext64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes an LWE circuit bootstrap private functional +/// packing keyswitch vector. +impl EntityDeserializationEngine<&[u8], LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(output_glwe_dimension, polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = + /// serialization_engine.serialize(&cbs_private_functional_packing_keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cbs_private_functional_packing_keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + EntityDeserializationError, + > { + #[derive(Deserialize)] + struct DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version, + inner: ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys>, + } + let deserialized: DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = + bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + version: + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version::V0, + inner, + } => Ok(LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32( + inner, + )), + } + } + + unsafe fn deserialize_unchecked( + &mut self, + serialized: &[u8], + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes an LWE circuit bootstrap private functional +/// packing keyswitch vector. +impl EntityDeserializationEngine<&[u8], LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(output_glwe_dimension, polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = + /// serialization_engine.serialize(&cbs_private_functional_packing_keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cbs_private_functional_packing_keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result< + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + EntityDeserializationError, + > { + #[derive(Deserialize)] + struct DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version, + inner: ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys>, + } + let deserialized: DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = + bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + version: + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version::V0, + inner, + } => Ok(LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64( + inner, + )), + } + } + + unsafe fn deserialize_unchecked( + &mut self, + serialized: &[u8], + ) -> LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a LWE keyswitch key entity. +impl EntityDeserializationEngine<&[u8], LweKeyswitchKey32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweKeyswitchKey32 { + version: LweKeyswitchKey32Version, + inner: ImplLweKeyswitchKey>, + } + let deserialized: DeserializableLweKeyswitchKey32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweKeyswitchKey32 { + version: LweKeyswitchKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweKeyswitchKey32 { + version: LweKeyswitchKey32Version::V0, + inner, + } => Ok(LweKeyswitchKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweKeyswitchKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a LWE keyswitch key entity. +impl EntityDeserializationEngine<&[u8], LweKeyswitchKey64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweKeyswitchKey64 { + version: LweKeyswitchKey64Version, + inner: ImplLweKeyswitchKey>, + } + let deserialized: DeserializableLweKeyswitchKey64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweKeyswitchKey64 { + version: LweKeyswitchKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweKeyswitchKey64 { + version: LweKeyswitchKey64Version::V0, + inner, + } => Ok(LweKeyswitchKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweKeyswitchKey64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a LWE secret key entity. +impl EntityDeserializationEngine<&[u8], LweSecretKey32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&lwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(lwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweSecretKey32 { + version: LweSecretKey32Version, + inner: ImplLweSecretKey>, + } + let deserialized: DeserializableLweSecretKey32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweSecretKey32 { + version: LweSecretKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweSecretKey32 { + version: LweSecretKey32Version::V0, + inner, + } => Ok(LweSecretKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweSecretKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a LWE secret key entity. +impl EntityDeserializationEngine<&[u8], LweSecretKey64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&lwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(lwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLweSecretKey64 { + version: LweSecretKey64Version, + inner: ImplLweSecretKey>, + } + let deserialized: DeserializableLweSecretKey64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLweSecretKey64 { + version: LweSecretKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLweSecretKey64 { + version: LweSecretKey64Version::V0, + inner, + } => Ok(LweSecretKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LweSecretKey64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes an LWE public key. +impl EntityDeserializationEngine<&[u8], LwePublicKey32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey32 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&public_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(public_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLwePublicKey32 { + version: LwePublicKey32Version, + inner: ImplLweList>, + } + let deserialized: DeserializableLwePublicKey32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLwePublicKey32 { + version: LwePublicKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLwePublicKey32 { + version: LwePublicKey32Version::V0, + inner, + } => Ok(LwePublicKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LwePublicKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes an LWE public key. +impl EntityDeserializationEngine<&[u8], LwePublicKey64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey64 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&public_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(public_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializableLwePublicKey64 { + version: LwePublicKey64Version, + inner: ImplLweList>, + } + let deserialized: DeserializableLwePublicKey64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializableLwePublicKey64 { + version: LwePublicKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializableLwePublicKey64 { + version: LwePublicKey64Version::V0, + inner, + } => Ok(LwePublicKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> LwePublicKey64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a plaintext entity. +impl EntityDeserializationEngine<&[u8], Plaintext32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext32 = engine.create_plaintext_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializablePlaintext32 { + version: Plaintext32Version, + inner: ImplPlaintext, + } + let deserialized: DeserializablePlaintext32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializablePlaintext32 { + version: Plaintext32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializablePlaintext32 { + version: Plaintext32Version::V0, + inner, + } => Ok(Plaintext32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> Plaintext32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a plaintext entity. +impl EntityDeserializationEngine<&[u8], Plaintext64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext64 = engine.create_plaintext_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializablePlaintext64 { + version: Plaintext64Version, + inner: ImplPlaintext, + } + let deserialized: DeserializablePlaintext64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializablePlaintext64 { + version: Plaintext64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializablePlaintext64 { + version: Plaintext64Version::V0, + inner, + } => Ok(Plaintext64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> Plaintext64 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 32 bits integers. It deserializes a plaintext vector entity. +impl EntityDeserializationEngine<&[u8], PlaintextVector32> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u32 << 20; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector32 = engine.create_plaintext_vector_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext_vector)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext_vector, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializablePlaintextVector32 { + version: PlaintextVector32Version, + inner: ImplPlaintextList>, + } + let deserialized: DeserializablePlaintextVector32 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializablePlaintextVector32 { + version: PlaintextVector32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializablePlaintextVector32 { + version: PlaintextVector32Version::V0, + inner, + } => Ok(PlaintextVector32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> PlaintextVector32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`DefaultSerializationEngine`] that +/// operates on 64 bits integers. It deserializes a plaintext vector entity. +impl EntityDeserializationEngine<&[u8], PlaintextVector64> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = vec![3_u64 << 50; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector64 = engine.create_plaintext_vector_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext_vector)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext_vector, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct DeserializablePlaintextVector64 { + version: PlaintextVector64Version, + inner: ImplPlaintextList>, + } + let deserialized: DeserializablePlaintextVector64 = bincode::deserialize(serialized) + .map_err(DefaultSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + DeserializablePlaintextVector64 { + version: PlaintextVector64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + DefaultSerializationError::UnsupportedVersion, + )), + DeserializablePlaintextVector64 { + version: PlaintextVector64Version::V0, + inner, + } => Ok(PlaintextVector64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> PlaintextVector64 { + self.deserialize(serialized).unwrap() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_serialization.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_serialization.rs new file mode 100644 index 000000000..803c63f80 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/entity_serialization.rs @@ -0,0 +1,1484 @@ +#![allow(clippy::missing_safety_doc)] + +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey as ImplStandardBootstrapKey; +use crate::core_crypto::commons::crypto::encoding::{ + Cleartext as ImplCleartext, Plaintext as ImplPlaintext, PlaintextList as ImplPlaintextList, +}; +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList as ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys; +use crate::core_crypto::commons::crypto::lwe::{ + LweCiphertext as ImplLweCiphertext, LweKeyswitchKey as ImplLweKeyswitchKey, + LweList as ImplLweList, +}; +use crate::core_crypto::commons::crypto::secret::{ + GlweSecretKey as ImplGlweSecretKey, LweSecretKey as ImplLweSecretKey, +}; +use crate::core_crypto::prelude::{ + BinaryKeyKind, Cleartext32, Cleartext32Version, Cleartext64, Cleartext64Version, CleartextF64, + CleartextF64Version, DefaultSerializationEngine, DefaultSerializationError, + EntitySerializationEngine, EntitySerializationError, GlweSecretKey32, GlweSecretKey32Version, + GlweSecretKey64, GlweSecretKey64Version, LweBootstrapKey32, LweBootstrapKey32Version, + LweBootstrapKey64, LweBootstrapKey64Version, LweCiphertext32, LweCiphertext32Version, + LweCiphertext64, LweCiphertext64Version, LweCiphertextMutView32, LweCiphertextMutView64, + LweCiphertextView32, LweCiphertextView64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version, LweKeyswitchKey32, + LweKeyswitchKey32Version, LweKeyswitchKey64, LweKeyswitchKey64Version, LwePublicKey32, + LwePublicKey32Version, LwePublicKey64, LwePublicKey64Version, LweSecretKey32, + LweSecretKey32Version, LweSecretKey64, LweSecretKey64Version, Plaintext32, Plaintext32Version, + Plaintext64, Plaintext64Version, PlaintextVector32, PlaintextVector32Version, + PlaintextVector64, PlaintextVector64Version, +}; +use serde::Serialize; + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a cleartext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u32 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext32 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &Cleartext32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableCleartext32<'a> { + version: Cleartext32Version, + inner: &'a ImplCleartext, + } + let serializable = SerializableCleartext32 { + version: Cleartext32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &Cleartext32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a cleartext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: u64 = 3; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: Cleartext64 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &Cleartext64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableCleartext64<'a> { + version: Cleartext64Version, + inner: &'a ImplCleartext, + } + let serializable = SerializableCleartext64 { + version: Cleartext64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &Cleartext64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a floating point cleartext. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// let input: f64 = 3.; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let cleartext: CleartextF64 = engine.create_cleartext_from(&input)?; + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&cleartext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cleartext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &CleartextF64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableCleartextF64<'a> { + version: CleartextF64Version, + inner: &'a ImplCleartext, + } + let serializable = SerializableCleartextF64 { + version: CleartextF64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &CleartextF64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a GLWE secret key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&glwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(glwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &GlweSecretKey32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableGlweSecretKey32<'a> { + version: GlweSecretKey32Version, + inner: &'a ImplGlweSecretKey>, + } + let serializable = SerializableGlweSecretKey32 { + version: GlweSecretKey32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &GlweSecretKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a GLWE secret key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let glwe_dimension = GlweDimension(2); + /// let polynomial_size = PolynomialSize(4); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let glwe_secret_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&glwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(glwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &GlweSecretKey64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableGlweSecretKey64<'a> { + version: GlweSecretKey64Version, + inner: &'a ImplGlweSecretKey>, + } + let serializable = SerializableGlweSecretKey64 { + version: GlweSecretKey64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &GlweSecretKey64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE bootstrap key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey32 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweBootstrapKey32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweBootstrapKey32<'a> { + version: LweBootstrapKey32Version, + inner: &'a ImplStandardBootstrapKey>, + } + let serializable = SerializableLweBootstrapKey32 { + version: LweBootstrapKey32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweBootstrapKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE bootstrap key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_sk: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// + /// let bsk: LweBootstrapKey64 = + /// engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweBootstrapKey64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweBootstrapKey64<'a> { + version: LweBootstrapKey64Version, + inner: &'a ImplStandardBootstrapKey>, + } + let serializable = SerializableLweBootstrapKey64 { + version: LweBootstrapKey64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweBootstrapKey64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE ciphertext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&ciphertext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(ciphertext, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweCiphertext32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertext32<'a> { + version: LweCiphertext32Version, + inner: &'a ImplLweCiphertext>, + } + let serializable = SerializableLweCiphertext32 { + version: LweCiphertext32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertext32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE ciphertext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::LweDimension; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&ciphertext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(ciphertext, recovered); + /// + /// # + /// # Ok(()) + /// # } + fn serialize( + &mut self, + entity: &LweCiphertext64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertext64<'a> { + version: LweCiphertext64Version, + inner: &'a ImplLweCiphertext>, + } + let serializable = SerializableLweCiphertext64 { + version: LweCiphertext64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertext64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE ciphertext view entity. +impl<'b> EntitySerializationEngine, Vec> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// let raw_buffer = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// let view: LweCiphertextView32 = engine.create_lwe_ciphertext_from(raw_buffer.as_slice())?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&view)?; + /// let recovered: LweCiphertext32 = serialization_engine.deserialize(serialized.as_slice())?; + /// let recovered_buffer = engine.consume_retrieve_lwe_ciphertext(recovered)?; + /// assert_eq!(raw_buffer, recovered_buffer); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweCiphertextView32<'b>, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertextView32<'a, 'b> { + version: LweCiphertext32Version, + inner: &'a ImplLweCiphertext<&'b [u32]>, + } + let serializable = SerializableLweCiphertextView32 { + version: LweCiphertext32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertextView32<'b>) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE ciphertext view entity. +impl<'b> EntitySerializationEngine, Vec> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::LweDimension; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let raw_buffer = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// let view: LweCiphertextView64 = engine.create_lwe_ciphertext_from(raw_buffer.as_slice())?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&view)?; + /// let recovered: LweCiphertext64 = serialization_engine.deserialize(serialized.as_slice())?; + /// let recovered_buffer = engine.consume_retrieve_lwe_ciphertext(recovered)?; + /// assert_eq!(raw_buffer, recovered_buffer); + /// + /// # + /// # Ok(()) + /// # } + fn serialize( + &mut self, + entity: &LweCiphertextView64<'b>, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertextView64<'a, 'b> { + version: LweCiphertext64Version, + inner: &'a ImplLweCiphertext<&'b [u64]>, + } + let serializable = SerializableLweCiphertextView64 { + version: LweCiphertext64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertextView64<'b>) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE ciphertext mut view entity. +impl<'b> EntitySerializationEngine, Vec> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, Variance, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// let mut raw_buffer = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// let view: LweCiphertextMutView32 = + /// engine.create_lwe_ciphertext_from(raw_buffer.as_mut_slice())?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&view)?; + /// let recovered: LweCiphertext32 = serialization_engine.deserialize(serialized.as_slice())?; + /// let recovered_buffer = engine.consume_retrieve_lwe_ciphertext(recovered)?; + /// assert_eq!(raw_buffer, recovered_buffer); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweCiphertextMutView32<'b>, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertextMutView32<'a, 'b> { + version: LweCiphertext32Version, + inner: &'a ImplLweCiphertext<&'b mut [u32]>, + } + let serializable = SerializableLweCiphertextMutView32 { + version: LweCiphertext32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertextMutView32<'b>) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE ciphertext mut view entity. +impl<'b> EntitySerializationEngine, Vec> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::Variance; + /// use tfhe::core_crypto::prelude::LweDimension; + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(2); + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let plaintext = engine.create_plaintext_from(&input)?; + /// + /// let ciphertext = engine.encrypt_lwe_ciphertext(&key, &plaintext, noise)?; + /// + /// let mut raw_buffer = engine.consume_retrieve_lwe_ciphertext(ciphertext)?; + /// let view: LweCiphertextMutView64 = engine.create_lwe_ciphertext_from(raw_buffer.as_mut_slice())?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&view)?; + /// let recovered: LweCiphertext64 = serialization_engine.deserialize(serialized.as_slice())?; + /// let recovered_buffer = engine.consume_retrieve_lwe_ciphertext(recovered)?; + /// assert_eq!(raw_buffer, recovered_buffer); + /// + /// # + /// # Ok(()) + /// # } + fn serialize( + &mut self, + entity: &LweCiphertextMutView64<'b>, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCiphertextMutView64<'a, 'b> { + version: LweCiphertext64Version, + inner: &'a ImplLweCiphertext<&'b mut [u64]>, + } + let serializable = SerializableLweCiphertextMutView64 { + version: LweCiphertext64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweCiphertextMutView64<'b>) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes an LWE circuit bootstrap private functional packing keyswitch +/// vector. +impl EntitySerializationEngine> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey32 = + /// engine.generate_new_glwe_secret_key(output_glwe_dimension, polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = + /// serialization_engine.serialize(&cbs_private_functional_packing_keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cbs_private_functional_packing_keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32<'a> { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version, + inner: &'a ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys>, + } + let serializable = SerializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked( + &mut self, + entity: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + ) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes an LWE circuit bootstrap private functional packing keyswitch +/// vector. +impl EntitySerializationEngine> + for DefaultSerializationEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(10); + /// let output_glwe_dimension = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomposition_base_log = DecompositionBaseLog(3); + /// let decomposition_level_count = DecompositionLevelCount(5); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: GlweSecretKey64 = + /// engine.generate_new_glwe_secret_key(output_glwe_dimension, polynomial_size)?; + /// + /// let cbs_private_functional_packing_keyswitch_key: + /// LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = + /// engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &input_key, + /// &output_key, + /// decomposition_base_log, + /// decomposition_level_count, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = + /// serialization_engine.serialize(&cbs_private_functional_packing_keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(cbs_private_functional_packing_keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64<'a> { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version, + inner: &'a ImplLweCircuitBoostrapPrivateFunctionalPackingKeyswitchKeys>, + } + let serializable = SerializableLweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + version: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked( + &mut self, + entity: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + ) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE keyswitch key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweKeyswitchKey32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweKeyswitchKey32<'a> { + version: LweKeyswitchKey32Version, + inner: &'a ImplLweKeyswitchKey>, + } + let serializable = SerializableLweKeyswitchKey32 { + version: LweKeyswitchKey32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweKeyswitchKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE keyswitch key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let input_lwe_dimension = LweDimension(6); + /// let output_lwe_dimension = LweDimension(3); + /// let decomposition_level_count = DecompositionLevelCount(2); + /// let decomposition_base_log = DecompositionBaseLog(8); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let input_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(input_lwe_dimension)?; + /// let output_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(output_lwe_dimension)?; + /// + /// let keyswitch_key = engine.generate_new_lwe_keyswitch_key( + /// &input_key, + /// &output_key, + /// decomposition_level_count, + /// decomposition_base_log, + /// noise, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&keyswitch_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(keyswitch_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweKeyswitchKey64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweKeyswitchKey64<'a> { + version: LweKeyswitchKey64Version, + inner: &'a ImplLweKeyswitchKey>, + } + let serializable = SerializableLweKeyswitchKey64 { + version: LweKeyswitchKey64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweKeyswitchKey64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a LWE secret key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&lwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(lwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweSecretKey32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweSecretKey32<'a> { + version: LweSecretKey32Version, + inner: &'a ImplLweSecretKey>, + } + let serializable = SerializableLweSecretKey32 { + version: LweSecretKey32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweSecretKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a LWE secret key entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&lwe_secret_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(lwe_secret_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LweSecretKey64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLweSecretKey64<'a> { + version: LweSecretKey64Version, + inner: &'a ImplLweSecretKey>, + } + let serializable = SerializableLweSecretKey64 { + version: LweSecretKey64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LweSecretKey64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes an LWE public key. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey32 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey32 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&public_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(public_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LwePublicKey32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLwePublicKey32<'a> { + version: LwePublicKey32Version, + inner: &'a ImplLweList>, + } + let serializable = SerializableLwePublicKey32 { + version: LwePublicKey32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LwePublicKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes an LWE public key. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let lwe_dimension = LweDimension(6); + /// let noise = Variance(2_f64.powf(-50.)); + /// let lwe_public_key_zero_encryption_count = LwePublicKeyZeroEncryptionCount(42); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let lwe_secret_key: LweSecretKey64 = engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// + /// let public_key: LwePublicKey64 = engine.generate_new_lwe_public_key( + /// &lwe_secret_key, + /// noise, + /// lwe_public_key_zero_encryption_count, + /// )?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&public_key)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(public_key, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &LwePublicKey64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializableLwePublicKey64<'a> { + version: LwePublicKey64Version, + inner: &'a ImplLweList>, + } + let serializable = SerializableLwePublicKey64 { + version: LwePublicKey64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &LwePublicKey64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a plaintext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext32 = engine.create_plaintext_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &Plaintext32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializablePlaintext32<'a> { + version: Plaintext32Version, + inner: &'a ImplPlaintext, + } + let serializable = SerializablePlaintext32 { + version: Plaintext32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &Plaintext32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a plaintext entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext: Plaintext64 = engine.create_plaintext_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext, recovered); + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &Plaintext64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializablePlaintext64<'a> { + version: Plaintext64Version, + inner: &'a ImplPlaintext, + } + let serializable = SerializablePlaintext64 { + version: Plaintext64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &Plaintext64) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 32 bits integers. It serializes a plaintext vector entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = vec![3_u32 << 20; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector32 = engine.create_plaintext_vector_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext_vector)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext_vector, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &PlaintextVector32, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializablePlaintextVector32<'a> { + version: PlaintextVector32Version, + inner: &'a ImplPlaintextList>, + } + let serializable = SerializablePlaintextVector32 { + version: PlaintextVector32Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &PlaintextVector32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`DefaultSerializationEngine`] that operates +/// on 64 bits integers. It serializes a plaintext vector entity. +impl EntitySerializationEngine> for DefaultSerializationEngine { + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::{PlaintextCount, *}; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = vec![3_u64 << 50; 3]; + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let plaintext_vector: PlaintextVector64 = engine.create_plaintext_vector_from(&input)?; + /// + /// let mut serialization_engine = DefaultSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&plaintext_vector)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(plaintext_vector, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &PlaintextVector64, + ) -> Result, EntitySerializationError> { + #[derive(Serialize)] + struct SerializablePlaintextVector64<'a> { + version: PlaintextVector64Version, + inner: &'a ImplPlaintextList>, + } + let serializable = SerializablePlaintextVector64 { + version: PlaintextVector64Version::V0, + inner: &entity.0, + }; + bincode::serialize(&serializable) + .map_err(DefaultSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &PlaintextVector64) -> Vec { + self.serialize(entity).unwrap() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/mod.rs new file mode 100644 index 000000000..f3fd36cce --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/default_serialization_engine/mod.rs @@ -0,0 +1,54 @@ +use crate::core_crypto::prelude::sealed::AbstractEngineSeal; +use crate::core_crypto::prelude::AbstractEngine; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// The error which can occur in the executions of the `DefaultSerializationEngine` operations. +#[derive(Debug)] +pub enum DefaultSerializationError { + Serialization(bincode::Error), + Deserialization(bincode::Error), + UnsupportedVersion, +} + +#[allow(unused_variables)] +#[allow(unreachable_patterns)] +impl Display for DefaultSerializationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DefaultSerializationError::Serialization(bincode_error) => { + write!(f, "Failed to serialize entity: {bincode_error}") + } + DefaultSerializationError::Deserialization(bincode_error) => { + write!(f, "Failed to deserialize entity: {bincode_error}") + } + DefaultSerializationError::UnsupportedVersion => { + write!( + f, + "The version used to serialize the entity is not supported." + ) + } + } + } +} + +impl Error for DefaultSerializationError {} + +pub struct DefaultSerializationEngine; + +impl AbstractEngineSeal for DefaultSerializationEngine {} + +impl AbstractEngine for DefaultSerializationEngine { + type EngineError = DefaultSerializationError; + type Parameters = (); + + fn new(_parameter: Self::Parameters) -> Result + where + Self: Sized, + { + Ok(DefaultSerializationEngine) + } +} + +mod entity_deserialization; +mod entity_serialization; diff --git a/tfhe/src/core_crypto/backends/default/implementation/engines/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/engines/mod.rs new file mode 100644 index 000000000..262361026 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/engines/mod.rs @@ -0,0 +1,18 @@ +//! A module containing the [engines](crate::core_crypto::specification::engines) exposed by +//! the default backend. + +mod default_engine; +pub use default_engine::*; + +#[cfg(feature = "backend_default_parallel")] +mod default_parallel_engine; +#[cfg(feature = "backend_default_parallel")] +pub use default_parallel_engine::*; + +#[cfg(feature = "backend_default_serialization")] +mod default_serialization_engine; +#[cfg(feature = "backend_default_serialization")] +pub use default_serialization_engine::*; + +mod activated_generator; +pub use activated_generator::ActivatedRandomGenerator; diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/cleartext.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/cleartext.rs new file mode 100644 index 000000000..fe998e8d8 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/cleartext.rs @@ -0,0 +1,53 @@ +use crate::core_crypto::commons::crypto::encoding::Cleartext as ImplCleartext; +use crate::core_crypto::specification::entities::markers::CleartextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, CleartextEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a cleartext with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cleartext32(pub(crate) ImplCleartext); +impl AbstractEntity for Cleartext32 { + type Kind = CleartextKind; +} +impl CleartextEntity for Cleartext32 {} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum Cleartext32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a cleartext with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cleartext64(pub(crate) ImplCleartext); +impl AbstractEntity for Cleartext64 { + type Kind = CleartextKind; +} +impl CleartextEntity for Cleartext64 {} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum Cleartext64Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a floating point cleartext with 64 bits of precision. +#[derive(Debug, Clone, PartialEq)] +pub struct CleartextF64(pub(crate) ImplCleartext); +impl AbstractEntity for CleartextF64 { + type Kind = CleartextKind; +} +impl CleartextEntity for CleartextF64 {} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum CleartextF64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_ciphertext.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_ciphertext.rs new file mode 100644 index 000000000..37b6f4d2e --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_ciphertext.rs @@ -0,0 +1,159 @@ +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext as ImplGlweCiphertext; +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; +use crate::core_crypto::specification::entities::markers::GlweCiphertextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, GlweCiphertextEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a GLWE ciphertext with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweCiphertext32(pub(crate) ImplGlweCiphertext>); + +impl AbstractEntity for GlweCiphertext32 { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertext32 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum GlweCiphertext32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a GLWE ciphertext with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweCiphertext64(pub(crate) ImplGlweCiphertext>); + +impl AbstractEntity for GlweCiphertext64 { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertext64 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum GlweCiphertext64Version { + V0, + #[serde(other)] + Unsupported, +} + +// GlweCiphertextViews are just GlweCiphertext entities that do not own their memory, they use a +// slice as a container as opposed to Vec for the standard GlweCiphertext + +/// A structure representing a GLWE ciphertext view, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct GlweCiphertextView32<'a>(pub(crate) ImplGlweCiphertext<&'a [u32]>); +impl AbstractEntity for GlweCiphertextView32<'_> { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertextView32<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +/// A structure representing a GLWE ciphertext view, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct GlweCiphertextMutView32<'a>(pub(crate) ImplGlweCiphertext<&'a mut [u32]>); +impl AbstractEntity for GlweCiphertextMutView32<'_> { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertextMutView32<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +/// A structure representing a GLWE ciphertext view, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct GlweCiphertextView64<'a>(pub(crate) ImplGlweCiphertext<&'a [u64]>); + +impl AbstractEntity for GlweCiphertextView64<'_> { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertextView64<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +/// A structure representing a GLWE ciphertext view, with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct GlweCiphertextMutView64<'a>(pub(crate) ImplGlweCiphertext<&'a mut [u64]>); + +impl AbstractEntity for GlweCiphertextMutView64<'_> { + type Kind = GlweCiphertextKind; +} + +impl GlweCiphertextEntity for GlweCiphertextMutView64<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_secret_key.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_secret_key.rs new file mode 100644 index 000000000..5b178289e --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/glwe_secret_key.rs @@ -0,0 +1,54 @@ +use crate::core_crypto::commons::crypto::secret::GlweSecretKey as ImpGlweSecretKey; +use crate::core_crypto::prelude::{BinaryKeyKind, GlweDimension, PolynomialSize}; +use crate::core_crypto::specification::entities::markers::GlweSecretKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, GlweSecretKeyEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a GLWE secret key with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweSecretKey32(pub(crate) ImpGlweSecretKey>); +impl AbstractEntity for GlweSecretKey32 { + type Kind = GlweSecretKeyKind; +} +impl GlweSecretKeyEntity for GlweSecretKey32 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.key_size() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum GlweSecretKey32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a GLWE secret key with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweSecretKey64(pub(crate) ImpGlweSecretKey>); +impl AbstractEntity for GlweSecretKey64 { + type Kind = GlweSecretKeyKind; +} +impl GlweSecretKeyEntity for GlweSecretKey64 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.key_size() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum GlweSecretKey64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_bootstrap_key.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_bootstrap_key.rs new file mode 100644 index 000000000..f9ec80834 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_bootstrap_key.rs @@ -0,0 +1,220 @@ +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey as ImplStandardBootstrapKey; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, +}; +use crate::core_crypto::specification::entities::markers::LweBootstrapKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweBootstrapKeyEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE bootstrap key with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweBootstrapKey32(pub(crate) ImplStandardBootstrapKey>); +impl AbstractEntity for LweBootstrapKey32 { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKey32 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweBootstrapKey32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE bootstrap key with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweBootstrapKey64(pub(crate) ImplStandardBootstrapKey>); +impl AbstractEntity for LweBootstrapKey64 { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKey64 { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweBootstrapKey64Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE bootstrap key with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweBootstrapKeyMutView32<'a>(pub(crate) ImplStandardBootstrapKey<&'a mut [u32]>); +impl AbstractEntity for LweBootstrapKeyMutView32<'_> { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKeyMutView32<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} + +/// A structure representing an LWE bootstrap key with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweBootstrapKeyMutView64<'a>(pub(crate) ImplStandardBootstrapKey<&'a mut [u64]>); +impl AbstractEntity for LweBootstrapKeyMutView64<'_> { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKeyMutView64<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} + +/// A structure representing an LWE bootstrap key with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweBootstrapKeyView32<'a>(pub(crate) ImplStandardBootstrapKey<&'a [u32]>); +impl AbstractEntity for LweBootstrapKeyView32<'_> { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKeyView32<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} + +/// A structure representing an LWE bootstrap key with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweBootstrapKeyView64<'a>(pub(crate) ImplStandardBootstrapKey<&'a [u64]>); +impl AbstractEntity for LweBootstrapKeyView64<'_> { + type Kind = LweBootstrapKeyKind; +} +impl LweBootstrapKeyEntity for LweBootstrapKeyView64<'_> { + fn glwe_dimension(&self) -> GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.base_log() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.level_count() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext.rs new file mode 100644 index 000000000..5c5ed6f8d --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext.rs @@ -0,0 +1,129 @@ +use crate::core_crypto::commons::crypto::lwe::LweCiphertext as ImplLweCiphertext; +use crate::core_crypto::prelude::LweDimension; +use crate::core_crypto::specification::entities::markers::LweCiphertextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweCiphertextEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE ciphertext with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCiphertext32(pub(crate) ImplLweCiphertext>); +impl AbstractEntity for LweCiphertext32 { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertext32 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCiphertext32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE ciphertext with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCiphertext64(pub(crate) ImplLweCiphertext>); +impl AbstractEntity for LweCiphertext64 { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertext64 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCiphertext64Version { + V0, + #[serde(other)] + Unsupported, +} + +// LweCiphertextViews are just LweCiphertext entities that do not own their memory, they use a slice +// as a container as opposed to Vec for the standard LweCiphertext + +/// A structure representing an LWE ciphertext view, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextView32<'a>(pub(crate) ImplLweCiphertext<&'a [u32]>); + +impl AbstractEntity for LweCiphertextView32<'_> { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertextView32<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} + +/// A structure representing an LWE ciphertext view, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextMutView32<'a>(pub(crate) ImplLweCiphertext<&'a mut [u32]>); + +impl AbstractEntity for LweCiphertextMutView32<'_> { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertextMutView32<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} + +/// A structure representing an LWE ciphertext view, with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextView64<'a>(pub(crate) ImplLweCiphertext<&'a [u64]>); + +impl AbstractEntity for LweCiphertextView64<'_> { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertextView64<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} + +/// A structure representing an LWE ciphertext view, with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextMutView64<'a>(pub(crate) ImplLweCiphertext<&'a mut [u64]>); + +impl AbstractEntity for LweCiphertextMutView64<'_> { + type Kind = LweCiphertextKind; +} +impl LweCiphertextEntity for LweCiphertextMutView64<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext_vector.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext_vector.rs new file mode 100644 index 000000000..23b0312ae --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_ciphertext_vector.rs @@ -0,0 +1,161 @@ +use crate::core_crypto::commons::crypto::lwe::LweList as ImplLweList; +use crate::core_crypto::prelude::{LweCiphertextCount, LweDimension}; +use crate::core_crypto::specification::entities::markers::LweCiphertextVectorKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweCiphertextVectorEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a vector of LWE ciphertexts with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCiphertextVector32(pub(crate) ImplLweList>); + +impl AbstractEntity for LweCiphertextVector32 { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVector32 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCiphertextVector32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a vector of LWE ciphertexts with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCiphertextVector64(pub(crate) ImplLweList>); + +impl AbstractEntity for LweCiphertextVector64 { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVector64 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCiphertextVector64Version { + V0, + #[serde(other)] + Unsupported, +} + +// LweCiphertextVectorViews are just LweCiphertextVector entities that do not own their memory, +// they use a slice as a container as opposed to Vec for the standard LweCiphertextVector + +/// A structure representing a vector of LWE ciphertext views, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextVectorView32<'a>(pub(crate) ImplLweList<&'a [u32]>); + +impl AbstractEntity for LweCiphertextVectorView32<'_> { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVectorView32<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} + +/// A structure representing a vector of LWE ciphertext views, with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextVectorMutView32<'a>(pub(crate) ImplLweList<&'a mut [u32]>); + +impl AbstractEntity for LweCiphertextVectorMutView32<'_> { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVectorMutView32<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} + +/// A structure representing a vector of LWE ciphertext views, with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextVectorView64<'a>(pub(crate) ImplLweList<&'a [u64]>); + +impl AbstractEntity for LweCiphertextVectorView64<'_> { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVectorView64<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} + +/// A structure representing a vector of LWE ciphertext views, with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweCiphertextVectorMutView64<'a>(pub(crate) ImplLweList<&'a mut [u64]>); + +impl AbstractEntity for LweCiphertextVectorMutView64<'_> { + type Kind = LweCiphertextVectorKind; +} + +impl LweCiphertextVectorEntity for LweCiphertextVectorMutView64<'_> { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_ciphertext_count(&self) -> LweCiphertextCount { + LweCiphertextCount(self.0.count().0) + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs new file mode 100644 index 000000000..d5c0a3c79 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs @@ -0,0 +1,99 @@ +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList as ImplLwePrivateFunctionalPackingKeyswitchKeyList; +use crate::core_crypto::prelude::markers::LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysKind; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + GlweDimension, LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, LweDimension, +}; +use crate::core_crypto::specification::entities::AbstractEntity; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a vector of private functional packing keyswitch keys used for a +/// circuit bootsrap with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32( + pub(crate) ImplLwePrivateFunctionalPackingKeyswitchKeyList>, +); +impl AbstractEntity for LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 { + type Kind = LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysKind; +} +impl LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity + for LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 +{ + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_key_dimension() + } + + fn output_glwe_dimension(&self) -> GlweDimension { + self.0.output_glwe_key_dimension() + } + + fn output_polynomial_size(&self) -> crate::core_crypto::prelude::PolynomialSize { + self.0.output_polynomial_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_level_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } + + fn key_count(&self) -> FunctionalPackingKeyswitchKeyCount { + self.0.fpksk_count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a vector of private functional packing keyswitch keys used for a +/// circuit bootsrap with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64( + pub ImplLwePrivateFunctionalPackingKeyswitchKeyList>, +); +impl AbstractEntity for LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 { + type Kind = LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysKind; +} +impl LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity + for LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 +{ + fn input_lwe_dimension(&self) -> LweDimension { + self.0.input_lwe_key_dimension() + } + + fn output_glwe_dimension(&self) -> GlweDimension { + self.0.output_glwe_key_dimension() + } + + fn output_polynomial_size(&self) -> crate::core_crypto::prelude::PolynomialSize { + self.0.output_polynomial_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_level_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } + + fn key_count(&self) -> FunctionalPackingKeyswitchKeyCount { + self.0.fpksk_count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_keyswitch_key.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_keyswitch_key.rs new file mode 100644 index 000000000..8c1ab7da6 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_keyswitch_key.rs @@ -0,0 +1,194 @@ +use crate::core_crypto::commons::crypto::lwe::LweKeyswitchKey as ImplLweKeyswitchKey; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; +use crate::core_crypto::specification::entities::markers::LweKeyswitchKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweKeyswitchKeyEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE keyswitch key with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweKeyswitchKey32(pub(crate) ImplLweKeyswitchKey>); +impl AbstractEntity for LweKeyswitchKey32 { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKey32 { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweKeyswitchKey32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE keyswitch key with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweKeyswitchKey64(pub(crate) ImplLweKeyswitchKey>); +impl AbstractEntity for LweKeyswitchKey64 { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKey64 { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweKeyswitchKey64Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE keyswitch key with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweKeyswitchKeyMutView32<'a>(pub(crate) ImplLweKeyswitchKey<&'a mut [u32]>); +impl AbstractEntity for LweKeyswitchKeyMutView32<'_> { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKeyMutView32<'_> { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} + +/// A structure representing an LWE keyswitch key with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but mutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Mutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweKeyswitchKeyMutView64<'a>(pub(crate) ImplLweKeyswitchKey<&'a mut [u64]>); +impl AbstractEntity for LweKeyswitchKeyMutView64<'_> { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKeyMutView64<'_> { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} + +/// A structure representing an LWE keyswitch key with 32 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweKeyswitchKeyView32<'a>(pub(crate) ImplLweKeyswitchKey<&'a [u32]>); +impl AbstractEntity for LweKeyswitchKeyView32<'_> { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKeyView32<'_> { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} + +/// A structure representing an LWE keyswitch key with 64 bits of precision. +/// +/// By _view_ here, we mean that the entity does not own the data, but immutably borrows it. +/// +/// Notes: +/// ------ +/// This view is not Clone as Clone for a slice is not defined. It is not Deserialize either, +/// as Deserialize of a slice is not defined. Immutable variant. +#[derive(Debug, PartialEq, Eq)] +pub struct LweKeyswitchKeyView64<'a>(pub(crate) ImplLweKeyswitchKey<&'a [u64]>); +impl AbstractEntity for LweKeyswitchKeyView64<'_> { + type Kind = LweKeyswitchKeyKind; +} +impl LweKeyswitchKeyEntity for LweKeyswitchKeyView64<'_> { + fn input_lwe_dimension(&self) -> LweDimension { + self.0.before_key_size() + } + + fn output_lwe_dimension(&self) -> LweDimension { + self.0.after_key_size() + } + + fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.0.decomposition_levels_count() + } + + fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.0.decomposition_base_log() + } +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_public_key.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_public_key.rs new file mode 100644 index 000000000..3b8b99b14 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_public_key.rs @@ -0,0 +1,55 @@ +use crate::core_crypto::commons::crypto::lwe::LweList as ImpLwePublicKey; +use crate::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount}; +use crate::core_crypto::specification::entities::markers::LwePublicKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LwePublicKeyEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE secret key with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LwePublicKey32(pub(crate) ImpLwePublicKey>); +impl AbstractEntity for LwePublicKey32 { + type Kind = LwePublicKeyKind; +} + +impl LwePublicKeyEntity for LwePublicKey32 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_zero_encryption_count(&self) -> LwePublicKeyZeroEncryptionCount { + LwePublicKeyZeroEncryptionCount(self.0.count().0) + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LwePublicKey32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE secret key with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LwePublicKey64(pub(crate) ImpLwePublicKey>); +impl AbstractEntity for LwePublicKey64 { + type Kind = LwePublicKeyKind; +} +impl LwePublicKeyEntity for LwePublicKey64 { + fn lwe_dimension(&self) -> LweDimension { + self.0.lwe_size().to_lwe_dimension() + } + + fn lwe_zero_encryption_count(&self) -> LwePublicKeyZeroEncryptionCount { + LwePublicKeyZeroEncryptionCount(self.0.count().0) + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LwePublicKey64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_secret_key.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_secret_key.rs new file mode 100644 index 000000000..06c43a8c7 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/lwe_secret_key.rs @@ -0,0 +1,46 @@ +use crate::core_crypto::commons::crypto::secret::LweSecretKey as ImpLweSecretKey; +use crate::core_crypto::prelude::{BinaryKeyKind, LweDimension}; +use crate::core_crypto::specification::entities::markers::LweSecretKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweSecretKeyEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE secret key with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSecretKey32(pub(crate) ImpLweSecretKey>); +impl AbstractEntity for LweSecretKey32 { + type Kind = LweSecretKeyKind; +} +impl LweSecretKeyEntity for LweSecretKey32 { + fn lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweSecretKey32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing an LWE secret key with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSecretKey64(pub(crate) ImpLweSecretKey>); +impl AbstractEntity for LweSecretKey64 { + type Kind = LweSecretKeyKind; +} +impl LweSecretKeyEntity for LweSecretKey64 { + fn lwe_dimension(&self) -> LweDimension { + self.0.key_size() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum LweSecretKey64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/mod.rs new file mode 100644 index 000000000..8d6f4934f --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/mod.rs @@ -0,0 +1,28 @@ +//! A module containing all the [entities](crate::core_crypto::specification::entities) +//! exposed by the default backend. + +mod cleartext; +mod glwe_ciphertext; +mod glwe_secret_key; +mod lwe_bootstrap_key; +mod lwe_ciphertext; +mod lwe_ciphertext_vector; +mod lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys; +mod lwe_keyswitch_key; +mod lwe_public_key; +mod lwe_secret_key; +mod plaintext; +mod plaintext_vector; + +pub use cleartext::*; +pub use glwe_ciphertext::*; +pub use glwe_secret_key::*; +pub use lwe_bootstrap_key::*; +pub use lwe_ciphertext::*; +pub use lwe_ciphertext_vector::*; +pub use lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys::*; +pub use lwe_keyswitch_key::*; +pub use lwe_public_key::*; +pub use lwe_secret_key::*; +pub use plaintext::*; +pub use plaintext_vector::*; diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext.rs new file mode 100644 index 000000000..2cf5ecee1 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext.rs @@ -0,0 +1,37 @@ +use crate::core_crypto::commons::crypto::encoding::Plaintext as ImplPlaintext; +use crate::core_crypto::specification::entities::markers::PlaintextKind; +use crate::core_crypto::specification::entities::{AbstractEntity, PlaintextEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a plaintext with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Plaintext32(pub(crate) ImplPlaintext); +impl AbstractEntity for Plaintext32 { + type Kind = PlaintextKind; +} +impl PlaintextEntity for Plaintext32 {} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum Plaintext32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a plaintext with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Plaintext64(pub(crate) ImplPlaintext); +impl AbstractEntity for Plaintext64 { + type Kind = PlaintextKind; +} +impl PlaintextEntity for Plaintext64 {} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum Plaintext64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext_vector.rs b/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext_vector.rs new file mode 100644 index 000000000..2cb1801ba --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/entities/plaintext_vector.rs @@ -0,0 +1,46 @@ +use crate::core_crypto::commons::crypto::encoding::PlaintextList as ImplPlaintextList; +use crate::core_crypto::prelude::PlaintextCount; +use crate::core_crypto::specification::entities::markers::PlaintextVectorKind; +use crate::core_crypto::specification::entities::{AbstractEntity, PlaintextVectorEntity}; +#[cfg(feature = "backend_default_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing a vector of plaintexts with 32 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PlaintextVector32(pub(crate) ImplPlaintextList>); +impl AbstractEntity for PlaintextVector32 { + type Kind = PlaintextVectorKind; +} +impl PlaintextVectorEntity for PlaintextVector32 { + fn plaintext_count(&self) -> PlaintextCount { + self.0.count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum PlaintextVector32Version { + V0, + #[serde(other)] + Unsupported, +} + +/// A structure representing a vector of plaintexts with 64 bits of precision. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PlaintextVector64(pub(crate) ImplPlaintextList>); +impl AbstractEntity for PlaintextVector64 { + type Kind = PlaintextVectorKind; +} +impl PlaintextVectorEntity for PlaintextVector64 { + fn plaintext_count(&self) -> PlaintextCount { + self.0.count() + } +} + +#[cfg(feature = "backend_default_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum PlaintextVector64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/default/implementation/mod.rs b/tfhe/src/core_crypto/backends/default/implementation/mod.rs new file mode 100644 index 000000000..49169443f --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/implementation/mod.rs @@ -0,0 +1,2 @@ +pub mod engines; +pub mod entities; diff --git a/tfhe/src/core_crypto/backends/default/mod.rs b/tfhe/src/core_crypto/backends/default/mod.rs new file mode 100644 index 000000000..00b126bf7 --- /dev/null +++ b/tfhe/src/core_crypto/backends/default/mod.rs @@ -0,0 +1,5 @@ +//! A pure-rust backend. + +mod implementation; + +pub use implementation::{engines, entities}; diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_bootstrap_key_conversion.rs new file mode 100644 index 000000000..ed15b2a43 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_bootstrap_key_conversion.rs @@ -0,0 +1,215 @@ +use super::{FftEngine, FftError}; +use crate::core_crypto::backends::fft::private::crypto::bootstrap::FourierLweBootstrapKey; +use crate::core_crypto::backends::fft::private::crypto::ggsw::fill_with_forward_fourier_scratch; +use crate::core_crypto::backends::fft::private::math::fft::Fft; +use crate::core_crypto::prelude::{ + FftFourierLweBootstrapKey32, FftFourierLweBootstrapKey64, LweBootstrapKey32, LweBootstrapKey64, + LweBootstrapKeyConversionEngine, LweBootstrapKeyConversionError, LweBootstrapKeyEntity, +}; +use aligned_vec::avec; +use concrete_fft::c64; + +impl From for LweBootstrapKeyConversionError { + fn from(err: FftError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// +/// Implementation of [`LweBootstrapKeyConversionEngine`] for [`FftEngine`] that operates on +/// 32 bit integers. It converts a bootstrap key from the standard to the Fourier domain. +impl LweBootstrapKeyConversionEngine for FftEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// # + /// assert_eq!(fourier_bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(fourier_bsk.polynomial_size(), poly_size); + /// assert_eq!(fourier_bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(fourier_bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(fourier_bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_bootstrap_key( + &mut self, + input: &LweBootstrapKey32, + ) -> Result> + { + FftError::perform_fft_checks(input.polynomial_size())?; + Ok(unsafe { self.convert_lwe_bootstrap_key_unchecked(input) }) + } + + unsafe fn convert_lwe_bootstrap_key_unchecked( + &mut self, + input: &LweBootstrapKey32, + ) -> FftFourierLweBootstrapKey32 { + let glwe_size = input.0.glwe_size(); + + let boxed = avec![ + c64::default(); + input.0.polynomial_size().0 + * input.0.key_size().0 + * input.0.level_count().0 + * glwe_size.0 + * glwe_size.0 + / 2 + ] + .into_boxed_slice(); + let fft = Fft::new(input.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + fill_with_forward_fourier_scratch(fft) + .unwrap() + .unaligned_bytes_required(), + ); + let stack = self.stack(); + + let mut output = FourierLweBootstrapKey::new( + boxed, + input.0.key_size(), + input.0.polynomial_size(), + input.0.glwe_size(), + input.0.base_log(), + input.0.level_count(), + ); + output + .as_mut_view() + .fill_with_forward_fourier(input.0.as_view(), fft, stack); + FftFourierLweBootstrapKey32(output) + } +} + +/// # Description +/// +/// Implementation of [`LweBootstrapKeyConversionEngine`] for [`FftEngine`] that operates on +/// 64 bit integers. It converts a bootstrap key from the standard to the Fourier domain. +impl LweBootstrapKeyConversionEngine for FftEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// # + /// assert_eq!(fourier_bsk.glwe_dimension(), glwe_dim); + /// assert_eq!(fourier_bsk.polynomial_size(), poly_size); + /// assert_eq!(fourier_bsk.input_lwe_dimension(), lwe_dim); + /// assert_eq!(fourier_bsk.decomposition_base_log(), dec_bl); + /// assert_eq!(fourier_bsk.decomposition_level_count(), dec_lc); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn convert_lwe_bootstrap_key( + &mut self, + input: &LweBootstrapKey64, + ) -> Result> + { + FftError::perform_fft_checks(input.polynomial_size())?; + Ok(unsafe { self.convert_lwe_bootstrap_key_unchecked(input) }) + } + + unsafe fn convert_lwe_bootstrap_key_unchecked( + &mut self, + input: &LweBootstrapKey64, + ) -> FftFourierLweBootstrapKey64 { + let glwe_size = input.0.glwe_size(); + + let boxed = avec![ + c64::default(); + input.0.polynomial_size().0 + * input.0.key_size().0 + * input.0.level_count().0 + * glwe_size.0 + * glwe_size.0 + / 2 + ] + .into_boxed_slice(); + + let fft = Fft::new(input.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + fill_with_forward_fourier_scratch(fft) + .unwrap() + .unaligned_bytes_required(), + ); + let stack = self.stack(); + + let mut output = FourierLweBootstrapKey::new( + boxed, + input.0.key_size(), + input.0.polynomial_size(), + input.0.glwe_size(), + input.0.base_log(), + input.0.level_count(), + ); + output + .as_mut_view() + .fill_with_forward_fourier(input.0.as_view(), fft, stack); + FftFourierLweBootstrapKey64(output) + } +} + +impl LweBootstrapKeyConversionEngine for FftEngine +where + Key: LweBootstrapKeyEntity + Clone, +{ + fn convert_lwe_bootstrap_key( + &mut self, + input: &Key, + ) -> Result> { + Ok(unsafe { self.convert_lwe_bootstrap_key_unchecked(input) }) + } + + unsafe fn convert_lwe_bootstrap_key_unchecked(&mut self, input: &Key) -> Key { + (*input).clone() + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bit_extraction.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bit_extraction.rs new file mode 100644 index 000000000..6e1728b93 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bit_extraction.rs @@ -0,0 +1,641 @@ +use crate::core_crypto::backends::fft::engines::{FftEngine, FftError}; +use crate::core_crypto::backends::fft::entities::{ + FftFourierLweBootstrapKey32, FftFourierLweBootstrapKey64, +}; +use crate::core_crypto::backends::fft::private::crypto::wop_pbs::{ + extract_bits, extract_bits_scratch, +}; +use crate::core_crypto::backends::fft::private::math::fft::Fft; +use crate::core_crypto::prelude::{ + CiphertextModulusLog, DeltaLog, ExtractedBitsCount, LweBootstrapKeyEntity, LweCiphertext32, + LweCiphertext64, LweCiphertextEntity, LweCiphertextVector32, LweCiphertextVector64, + LweCiphertextVectorMutView32, LweCiphertextVectorMutView64, LweCiphertextView32, + LweCiphertextView64, LweKeyswitchKey32, LweKeyswitchKey64, LweKeyswitchKeyEntity, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextDiscardingBitExtractEngine, LweCiphertextDiscardingBitExtractError, +}; + +impl From for LweCiphertextDiscardingBitExtractError { + fn from(err: FftError) -> Self { + Self::Engine(err) + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingBitExtractEngine`] for [`FftEngine`] that operates +/// on 32 bits integers. +impl + LweCiphertextDiscardingBitExtractEngine< + FftFourierLweBootstrapKey32, + LweKeyswitchKey32, + LweCiphertext32, + LweCiphertextVector32, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let extracted_bits_count = ExtractedBitsCount(1); + /// let delta_log = DeltaLog(5); + /// let noise = Variance(2_f64.powf(-50.)); + /// let large_lwe_dim = LweDimension(glwe_dim.0 * poly_size.0); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, and rely on /dev/random only for tests. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let input_lwe_sk: LweSecretKey32 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// let output_lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let bsk: LweBootstrapKey32 = default_engine.generate_new_lwe_bootstrap_key( + /// &output_lwe_sk, + /// &glwe_sk, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// let ksk: LweKeyswitchKey32 = default_engine.generate_new_lwe_keyswitch_key( + /// &input_lwe_sk, + /// &output_lwe_sk, + /// dec_lc, + /// dec_bl, + /// noise, + /// )?; + /// let bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let input = default_engine.encrypt_lwe_ciphertext(&input_lwe_sk, &plaintext, noise)?; + /// let mut output = default_engine.zero_encrypt_lwe_ciphertext_vector( + /// &output_lwe_sk, + /// noise, + /// LweCiphertextCount(extracted_bits_count.0), + /// )?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut output, + /// &input, + /// &bsk, + /// &ksk, + /// extracted_bits_count, + /// delta_log, + /// )?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim); + /// assert_eq!( + /// output.lwe_ciphertext_count(), + /// LweCiphertextCount(extracted_bits_count.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_extract_bits_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextVector32, + input: &LweCiphertext32, + bsk: &FftFourierLweBootstrapKey32, + ksk: &LweKeyswitchKey32, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) -> Result<(), LweCiphertextDiscardingBitExtractError> { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextDiscardingBitExtractError::perform_generic_checks( + output, + input, + bsk, + ksk, + extracted_bits_count, + CiphertextModulusLog(32), + delta_log, + )?; + unsafe { + self.discard_extract_bits_lwe_ciphertext_unchecked( + output, + input, + bsk, + ksk, + extracted_bits_count, + delta_log, + ) + }; + Ok(()) + } + + unsafe fn discard_extract_bits_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextVector32, + input: &LweCiphertext32, + bsk: &FftFourierLweBootstrapKey32, + ksk: &LweKeyswitchKey32, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) { + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + extract_bits_scratch::( + input.lwe_dimension(), + ksk.output_lwe_dimension(), + bsk.glwe_dimension().to_glwe_size(), + bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + extract_bits( + output.0.as_mut_view(), + input.0.as_view(), + ksk.0.as_view(), + bsk.0.as_view(), + delta_log, + extracted_bits_count, + fft, + self.stack(), + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingBitExtractEngine`] for [`FftEngine`] that operates +/// on 64 bits integers. +impl + LweCiphertextDiscardingBitExtractEngine< + FftFourierLweBootstrapKey64, + LweKeyswitchKey64, + LweCiphertext64, + LweCiphertextVector64, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 50; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let extracted_bits_count = ExtractedBitsCount(1); + /// let delta_log = DeltaLog(5); + /// let noise = Variance(2_f64.powf(-50.)); + /// let large_lwe_dim = LweDimension(glwe_dim.0 * poly_size.0); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, and rely on /dev/random only for tests. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let input_lwe_sk: LweSecretKey64 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// let output_lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let bsk: LweBootstrapKey64 = default_engine.generate_new_lwe_bootstrap_key( + /// &output_lwe_sk, + /// &glwe_sk, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// let ksk: LweKeyswitchKey64 = default_engine.generate_new_lwe_keyswitch_key( + /// &input_lwe_sk, + /// &output_lwe_sk, + /// dec_lc, + /// dec_bl, + /// noise, + /// )?; + /// let bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let input = default_engine.encrypt_lwe_ciphertext(&input_lwe_sk, &plaintext, noise)?; + /// let mut output = default_engine.zero_encrypt_lwe_ciphertext_vector( + /// &output_lwe_sk, + /// noise, + /// LweCiphertextCount(extracted_bits_count.0), + /// )?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut output, + /// &input, + /// &bsk, + /// &ksk, + /// extracted_bits_count, + /// delta_log, + /// )?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim); + /// assert_eq!( + /// output.lwe_ciphertext_count(), + /// LweCiphertextCount(extracted_bits_count.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_extract_bits_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextVector64, + input: &LweCiphertext64, + bsk: &FftFourierLweBootstrapKey64, + ksk: &LweKeyswitchKey64, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) -> Result<(), LweCiphertextDiscardingBitExtractError> { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextDiscardingBitExtractError::perform_generic_checks( + output, + input, + bsk, + ksk, + extracted_bits_count, + CiphertextModulusLog(64), + delta_log, + )?; + unsafe { + self.discard_extract_bits_lwe_ciphertext_unchecked( + output, + input, + bsk, + ksk, + extracted_bits_count, + delta_log, + ) + }; + Ok(()) + } + + unsafe fn discard_extract_bits_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextVector64, + input: &LweCiphertext64, + bsk: &FftFourierLweBootstrapKey64, + ksk: &LweKeyswitchKey64, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) { + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + extract_bits_scratch::( + input.lwe_dimension(), + ksk.output_lwe_dimension(), + bsk.glwe_dimension().to_glwe_size(), + bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + extract_bits( + output.0.as_mut_view(), + input.0.as_view(), + ksk.0.as_view(), + bsk.0.as_view(), + delta_log, + extracted_bits_count, + fft, + self.stack(), + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingBitExtractEngine`] for [`FftEngine`] that operates +/// on views containing 32 bits integers. +impl + LweCiphertextDiscardingBitExtractEngine< + FftFourierLweBootstrapKey32, + LweKeyswitchKey32, + LweCiphertextView32<'_>, + LweCiphertextVectorMutView32<'_>, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let extracted_bits_count = ExtractedBitsCount(1); + /// let delta_log = DeltaLog(5); + /// let noise = Variance(2_f64.powf(-50.)); + /// let large_lwe_dim = LweDimension(glwe_dim.0 * poly_size.0); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, and rely on /dev/random only for tests. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let input_lwe_sk: LweSecretKey32 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// let output_lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let bsk: LweBootstrapKey32 = default_engine.generate_new_lwe_bootstrap_key( + /// &output_lwe_sk, + /// &glwe_sk, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// let ksk: LweKeyswitchKey32 = default_engine.generate_new_lwe_keyswitch_key( + /// &input_lwe_sk, + /// &output_lwe_sk, + /// dec_lc, + /// dec_bl, + /// noise, + /// )?; + /// let bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// + /// let mut input_ct_container = vec![0u32; input_lwe_sk.lwe_dimension().to_lwe_size().0]; + /// let mut input: LweCiphertextMutView32 = + /// default_engine.create_lwe_ciphertext_from(input_ct_container.as_mut_slice())?; + /// + /// let mut output_ct_vec_container = + /// vec![0u32; output_lwe_sk.lwe_dimension().to_lwe_size().0 * extracted_bits_count.0]; + /// let mut output: LweCiphertextVectorMutView32 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// output_ct_vec_container.as_mut_slice(), + /// output_lwe_sk.lwe_dimension().to_lwe_size(), + /// )?; + /// + /// default_engine.discard_encrypt_lwe_ciphertext(&input_lwe_sk, &mut input, &plaintext, noise)?; + /// + /// let input_slice = default_engine.consume_retrieve_lwe_ciphertext(input)?; + /// let input: LweCiphertextView32 = default_engine.create_lwe_ciphertext_from(&input_slice[..])?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut output, + /// &input, + /// &bsk, + /// &ksk, + /// extracted_bits_count, + /// delta_log, + /// )?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim); + /// assert_eq!( + /// output.lwe_ciphertext_count(), + /// LweCiphertextCount(extracted_bits_count.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_extract_bits_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextVectorMutView32<'_>, + input: &LweCiphertextView32<'_>, + bsk: &FftFourierLweBootstrapKey32, + ksk: &LweKeyswitchKey32, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) -> Result<(), LweCiphertextDiscardingBitExtractError> { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextDiscardingBitExtractError::perform_generic_checks( + output, + input, + bsk, + ksk, + extracted_bits_count, + CiphertextModulusLog(32), + delta_log, + )?; + unsafe { + self.discard_extract_bits_lwe_ciphertext_unchecked( + output, + input, + bsk, + ksk, + extracted_bits_count, + delta_log, + ) + }; + Ok(()) + } + + unsafe fn discard_extract_bits_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextVectorMutView32<'_>, + input: &LweCiphertextView32<'_>, + bsk: &FftFourierLweBootstrapKey32, + ksk: &LweKeyswitchKey32, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) { + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + extract_bits_scratch::( + input.lwe_dimension(), + ksk.output_lwe_dimension(), + bsk.glwe_dimension().to_glwe_size(), + bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + extract_bits( + output.0.as_mut_view(), + input.0.as_view(), + ksk.0.as_view(), + bsk.0.as_view(), + delta_log, + extracted_bits_count, + fft, + self.stack(), + ); + } +} + +/// # Description: +/// Implementation of [`LweCiphertextDiscardingBitExtractEngine`] for [`FftEngine`] that operates +/// on views containing 64 bits integers. +impl + LweCiphertextDiscardingBitExtractEngine< + FftFourierLweBootstrapKey64, + LweKeyswitchKey64, + LweCiphertextView64<'_>, + LweCiphertextVectorMutView64<'_>, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u64 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(1), PolynomialSize(512)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let extracted_bits_count = ExtractedBitsCount(1); + /// let delta_log = DeltaLog(5); + /// let noise = Variance(2_f64.powf(-50.)); + /// let large_lwe_dim = LweDimension(glwe_dim.0 * poly_size.0); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, and rely on /dev/random only for tests. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let input_lwe_sk: LweSecretKey64 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// let output_lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let bsk: LweBootstrapKey64 = default_engine.generate_new_lwe_bootstrap_key( + /// &output_lwe_sk, + /// &glwe_sk, + /// dec_bl, + /// dec_lc, + /// noise, + /// )?; + /// let ksk: LweKeyswitchKey64 = default_engine.generate_new_lwe_keyswitch_key( + /// &input_lwe_sk, + /// &output_lwe_sk, + /// dec_lc, + /// dec_bl, + /// noise, + /// )?; + /// let bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// + /// let mut input_ct_container = vec![0u64; input_lwe_sk.lwe_dimension().to_lwe_size().0]; + /// let mut input: LweCiphertextMutView64 = + /// default_engine.create_lwe_ciphertext_from(input_ct_container.as_mut_slice())?; + /// + /// let mut output_ct_vec_container = + /// vec![0u64; output_lwe_sk.lwe_dimension().to_lwe_size().0 * extracted_bits_count.0]; + /// let mut output: LweCiphertextVectorMutView64 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// output_ct_vec_container.as_mut_slice(), + /// output_lwe_sk.lwe_dimension().to_lwe_size(), + /// )?; + /// + /// default_engine.discard_encrypt_lwe_ciphertext(&input_lwe_sk, &mut input, &plaintext, noise)?; + /// + /// let input_slice = default_engine.consume_retrieve_lwe_ciphertext(input)?; + /// let input: LweCiphertextView64 = default_engine.create_lwe_ciphertext_from(&input_slice[..])?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut output, + /// &input, + /// &bsk, + /// &ksk, + /// extracted_bits_count, + /// delta_log, + /// )?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim); + /// assert_eq!( + /// output.lwe_ciphertext_count(), + /// LweCiphertextCount(extracted_bits_count.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_extract_bits_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextVectorMutView64<'_>, + input: &LweCiphertextView64<'_>, + bsk: &FftFourierLweBootstrapKey64, + ksk: &LweKeyswitchKey64, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) -> Result<(), LweCiphertextDiscardingBitExtractError> { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextDiscardingBitExtractError::perform_generic_checks( + output, + input, + bsk, + ksk, + extracted_bits_count, + CiphertextModulusLog(64), + delta_log, + )?; + unsafe { + self.discard_extract_bits_lwe_ciphertext_unchecked( + output, + input, + bsk, + ksk, + extracted_bits_count, + delta_log, + ) + }; + Ok(()) + } + + unsafe fn discard_extract_bits_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextVectorMutView64<'_>, + input: &LweCiphertextView64<'_>, + bsk: &FftFourierLweBootstrapKey64, + ksk: &LweKeyswitchKey64, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) { + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + extract_bits_scratch::( + input.lwe_dimension(), + ksk.output_lwe_dimension(), + bsk.glwe_dimension().to_glwe_size(), + bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + extract_bits( + output.0.as_mut_view(), + input.0.as_view(), + ksk.0.as_view(), + bsk.0.as_view(), + delta_log, + extracted_bits_count, + fft, + self.stack(), + ); + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bootstrap.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bootstrap.rs new file mode 100644 index 000000000..548fbe49d --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_discarding_bootstrap.rs @@ -0,0 +1,667 @@ +use super::{FftEngine, FftError}; +use crate::core_crypto::backends::fft::private::crypto::bootstrap::bootstrap_scratch; +use crate::core_crypto::backends::fft::private::math::fft::Fft; +use crate::core_crypto::commons::math::tensor::{AsMutSlice, AsRefSlice}; +use crate::core_crypto::prelude::{ + FftFourierLweBootstrapKey32, FftFourierLweBootstrapKey64, GlweCiphertext32, GlweCiphertext64, + GlweCiphertextEntity, GlweCiphertextView32, GlweCiphertextView64, LweCiphertext32, + LweCiphertext64, LweCiphertextDiscardingBootstrapEngine, LweCiphertextDiscardingBootstrapError, + LweCiphertextMutView32, LweCiphertextMutView64, LweCiphertextView32, LweCiphertextView64, +}; + +impl From for LweCiphertextDiscardingBootstrapError { + fn from(err: FftError) -> Self { + Self::Engine(err) + } +} + +/// # Description +/// +/// Implementation of [`LweCiphertextDiscardingBootstrapEngine`] for [`FftEngine`] that operates +/// on 32 bit integers. +impl + LweCiphertextDiscardingBootstrapEngine< + FftFourierLweBootstrapKey32, + GlweCiphertext32, + LweCiphertext32, + LweCiphertext32, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// let input = 3_u32 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(4), + /// LweDimension(1024), + /// GlweDimension(1), + /// PolynomialSize(1024), + /// ); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// // A constant function is applied during the bootstrap + /// let lut = vec![8_u32 << 20; poly_size.0]; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// let bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let lwe_sk_output: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let acc = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &plaintext_vector)?; + /// let input = default_engine.encrypt_lwe_ciphertext(&lwe_sk, &plaintext, noise)?; + /// let mut output = default_engine.zero_encrypt_lwe_ciphertext(&lwe_sk_output, noise)?; + /// + /// fft_engine.discard_bootstrap_lwe_ciphertext(&mut output, &input, &acc, &bsk)?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + acc: &GlweCiphertext32, + bsk: &FftFourierLweBootstrapKey32, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + FftError::perform_fft_checks(acc.polynomial_size())?; + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext32, + input: &LweCiphertext32, + acc: &GlweCiphertext32, + bsk: &FftFourierLweBootstrapKey32, + ) { + let fft = Fft::new(acc.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + bootstrap_scratch::(acc.0.size(), acc.0.polynomial_size(), fft) + .unwrap() + .unaligned_bytes_required(), + ); + bsk.0.as_view().bootstrap( + output.0.tensor.as_mut_slice(), + input.0.tensor.as_slice(), + acc.0.as_view(), + fft, + self.stack(), + ); + } +} + +/// # Description +/// +/// Implementation of [`LweCiphertextDiscardingBootstrapEngine`] for [`FftEngine`] that operates +/// on 64 bit integers. +impl + LweCiphertextDiscardingBootstrapEngine< + FftFourierLweBootstrapKey64, + GlweCiphertext64, + LweCiphertext64, + LweCiphertext64, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 50 bits) + /// let input = 3_u64 << 50; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(4), + /// LweDimension(1024), + /// GlweDimension(1), + /// PolynomialSize(1024), + /// ); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// // A constant function is applied during the bootstrap + /// let lut = vec![8_u64 << 50; poly_size.0]; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// let bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let lwe_sk_output: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let acc = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &plaintext_vector)?; + /// let input = default_engine.encrypt_lwe_ciphertext(&lwe_sk, &plaintext, noise)?; + /// let mut output = default_engine.zero_encrypt_lwe_ciphertext(&lwe_sk_output, noise)?; + /// + /// fft_engine.discard_bootstrap_lwe_ciphertext(&mut output, &input, &acc, &bsk)?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + acc: &GlweCiphertext64, + bsk: &FftFourierLweBootstrapKey64, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + FftError::perform_fft_checks(acc.polynomial_size())?; + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertext64, + input: &LweCiphertext64, + acc: &GlweCiphertext64, + bsk: &FftFourierLweBootstrapKey64, + ) { + let fft = Fft::new(acc.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + bootstrap_scratch::(acc.0.size(), acc.0.polynomial_size(), fft) + .unwrap() + .unaligned_bytes_required(), + ); + bsk.0.as_view().bootstrap( + output.0.tensor.as_mut_slice(), + input.0.tensor.as_slice(), + acc.0.as_view(), + fft, + self.stack(), + ); + } +} + +/// # Description +/// +/// Implementation of [`LweCiphertextDiscardingBootstrapEngine`] for [`FftEngine`] that operates +/// on 32 bit integers. +impl + LweCiphertextDiscardingBootstrapEngine< + FftFourierLweBootstrapKey32, + GlweCiphertextView32<'_>, + LweCiphertextView32<'_>, + LweCiphertextMutView32<'_>, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// use tfhe::core_crypto::backends::fft::engines::FftEngine; + /// use tfhe::core_crypto::backends::fft::entities::FftFourierLweBootstrapKey32; + /// let input = 3_u32 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(4), + /// LweDimension(1024), + /// GlweDimension(1), + /// PolynomialSize(1024), + /// ); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// // A constant function is applied during the bootstrap + /// let lut = vec![8_u32 << 20; poly_size.0]; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// let bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let lwe_sk_output: LweSecretKey32 = + /// default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let acc = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &plaintext_vector)?; + /// + /// // Get the GlweCiphertext as a View + /// let raw_glwe = default_engine.consume_retrieve_glwe_ciphertext(acc)?; + /// let acc: GlweCiphertextView32 = + /// default_engine.create_glwe_ciphertext_from(&raw_glwe[..], poly_size)?; + /// + /// let mut raw_input_container = vec![0_u32; lwe_sk.lwe_dimension().to_lwe_size().0]; + /// let input: LweCiphertextMutView32 = + /// default_engine.create_lwe_ciphertext_from(&mut raw_input_container[..])?; + /// let input = default_engine.encrypt_lwe_ciphertext(&lwe_sk, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_input = default_engine.consume_retrieve_lwe_ciphertext(input)?; + /// let input = default_engine.create_lwe_ciphertext_from(&raw_input[..])?; + /// + /// let mut raw_output_container = vec![0_u32; lwe_sk_output.lwe_dimension().to_lwe_size().0]; + /// let mut output = default_engine.create_lwe_ciphertext_from(&mut raw_output_container[..])?; + /// + /// fft_engine.discard_bootstrap_lwe_ciphertext(&mut output, &input, &acc, &bsk)?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView32, + input: &LweCiphertextView32, + acc: &GlweCiphertextView32, + bsk: &FftFourierLweBootstrapKey32, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + FftError::perform_fft_checks(acc.polynomial_size())?; + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView32, + input: &LweCiphertextView32, + acc: &GlweCiphertextView32, + bsk: &FftFourierLweBootstrapKey32, + ) { + let fft = Fft::new(acc.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + bootstrap_scratch::(acc.0.size(), acc.0.polynomial_size(), fft) + .unwrap() + .unaligned_bytes_required(), + ); + bsk.0.as_view().bootstrap( + output.0.tensor.as_mut_slice(), + input.0.tensor.as_slice(), + acc.0.as_view(), + fft, + self.stack(), + ); + } +} + +/// # Description +/// +/// Implementation of [`LweCiphertextDiscardingBootstrapEngine`] for [`FftEngine`] that operates +/// on 64 bit integers. +impl + LweCiphertextDiscardingBootstrapEngine< + FftFourierLweBootstrapKey64, + GlweCiphertextView64<'_>, + LweCiphertextView64<'_>, + LweCiphertextMutView64<'_>, + > for FftEngine +{ + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // Here a hard-set encoding is applied (shift by 20 bits) + /// use tfhe::core_crypto::backends::fft::engines::FftEngine; + /// use tfhe::core_crypto::backends::fft::entities::FftFourierLweBootstrapKey32; + /// let input = 3_u64 << 20; + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, lwe_dim_output, glwe_dim, poly_size) = ( + /// LweDimension(4), + /// LweDimension(1024), + /// GlweDimension(1), + /// PolynomialSize(1024), + /// ); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// // A constant function is applied during the bootstrap + /// let lut = vec![8_u64 << 20; poly_size.0]; + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// let bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// let lwe_sk_output: LweSecretKey64 = + /// default_engine.generate_new_lwe_secret_key(lwe_dim_output)?; + /// let plaintext = default_engine.create_plaintext_from(&input)?; + /// let plaintext_vector = default_engine.create_plaintext_vector_from(&lut)?; + /// let acc = default_engine + /// .trivially_encrypt_glwe_ciphertext(glwe_dim.to_glwe_size(), &plaintext_vector)?; + /// + /// // Get the GlweCiphertext as a View + /// let raw_glwe = default_engine.consume_retrieve_glwe_ciphertext(acc)?; + /// let acc: GlweCiphertextView64 = + /// default_engine.create_glwe_ciphertext_from(&raw_glwe[..], poly_size)?; + /// + /// let mut raw_input_container = vec![0_u64; lwe_sk.lwe_dimension().to_lwe_size().0]; + /// let input: LweCiphertextMutView64 = + /// default_engine.create_lwe_ciphertext_from(&mut raw_input_container[..])?; + /// let input = default_engine.encrypt_lwe_ciphertext(&lwe_sk, &plaintext, noise)?; + /// + /// // Convert MutView to View + /// let raw_input = default_engine.consume_retrieve_lwe_ciphertext(input)?; + /// let input = default_engine.create_lwe_ciphertext_from(&raw_input[..])?; + /// + /// let mut raw_output_container = vec![0_u64; lwe_sk_output.lwe_dimension().to_lwe_size().0]; + /// let mut output = default_engine.create_lwe_ciphertext_from(&mut raw_output_container[..])?; + /// + /// fft_engine.discard_bootstrap_lwe_ciphertext(&mut output, &input, &acc, &bsk)?; + /// # + /// assert_eq!(output.lwe_dimension(), lwe_dim_output); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut LweCiphertextMutView64, + input: &LweCiphertextView64, + acc: &GlweCiphertextView64, + bsk: &FftFourierLweBootstrapKey64, + ) -> Result<(), LweCiphertextDiscardingBootstrapError> { + FftError::perform_fft_checks(acc.polynomial_size())?; + LweCiphertextDiscardingBootstrapError::perform_generic_checks(output, input, acc, bsk)?; + unsafe { self.discard_bootstrap_lwe_ciphertext_unchecked(output, input, acc, bsk) }; + Ok(()) + } + + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut LweCiphertextMutView64, + input: &LweCiphertextView64, + acc: &GlweCiphertextView64, + bsk: &FftFourierLweBootstrapKey64, + ) { + let fft = Fft::new(acc.0.polynomial_size()); + let fft = fft.as_view(); + self.resize( + bootstrap_scratch::(acc.0.size(), acc.0.polynomial_size(), fft) + .unwrap() + .unaligned_bytes_required(), + ); + bsk.0.as_view().bootstrap( + output.0.tensor.as_mut_slice(), + input.0.tensor.as_slice(), + acc.0.as_view(), + fft, + self.stack(), + ); + } +} + +#[cfg(test)] +mod unit_test_pbs { + use crate::core_crypto::commons::test_tools::new_random_generator; + use crate::core_crypto::prelude::*; + use std::error::Error; + + fn generate_accumulator_with_engine( + engine: &mut DefaultEngine, + bootstrapping_key: &FftFourierLweBootstrapKey64, + message_modulus: usize, + carry_modulus: usize, + f: F, + ) -> Result> + where + F: Fn(u64) -> u64, + { + // Modulus of the msg contained in the msg bits and operations buffer + let modulus_sup = message_modulus * carry_modulus; + + // N/(p/2) = size of each block + let box_size = bootstrapping_key.polynomial_size().0 / modulus_sup; + + // Value of the shift we multiply our messages by + let delta = (1_u64 << 63) / (modulus_sup) as u64; + + // Create the accumulator + let mut accumulator_u64 = vec![0_u64; bootstrapping_key.polynomial_size().0]; + + // This accumulator extracts the carry bits + for i in 0..modulus_sup { + let index = i as usize * box_size; + accumulator_u64[index..index + box_size] + .iter_mut() + .for_each(|a| *a = f(i as u64) * delta); + } + + let half_box_size = box_size / 2; + + // Negate the first half_box_size coefficients + for a_i in accumulator_u64[0..half_box_size].iter_mut() { + *a_i = (*a_i).wrapping_neg(); + } + + // Rotate the accumulator + accumulator_u64.rotate_left(half_box_size); + + // Everywhere + let accumulator_plaintext = engine.create_plaintext_vector_from(&accumulator_u64)?; + + let accumulator = engine.trivially_encrypt_glwe_ciphertext( + bootstrapping_key.glwe_dimension().to_glwe_size(), + &accumulator_plaintext, + )?; + + Ok(accumulator) + } + + #[test] + fn test_pbs() -> Result<(), Box> { + // Shortint 2_2 params + let lwe_dimension = LweDimension(742); + let glwe_dimension = GlweDimension(1); + let polynomial_size = PolynomialSize(2048); + let lwe_modular_std_dev = StandardDev(0.000007069849454709433); + let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); + let pbs_base_log = DecompositionBaseLog(23); + let pbs_level = DecompositionLevelCount(1); + let message_modulus: usize = 4; + let carry_modulus: usize = 4; + + let payload_modulus = (message_modulus * carry_modulus) as u64; + + // Value of the shift we multiply our messages by + let delta = (1_u64 << 63) / payload_modulus; + + // Unix seeder must be given a secret input. + // Here we just give it 0, which is totally unsafe. + const UNSAFE_SECRET: u128 = 0; + + let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + let mut fft_engine = FftEngine::new(())?; + + let mut default_parallel_engine = + DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + + let repetitions = 10; + let samples = 100; + + let mut error_sample_vec = Vec::::with_capacity(repetitions * samples); + + let mut generator = new_random_generator(); + + for _ in 0..repetitions { + // Generate client-side keys + + // generate the lwe secret key + let small_lwe_secret_key: LweSecretKey64 = + default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + + // generate the rlwe secret key + let glwe_secret_key: GlweSecretKey64 = + default_engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + + let large_lwe_secret_key = default_engine + .transform_glwe_secret_key_to_lwe_secret_key(glwe_secret_key.clone())?; + + // Convert into a variance for rlwe context + let var_rlwe = Variance(glwe_modular_std_dev.get_variance()); + + let bootstrap_key: LweBootstrapKey64 = default_parallel_engine + .generate_new_lwe_bootstrap_key( + &small_lwe_secret_key, + &glwe_secret_key, + pbs_base_log, + pbs_level, + var_rlwe, + )?; + + // Creation of the bootstrapping key in the Fourier domain + + let fourier_bsk: FftFourierLweBootstrapKey64 = + fft_engine.convert_lwe_bootstrap_key(&bootstrap_key)?; + + let accumulator = generate_accumulator_with_engine( + &mut default_engine, + &fourier_bsk, + message_modulus, + carry_modulus, + |x| x, + )?; + + // convert into a variance + let var_lwe = Variance(lwe_modular_std_dev.get_variance()); + + for _ in 0..samples { + let input_plaintext: u64 = + (generator.random_uniform::() % payload_modulus) << delta; + + let plaintext = default_engine.create_plaintext_from(&input_plaintext)?; + let input = default_engine.encrypt_lwe_ciphertext( + &small_lwe_secret_key, + &plaintext, + var_lwe, + )?; + + let mut output = + default_engine.zero_encrypt_lwe_ciphertext(&large_lwe_secret_key, var_lwe)?; + + fft_engine.discard_bootstrap_lwe_ciphertext( + &mut output, + &input, + &accumulator, + &fourier_bsk, + )?; + + // decryption + let decrypted = + default_engine.decrypt_lwe_ciphertext(&large_lwe_secret_key, &output)?; + + if decrypted == plaintext { + panic!("Equal {decrypted:?}, {plaintext:?}"); + } + + let mut decrypted_u64: u64 = 0; + default_engine.discard_retrieve_plaintext(&mut decrypted_u64, &decrypted)?; + + // let err = if decrypted_u64 >= input_plaintext { + // decrypted_u64 - input_plaintext + // } else { + // input_plaintext - decrypted_u64 + // }; + + let err = { + let d0 = decrypted_u64.wrapping_sub(input_plaintext); + let d1 = input_plaintext.wrapping_sub(decrypted_u64); + std::cmp::min(d0, d1) + }; + + // let err = torus_modular_distance(input_plaintext, decrypted_u64); + + error_sample_vec.push(err); + + //The bit before the message + let rounding_bit = delta >> 1; + + //compute the rounding bit + let rounding = (decrypted_u64 & rounding_bit) << 1; + + let decoded = (decrypted_u64.wrapping_add(rounding)) / delta; + + assert_eq!(decoded, input_plaintext / delta); + } + } + + error_sample_vec.sort(); + + let bit_errors: Vec<_> = error_sample_vec + .iter() + .map(|&x| if x != 0 { 63 - x.leading_zeros() } else { 0 }) + .collect(); + + let mean_bit_errors: u32 = bit_errors.iter().sum::() / bit_errors.len() as u32; + let mean_bit_errors_f64: f64 = + bit_errors.iter().map(|&x| x as f64).sum::() as f64 / bit_errors.len() as f64; + + for (idx, (&val, &bit_error)) in error_sample_vec.iter().zip(bit_errors.iter()).enumerate() + { + println!("#{idx}: Error {val}, bit_error {bit_error}"); + } + + println!("Mean bit error: {mean_bit_errors}"); + println!("Mean bit error f64: {mean_bit_errors_f64}"); + + Ok(()) + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs new file mode 100644 index 000000000..253df4f7c --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs @@ -0,0 +1,531 @@ +use crate::core_crypto::backends::default::entities::{ + LweCiphertextVectorMutView32, LweCiphertextVectorMutView64, LweCiphertextVectorView32, + LweCiphertextVectorView64, LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, PlaintextVector32, + PlaintextVector64, +}; +use crate::core_crypto::backends::fft::engines::{FftEngine, FftError}; +use crate::core_crypto::backends::fft::entities::{ + FftFourierLweBootstrapKey32, FftFourierLweBootstrapKey64, +}; +use crate::core_crypto::backends::fft::private::crypto::wop_pbs::{ + circuit_bootstrap_boolean_vertical_packing, circuit_bootstrap_boolean_vertical_packing_scratch, +}; +use crate::core_crypto::backends::fft::private::math::fft::Fft; +use crate::core_crypto::commons::math::polynomial::PolynomialList; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::prelude::{ + CiphertextCount, LweCiphertextVectorEntity, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, PlaintextVectorEntity, + PolynomialCount, +}; +use crate::core_crypto::specification::engines::{ + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingEngine, + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError, +}; +use crate::core_crypto::specification::entities::LweBootstrapKeyEntity; +use crate::core_crypto::specification::parameters::{ + DecompositionBaseLog, DecompositionLevelCount, +}; + +impl From + for LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError +{ + fn from(err: FftError) -> Self { + Self::Engine(err) + } +} + +impl + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingEngine< + LweCiphertextVectorView32<'_>, + LweCiphertextVectorMutView32<'_>, + FftFourierLweBootstrapKey32, + PlaintextVector32, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + > for FftEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let polynomial_size = PolynomialSize(1024); + /// let glwe_dimension = GlweDimension(1); + /// let lwe_dimension = LweDimension(481); + /// + /// let var_small = Variance::from_variance(2f64.powf(-70.0)); + /// let var_big = Variance::from_variance(2f64.powf(-60.0)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// let lwe_small_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let lwe_big_sk: LweSecretKey32 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// + /// let bsk_level_count = DecompositionLevelCount(7); + /// let bsk_base_log = DecompositionBaseLog(4); + /// + /// let std_bsk: LweBootstrapKey32 = default_parallel_engine.generate_new_lwe_bootstrap_key( + /// &lwe_small_sk, + /// &glwe_sk, + /// bsk_base_log, + /// bsk_level_count, + /// var_small, + /// )?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey32 = + /// fft_engine.convert_lwe_bootstrap_key(&std_bsk)?; + /// + /// let ksk_level_count = DecompositionLevelCount(9); + /// let ksk_base_log = DecompositionBaseLog(1); + /// + /// let ksk_big_to_small: LweKeyswitchKey32 = default_engine.generate_new_lwe_keyswitch_key( + /// &lwe_big_sk, + /// &lwe_small_sk, + /// ksk_level_count, + /// ksk_base_log, + /// var_big, + /// )?; + /// + /// let pfpksk_level_count = DecompositionLevelCount(7); + /// let pfpksk_base_log = DecompositionBaseLog(4); + /// + /// let cbs_pfpksk: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32 = default_engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &lwe_big_sk, + /// &glwe_sk, + /// pfpksk_base_log, + /// pfpksk_level_count, + /// var_small, + /// )?; + /// + /// // We will have a message with 10 bits of information + /// let message_bits = 10; + /// let bits_to_extract = ExtractedBitsCount(message_bits); + /// + /// // The value we encrypt is 42, we will extract the bits of this value and apply the + /// // circuit bootstrapping followed by the vertical packing on the extracted bits. + /// let cleartext = 42; + /// let delta_log_msg = DeltaLog(32 - message_bits); + /// + /// let encoded_message = default_engine.create_plaintext_from(&(cleartext << delta_log_msg.0))?; + /// let lwe_in = default_engine.encrypt_lwe_ciphertext(&lwe_big_sk, &encoded_message, var_big)?; + /// + /// // Bit extraction output, use the zero_encrypt engine to allocate a ciphertext vector + /// let mut bit_extraction_output = default_engine.zero_encrypt_lwe_ciphertext_vector( + /// &lwe_small_sk, + /// var_small, + /// LweCiphertextCount(bits_to_extract.0), + /// )?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut bit_extraction_output, + /// &lwe_in, + /// &fourier_bsk, + /// &ksk_big_to_small, + /// bits_to_extract, + /// delta_log_msg, + /// )?; + /// + /// // Though the delta log here is the same as the message delta log, in the general case they + /// // are different, so we create two DeltaLog parameters + /// let delta_log_lut = DeltaLog(32 - message_bits); + /// + /// // Create a look-up table we want to apply during vertical packing, here just the identity + /// // with the proper encoding. + /// // Note that this particular table will not trigger the cmux tree from the vertical packing, + /// // adapt the LUT generation to your usage. + /// // Here we apply a single look-up table as we output a single ciphertext. + /// let number_of_luts_and_output_vp_ciphertexts = 1; + /// let lut_size = 1 << bits_to_extract.0; + /// let mut lut: Vec = Vec::with_capacity(lut_size); + /// + /// for i in 0..lut_size { + /// lut.push((i as u32 % (1 << message_bits)) << delta_log_lut.0); + /// } + /// + /// let lut_as_plaintext_vector = default_engine.create_plaintext_vector_from(lut.as_slice())?; + /// + /// // We run on views, so we need a container for the output + /// let mut output_cbs_vp_ct_container = vec![ + /// 0u32; + /// lwe_big_sk.lwe_dimension().to_lwe_size().0 + /// * number_of_luts_and_output_vp_ciphertexts + /// ]; + /// + /// let mut output_cbs_vp_ct_mut_view: LweCiphertextVectorMutView32 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// output_cbs_vp_ct_container.as_mut_slice(), + /// lwe_big_sk.lwe_dimension().to_lwe_size(), + /// )?; + /// // And we need to get a view on the bits extracted earlier that serve as inputs to the + /// // circuit bootstrap + vertical packing + /// let extracted_bits_lwe_size = bit_extraction_output.lwe_dimension().to_lwe_size(); + /// let extracted_bits_container = + /// default_engine.consume_retrieve_lwe_ciphertext_vector(bit_extraction_output)?; + /// let cbs_vp_input_vector_view: LweCiphertextVectorView32 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// extracted_bits_container.as_slice(), + /// extracted_bits_lwe_size, + /// )?; + /// + /// let cbs_level_count = DecompositionLevelCount(4); + /// let cbs_base_log = DecompositionBaseLog(6); + /// + /// fft_engine.discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + /// &mut output_cbs_vp_ct_mut_view, + /// &cbs_vp_input_vector_view, + /// &fourier_bsk, + /// &lut_as_plaintext_vector, + /// cbs_level_count, + /// cbs_base_log, + /// &cbs_pfpksk, + /// )?; + /// + /// assert_eq!(output_cbs_vp_ct_mut_view.lwe_ciphertext_count().0, 1); + /// assert_eq!( + /// output_cbs_vp_ct_mut_view.lwe_dimension(), + /// LweDimension(glwe_dimension.0 * polynomial_size.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + &mut self, + output: &mut LweCiphertextVectorMutView32, + input: &LweCiphertextVectorView32, + bsk: &FftFourierLweBootstrapKey32, + luts: &PlaintextVector32, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + ) -> Result< + (), + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError, + > { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError:: + perform_generic_checks( + input, + output, + bsk, + luts, + cbs_level_count, + cbs_base_log, + cbs_pfpksk, + 32, + )?; + unsafe { + self.discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector_unchecked( + output, + input, + bsk, + luts, + cbs_level_count, + cbs_base_log, + cbs_pfpksk, + ); + } + Ok(()) + } + + unsafe fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector_unchecked( + &mut self, + output: &mut LweCiphertextVectorMutView32, + input: &LweCiphertextVectorView32, + bsk: &FftFourierLweBootstrapKey32, + luts: &PlaintextVector32, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys32, + ) { + let lut_as_polynomial_list = + PolynomialList::from_container(luts.0.as_tensor().as_slice(), bsk.polynomial_size()); + + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + circuit_bootstrap_boolean_vertical_packing_scratch::( + CiphertextCount(input.lwe_ciphertext_count().0), + CiphertextCount(output.lwe_ciphertext_count().0), + input.lwe_dimension().to_lwe_size(), + PolynomialCount(luts.plaintext_count().0), + bsk.output_lwe_dimension().to_lwe_size(), + cbs_pfpksk.output_polynomial_size(), + bsk.glwe_dimension().to_glwe_size(), + cbs_level_count, + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + circuit_bootstrap_boolean_vertical_packing( + lut_as_polynomial_list.as_view(), + bsk.0.as_view(), + output.0.as_mut_view(), + input.0.as_view(), + cbs_pfpksk.0.as_view(), + cbs_level_count, + cbs_base_log, + fft, + self.stack(), + ) + } +} + +impl + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingEngine< + LweCiphertextVectorView64<'_>, + LweCiphertextVectorMutView64<'_>, + FftFourierLweBootstrapKey64, + PlaintextVector64, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + > for FftEngine +{ + /// # Example: + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// # use std::error::Error; + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let polynomial_size = PolynomialSize(1024); + /// let glwe_dimension = GlweDimension(1); + /// let lwe_dimension = LweDimension(481); + /// + /// let var_small = Variance::from_variance(2f64.powf(-80.0)); + /// let var_big = Variance::from_variance(2f64.powf(-70.0)); + /// + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut default_parallel_engine = + /// DefaultParallelEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dimension, polynomial_size)?; + /// let lwe_small_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dimension)?; + /// let lwe_big_sk: LweSecretKey64 = + /// default_engine.transform_glwe_secret_key_to_lwe_secret_key(glwe_sk.clone())?; + /// + /// let bsk_level_count = DecompositionLevelCount(9); + /// let bsk_base_log = DecompositionBaseLog(4); + /// + /// let std_bsk: LweBootstrapKey64 = default_parallel_engine.generate_new_lwe_bootstrap_key( + /// &lwe_small_sk, + /// &glwe_sk, + /// bsk_base_log, + /// bsk_level_count, + /// var_small, + /// )?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey64 = + /// fft_engine.convert_lwe_bootstrap_key(&std_bsk)?; + /// + /// let ksk_level_count = DecompositionLevelCount(9); + /// let ksk_base_log = DecompositionBaseLog(1); + /// + /// let ksk_big_to_small: LweKeyswitchKey64 = default_engine.generate_new_lwe_keyswitch_key( + /// &lwe_big_sk, + /// &lwe_small_sk, + /// ksk_level_count, + /// ksk_base_log, + /// var_big, + /// )?; + /// + /// let pfpksk_level_count = DecompositionLevelCount(9); + /// let pfpksk_base_log = DecompositionBaseLog(4); + /// + /// let cbs_pfpksk: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64 = default_engine + /// .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + /// &lwe_big_sk, + /// &glwe_sk, + /// pfpksk_base_log, + /// pfpksk_level_count, + /// var_small, + /// )?; + /// + /// // We will have a message with 10 bits of information + /// let message_bits = 10; + /// let bits_to_extract = ExtractedBitsCount(message_bits); + /// + /// // The value we encrypt is 42, we will extract the bits of this value and apply the + /// // circuit bootstrapping followed by the vertical packing on the extracted bits. + /// let cleartext = 42; + /// let delta_log_msg = DeltaLog(64 - message_bits); + /// + /// let encoded_message = default_engine.create_plaintext_from(&(cleartext << delta_log_msg.0))?; + /// let lwe_in = default_engine.encrypt_lwe_ciphertext(&lwe_big_sk, &encoded_message, var_big)?; + /// + /// // Bit extraction output, use the zero_encrypt engine to allocate a ciphertext vector + /// let mut bit_extraction_output = default_engine.zero_encrypt_lwe_ciphertext_vector( + /// &lwe_small_sk, + /// var_small, + /// LweCiphertextCount(bits_to_extract.0), + /// )?; + /// + /// fft_engine.discard_extract_bits_lwe_ciphertext( + /// &mut bit_extraction_output, + /// &lwe_in, + /// &fourier_bsk, + /// &ksk_big_to_small, + /// bits_to_extract, + /// delta_log_msg, + /// )?; + /// + /// // Though the delta log here is the same as the message delta log, in the general case they + /// // are different, so we create two DeltaLog parameters + /// let delta_log_lut = DeltaLog(64 - message_bits); + /// + /// // Create a look-up table we want to apply during vertical packing, here just the identity + /// // with the proper encoding. + /// // Note that this particular table will not trigger the cmux tree from the vertical packing, + /// // adapt the LUT generation to your usage. + /// // Here we apply a single look-up table as we output a single ciphertext. + /// let number_of_luts_and_output_vp_ciphertexts = 1; + /// let lut_size = 1 << bits_to_extract.0; + /// let mut lut: Vec = Vec::with_capacity(lut_size); + /// + /// for i in 0..lut_size { + /// lut.push((i as u64 % (1 << message_bits)) << delta_log_lut.0); + /// } + /// + /// let lut_as_plaintext_vector = default_engine.create_plaintext_vector_from(lut.as_slice())?; + /// + /// // We run on views, so we need a container for the output + /// let mut output_cbs_vp_ct_container = vec![ + /// 0u64; + /// lwe_big_sk.lwe_dimension().to_lwe_size().0 + /// * number_of_luts_and_output_vp_ciphertexts + /// ]; + /// + /// let mut output_cbs_vp_ct_mut_view: LweCiphertextVectorMutView64 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// output_cbs_vp_ct_container.as_mut_slice(), + /// lwe_big_sk.lwe_dimension().to_lwe_size(), + /// )?; + /// // And we need to get a view on the bits extracted earlier that serve as inputs to the + /// // circuit bootstrap + vertical packing + /// let extracted_bits_lwe_size = bit_extraction_output.lwe_dimension().to_lwe_size(); + /// let extracted_bits_container = + /// default_engine.consume_retrieve_lwe_ciphertext_vector(bit_extraction_output)?; + /// let cbs_vp_input_vector_view: LweCiphertextVectorView64 = default_engine + /// .create_lwe_ciphertext_vector_from( + /// extracted_bits_container.as_slice(), + /// extracted_bits_lwe_size, + /// )?; + /// + /// let cbs_level_count = DecompositionLevelCount(4); + /// let cbs_base_log = DecompositionBaseLog(6); + /// + /// fft_engine.discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + /// &mut output_cbs_vp_ct_mut_view, + /// &cbs_vp_input_vector_view, + /// &fourier_bsk, + /// &lut_as_plaintext_vector, + /// cbs_level_count, + /// cbs_base_log, + /// &cbs_pfpksk, + /// )?; + /// + /// assert_eq!(output_cbs_vp_ct_mut_view.lwe_ciphertext_count().0, 1); + /// assert_eq!( + /// output_cbs_vp_ct_mut_view.lwe_dimension(), + /// LweDimension(glwe_dimension.0 * polynomial_size.0) + /// ); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + &mut self, + output: &mut LweCiphertextVectorMutView64, + input: &LweCiphertextVectorView64, + bsk: &FftFourierLweBootstrapKey64, + luts: &PlaintextVector64, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + ) -> Result< + (), + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError, + > { + FftError::perform_fft_checks(bsk.polynomial_size())?; + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError:: + perform_generic_checks( + input, + output, + bsk, + luts, + cbs_level_count, + cbs_base_log, + cbs_pfpksk, + 64, + )?; + unsafe { + self.discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector_unchecked( + output, + input, + bsk, + luts, + cbs_level_count, + cbs_base_log, + cbs_pfpksk, + ); + } + Ok(()) + } + + unsafe fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector_unchecked( + &mut self, + output: &mut LweCiphertextVectorMutView64, + input: &LweCiphertextVectorView64, + bsk: &FftFourierLweBootstrapKey64, + luts: &PlaintextVector64, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + ) { + let lut_as_polynomial_list = + PolynomialList::from_container(luts.0.as_tensor().as_slice(), bsk.polynomial_size()); + + let fft = Fft::new(bsk.polynomial_size()); + let fft = fft.as_view(); + self.resize( + circuit_bootstrap_boolean_vertical_packing_scratch::( + CiphertextCount(input.lwe_ciphertext_count().0), + CiphertextCount(output.lwe_ciphertext_count().0), + input.lwe_dimension().to_lwe_size(), + PolynomialCount(luts.plaintext_count().0), + bsk.output_lwe_dimension().to_lwe_size(), + cbs_pfpksk.output_polynomial_size(), + bsk.glwe_dimension().to_glwe_size(), + cbs_level_count, + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + circuit_bootstrap_boolean_vertical_packing( + lut_as_polynomial_list.as_view(), + bsk.0.as_view(), + output.0.as_mut_view(), + input.0.as_view(), + cbs_pfpksk.0.as_view(), + cbs_level_count, + cbs_base_log, + fft, + self.stack(), + ) + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/mod.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/mod.rs new file mode 100644 index 000000000..d0d915884 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_engine/mod.rs @@ -0,0 +1,66 @@ +use crate::core_crypto::prelude::PolynomialSize; +use dyn_stack::DynStack; + +use crate::core_crypto::specification::engines::sealed::AbstractEngineSeal; +use crate::core_crypto::specification::engines::AbstractEngine; +use core::mem::MaybeUninit; + +/// Error that can occur in the execution of FHE operations by the [`FftEngine`]. +#[derive(Debug)] +#[non_exhaustive] +pub enum FftError { + UnsupportedPolynomialSize, +} + +impl core::fmt::Display for FftError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FftError::UnsupportedPolynomialSize => f.write_str( + "The Concrete-FFT backend only supports polynomials of sizes that are powers of two \ + and greater than or equal to 32.", + ), + } + } +} + +impl std::error::Error for FftError {} + +impl FftError { + pub fn perform_fft_checks(polynomial_size: PolynomialSize) -> Result<(), FftError> { + if polynomial_size.0.is_power_of_two() && polynomial_size.0 >= 32 { + Ok(()) + } else { + Err(FftError::UnsupportedPolynomialSize) + } + } +} + +/// The main engine exposed by the Concrete-FFT backend. +pub struct FftEngine { + memory: Vec>, +} + +impl FftEngine { + pub(crate) fn resize(&mut self, capacity: usize) { + self.memory.resize_with(capacity, MaybeUninit::uninit); + } + + pub(crate) fn stack(&mut self) -> DynStack<'_> { + DynStack::new(&mut self.memory) + } +} + +impl AbstractEngineSeal for FftEngine {} +impl AbstractEngine for FftEngine { + type EngineError = FftError; + type Parameters = (); + + fn new(_parameter: Self::Parameters) -> Result { + Ok(FftEngine { memory: Vec::new() }) + } +} + +mod lwe_bootstrap_key_conversion; +mod lwe_ciphertext_discarding_bit_extraction; +mod lwe_ciphertext_discarding_bootstrap; +mod lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing; diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/deserialization.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/deserialization.rs new file mode 100644 index 000000000..ef0a18d34 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/deserialization.rs @@ -0,0 +1,156 @@ +#![allow(clippy::missing_safety_doc)] + +use super::{FftSerializationEngine, FftSerializationError}; +use crate::core_crypto::backends::fft::private::crypto::bootstrap::FourierLweBootstrapKey; +use crate::core_crypto::prelude::{ + EntityDeserializationEngine, EntityDeserializationError, FftFourierLweBootstrapKey32, + FftFourierLweBootstrapKey32Version, FftFourierLweBootstrapKey64, + FftFourierLweBootstrapKey64Version, +}; +use aligned_vec::ABox; +use concrete_fft::c64; +use serde::Deserialize; + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`FftSerializationEngine`] that operates +/// on 32 bits integers. It deserializes an LWE bootstrap key in the Fourier domain. +impl EntityDeserializationEngine<&[u8], FftFourierLweBootstrapKey32> for FftSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// let mut serialization_engine = FftSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&fourier_bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(fourier_bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct SerializableFftFourierLweBootstrapKey32 { + version: FftFourierLweBootstrapKey32Version, + inner: FourierLweBootstrapKey>, + } + let deserialized: SerializableFftFourierLweBootstrapKey32 = + bincode::deserialize(serialized) + .map_err(FftSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + SerializableFftFourierLweBootstrapKey32 { + version: FftFourierLweBootstrapKey32Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + FftSerializationError::UnsupportedVersion, + )), + SerializableFftFourierLweBootstrapKey32 { + version: FftFourierLweBootstrapKey32Version::V0, + inner, + } => Ok(FftFourierLweBootstrapKey32(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> FftFourierLweBootstrapKey32 { + self.deserialize(serialized).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntityDeserializationEngine`] for [`FftSerializationEngine`] that operates +/// on 64 bits integers. It deserializes an LWE bootstrap key in the Fourier domain. +impl EntityDeserializationEngine<&[u8], FftFourierLweBootstrapKey64> for FftSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// let mut serialization_engine = FftSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&fourier_bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(fourier_bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn deserialize( + &mut self, + serialized: &[u8], + ) -> Result> { + #[derive(Deserialize)] + struct SerializableFftFourierLweBootstrapKey64 { + version: FftFourierLweBootstrapKey64Version, + inner: FourierLweBootstrapKey>, + } + let deserialized: SerializableFftFourierLweBootstrapKey64 = + bincode::deserialize(serialized) + .map_err(FftSerializationError::Deserialization) + .map_err(EntityDeserializationError::Engine)?; + match deserialized { + SerializableFftFourierLweBootstrapKey64 { + version: FftFourierLweBootstrapKey64Version::Unsupported, + .. + } => Err(EntityDeserializationError::Engine( + FftSerializationError::UnsupportedVersion, + )), + SerializableFftFourierLweBootstrapKey64 { + version: FftFourierLweBootstrapKey64Version::V0, + inner, + } => Ok(FftFourierLweBootstrapKey64(inner)), + } + } + + unsafe fn deserialize_unchecked(&mut self, serialized: &[u8]) -> FftFourierLweBootstrapKey64 { + self.deserialize(serialized).unwrap() + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/mod.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/mod.rs new file mode 100644 index 000000000..6aa0aa8a1 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/mod.rs @@ -0,0 +1,50 @@ +use crate::core_crypto::specification::engines::sealed::AbstractEngineSeal; +use crate::core_crypto::specification::engines::AbstractEngine; +use std::error::Error; +use std::fmt::{Display, Formatter}; + +/// The error which can occur in the execution of FHE operations, due to the FFT implementation. +#[derive(Debug)] +pub enum FftSerializationError { + Serialization(bincode::Error), + Deserialization(bincode::Error), + UnsupportedVersion, +} + +impl Display for FftSerializationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + FftSerializationError::Serialization(bincode_error) => { + write!(f, "Failed to serialize entity: {bincode_error}") + } + FftSerializationError::Deserialization(bincode_error) => { + write!(f, "Failed to deserialize entity: {bincode_error}") + } + FftSerializationError::UnsupportedVersion => { + write!( + f, + "The version used to serialize the entity is not supported." + ) + } + } + } +} + +impl Error for FftSerializationError {} + +/// The serialization engine exposed by the fft backend. +pub struct FftSerializationEngine; + +impl AbstractEngineSeal for FftSerializationEngine {} + +impl AbstractEngine for FftSerializationEngine { + type EngineError = FftSerializationError; + type Parameters = (); + + fn new(_parameters: Self::Parameters) -> Result { + Ok(FftSerializationEngine) + } +} + +mod deserialization; +mod serialization; diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/serialization.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/serialization.rs new file mode 100644 index 000000000..a4ac73bd3 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/fft_serialization_engine/serialization.rs @@ -0,0 +1,138 @@ +#![allow(clippy::missing_safety_doc)] + +use super::{FftSerializationEngine, FftSerializationError}; +use crate::core_crypto::backends::fft::private::crypto::bootstrap::FourierLweBootstrapKeyView; +use crate::core_crypto::prelude::{ + EntitySerializationEngine, EntitySerializationError, FftFourierLweBootstrapKey32, + FftFourierLweBootstrapKey32Version, FftFourierLweBootstrapKey64, + FftFourierLweBootstrapKey64Version, +}; +use serde::Serialize; + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`FftSerializationEngine`] that operates on +/// 32 bits integers. It serializes an LWE bootstrap key in the Fourier domain. +impl EntitySerializationEngine> for FftSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey32 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey32 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey32 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey32 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// let mut serialization_engine = FftSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&fourier_bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(fourier_bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &FftFourierLweBootstrapKey32, + ) -> Result, EntitySerializationError> { + let entity = entity.0.as_view(); + #[derive(Serialize)] + struct SerializableFftFourierLweBootstrapKey32<'a> { + version: FftFourierLweBootstrapKey32Version, + inner: FourierLweBootstrapKeyView<'a>, + } + let value = SerializableFftFourierLweBootstrapKey32 { + version: FftFourierLweBootstrapKey32Version::V0, + inner: entity, + }; + bincode::serialize(&value) + .map_err(FftSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &FftFourierLweBootstrapKey32) -> Vec { + self.serialize(entity).unwrap() + } +} + +/// # Description: +/// Implementation of [`EntitySerializationEngine`] for [`FftSerializationEngine`] that operates on +/// 64 bits integers. It serializes an LWE bootstrap key in the Fourier domain. +impl EntitySerializationEngine> for FftSerializationEngine { + /// # Example + /// ``` + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// Variance, *, + /// }; + /// # use std::error::Error; + /// + /// # fn main() -> Result<(), Box> { + /// // DISCLAIMER: the parameters used here are only for test purpose, and are not secure. + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(256)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let noise = Variance(2_f64.powf(-25.)); + /// + /// // Unix seeder must be given a secret input. + /// // Here we just give it 0, which is totally unsafe. + /// const UNSAFE_SECRET: u128 = 0; + /// let mut default_engine = DefaultEngine::new(Box::new(UnixSeeder::new(UNSAFE_SECRET)))?; + /// let mut fft_engine = FftEngine::new(())?; + /// let lwe_sk: LweSecretKey64 = default_engine.generate_new_lwe_secret_key(lwe_dim)?; + /// let glwe_sk: GlweSecretKey64 = + /// default_engine.generate_new_glwe_secret_key(glwe_dim, poly_size)?; + /// let bsk: LweBootstrapKey64 = + /// default_engine.generate_new_lwe_bootstrap_key(&lwe_sk, &glwe_sk, dec_bl, dec_lc, noise)?; + /// + /// let fourier_bsk: FftFourierLweBootstrapKey64 = fft_engine.convert_lwe_bootstrap_key(&bsk)?; + /// + /// let mut serialization_engine = FftSerializationEngine::new(())?; + /// let serialized = serialization_engine.serialize(&fourier_bsk)?; + /// let recovered = serialization_engine.deserialize(serialized.as_slice())?; + /// assert_eq!(fourier_bsk, recovered); + /// + /// # + /// # Ok(()) + /// # } + /// ``` + fn serialize( + &mut self, + entity: &FftFourierLweBootstrapKey64, + ) -> Result, EntitySerializationError> { + let entity = entity.0.as_view(); + #[derive(Serialize)] + struct SerializableFftFourierLweBootstrapKey64<'a> { + version: FftFourierLweBootstrapKey64Version, + inner: FourierLweBootstrapKeyView<'a>, + } + let value = SerializableFftFourierLweBootstrapKey64 { + version: FftFourierLweBootstrapKey64Version::V0, + inner: entity, + }; + bincode::serialize(&value) + .map_err(FftSerializationError::Serialization) + .map_err(EntitySerializationError::Engine) + } + + unsafe fn serialize_unchecked(&mut self, entity: &FftFourierLweBootstrapKey64) -> Vec { + self.serialize(entity).unwrap() + } +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/engines/mod.rs b/tfhe/src/core_crypto/backends/fft/implementation/engines/mod.rs new file mode 100644 index 000000000..c3de24701 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/engines/mod.rs @@ -0,0 +1,10 @@ +//! A module containing the [engines](crate::core_crypto::specification::engines) exposed by +//! the `Concrete-FFT` backend. + +mod fft_engine; +pub use fft_engine::*; + +#[cfg(feature = "backend_fft_serialization")] +mod fft_serialization_engine; +#[cfg(feature = "backend_fft_serialization")] +pub use fft_serialization_engine::*; diff --git a/tfhe/src/core_crypto/backends/fft/implementation/entities/lwe_bootstrap_key.rs b/tfhe/src/core_crypto/backends/fft/implementation/entities/lwe_bootstrap_key.rs new file mode 100644 index 000000000..77e8b5ba3 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/entities/lwe_bootstrap_key.rs @@ -0,0 +1,80 @@ +use super::super::super::private::crypto::bootstrap::FourierLweBootstrapKey; +use crate::core_crypto::specification::entities::markers::LweBootstrapKeyKind; +use crate::core_crypto::specification::entities::{AbstractEntity, LweBootstrapKeyEntity}; +use aligned_vec::ABox; +use concrete_fft::c64; +#[cfg(feature = "backend_fft_serialization")] +use serde::{Deserialize, Serialize}; + +/// A structure representing an LWE bootstrap key with 32 bits of precision, in the Fourier domain. +#[derive(Debug, Clone, PartialEq)] +pub struct FftFourierLweBootstrapKey32(pub(crate) FourierLweBootstrapKey>); + +/// A structure representing an LWE bootstrap key with 64 bits of precision, in the Fourier domain. +#[derive(Debug, Clone, PartialEq)] +pub struct FftFourierLweBootstrapKey64(pub(crate) FourierLweBootstrapKey>); + +impl AbstractEntity for FftFourierLweBootstrapKey32 { + type Kind = LweBootstrapKeyKind; +} +impl AbstractEntity for FftFourierLweBootstrapKey64 { + type Kind = LweBootstrapKeyKind; +} + +impl LweBootstrapKeyEntity for FftFourierLweBootstrapKey32 { + fn glwe_dimension(&self) -> crate::core_crypto::prelude::GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> crate::core_crypto::prelude::PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> crate::core_crypto::prelude::LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> crate::core_crypto::prelude::DecompositionBaseLog { + self.0.decomposition_base_log() + } + + fn decomposition_level_count(&self) -> crate::core_crypto::prelude::DecompositionLevelCount { + self.0.decomposition_level_count() + } +} +impl LweBootstrapKeyEntity for FftFourierLweBootstrapKey64 { + fn glwe_dimension(&self) -> crate::core_crypto::prelude::GlweDimension { + self.0.glwe_size().to_glwe_dimension() + } + + fn polynomial_size(&self) -> crate::core_crypto::prelude::PolynomialSize { + self.0.polynomial_size() + } + + fn input_lwe_dimension(&self) -> crate::core_crypto::prelude::LweDimension { + self.0.key_size() + } + + fn decomposition_base_log(&self) -> crate::core_crypto::prelude::DecompositionBaseLog { + self.0.decomposition_base_log() + } + + fn decomposition_level_count(&self) -> crate::core_crypto::prelude::DecompositionLevelCount { + self.0.decomposition_level_count() + } +} + +#[cfg(feature = "backend_fft_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum FftFourierLweBootstrapKey32Version { + V0, + #[serde(other)] + Unsupported, +} +#[cfg(feature = "backend_fft_serialization")] +#[derive(Serialize, Deserialize)] +pub(crate) enum FftFourierLweBootstrapKey64Version { + V0, + #[serde(other)] + Unsupported, +} diff --git a/tfhe/src/core_crypto/backends/fft/implementation/entities/mod.rs b/tfhe/src/core_crypto/backends/fft/implementation/entities/mod.rs new file mode 100644 index 000000000..ca4cc59f0 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/entities/mod.rs @@ -0,0 +1,6 @@ +//! A module containing all the [entities](crate::core_crypto::specification::entities) +//! exposed by the Concrete-FFT backend. + +mod lwe_bootstrap_key; + +pub use lwe_bootstrap_key::*; diff --git a/tfhe/src/core_crypto/backends/fft/implementation/mod.rs b/tfhe/src/core_crypto/backends/fft/implementation/mod.rs new file mode 100644 index 000000000..49169443f --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/implementation/mod.rs @@ -0,0 +1,2 @@ +pub mod engines; +pub mod entities; diff --git a/tfhe/src/core_crypto/backends/fft/mod.rs b/tfhe/src/core_crypto/backends/fft/mod.rs new file mode 100644 index 000000000..8abc32379 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/mod.rs @@ -0,0 +1,6 @@ +//! An accelerated backend using `Concrete-FFT`. + +mod implementation; +#[cfg_attr(not(feature = "__private_docs"), doc(hidden))] +pub mod private; +pub use implementation::{engines, entities}; diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/bootstrap.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/bootstrap.rs new file mode 100644 index 000000000..876ed2421 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/bootstrap.rs @@ -0,0 +1,298 @@ +use super::super::math::fft::FftView; +use super::ggsw::{cmux, *}; +use crate::core_crypto::backends::fft::private::math::fft::FourierPolynomialList; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext; +#[cfg(feature = "backend_fft_serialization")] +use crate::core_crypto::commons::math::tensor::ContainerOwned; +use crate::core_crypto::commons::math::tensor::{Container, Split}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::CastInto; +use crate::core_crypto::commons::utils::izip; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSize, LutCountLog, LweDimension, + ModulusSwitchOffset, MonomialDegree, PolynomialSize, +}; +use aligned_vec::CACHELINE_ALIGN; +use concrete_fft::c64; +use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr( + feature = "backend_fft_serialization", + derive(serde::Serialize, serde::Deserialize), + serde(bound(deserialize = "C: ContainerOwned")) +)] +pub struct FourierLweBootstrapKey> { + fourier: FourierPolynomialList, + key_size: LweDimension, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, +} + +pub type FourierLweBootstrapKeyView<'a> = FourierLweBootstrapKey<&'a [c64]>; +pub type FourierLweBootstrapKeyMutView<'a> = FourierLweBootstrapKey<&'a mut [c64]>; + +impl> FourierLweBootstrapKey { + pub fn new( + data: C, + key_size: LweDimension, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + ) -> Self { + assert_eq!(polynomial_size.0 % 2, 0); + assert_eq!( + data.container_len(), + key_size.0 * polynomial_size.0 / 2 + * decomposition_level_count.0 + * glwe_size.0 + * glwe_size.0 + ); + Self { + fourier: FourierPolynomialList { + data, + polynomial_size, + }, + key_size, + glwe_size, + decomposition_base_log, + decomposition_level_count, + } + } + + /// Returns an iterator over the GGSW ciphertexts composing the key. + pub fn into_ggsw_iter(self) -> impl DoubleEndedIterator> + where + C: Split, + { + self.fourier + .data + .split_into(self.key_size.0) + .map(move |slice| { + FourierGgswCiphertext::new( + slice, + self.fourier.polynomial_size, + self.glwe_size, + self.decomposition_base_log, + self.decomposition_level_count, + ) + }) + } + + pub fn key_size(&self) -> LweDimension { + self.key_size + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.fourier.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomposition_base_log + } + + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomposition_level_count + } + + pub fn output_lwe_dimension(&self) -> LweDimension { + LweDimension((self.glwe_size.0 - 1) * self.polynomial_size().0) + } + + pub fn data(self) -> C { + self.fourier.data + } + + pub fn as_view(&self) -> FourierLweBootstrapKeyView<'_> { + FourierLweBootstrapKeyView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_ref(), + polynomial_size: self.fourier.polynomial_size, + }, + key_size: self.key_size, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + } + } + + pub fn as_mut_view(&mut self) -> FourierLweBootstrapKeyMutView<'_> + where + C: AsMut<[c64]>, + { + FourierLweBootstrapKeyMutView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_mut(), + polynomial_size: self.fourier.polynomial_size, + }, + key_size: self.key_size, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + } + } +} + +/// Returns the required memory for [`FourierLweBootstrapKeyMutView::fill_with_forward_fourier`]. +pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result { + fft.forward_scratch() +} + +impl<'a> FourierLweBootstrapKeyMutView<'a> { + /// Fills a bootstrapping key with the Fourier transform of a bootstrapping key in the standard + /// domain. + pub fn fill_with_forward_fourier>( + mut self, + coef_bsk: StandardBootstrapKey<&'_ [Scalar]>, + fft: FftView<'_>, + mut stack: DynStack<'_>, + ) { + for (fourier_ggsw, standard_ggsw) in + izip!(self.as_mut_view().into_ggsw_iter(), coef_bsk.ggsw_iter()) + { + fourier_ggsw.fill_with_forward_fourier(standard_ggsw, fft, stack.rb_mut()); + } + } +} + +/// Returns the required memory for [`FourierLweBootstrapKeyView::blind_rotate`]. +pub fn blind_rotate_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + StackReq::try_new_aligned::(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)? + .try_and(cmux_scratch::(glwe_size, polynomial_size, fft)?) +} + +/// Returns the required memory for [`FourierLweBootstrapKeyView::bootstrap`]. +pub fn bootstrap_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + blind_rotate_scratch::(glwe_size, polynomial_size, fft)?.try_and( + StackReq::try_new_aligned::(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN)?, + ) +} + +impl<'a> FourierLweBootstrapKeyView<'a> { + pub fn blind_rotate>( + self, + mut lut: GlweCiphertext<&'_ mut [Scalar]>, + lwe: &[Scalar], + fft: FftView<'_>, + mut stack: DynStack<'_>, + ) { + let (lwe_body, lwe_mask) = lwe.split_last().unwrap(); + + let lut_poly_size = lut.polynomial_size(); + let monomial_degree = pbs_modulus_switch( + *lwe_body, + lut_poly_size, + ModulusSwitchOffset(0), + LutCountLog(0), + ); + lut.as_mut_view() + .into_polynomial_list() + .into_polynomial_iter() + .for_each(|mut poly| { + poly.update_with_wrapping_unit_monomial_div(MonomialDegree(monomial_degree)); + }); + + // We initialize the ct_0 used for the successive cmuxes + let mut ct0 = lut; + + for (lwe_mask_element, bootstrap_key_ggsw) in izip!(lwe_mask.iter(), self.into_ggsw_iter()) + { + if *lwe_mask_element != Scalar::ZERO { + let stack = stack.rb_mut(); + // We copy ct_0 to ct_1 + let (mut ct1, stack) = stack.collect_aligned( + CACHELINE_ALIGN, + ct0.as_view().into_container().iter().copied(), + ); + let mut ct1 = GlweCiphertext::from_container(&mut *ct1, ct0.polynomial_size()); + + // We rotate ct_1 by performing ct_1 <- ct_1 * X^{a_hat} + for mut poly in ct1 + .as_mut_view() + .into_polynomial_list() + .into_polynomial_iter() + { + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree( + pbs_modulus_switch( + *lwe_mask_element, + lut_poly_size, + ModulusSwitchOffset(0), + LutCountLog(0), + ), + )); + } + + cmux( + ct0.as_mut_view(), + ct1.as_mut_view(), + bootstrap_key_ggsw, + fft, + stack, + ); + } + } + } + + pub fn bootstrap>( + self, + lwe_out: &mut [Scalar], + lwe_in: &[Scalar], + accumulator: GlweCiphertext<&'_ [Scalar]>, + fft: FftView<'_>, + stack: DynStack<'_>, + ) { + let (mut local_accumulator_data, stack) = stack.collect_aligned( + CACHELINE_ALIGN, + accumulator.as_view().into_container().iter().copied(), + ); + let mut local_accumulator = GlweCiphertext::from_container( + &mut *local_accumulator_data, + accumulator.polynomial_size(), + ); + self.blind_rotate(local_accumulator.as_mut_view(), lwe_in, fft, stack); + local_accumulator.as_view().fill_lwe_with_sample_extraction( + &mut LweCiphertext::from_container(&mut *lwe_out), + MonomialDegree(0), + ); + } +} + +/// This function switches modulus for a single coefficient of a ciphertext, +/// only in the context of a PBS +/// +/// offset: the number of msb discarded +/// lut_count_log: the right padding +pub fn pbs_modulus_switch>( + input: Scalar, + poly_size: PolynomialSize, + offset: ModulusSwitchOffset, + lut_count_log: LutCountLog, +) -> usize { + // First, do the left shift (we discard the offset msb) + let mut output = input << offset.0; + // Start doing the right shift + output >>= Scalar::BITS - poly_size.log2().0 - 2 + lut_count_log.0; + // Do the rounding + output += output & Scalar::ONE; + // Finish the right shift + output >>= 1; + // Apply the lsb padding + output <<= lut_count_log.0; + >::cast_into(output) +} diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/ggsw.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/ggsw.rs new file mode 100644 index 000000000..e43db8913 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/ggsw.rs @@ -0,0 +1,674 @@ +use core::mem::MaybeUninit; + +use super::super::math::decomposition::TensorSignedDecompositionLendingIter; +use super::super::math::fft::{FftView, FourierPolynomialList}; +use super::super::math::polynomial::{FourierPolynomialUninitMutView, FourierPolynomialView}; +use super::super::{as_mut_uninit, assume_init_mut}; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; +use crate::core_crypto::commons::math::polynomial::Polynomial; +#[cfg(feature = "backend_fft_serialization")] +use crate::core_crypto::commons::math::tensor::ContainerOwned; +use crate::core_crypto::commons::math::tensor::{Container, IntoTensor, Split}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::utils::izip; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, +}; +use aligned_vec::CACHELINE_ALIGN; +use concrete_fft::c64; +use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq}; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +/// A GGSW ciphertext in the Fourier domain. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr( + feature = "backend_fft_serialization", + derive(serde::Serialize, serde::Deserialize), + serde(bound(deserialize = "C: ContainerOwned")) +)] +pub struct FourierGgswCiphertext> { + fourier: FourierPolynomialList, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, +} + +/// A matrix containing a single level of gadget decomposition, in the Fourier domain. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FourierGgswLevelMatrix> { + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + row_count: usize, + decomposition_level: DecompositionLevel, +} + +/// A row of a GGSW level matrix, in the Fourier domain. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FourierGgswLevelRow> { + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + decomposition_level: DecompositionLevel, +} + +pub type FourierGgswCiphertextView<'a> = FourierGgswCiphertext<&'a [c64]>; +pub type FourierGgswCiphertextMutView<'a> = FourierGgswCiphertext<&'a mut [c64]>; +pub type FourierGgswLevelMatrixView<'a> = FourierGgswLevelMatrix<&'a [c64]>; +pub type FourierGgswLevelMatrixMutView<'a> = FourierGgswLevelMatrix<&'a mut [c64]>; +pub type FourierGgswLevelRowView<'a> = FourierGgswLevelRow<&'a [c64]>; +pub type FourierGgswLevelRowMutView<'a> = FourierGgswLevelRow<&'a mut [c64]>; + +impl> FourierGgswCiphertext { + pub fn new( + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + ) -> Self { + assert_eq!(polynomial_size.0 % 2, 0); + assert_eq!( + data.container_len(), + polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0 + ); + + Self { + fourier: FourierPolynomialList { + data, + polynomial_size, + }, + glwe_size, + decomposition_base_log, + decomposition_level_count, + } + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.fourier.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomposition_base_log + } + + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomposition_level_count + } + + pub fn data(self) -> C { + self.fourier.data + } + + pub fn as_view(&self) -> FourierGgswCiphertextView<'_> + where + C: AsRef<[c64]>, + { + FourierGgswCiphertextView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_ref(), + polynomial_size: self.fourier.polynomial_size, + }, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + } + } + + pub fn as_mut_view(&mut self) -> FourierGgswCiphertextMutView<'_> + where + C: AsMut<[c64]>, + { + FourierGgswCiphertextMutView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_mut(), + polynomial_size: self.fourier.polynomial_size, + }, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + } + } +} + +impl> FourierGgswLevelMatrix { + pub fn new( + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + row_count: usize, + decomposition_level: DecompositionLevel, + ) -> Self { + assert_eq!(polynomial_size.0 % 2, 0); + assert_eq!( + data.container_len(), + polynomial_size.0 / 2 * glwe_size.0 * row_count + ); + Self { + data, + polynomial_size, + glwe_size, + row_count, + decomposition_level, + } + } + + /// Returns an iterator over the rows of the level matrices. + pub fn into_rows(self) -> impl DoubleEndedIterator> + where + C: Split, + { + self.data + .split_into(self.row_count) + .map(move |slice| FourierGgswLevelRow { + data: slice, + polynomial_size: self.polynomial_size, + glwe_size: self.glwe_size, + decomposition_level: self.decomposition_level, + }) + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn row_count(&self) -> usize { + self.row_count + } + + pub fn decomposition_level(&self) -> DecompositionLevel { + self.decomposition_level + } + + pub fn data(self) -> C { + self.data + } +} + +impl> FourierGgswLevelRow { + pub fn new( + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + decomposition_level: DecompositionLevel, + ) -> Self { + assert_eq!(polynomial_size.0 % 2, 0); + assert_eq!(data.container_len(), polynomial_size.0 / 2 * glwe_size.0); + Self { + data, + polynomial_size, + glwe_size, + decomposition_level, + } + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn decomposition_level(&self) -> DecompositionLevel { + self.decomposition_level + } + + pub fn data(self) -> C { + self.data + } +} + +impl<'a> FourierGgswCiphertextView<'a> { + /// Returns an iterator over the level matrices. + pub fn into_levels(self) -> impl DoubleEndedIterator> { + self.fourier + .data + .split_into(self.decomposition_level_count.0) + .enumerate() + .map(move |(i, slice)| { + FourierGgswLevelMatrixView::new( + slice, + self.fourier.polynomial_size, + self.glwe_size, + self.glwe_size.0, + DecompositionLevel(i + 1), + ) + }) + } +} + +/// Returns the required memory for [`FourierGgswCiphertextMutView::fill_with_forward_fourier`]. +pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result { + fft.forward_scratch() +} + +impl<'a> FourierGgswCiphertextMutView<'a> { + /// Fills a GGSW ciphertext with the Fourier transform of a GGSW ciphertext in the standard + /// domain. + pub fn fill_with_forward_fourier( + self, + coef_ggsw: StandardGgswCiphertext<&'_ [Scalar]>, + fft: FftView<'_>, + mut stack: DynStack<'_>, + ) { + debug_assert_eq!(coef_ggsw.polynomial_size(), self.polynomial_size()); + let poly_size = coef_ggsw.polynomial_size().0; + + for (fourier_poly, coef_poly) in izip!( + self.data().into_chunks(poly_size / 2), + coef_ggsw.into_container().into_chunks(poly_size) + ) { + // SAFETY: forward_as_torus doesn't write any uninitialized values into its output + fft.forward_as_torus( + FourierPolynomialUninitMutView { + data: unsafe { as_mut_uninit(fourier_poly) }, + }, + Polynomial::from_container(coef_poly), + stack.rb_mut(), + ); + } + } +} + +/// Returns the required memory for [`external_product`]. +pub fn external_product_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + let align = CACHELINE_ALIGN; + let standard_scratch = + StackReq::try_new_aligned::(glwe_size.0 * polynomial_size.0, align)?; + let fourier_scratch = + StackReq::try_new_aligned::(glwe_size.0 * polynomial_size.0 / 2, align)?; + let fourier_scratch_single = StackReq::try_new_aligned::(polynomial_size.0 / 2, align)?; + + let substack3 = fft.forward_scratch()?; + let substack2 = substack3.try_and(fourier_scratch_single)?; + let substack1 = substack2.try_and(standard_scratch)?; + let substack0 = StackReq::try_any_of([ + substack1.try_and(standard_scratch)?, + fft.backward_scratch()?, + ])?; + substack0.try_and(fourier_scratch) +} + +/// Performs the external product of `ggsw` and `glwe`, and stores the result in `out`. +#[cfg_attr(__profiling, inline(never))] +pub fn external_product( + mut out: GlweCiphertext<&'_ mut [Scalar]>, + ggsw: FourierGgswCiphertextView<'_>, + glwe: GlweCiphertext<&'_ [Scalar]>, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + // we check that the polynomial sizes match + debug_assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size()); + debug_assert_eq!(ggsw.polynomial_size(), out.polynomial_size()); + // we check that the glwe sizes match + debug_assert_eq!(ggsw.glwe_size(), glwe.size()); + debug_assert_eq!(ggsw.glwe_size(), out.size()); + + let align = CACHELINE_ALIGN; + let poly_size = ggsw.polynomial_size().0; + + // we round the input mask and body + let decomposer = SignedDecomposer::::new( + ggsw.decomposition_base_log(), + ggsw.decomposition_level_count(), + ); + + let (mut output_fft_buffer, mut substack0) = + stack.make_aligned_uninit::(poly_size / 2 * ggsw.glwe_size().0, align); + // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid + // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once + // it has been fully initialized for the first time. + let output_fft_buffer = &mut *output_fft_buffer; + let mut is_output_uninit = true; + + { + // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER DOMAIN + // In this section, we perform the external product in the fourier domain, and accumulate + // the result in the output_fft_buffer variable. + let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new( + glwe.into_container() + .iter() + .map(|s| decomposer.closest_representable(*s)), + DecompositionBaseLog(decomposer.base_log), + DecompositionLevelCount(decomposer.level_count), + substack0.rb_mut(), + ); + + // We loop through the levels (we reverse to match the order of the decomposition iterator.) + ggsw.into_levels().rev().for_each(|ggsw_decomp_matrix| { + // We retrieve the decomposition of this level. + let (glwe_level, glwe_decomp_term, mut substack2) = + collect_next_term(&mut decomposition, &mut substack1, align); + let glwe_decomp_term = + GlweCiphertext::from_container(&*glwe_decomp_term, ggsw.polynomial_size()); + debug_assert_eq!(ggsw_decomp_matrix.decomposition_level(), glwe_level); + + // For each level we have to add the result of the vector-matrix product between the + // decomposition of the glwe, and the ggsw level matrix to the output. To do so, we + // iteratively add to the output, the product between every line of the matrix, and + // the corresponding (scalar) polynomial in the glwe decomposition: + // + // ggsw_mat ggsw_mat + // glwe_dec | - - - - | < glwe_dec | - - - - | + // | - - - | x | - - - - | | - - - | x | - - - - | < + // ^ | - - - - | ^ | - - - - | + // + // t = 1 t = 2 ... + + izip!( + ggsw_decomp_matrix.into_rows(), + glwe_decomp_term + .into_polynomial_list() + .into_polynomial_iter() + ) + .for_each(|(ggsw_row, glwe_poly)| { + let (mut fourier, substack3) = substack2 + .rb_mut() + .make_aligned_uninit::(poly_size / 2, align); + // We perform the forward fft transform for the glwe polynomial + let fourier = fft + .forward_as_integer( + FourierPolynomialUninitMutView { data: &mut fourier }, + glwe_poly, + substack3, + ) + .data; + // Now we loop through the polynomials of the output, and add the + // corresponding product of polynomials. + + // SAFETY: see comment above definition of `output_fft_buffer` + unsafe { + update_with_fmadd( + output_fft_buffer, + ggsw_row, + fourier, + is_output_uninit, + poly_size, + ) + }; + + // we initialized `output_fft_buffer, so we can set this to false + is_output_uninit = false; + }); + }); + } + + // -------------------------------------------- TRANSFORMATION OF RESULT TO STANDARD DOMAIN + // In this section, we bring the result from the fourier domain, back to the standard + // domain, and add it to the output. + // + // We iterate over the polynomials in the output. + if !is_output_uninit { + // SAFETY: output_fft_buffer is initialized, since `is_output_uninit` is false + let output_fft_buffer = &*unsafe { assume_init_mut(output_fft_buffer) }; + izip!( + out.as_mut_view() + .into_polynomial_list() + .into_polynomial_iter(), + output_fft_buffer + .into_chunks(poly_size / 2) + .map(|slice| FourierPolynomialView { data: slice }), + ) + .for_each(|(out, fourier)| { + fft.add_backward_as_torus(out, fourier, substack0.rb_mut()); + }); + } +} + +#[cfg_attr(__profiling, inline(never))] +fn collect_next_term<'a, Scalar: UnsignedTorus>( + decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>, + substack1: &'a mut DynStack, + align: usize, +) -> ( + DecompositionLevel, + dyn_stack::DynArray<'a, Scalar>, + DynStack<'a>, +) { + let (glwe_level, _, glwe_decomp_term) = decomposition.next_term().unwrap(); + let (glwe_decomp_term, substack2) = substack1.rb_mut().collect_aligned(align, glwe_decomp_term); + (glwe_level, glwe_decomp_term, substack2) +} + +/// # Note +/// +/// this function leaves all the elements of `output_fourier` in an initialized state. +/// +/// # Safety +/// +/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values. +/// - `is_x86_feature_detected!("avx512f")` must be true. +#[cfg(all( + feature = "backend_fft_nightly_avx512", + any(target_arch = "x86_64", target_arch = "x86") +))] +#[target_feature(enable = "avx512f")] +unsafe fn update_with_fmadd_avx512( + output_fourier: &mut [MaybeUninit], + ggsw_poly: &[c64], + fourier: &[c64], + is_output_uninit: bool, +) { + let n = output_fourier.len(); + + debug_assert_eq!(n, ggsw_poly.len()); + debug_assert_eq!(n, fourier.len()); + debug_assert_eq!(n % 4, 0); + + let out = output_fourier.as_mut_ptr(); + let lhs = ggsw_poly.as_ptr(); + let rhs = fourier.as_ptr(); + + // 4×c64 per register + + if is_output_uninit { + for i in 0..n / 4 { + let i = 4 * i; + let ab = _mm512_loadu_pd(lhs.add(i) as _); + let xy = _mm512_loadu_pd(rhs.add(i) as _); + let aa = _mm512_unpacklo_pd(ab, ab); + let bb = _mm512_unpackhi_pd(ab, ab); + let yx = _mm512_permute_pd::<0b01010101>(xy); + _mm512_storeu_pd( + out.add(i) as _, + _mm512_fmaddsub_pd(aa, xy, _mm512_mul_pd(bb, yx)), + ); + } + } else { + for i in 0..n / 4 { + let i = 4 * i; + let ab = _mm512_loadu_pd(lhs.add(i) as _); + let xy = _mm512_loadu_pd(rhs.add(i) as _); + let aa = _mm512_unpacklo_pd(ab, ab); + let bb = _mm512_unpackhi_pd(ab, ab); + let yx = _mm512_permute_pd::<0b01010101>(xy); + _mm512_storeu_pd( + out.add(i) as _, + _mm512_fmaddsub_pd( + aa, + xy, + _mm512_fmaddsub_pd(bb, yx, _mm512_loadu_pd(out.add(i) as _)), + ), + ); + } + } +} + +/// # Note +/// +/// this function leaves all the elements of `output_fourier` in an initialized state. +/// +/// # Safety +/// +/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values. +/// - `is_x86_feature_detected!("fma")` must be true. +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +#[target_feature(enable = "fma")] +unsafe fn update_with_fmadd_fma( + output_fourier: &mut [MaybeUninit], + ggsw_poly: &[c64], + fourier: &[c64], + is_output_uninit: bool, +) { + let n = output_fourier.len(); + + debug_assert_eq!(n, ggsw_poly.len()); + debug_assert_eq!(n, fourier.len()); + debug_assert_eq!(n % 4, 0); + + let out = output_fourier.as_mut_ptr(); + let lhs = ggsw_poly.as_ptr(); + let rhs = fourier.as_ptr(); + + // 2×c64 per register + + if is_output_uninit { + for i in 0..n / 2 { + let i = 2 * i; + let ab = _mm256_loadu_pd(lhs.add(i) as _); + let xy = _mm256_loadu_pd(rhs.add(i) as _); + let aa = _mm256_unpacklo_pd(ab, ab); + let bb = _mm256_unpackhi_pd(ab, ab); + let yx = _mm256_permute_pd::<0b0101>(xy); + _mm256_storeu_pd( + out.add(i) as _, + _mm256_fmaddsub_pd(aa, xy, _mm256_mul_pd(bb, yx)), + ); + } + } else { + for i in 0..n / 2 { + let i = 2 * i; + let ab = _mm256_loadu_pd(lhs.add(i) as _); + let xy = _mm256_loadu_pd(rhs.add(i) as _); + let aa = _mm256_unpacklo_pd(ab, ab); + let bb = _mm256_unpackhi_pd(ab, ab); + let yx = _mm256_permute_pd::<0b0101>(xy); + _mm256_storeu_pd( + out.add(i) as _, + _mm256_fmaddsub_pd( + aa, + xy, + _mm256_fmaddsub_pd(bb, yx, _mm256_loadu_pd(out.add(i) as _)), + ), + ); + } + } +} + +/// # Note +/// +/// this function leaves all the elements of `output_fourier` in an initialized state. +/// +/// # Safety +/// +/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values. +unsafe fn update_with_fmadd_scalar( + output_fourier: &mut [MaybeUninit], + ggsw_poly: &[c64], + fourier: &[c64], + is_output_uninit: bool, +) { + if is_output_uninit { + // we're writing to output_fft_buffer for the first time + // so its contents are uninitialized + izip!(output_fourier, ggsw_poly, fourier).for_each(|(out_fourier, lhs, rhs)| { + out_fourier.write(lhs * rhs); + }); + } else { + // we already wrote to output_fft_buffer, so we can assume its contents are + // initialized. + izip!(output_fourier, ggsw_poly, fourier).for_each(|(out_fourier, lhs, rhs)| { + *{ out_fourier.assume_init_mut() } += lhs * rhs; + }); + } +} + +/// # Note +/// +/// this function leaves all the elements of `output_fourier` in an initialized state. +/// +/// # Safety +/// +/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values. +#[cfg_attr(__profiling, inline(never))] +unsafe fn update_with_fmadd( + output_fft_buffer: &mut [MaybeUninit], + ggsw_row: FourierGgswLevelRowView, + fourier: &[c64], + is_output_uninit: bool, + poly_size: usize, +) { + #[allow(clippy::type_complexity)] + let ptr_fn = || -> unsafe fn(&mut [MaybeUninit], &[c64], &[c64], bool) { + #[cfg(all( + feature = "backend_fft_nightly_avx512", + any(target_arch = "x86_64", target_arch = "x86") + ))] + if is_x86_feature_detected!("avx512f") { + return update_with_fmadd_avx512; + } + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + if is_x86_feature_detected!("fma") { + return update_with_fmadd_fma; + } + + update_with_fmadd_scalar + }; + + let ptr = ptr_fn(); + + izip!( + output_fft_buffer.into_chunks(poly_size / 2), + ggsw_row.data.into_chunks(poly_size / 2) + ) + .for_each(|(output_fourier, ggsw_poly)| { + ptr(output_fourier, ggsw_poly, fourier, is_output_uninit); + }); +} + +/// Returns the required memory for [`cmux`]. +pub fn cmux_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + external_product_scratch::(glwe_size, polynomial_size, fft) +} + +/// This cmux mutates both ct1 and ct0. The result is in ct0 after the method was called. +pub fn cmux( + ct0: GlweCiphertext<&'_ mut [Scalar]>, + mut ct1: GlweCiphertext<&'_ mut [Scalar]>, + ggsw: FourierGgswCiphertextView<'_>, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + izip!( + ct1.as_mut_view().into_tensor().into_container(), + ct0.as_view().into_tensor().into_container(), + ) + .for_each(|(c1, c0)| { + *c1 = c1.wrapping_sub(*c0); + }); + external_product(ct0, ggsw, ct1.as_view(), fft, stack); +} diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/glwe.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/glwe.rs new file mode 100644 index 000000000..22890f80f --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/glwe.rs @@ -0,0 +1,93 @@ +use super::super::math::polynomial::*; +use crate::core_crypto::commons::math::tensor::{Container, IntoChunks}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::prelude::{GlweSize, PolynomialSize}; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr( + feature = "backend_fft_serialization", + derive(serde::Serialize, serde::Deserialize) +)] +pub struct GlweCiphertext { + data: C, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, +} + +pub type GlweCiphertextView<'a, Scalar> = GlweCiphertext<&'a [Scalar]>; +pub type GlweCiphertextMutView<'a, Scalar> = GlweCiphertext<&'a mut [Scalar]>; + +impl GlweCiphertext { + pub fn new(data: C, polynomial_size: PolynomialSize, glwe_size: GlweSize) -> Self + where + C: Container, + { + assert_eq!(data.container_len(), polynomial_size.0 * glwe_size.0); + + Self { + data, + polynomial_size, + glwe_size, + } + } + + /// Returns an iterator over the polynomials in `self`. + pub fn into_polynomials(self) -> impl DoubleEndedIterator> + where + C: IntoChunks, + { + self.data + .split_into(self.glwe_size.0) + .map(|chunk| Polynomial { data: chunk }) + } + + pub fn data(self) -> C { + self.data + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn as_view(&self) -> GlweCiphertextView<'_, C::Element> { + GlweCiphertext { + data: self.data.as_ref(), + polynomial_size: self.polynomial_size, + glwe_size: self.glwe_size, + } + } + + pub fn as_mut_view(&mut self) -> GlweCiphertextMutView<'_, C::Element> + where + C: AsMut<[C::Element]>, + { + GlweCiphertext { + data: self.data.as_mut(), + polynomial_size: self.polynomial_size, + glwe_size: self.glwe_size, + } + } +} + +impl<'a, Scalar> GlweCiphertextView<'a, Scalar> { + /// Fills an LWE ciphertext with the extraction of one coefficient of the current GLWE + /// ciphertext. + pub fn fill_lwe_with_sample_extraction(self, lwe: &mut [Scalar], nth: usize) + where + Scalar: UnsignedTorus, + { + let this = crate::core_crypto::commons::crypto::glwe::GlweCiphertext::from_container( + self.data, + self.polynomial_size, + ); + let mut lwe = crate::core_crypto::commons::crypto::lwe::LweCiphertext::from_container(lwe); + #[allow(deprecated)] + let n_th = crate::core_crypto::prelude::MonomialDegree(nth); + + this.fill_lwe_with_sample_extraction(&mut lwe, n_th); + } +} diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/mod.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/mod.rs new file mode 100644 index 000000000..3b150e37e --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/mod.rs @@ -0,0 +1,3 @@ +pub mod bootstrap; +pub mod ggsw; +pub mod wop_pbs; diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/mod.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/mod.rs new file mode 100644 index 000000000..cb77ed0d3 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/mod.rs @@ -0,0 +1,1056 @@ +#![allow(clippy::too_many_arguments)] + +use aligned_vec::CACHELINE_ALIGN; +use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq}; + +use super::super::math::fft::FftView; +use super::bootstrap::{bootstrap_scratch, FourierLweBootstrapKeyView}; +use super::ggsw::{ + cmux, cmux_scratch, external_product, external_product_scratch, + fill_with_forward_fourier_scratch, FourierGgswCiphertext, +}; +use crate::core_crypto::backends::fft::private::math::fft::FourierPolynomialList; +use crate::core_crypto::commons::crypto::encoding::Cleartext; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; +use crate::core_crypto::commons::crypto::glwe::{ + GlweCiphertext, LwePrivateFunctionalPackingKeyswitchKeyList, +}; +use crate::core_crypto::commons::crypto::lwe::{LweCiphertext, LweKeyswitchKey, LweList}; +use crate::core_crypto::commons::math::polynomial::PolynomialList; +use crate::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor, Container, Split}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::CastInto; +use crate::core_crypto::commons::utils::izip; +use crate::core_crypto::prelude::{ + CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, DeltaLog, ExtractedBitsCount, + GlweSize, LweDimension, LweSize, MonomialDegree, PolynomialCount, PolynomialSize, +}; + +use concrete_fft::c64; + +pub fn extract_bits_scratch( + lwe_dimension: LweDimension, + ksk_after_key_size: LweDimension, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + let align = CACHELINE_ALIGN; + + let lwe_in_buffer = StackReq::try_new_aligned::(lwe_dimension.to_lwe_size().0, align)?; + let lwe_out_ks_buffer = + StackReq::try_new_aligned::(ksk_after_key_size.to_lwe_size().0, align)?; + let pbs_accumulator = + StackReq::try_new_aligned::(glwe_size.0 * polynomial_size.0, align)?; + let lwe_out_pbs_buffer = StackReq::try_new_aligned::( + glwe_size.to_glwe_dimension().0 * polynomial_size.0 + 1, + align, + )?; + let lwe_bit_left_shift_buffer = lwe_in_buffer; + let bootstrap_scratch = bootstrap_scratch::(glwe_size, polynomial_size, fft)?; + + lwe_in_buffer + .try_and(lwe_out_ks_buffer)? + .try_and(pbs_accumulator)? + .try_and(lwe_out_pbs_buffer)? + .try_and(StackReq::try_any_of([ + lwe_bit_left_shift_buffer, + bootstrap_scratch, + ])?) +} + +/// Function to extract `number_of_bits_to_extract` from an [`LweCiphertext`] starting at the bit +/// number `delta_log` (0-indexed) included. +/// +/// Output bits are ordered from the MSB to the LSB. Each one of them is output in a distinct LWE +/// ciphertext, containing the encryption of the bit scaled by q/2 (i.e., the most significant bit +/// in the plaintext representation). +pub fn extract_bits>( + mut lwe_list_out: LweList<&'_ mut [Scalar]>, + lwe_in: LweCiphertext<&'_ [Scalar]>, + ksk: LweKeyswitchKey<&'_ [Scalar]>, + fourier_bsk: FourierLweBootstrapKeyView<'_>, + delta_log: DeltaLog, + number_of_bits_to_extract: ExtractedBitsCount, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + let ciphertext_n_bits = Scalar::BITS; + let number_of_bits_to_extract = number_of_bits_to_extract.0; + + debug_assert!( + ciphertext_n_bits >= number_of_bits_to_extract + delta_log.0, + "Tried to extract {} bits, while the maximum number of extractable bits for {} bits + ciphertexts and a scaling factor of 2^{} is {}", + number_of_bits_to_extract, + ciphertext_n_bits, + delta_log.0, + ciphertext_n_bits - delta_log.0, + ); + debug_assert!( + lwe_list_out.lwe_size().to_lwe_dimension() == ksk.after_key_size(), + "lwe_list_out needs to have an lwe_size of {}, got {}", + ksk.after_key_size().0, + lwe_list_out.lwe_size().to_lwe_dimension().0, + ); + debug_assert!( + lwe_list_out.count().0 == number_of_bits_to_extract, + "lwe_list_out needs to have a ciphertext count of {}, got {}", + number_of_bits_to_extract, + lwe_list_out.count().0, + ); + debug_assert!( + lwe_in.lwe_size() == fourier_bsk.output_lwe_dimension().to_lwe_size(), + "lwe_in needs to have an LWE dimension of {}, got {}", + fourier_bsk.output_lwe_dimension().to_lwe_size().0, + lwe_in.lwe_size().0, + ); + debug_assert!( + ksk.after_key_size() == fourier_bsk.key_size(), + "ksk needs to have an output LWE dimension of {}, got {}", + fourier_bsk.key_size().0, + ksk.after_key_size().0, + ); + + let polynomial_size = fourier_bsk.polynomial_size(); + let glwe_size = fourier_bsk.glwe_size(); + let glwe_dimension = glwe_size.to_glwe_dimension(); + + let align = CACHELINE_ALIGN; + + let (mut lwe_in_buffer_data, stack) = + stack.collect_aligned(align, lwe_in.into_container().iter().copied()); + let mut lwe_in_buffer = LweCiphertext::from_container(&mut *lwe_in_buffer_data); + + let (mut lwe_out_ks_buffer_data, stack) = + stack.make_aligned_with(ksk.lwe_size().0, align, |_| Scalar::ZERO); + let mut lwe_out_ks_buffer = LweCiphertext::from_container(&mut *lwe_out_ks_buffer_data); + + let (mut pbs_accumulator_data, stack) = + stack.make_aligned_with(glwe_size.0 * polynomial_size.0, align, |_| Scalar::ZERO); + let mut pbs_accumulator = + GlweCiphertext::from_container(&mut *pbs_accumulator_data, polynomial_size); + + let lwe_size = LweSize(glwe_dimension.0 * polynomial_size.0 + 1); + let (mut lwe_out_pbs_buffer_data, mut stack) = + stack.make_aligned_with(lwe_size.0, align, |_| Scalar::ZERO); + let mut lwe_out_pbs_buffer = LweCiphertext::from_container(&mut *lwe_out_pbs_buffer_data); + + // We iterate on the list in reverse as we want to store the extracted MSB at index 0 + for (bit_idx, output_ct) in lwe_list_out.ciphertext_iter_mut().rev().enumerate() { + // Shift on padding bit + let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned( + align, + lwe_in_buffer + .as_view() + .into_container() + .iter() + .map(|s| *s << (ciphertext_n_bits - delta_log.0 - bit_idx - 1)), + ); + + // Key switch to input PBS key + ksk.keyswitch_ciphertext( + &mut lwe_out_ks_buffer.as_mut_view(), + &LweCiphertext::from_container(&*lwe_bit_left_shift_buffer_data), + ); + + drop(lwe_bit_left_shift_buffer_data); + + // Store the keyswitch output unmodified to the output list (as we need to to do other + // computations on the output of the keyswitch) + output_ct + .into_container() + .copy_from_slice(lwe_out_ks_buffer.as_view().into_container()); + + // If this was the last extracted bit, break + // we subtract 1 because if the number_of_bits_to_extract is 1 we want to stop right away + if bit_idx == number_of_bits_to_extract - 1 { + break; + } + + // Add q/4 to center the error while computing a negacyclic LUT + let out_ks_body = &mut lwe_out_ks_buffer.get_mut_body().0; + *out_ks_body = out_ks_body.wrapping_add(Scalar::ONE << (ciphertext_n_bits - 2)); + + // Fill lut for the current bit (equivalent to trivial encryption as mask is 0s) + // The LUT is filled with -alpha in each coefficient where alpha = delta*2^{bit_idx-1} + for poly_coeff in &mut pbs_accumulator + .as_mut_view() + .get_mut_body() + .into_polynomial() + .coefficient_iter_mut() + { + *poly_coeff = Scalar::ZERO.wrapping_sub(Scalar::ONE << (delta_log.0 - 1 + bit_idx)); + } + + fourier_bsk.bootstrap( + lwe_out_pbs_buffer.as_mut_view().into_container(), + lwe_out_ks_buffer.as_view().into_container(), + pbs_accumulator.as_view(), + fft, + stack.rb_mut(), + ); + + // Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0 if the + // extracted bit was 0 and 1 in the other case + let out_pbs_body = &mut lwe_out_pbs_buffer.get_mut_body().0; + + *out_pbs_body = out_pbs_body.wrapping_add(Scalar::ONE << (delta_log.0 + bit_idx - 1)); + + // Remove the extracted bit from the initial LWE to get a 0 at the extracted bit location. + izip!( + lwe_in_buffer.as_mut_view().into_container(), + lwe_out_pbs_buffer.as_view().into_container() + ) + .for_each(|(out, inp)| *out = out.wrapping_sub(*inp)); + } +} + +pub fn circuit_bootstrap_boolean_scratch( + lwe_in_size: LweSize, + bsk_output_lwe_size: LweSize, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + fft: FftView<'_>, +) -> Result { + StackReq::try_new_aligned::(bsk_output_lwe_size.0, CACHELINE_ALIGN)?.try_and( + homomorphic_shift_boolean_scratch::(lwe_in_size, polynomial_size, glwe_size, fft)?, + ) +} + +/// Circuit bootstrapping for boolean messages, i.e. containing only one bit of message +/// +/// The output GGSW ciphertext `ggsw_out` decomposition base log and level count are used as the +/// circuit_bootstrap_boolean decomposition base log and level count. +pub fn circuit_bootstrap_boolean>( + fourier_bsk: FourierLweBootstrapKeyView<'_>, + lwe_in: LweCiphertext<&[Scalar]>, + mut ggsw_out: StandardGgswCiphertext<&mut [Scalar]>, + delta_log: DeltaLog, + fpksk_list: LwePrivateFunctionalPackingKeyswitchKeyList<&[Scalar]>, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + let level_cbs = ggsw_out.decomposition_level_count(); + let base_log_cbs = ggsw_out.decomposition_base_log(); + + debug_assert!( + level_cbs.0 >= 1, + "level_cbs needs to be >= 1, got {}", + level_cbs.0 + ); + debug_assert!( + base_log_cbs.0 >= 1, + "base_log_cbs needs to be >= 1, got {}", + base_log_cbs.0 + ); + + let fpksk_input_lwe_key_dimension = fpksk_list.input_lwe_key_dimension(); + let fourier_bsk_output_lwe_dimension = fourier_bsk.output_lwe_dimension(); + + debug_assert!( + fpksk_input_lwe_key_dimension == fourier_bsk_output_lwe_dimension, + "The fourier_bsk output_lwe_dimension, got {}, must be equal to the fpksk \ + input_lwe_key_dimension, got {}", + fourier_bsk_output_lwe_dimension.0, + fpksk_input_lwe_key_dimension.0 + ); + + let fpksk_output_polynomial_size = fpksk_list.output_polynomial_size(); + let fpksk_output_glwe_key_dimension = fpksk_list.output_glwe_key_dimension(); + + debug_assert!( + ggsw_out.polynomial_size() == fpksk_output_polynomial_size, + "The output GGSW ciphertext needs to have the same polynomial size as the fpksks, \ + got {}, expeceted {}", + ggsw_out.polynomial_size().0, + fpksk_output_polynomial_size.0 + ); + + debug_assert!( + ggsw_out.glwe_size().to_glwe_dimension() == fpksk_output_glwe_key_dimension, + "The output GGSW ciphertext needs to have the same GLWE dimension as the fpksks, \ + got {}, expeceted {}", + ggsw_out.glwe_size().to_glwe_dimension().0, + fpksk_output_glwe_key_dimension.0 + ); + + debug_assert!( + ggsw_out.glwe_size().0 == fpksk_list.fpksk_count().0, + "The input vector of fpksk needs to have {} (ggsw.glwe_size * \ + ggsw.decomposition_level_count) elements got {}", + ggsw_out.glwe_size().0, + fpksk_list.fpksk_count().0, + ); + + // Output for every bootstrapping + let (mut lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with( + fourier_bsk_output_lwe_dimension.to_lwe_size().0, + CACHELINE_ALIGN, + |_| Scalar::ZERO, + ); + let mut lwe_out_bs_buffer = LweCiphertext::from_container(&mut *lwe_out_bs_buffer_data); + + // Output for every pfksk that that come from the output GGSW + let mut glwe_out_pfksk_buffer = ggsw_out.as_mut_glwe_list(); + + let mut out_pfksk_buffer_iter = glwe_out_pfksk_buffer.ciphertext_iter_mut(); + + for decomposition_level in (1..=level_cbs.0).map(DecompositionLevelCount) { + homomorphic_shift_boolean( + fourier_bsk, + lwe_out_bs_buffer.as_mut_view(), + lwe_in, + decomposition_level, + base_log_cbs, + delta_log, + fft, + stack.rb_mut(), + ); + + for pfksk in fpksk_list.fpksk_iter() { + let mut glwe_out = out_pfksk_buffer_iter.next().unwrap(); + pfksk.private_functional_keyswitch_ciphertext(&mut glwe_out, &lwe_out_bs_buffer); + } + } +} + +pub fn homomorphic_shift_boolean_scratch( + lwe_in_size: LweSize, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + fft: FftView<'_>, +) -> Result { + let align = CACHELINE_ALIGN; + StackReq::try_new_aligned::(lwe_in_size.0, align)? + .try_and(StackReq::try_new_aligned::( + polynomial_size.0 * glwe_size.0, + align, + )?)? + .try_and(bootstrap_scratch::( + glwe_size, + polynomial_size, + fft, + )?) +} + +/// Homomorphic shift for LWE without padding bit +/// +/// Starts by shifting the message bit at bit #delta_log to the padding bit and then shifts it to +/// the right by base_log * level. +pub fn homomorphic_shift_boolean>( + fourier_bsk: FourierLweBootstrapKeyView<'_>, + mut lwe_out: LweCiphertext<&mut [Scalar]>, + lwe_in: LweCiphertext<&[Scalar]>, + level_count_cbs: DecompositionLevelCount, + base_log_cbs: DecompositionBaseLog, + delta_log: DeltaLog, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + let ciphertext_n_bits = Scalar::BITS; + let lwe_in_size = lwe_in.lwe_size(); + let polynomial_size = fourier_bsk.polynomial_size(); + + let (mut lwe_left_shift_buffer_data, stack) = + stack.make_aligned_with(lwe_in_size.0, CACHELINE_ALIGN, |_| Scalar::ZERO); + let mut lwe_left_shift_buffer = LweCiphertext::from_container(&mut *lwe_left_shift_buffer_data); + // Shift message LSB on padding bit, at this point we expect to have messages with only 1 bit + // of information + lwe_left_shift_buffer.fill_with_scalar_mul( + &lwe_in, + &Cleartext(Scalar::ONE << (ciphertext_n_bits - delta_log.0 - 1)), + ); + + // Add q/4 to center the error while computing a negacyclic LUT + let mut shift_buffer_body = lwe_left_shift_buffer.get_mut_body(); + shift_buffer_body.0 = shift_buffer_body + .0 + .wrapping_add(Scalar::ONE << (ciphertext_n_bits - 2)); + + let (mut pbs_accumulator_data, stack) = stack.make_aligned_with( + polynomial_size.0 * fourier_bsk.glwe_size().0, + CACHELINE_ALIGN, + |_| Scalar::ZERO, + ); + let mut pbs_accumulator = + GlweCiphertext::from_container(&mut *pbs_accumulator_data, polynomial_size); + + // Fill lut (equivalent to trivial encryption as mask is 0s) + // The LUT is filled with -alpha in each coefficient where + // alpha = 2^{log(q) - 1 - base_log * level} + pbs_accumulator + .get_mut_body() + .as_mut_tensor() + .fill_with_element(Scalar::ZERO.wrapping_sub( + Scalar::ONE << (ciphertext_n_bits - 1 - base_log_cbs.0 * level_count_cbs.0), + )); + + // Applying a negacyclic LUT on a ciphertext with one bit of message in the MSB and no bit + // of padding + fourier_bsk.bootstrap( + lwe_out.as_mut_view().into_container(), + lwe_left_shift_buffer.as_view().into_container(), + pbs_accumulator.as_view(), + fft, + stack, + ); + + // Add alpha where alpha = 2^{log(q) - 1 - base_log * level} + // To end up with an encryption of 0 if the message bit was 0 and 1 in the other case + let out_body = lwe_out.get_mut_body(); + out_body.0 = out_body + .0 + .wrapping_add(Scalar::ONE << (ciphertext_n_bits - 1 - base_log_cbs.0 * level_count_cbs.0)); +} + +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub struct GlweCiphertextList { + data: C, + count: usize, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, +} + +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub struct FourierGgswCiphertextList> { + fourier: FourierPolynomialList, + count: usize, + glwe_size: GlweSize, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, +} + +pub type FourierGgswCiphertextListView<'a> = FourierGgswCiphertextList<&'a [c64]>; +pub type FourierGgswCiphertextListMutView<'a> = FourierGgswCiphertextList<&'a mut [c64]>; +pub type GlweCiphertextListView<'a, Scalar> = GlweCiphertextList<&'a [Scalar]>; +pub type GlweCiphertextListMutView<'a, Scalar> = GlweCiphertextList<&'a mut [Scalar]>; + +impl GlweCiphertextList { + pub fn new( + data: C, + count: usize, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + ) -> Self { + assert_eq!( + data.container_len(), + count * polynomial_size.0 * glwe_size.0, + ); + Self { + data, + count, + polynomial_size, + glwe_size, + } + } + + pub fn data(self) -> C { + self.data + } + + pub fn count(&self) -> usize { + self.count + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn as_view(&self) -> GlweCiphertextListView<'_, C::Element> { + GlweCiphertextListView { + data: self.data.as_ref(), + count: self.count, + polynomial_size: self.polynomial_size, + glwe_size: self.glwe_size, + } + } + + pub fn as_mut_view(&mut self) -> GlweCiphertextListMutView<'_, C::Element> + where + C: AsMut<[C::Element]>, + { + GlweCiphertextListMutView { + data: self.data.as_mut(), + count: self.count, + polynomial_size: self.polynomial_size, + glwe_size: self.glwe_size, + } + } + + pub fn into_glwe_iter(self) -> impl DoubleEndedIterator> + where + C: Split, + { + self.data + .split_into(self.count) + .map(move |slice| GlweCiphertext::from_container(slice, self.polynomial_size)) + } +} + +impl> FourierGgswCiphertextList { + pub fn new( + data: C, + count: usize, + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + ) -> Self { + assert_eq!(polynomial_size.0 % 2, 0); + assert_eq!( + data.container_len(), + count * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0 + ); + + Self { + fourier: FourierPolynomialList { + data, + polynomial_size, + }, + count, + glwe_size, + decomposition_level_count, + decomposition_base_log, + } + } + + pub fn data(self) -> C { + self.fourier.data + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.fourier.polynomial_size + } + + pub fn count(&self) -> usize { + self.count + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomposition_level_count + } + + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomposition_base_log + } + + pub fn as_view(&self) -> FourierGgswCiphertextListView<'_> { + let fourier = FourierPolynomialList { + data: self.fourier.data.as_ref(), + polynomial_size: self.fourier.polynomial_size, + }; + FourierGgswCiphertextListView { + fourier, + count: self.count, + glwe_size: self.glwe_size, + decomposition_level_count: self.decomposition_level_count, + decomposition_base_log: self.decomposition_base_log, + } + } + + pub fn as_mut_view(&mut self) -> FourierGgswCiphertextListMutView<'_> + where + C: AsMut<[c64]>, + { + let fourier = FourierPolynomialList { + data: self.fourier.data.as_mut(), + polynomial_size: self.fourier.polynomial_size, + }; + FourierGgswCiphertextListMutView { + fourier, + count: self.count, + glwe_size: self.glwe_size, + decomposition_level_count: self.decomposition_level_count, + decomposition_base_log: self.decomposition_base_log, + } + } + + pub fn into_ggsw_iter(self) -> impl DoubleEndedIterator> + where + C: Split, + { + self.fourier.data.split_into(self.count).map(move |slice| { + FourierGgswCiphertext::new( + slice, + self.fourier.polynomial_size, + self.glwe_size, + self.decomposition_base_log, + self.decomposition_level_count, + ) + }) + } + + pub fn split_at(self, mid: usize) -> (Self, Self) + where + C: Split, + { + let polynomial_size = self.fourier.polynomial_size; + let glwe_size = self.glwe_size; + let decomposition_level_count = self.decomposition_level_count; + let decomposition_base_log = self.decomposition_base_log; + + let (left, right) = self.fourier.data.split_at( + mid * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * decomposition_level_count.0, + ); + ( + Self::new( + left, + mid, + polynomial_size, + glwe_size, + decomposition_base_log, + decomposition_level_count, + ), + Self::new( + right, + self.count - mid, + polynomial_size, + glwe_size, + decomposition_base_log, + decomposition_level_count, + ), + ) + } +} + +pub fn cmux_tree_memory_optimized_scratch( + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + nb_layer: usize, + fft: FftView<'_>, +) -> Result { + let t_scratch = StackReq::try_new_aligned::( + polynomial_size.0 * glwe_size.0 * nb_layer, + CACHELINE_ALIGN, + )?; + + StackReq::try_all_of([ + t_scratch, // t_0 + t_scratch, // t_1 + StackReq::try_new::(nb_layer)?, // t_fill + t_scratch, // diff + external_product_scratch::(glwe_size, polynomial_size, fft)?, + ]) +} + +/// Performs a tree of cmux in a way that limits the total allocated memory to avoid issues for +/// bigger trees. +pub fn cmux_tree_memory_optimized>( + mut output_glwe: GlweCiphertext<&mut [Scalar]>, + lut_per_layer: PolynomialList<&[Scalar]>, + ggsw_list: FourierGgswCiphertextListView<'_>, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + debug_assert!(lut_per_layer.polynomial_count().0 == 1 << ggsw_list.count()); + + if ggsw_list.count() > 0 { + let polynomial_size = ggsw_list.polynomial_size(); + let glwe_size = output_glwe.size(); + let nb_layer = ggsw_list.count(); + + debug_assert!(stack.can_hold( + cmux_tree_memory_optimized_scratch::(polynomial_size, glwe_size, nb_layer, fft) + .unwrap() + )); + + // These are accumulator that will be used to propagate the result from layer to layer + // At index 0 you have the lut that will be loaded, and then the result for each layer gets + // computed at the next index, last layer result gets stored in `result`. + // This allow to use memory space in C * nb_layer instead of C' * 2 ^ nb_layer + let (mut t_0_data, stack) = stack.make_aligned_with( + polynomial_size.0 * glwe_size.0 * nb_layer, + CACHELINE_ALIGN, + |_| Scalar::ZERO, + ); + let (mut t_1_data, stack) = stack.make_aligned_with( + polynomial_size.0 * glwe_size.0 * nb_layer, + CACHELINE_ALIGN, + |_| Scalar::ZERO, + ); + + let mut t_0 = + GlweCiphertextList::new(t_0_data.as_mut(), nb_layer, polynomial_size, glwe_size); + let mut t_1 = + GlweCiphertextList::new(t_1_data.as_mut(), nb_layer, polynomial_size, glwe_size); + + let (mut t_fill, mut stack) = stack.make_with(nb_layer, |_| 0_usize); + + let mut lut_polynomial_iter = lut_per_layer.into_polynomial_iter(); + loop { + let even = lut_polynomial_iter.next(); + let odd = lut_polynomial_iter.next(); + + let (lut_2i, lut_2i_plus_1) = match (even, odd) { + (Some(even), Some(odd)) => (even, odd), + _ => break, + }; + + let mut t_iter = izip!( + t_0.as_mut_view().into_glwe_iter(), + t_1.as_mut_view().into_glwe_iter(), + ) + .enumerate(); + + let (mut j_counter, (mut t0_j, mut t1_j)) = t_iter.next().unwrap(); + + t0_j.get_mut_body() + .as_mut_tensor() + .fill_with_copy(lut_2i.as_tensor()); + + t1_j.get_mut_body() + .as_mut_tensor() + .fill_with_copy(lut_2i_plus_1.as_tensor()); + + t_fill[0] = 2; + + for (j, ggsw) in ggsw_list.into_ggsw_iter().rev().enumerate() { + if t_fill[j] == 2 { + let (diff_data, stack) = stack.rb_mut().collect_aligned( + CACHELINE_ALIGN, + izip!( + t1_j.as_view().into_container(), + t0_j.as_view().into_container() + ) + .map(|(a, b)| a.wrapping_sub(*b)), + ); + let diff = GlweCiphertext::from_container(&*diff_data, polynomial_size); + + if j != nb_layer - 1 { + let (j_counter_plus_1, (mut t_0_j_plus_1, mut t_1_j_plus_1)) = + t_iter.next().unwrap(); + + assert_eq!(j_counter, j); + assert_eq!(j_counter_plus_1, j + 1); + + let mut output = if t_fill[j + 1] == 0 { + t_0_j_plus_1.as_mut_view() + } else { + t_1_j_plus_1.as_mut_view() + }; + + output + .as_mut_view() + .into_container() + .copy_from_slice(t0_j.as_view().into_container()); + external_product(output, ggsw, diff, fft, stack); + t_fill[j + 1] += 1; + t_fill[j] = 0; + + drop(diff_data); + + (j_counter, t0_j, t1_j) = (j_counter_plus_1, t_0_j_plus_1, t_1_j_plus_1); + } else { + let mut output = output_glwe.as_mut_view(); + output + .as_mut_view() + .into_container() + .copy_from_slice(t0_j.as_view().into_container()); + external_product(output, ggsw, diff, fft, stack); + } + } else { + break; + } + } + } + } else { + output_glwe + .get_mut_mask() + .as_mut_tensor() + .fill_with(|| Scalar::ZERO); + output_glwe + .get_mut_body() + .as_mut_tensor() + .fill_with_copy(lut_per_layer.as_tensor()); + } +} + +pub fn circuit_bootstrap_boolean_vertical_packing_scratch( + lwe_list_in_count: CiphertextCount, + lwe_list_out_count: CiphertextCount, + lwe_in_size: LweSize, + big_lut_polynomial_count: PolynomialCount, + bsk_output_lwe_size: LweSize, + fpksk_output_polynomial_size: PolynomialSize, + glwe_size: GlweSize, + level_cbs: DecompositionLevelCount, + fft: FftView<'_>, +) -> Result { + // We deduce the number of luts in the vec_lut from the number of cipherxtexts in lwe_list_out + let number_of_luts = lwe_list_out_count.0; + let small_lut_size = PolynomialCount(big_lut_polynomial_count.0 / number_of_luts); + + StackReq::try_all_of([ + StackReq::try_new_aligned::( + lwe_list_in_count.0 * fpksk_output_polynomial_size.0 / 2 + * glwe_size.0 + * glwe_size.0 + * level_cbs.0, + CACHELINE_ALIGN, + )?, + StackReq::try_new_aligned::( + fpksk_output_polynomial_size.0 * glwe_size.0 * glwe_size.0 * level_cbs.0, + CACHELINE_ALIGN, + )?, + StackReq::try_any_of([ + circuit_bootstrap_boolean_scratch::( + lwe_in_size, + bsk_output_lwe_size, + fpksk_output_polynomial_size, + glwe_size, + fft, + )?, + fill_with_forward_fourier_scratch(fft)?, + vertical_packing_scratch::( + glwe_size, + fpksk_output_polynomial_size, + small_lut_size, + lwe_list_in_count.0, + fft, + )?, + ])?, + ]) +} + +/// Perform a circuit bootstrap followed by a vertical packing on ciphertexts encrypting boolean +/// messages. +/// +/// The circuit bootstrapping uses the private functional packing key switch. +/// +/// This is supposed to be used only with boolean (1 bit of message) LWE ciphertexts. +pub fn circuit_bootstrap_boolean_vertical_packing>( + big_lut_as_polynomial_list: PolynomialList<&[Scalar]>, + fourier_bsk: FourierLweBootstrapKeyView<'_>, + mut lwe_list_out: LweList<&mut [Scalar]>, + lwe_list_in: LweList<&[Scalar]>, + fpksk_list: LwePrivateFunctionalPackingKeyswitchKeyList<&[Scalar]>, + level_cbs: DecompositionLevelCount, + base_log_cbs: DecompositionBaseLog, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + debug_assert!(stack.can_hold( + circuit_bootstrap_boolean_vertical_packing_scratch::( + lwe_list_in.count(), + lwe_list_out.count(), + lwe_list_in.lwe_size(), + big_lut_as_polynomial_list.polynomial_count(), + fourier_bsk.output_lwe_dimension().to_lwe_size(), + fpksk_list.output_polynomial_size(), + fourier_bsk.glwe_size(), + level_cbs, + fft + ) + .unwrap() + )); + debug_assert!(lwe_list_in.count().0 != 0, "Got empty `lwe_list_in`"); + debug_assert!( + lwe_list_out.lwe_size().to_lwe_dimension() == fourier_bsk.output_lwe_dimension(), + "Output LWE ciphertext needs to have an LweDimension of {}, got {}", + lwe_list_out.lwe_size().to_lwe_dimension().0, + fourier_bsk.output_lwe_dimension().0 + ); + + let glwe_size = fpksk_list.output_glwe_key_dimension().to_glwe_size(); + let (mut ggsw_list_data, stack) = stack.make_aligned_with( + lwe_list_in.count().0 * fpksk_list.output_polynomial_size().0 / 2 + * glwe_size.0 + * glwe_size.0 + * level_cbs.0, + CACHELINE_ALIGN, + |_| c64::default(), + ); + let (mut ggsw_res_data, mut stack) = stack.make_aligned_with( + fpksk_list.output_polynomial_size().0 * glwe_size.0 * glwe_size.0 * level_cbs.0, + CACHELINE_ALIGN, + |_| Scalar::ZERO, + ); + + let mut ggsw_list = FourierGgswCiphertextListMutView::new( + &mut ggsw_list_data, + lwe_list_in.count().0, + fpksk_list.output_polynomial_size(), + glwe_size, + base_log_cbs, + level_cbs, + ); + + let mut ggsw_res = StandardGgswCiphertext::from_container( + &mut *ggsw_res_data, + glwe_size, + fpksk_list.output_polynomial_size(), + base_log_cbs, + ); + + for (lwe_in, ggsw) in izip!( + lwe_list_in.ciphertext_iter(), + ggsw_list.as_mut_view().into_ggsw_iter(), + ) { + circuit_bootstrap_boolean( + fourier_bsk, + lwe_in, + ggsw_res.as_mut_view(), + DeltaLog(Scalar::BITS - 1), + fpksk_list, + fft, + stack.rb_mut(), + ); + + ggsw.fill_with_forward_fourier(ggsw_res.as_view(), fft, stack.rb_mut()); + } + + // We deduce the number of luts in the vec_lut from the number of cipherxtexts in lwe_list_out + let number_of_luts = lwe_list_out.count().0; + + let small_lut_size = + PolynomialCount(big_lut_as_polynomial_list.polynomial_count().0 / number_of_luts); + + for (lut, lwe_out) in izip!( + big_lut_as_polynomial_list.sublist_iter(small_lut_size), + lwe_list_out.ciphertext_iter_mut(), + ) { + vertical_packing(lut, lwe_out, ggsw_list.as_view(), fft, stack.rb_mut()); + } +} + +pub fn vertical_packing_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + lut_polynomial_count: PolynomialCount, + ggsw_list_count: usize, + fft: FftView<'_>, +) -> Result { + let bits = core::mem::size_of::() * 8; + + // Get the base 2 logarithm (rounded down) of the number of polynomials in the list i.e. if + // there is one polynomial, the number will be 0 + let log_lut_number: usize = bits - 1 - lut_polynomial_count.0.leading_zeros() as usize; + + let log_number_of_luts_for_cmux_tree = if log_lut_number > ggsw_list_count { + // this means that we dont have enough GGSW to perform the CMux tree, we can only do the + // Blind rotation + 0 + } else { + log_lut_number + }; + + StackReq::try_all_of([ + // cmux_tree_lut_res + StackReq::try_new_aligned::(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?, + StackReq::try_any_of([ + blind_rotate_scratch::(glwe_size, polynomial_size, fft)?, + cmux_tree_memory_optimized_scratch::( + polynomial_size, + glwe_size, + log_number_of_luts_for_cmux_tree, + fft, + )?, + ])?, + ]) +} + +// GGSW ciphertexts are stored from the msb (vec_ggsw[0]) to the lsb (vec_ggsw[last]) +pub fn vertical_packing>( + lut: PolynomialList<&[Scalar]>, + mut lwe_out: LweCiphertext<&mut [Scalar]>, + ggsw_list: FourierGgswCiphertextListView<'_>, + fft: FftView<'_>, + stack: DynStack<'_>, +) { + let polynomial_size = ggsw_list.polynomial_size(); + let glwe_size = ggsw_list.glwe_size(); + let glwe_dimension = glwe_size.to_glwe_dimension(); + + debug_assert!( + lwe_out.lwe_size().to_lwe_dimension().0 == polynomial_size.0 * glwe_dimension.0, + "Output LWE ciphertext needs to have an LweDimension of {}, got {}", + polynomial_size.0 * glwe_dimension.0, + lwe_out.lwe_size().to_lwe_dimension().0, + ); + + // Get the base 2 logarithm (rounded down) of the number of polynomials in the list i.e. if + // there is one polynomial, the number will be 0 + let log_lut_number: usize = + Scalar::BITS - 1 - lut.polynomial_count().0.leading_zeros() as usize; + + let log_number_of_luts_for_cmux_tree = if log_lut_number > ggsw_list.count() { + // this means that we dont have enough GGSW to perform the CMux tree, we can only do the + // Blind rotation + 0 + } else { + log_lut_number + }; + + // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the lsb GGSW is for + // the last blind rotation. + let (cmux_ggsw, br_ggsw) = ggsw_list.split_at(log_number_of_luts_for_cmux_tree); + + let (mut cmux_tree_lut_res_data, mut stack) = + stack.make_aligned_with(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN, |_| { + Scalar::ZERO + }); + let mut cmux_tree_lut_res = + GlweCiphertext::from_container(&mut *cmux_tree_lut_res_data, polynomial_size); + + cmux_tree_memory_optimized( + cmux_tree_lut_res.as_mut_view(), + lut, + cmux_ggsw, + fft, + stack.rb_mut(), + ); + blind_rotate( + cmux_tree_lut_res.as_mut_view(), + br_ggsw, + fft, + stack.rb_mut(), + ); + + // sample extract of the RLWE of the Vertical packing + cmux_tree_lut_res.fill_lwe_with_sample_extraction(&mut lwe_out, MonomialDegree(0)); +} + +pub fn blind_rotate_scratch( + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + fft: FftView<'_>, +) -> Result { + StackReq::try_all_of([ + StackReq::try_new_aligned::(polynomial_size.0 * glwe_size.0, CACHELINE_ALIGN)?, + cmux_scratch::(glwe_size, polynomial_size, fft)?, + ]) +} + +pub fn blind_rotate>( + mut lut: GlweCiphertext<&mut [Scalar]>, + ggsw_list: FourierGgswCiphertextListView<'_>, + fft: FftView<'_>, + mut stack: DynStack<'_>, +) { + let mut monomial_degree = MonomialDegree(1); + + for ggsw in ggsw_list.into_ggsw_iter().rev() { + let ct_0 = lut.as_mut_view(); + let (mut ct1_data, stack) = stack.rb_mut().collect_aligned( + CACHELINE_ALIGN, + ct_0.as_view().into_container().iter().copied(), + ); + let mut ct_1 = GlweCiphertext::from_container(&mut *ct1_data, ct_0.polynomial_size()); + ct_1.as_mut_polynomial_list() + .update_with_wrapping_monic_monomial_div(monomial_degree); + monomial_degree.0 <<= 1; + cmux(ct_0, ct_1, ggsw, fft, stack); + } +} + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/tests.rs b/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/tests.rs new file mode 100644 index 000000000..836dc1f88 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/crypto/wop_pbs/tests.rs @@ -0,0 +1,780 @@ +use super::*; +use crate::core_crypto::backends::fft::private::crypto::bootstrap::{ + fill_with_forward_fourier_scratch, FourierLweBootstrapKey, +}; +use crate::core_crypto::backends::fft::private::math::fft::Fft; +use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; +use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList; +use crate::core_crypto::commons::crypto::lwe::{LweCiphertext, LweKeyswitchKey, LweList}; +use crate::core_crypto::commons::crypto::secret::generators::{ + EncryptionRandomGenerator, SecretRandomGenerator, +}; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +use crate::core_crypto::commons::math::decomposition::SignedDecomposer; +use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; +use crate::core_crypto::commons::test_tools; +use crate::core_crypto::prelude::{ + CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, DeltaLog, DispersionParameter, + ExtractedBitsCount, FunctionalPackingKeyswitchKeyCount, GlweDimension, LogStandardDev, + LweDimension, LweSize, PlaintextCount, PolynomialCount, PolynomialSize, StandardDev, Variance, +}; +use concrete_csprng::generators::SoftwareRandomGenerator; +use concrete_csprng::seeders::{Seeder, UnixSeeder}; +use concrete_fft::c64; +use dyn_stack::{DynStack, GlobalMemBuffer, ReborrowMut, StackReq}; + +// Extract all the bits of a LWE +#[test] +pub fn test_extract_bits() { + // Define settings for an insecure toy example + let polynomial_size = PolynomialSize(1024); + let rlwe_dimension = GlweDimension(1); + let lwe_dimension = LweDimension(585); + + let level_bsk = DecompositionLevelCount(2); + let base_log_bsk = DecompositionBaseLog(10); + + let level_ksk = DecompositionLevelCount(7); + let base_log_ksk = DecompositionBaseLog(4); + + let std = LogStandardDev::from_log_standard_dev(-60.); + + let number_of_bits_of_message_including_padding = 5_usize; + // Tests take about 2-3 seconds on a laptop with this number + let number_of_test_runs = 32; + + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + + // allocation and generation of the key in coef domain: + let rlwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(rlwe_dimension, polynomial_size, &mut secret_generator); + let lwe_small_sk: LweSecretKey<_, Vec> = + LweSecretKey::generate_binary(lwe_dimension, &mut secret_generator); + + let mut coef_bsk = StandardBootstrapKey::allocate( + 0_u64, + rlwe_dimension.to_glwe_size(), + polynomial_size, + level_bsk, + base_log_bsk, + lwe_dimension, + ); + coef_bsk.fill_with_new_key(&lwe_small_sk, &rlwe_sk, std, &mut encryption_generator); + + let mut fourier_bsk = FourierLweBootstrapKey::new( + vec![c64::default(); coef_bsk.as_view().into_container().len() / 2], + lwe_dimension, + polynomial_size, + rlwe_dimension.to_glwe_size(), + base_log_bsk, + level_bsk, + ); + + let fft = Fft::new(polynomial_size); + let fft = fft.as_view(); + + let lwe_big_sk = LweSecretKey::binary_from_container(rlwe_sk.as_tensor().as_slice()); + let mut ksk_lwe_big_to_small = LweKeyswitchKey::allocate( + 0_u64, + level_ksk, + base_log_ksk, + lwe_big_sk.key_size(), + lwe_small_sk.key_size(), + ); + ksk_lwe_big_to_small.fill_with_keyswitch_key( + &lwe_big_sk, + &lwe_small_sk, + std, + &mut encryption_generator, + ); + + let req = || { + StackReq::try_any_of([ + fill_with_forward_fourier_scratch(fft)?, + extract_bits_scratch::( + lwe_dimension, + ksk_lwe_big_to_small.after_key_size(), + rlwe_dimension.to_glwe_size(), + polynomial_size, + fft, + )?, + ]) + }; + let req = req().unwrap(); + let mut mem = GlobalMemBuffer::new(req); + let mut stack = DynStack::new(&mut mem); + + fourier_bsk + .as_mut_view() + .fill_with_forward_fourier(coef_bsk.as_view(), fft, stack.rb_mut()); + + let delta_log = DeltaLog(64 - number_of_bits_of_message_including_padding); + // Decomposer to manage the rounding after decrypting the extracted bit + let decomposer = SignedDecomposer::new(DecompositionBaseLog(1), DecompositionLevelCount(1)); + + //////////////////////////////////////////////////////////////////////////////////////////////// + + for _ in 0..number_of_test_runs { + // Generate a random plaintext in [0; 2^{number_of_bits_of_message_including_padding}[ + let val = test_tools::random_uint_between( + 0..2u64.pow(number_of_bits_of_message_including_padding as u32), + ); + + // Encryption + let message = Plaintext(val << delta_log.0); + println!("{:?}", message); + let mut lwe_in = LweCiphertext::allocate(0u64, LweSize(polynomial_size.0 + 1)); + lwe_big_sk.encrypt_lwe(&mut lwe_in, &message, std, &mut encryption_generator); + + // Bit extraction + // Extract all the bits + let number_values_to_extract = ExtractedBitsCount(64 - delta_log.0); + + let mut lwe_out_list = LweList::allocate( + 0u64, + ksk_lwe_big_to_small.lwe_size(), + CiphertextCount(number_values_to_extract.0), + ); + + extract_bits( + lwe_out_list.as_mut_view(), + lwe_in.as_view(), + ksk_lwe_big_to_small.as_view(), + fourier_bsk.as_view(), + delta_log, + number_values_to_extract, + fft, + stack.rb_mut(), + ); + + // Decryption of extracted bit + for (i, result_ct) in lwe_out_list.ciphertext_iter().rev().enumerate() { + let mut decrypted_message = Plaintext(0_u64); + lwe_small_sk.decrypt_lwe(&mut decrypted_message, &result_ct); + // Round after decryption using decomposer + let decrypted_rounded = decomposer.closest_representable(decrypted_message.0); + // Bring back the extracted bit found in the MSB in the LSB + let decrypted_extract_bit = decrypted_rounded >> 63; + println!("extracted bit : {:?}", decrypted_extract_bit); + println!("{:?}", decrypted_message); + assert_eq!( + ((message.0 >> delta_log.0) >> i) & 1, + decrypted_extract_bit, + "Bit #{}, for plaintext {:#066b}", + delta_log.0 + i, + message.0 + ) + } + } +} + +// Test the circuit bootstrapping with private functional ks +// Verify the decryption has the expected content +#[test] +fn test_circuit_bootstrapping_binary() { + // Define settings for an insecure toy example + let polynomial_size = PolynomialSize(512); + let glwe_dimension = GlweDimension(2); + let lwe_dimension = LweDimension(10); + + let level_bsk = DecompositionLevelCount(2); + let base_log_bsk = DecompositionBaseLog(15); + + let level_pksk = DecompositionLevelCount(2); + let base_log_pksk = DecompositionBaseLog(15); + + let level_count_cbs = DecompositionLevelCount(1); + let base_log_cbs = DecompositionBaseLog(10); + + let std = LogStandardDev::from_log_standard_dev(-60.); + + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + + // Create GLWE and LWE secret key + let glwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(glwe_dimension, polynomial_size, &mut secret_generator); + let lwe_sk: LweSecretKey<_, Vec> = + LweSecretKey::generate_binary(lwe_dimension, &mut secret_generator); + + // Allocation and generation of the bootstrap key in standard domain: + let mut std_bsk = StandardBootstrapKey::allocate( + 0u64, + glwe_dimension.to_glwe_size(), + polynomial_size, + level_bsk, + base_log_bsk, + lwe_dimension, + ); + std_bsk.fill_with_new_key(&lwe_sk, &glwe_sk, std, &mut encryption_generator); + + let mut fourier_bsk = FourierLweBootstrapKey::new( + vec![ + c64::default(); + lwe_dimension.0 * polynomial_size.0 / 2 + * level_bsk.0 + * glwe_dimension.to_glwe_size().0 + * glwe_dimension.to_glwe_size().0 + ], + lwe_dimension, + polynomial_size, + glwe_dimension.to_glwe_size(), + base_log_bsk, + level_bsk, + ); + + let fft = Fft::new(polynomial_size); + let fft = fft.as_view(); + + let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap()); + let stack = DynStack::new(&mut mem); + fourier_bsk + .as_mut_view() + .fill_with_forward_fourier(std_bsk.as_view(), fft, stack); + + let lwe_sk_bs_output = LweSecretKey::binary_from_container(glwe_sk.as_tensor().as_slice()); + + // Creation of all the pfksk for the circuit bootstrapping + let mut vec_pfksk = LwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u64, + level_pksk, + base_log_pksk, + lwe_sk_bs_output.key_size(), + glwe_sk.key_size(), + glwe_sk.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(glwe_dimension.to_glwe_size().0), + ); + + vec_pfksk.par_fill_with_fpksk_for_circuit_bootstrap( + &lwe_sk_bs_output, + &glwe_sk, + std, + &mut encryption_generator, + ); + + let delta_log = DeltaLog(60); + + // value is 0 or 1 as CBS works on messages expected to contain 1 bit of information + let value: u64 = test_tools::random_uint_between(0..2u64); + // Encryption of an LWE with the value 'message' + let message = Plaintext((value) << delta_log.0); + let mut lwe_in = LweCiphertext::allocate(0u64, lwe_dimension.to_lwe_size()); + lwe_sk.encrypt_lwe(&mut lwe_in, &message, std, &mut encryption_generator); + + let mut cbs_res = StandardGgswCiphertext::allocate( + 0u64, + polynomial_size, + glwe_dimension.to_glwe_size(), + level_count_cbs, + base_log_cbs, + ); + + let mut mem = GlobalMemBuffer::new( + circuit_bootstrap_boolean_scratch::( + lwe_in.lwe_size(), + fourier_bsk.output_lwe_dimension().to_lwe_size(), + polynomial_size, + glwe_dimension.to_glwe_size(), + fft, + ) + .unwrap(), + ); + let stack = DynStack::new(&mut mem); + // Execute the CBS + circuit_bootstrap_boolean( + fourier_bsk.as_view(), + lwe_in.as_view(), + cbs_res.as_mut_view(), + delta_log, + vec_pfksk.as_view(), + fft, + stack, + ); + + let glwe_size = glwe_dimension.to_glwe_size(); + + //print the key to check if the RLWE in the GGSW seem to be well created + println!("RLWE secret key:\n{:?}", glwe_sk); + let mut decrypted = PlaintextList::allocate( + 0_u64, + PlaintextCount(polynomial_size.0 * level_count_cbs.0 * glwe_size.0), + ); + glwe_sk.decrypt_glwe_list(&mut decrypted, &cbs_res.as_glwe_list()); + + let level_size = polynomial_size.0 * glwe_size.0; + + println!("\nGGSW decryption:"); + for (level_idx, level_decrypted_glwe) in decrypted + .sublist_iter(PlaintextCount(level_size)) + .enumerate() + { + for (decrypted_glwe, original_polynomial_from_glwe_sk) in level_decrypted_glwe + .sublist_iter(PlaintextCount(polynomial_size.0)) + .take(glwe_dimension.0) + .zip(glwe_sk.as_polynomial_list().polynomial_iter()) + { + let current_level = level_idx + 1; + let mut expected_decryption = PlaintextList::allocate( + 0u64, + PlaintextCount(original_polynomial_from_glwe_sk.polynomial_size().0), + ); + expected_decryption + .as_mut_tensor() + .fill_with_copy(original_polynomial_from_glwe_sk.as_tensor()); + + let multiplying_factor = 0u64.wrapping_sub(value); + + expected_decryption + .as_mut_tensor() + .update_with_wrapping_scalar_mul(&multiplying_factor); + + let decomposer = + SignedDecomposer::new(base_log_cbs, DecompositionLevelCount(current_level)); + + expected_decryption + .as_mut_tensor() + .update_with(|coeff| *coeff >>= 64 - base_log_cbs.0 * current_level); + + let mut decoded_glwe = + PlaintextList::from_container(decrypted_glwe.as_tensor().as_container().to_vec()); + + decoded_glwe.as_mut_tensor().update_with(|coeff| { + *coeff = decomposer.closest_representable(*coeff) + >> (64 - base_log_cbs.0 * current_level) + }); + + assert_eq!( + expected_decryption.as_tensor().as_slice(), + decoded_glwe.as_tensor().as_slice() + ); + } + let last_decrypted_glwe = level_decrypted_glwe + .sublist_iter(PlaintextCount(polynomial_size.0)) + .rev() + .next() + .unwrap(); + + let mut last_decoded_glwe = + PlaintextList::from_container(last_decrypted_glwe.as_tensor().as_container().to_vec()); + + let decomposer = SignedDecomposer::new(base_log_cbs, level_count_cbs); + + last_decoded_glwe.as_mut_tensor().update_with(|coeff| { + *coeff = decomposer.closest_representable(*coeff) + >> (64 - base_log_cbs.0 * level_count_cbs.0) + }); + + let mut expected_decryption = PlaintextList::allocate(0u64, last_decoded_glwe.count()); + + *expected_decryption.as_mut_tensor().first_mut() = value; + + assert_eq!( + expected_decryption.as_tensor().as_slice(), + last_decoded_glwe.as_tensor().as_slice() + ); + } +} + +#[test] +pub fn test_cmux_tree() { + // Define settings for an insecure toy example + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + let polynomial_size = PolynomialSize(512); + let glwe_dimension = GlweDimension(1); + let std = LogStandardDev::from_log_standard_dev(-60.); + let level = DecompositionLevelCount(3); + let base_log = DecompositionBaseLog(6); + // We need (1 << nb_ggsw) > polynomial_size to have an actual CMUX tree and not just a blind + // rotation + let nb_ggsw = 10; + let delta_log = 60; + + // Allocation and generation of the key in coef domain: + let glwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(glwe_dimension, polynomial_size, &mut secret_generator); + let glwe_size = glwe_sk.key_size().to_glwe_size(); + + // Creation of the 'big' lut + // lut = [[0...0][1...1][2...2] ...] where [X...X] is a lut + // The values in the lut are taken mod 2 ^ {64 - delta_log} and shifted by delta_log to the left + let mut lut = PolynomialList::allocate(0u64, PolynomialCount(1 << nb_ggsw), polynomial_size); + for (i, mut polynomial) in lut.polynomial_iter_mut().enumerate() { + polynomial + .as_mut_tensor() + .fill_with_element((i as u64 % (1 << (64 - delta_log))) << delta_log); + } + + // Values between [0; 1023] + // Note that we use a delta log which does not handle more than 4 bits of message + let number_of_bits_for_payload = nb_ggsw; + + // Decomposer to manage the rounding after decrypting + let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); + + let number_of_test_runs = 32; + + for _ in 0..number_of_test_runs { + let mut value = + test_tools::random_uint_between(0..2u64.pow(number_of_bits_for_payload as u32)); + println!("value: {}", value); + let witness = value % (1 << (64 - delta_log)); + + // Bit decomposition of the value from MSB to LSB + let mut vec_message = vec![Plaintext(0); nb_ggsw]; + for i in (0..nb_ggsw).rev() { + vec_message[i] = Plaintext(value & 1); + value >>= 1; + } + + let fft = Fft::new(polynomial_size); + let fft = fft.as_view(); + + // Encrypt all bits in fourier GGSW ciphertexts from MSB to LSB, store them in a vec + let mut ggsw_list = FourierGgswCiphertextList::new( + vec![ + c64::default(); + nb_ggsw * polynomial_size.0 / 2 * glwe_size.0 * glwe_size.0 * level.0 + ], + nb_ggsw, + polynomial_size, + glwe_size, + base_log, + level, + ); + for (single_bit_msg, mut fourier_ggsw) in + izip!(vec_message.iter(), ggsw_list.as_mut_view().into_ggsw_iter()) + { + let mut ggsw = StandardGgswCiphertext::allocate( + 0_u64, + polynomial_size, + glwe_dimension.to_glwe_size(), + level, + base_log, + ); + glwe_sk.encrypt_constant_ggsw( + &mut ggsw, + single_bit_msg, + std, + &mut encryption_generator, + ); + + let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap()); + let stack = DynStack::new(&mut mem); + fourier_ggsw + .as_mut_view() + .fill_with_forward_fourier(ggsw.as_view(), fft, stack); + } + + let mut result_cmux_tree = GlweCiphertext::allocate(0_u64, polynomial_size, glwe_size); + let mut mem = GlobalMemBuffer::new( + cmux_tree_memory_optimized_scratch::(polynomial_size, glwe_size, nb_ggsw, fft) + .unwrap(), + ); + cmux_tree_memory_optimized( + result_cmux_tree.as_mut_view(), + lut.as_view(), + ggsw_list.as_view(), + fft, + DynStack::new(&mut mem), + ); + let mut decrypted_result = + PlaintextList::allocate(0u64, PlaintextCount(glwe_sk.polynomial_size().0)); + glwe_sk.decrypt_glwe(&mut decrypted_result, &result_cmux_tree); + + let decoded_result = + decomposer.closest_representable(*decrypted_result.as_tensor().first()) >> delta_log; + + // The recovered lut_number must be equal to the value stored in the lut at index + // witness % 2 ^ {64 - delta_log} + println!("result : {:?}", decoded_result); + println!("witness : {:?}", witness); + assert_eq!(decoded_result, witness) + } +} + +// Circuit bootstrap + vecrtical packing applying an identity lut +#[test] +pub fn test_extract_bit_circuit_bootstrapping_vertical_packing() { + // define settings + let polynomial_size = PolynomialSize(1024); + let glwe_dimension = GlweDimension(1); + let lwe_dimension = LweDimension(481); + + let level_bsk = DecompositionLevelCount(9); + let base_log_bsk = DecompositionBaseLog(4); + + let level_pksk = DecompositionLevelCount(9); + let base_log_pksk = DecompositionBaseLog(4); + + let level_ksk = DecompositionLevelCount(9); + let base_log_ksk = DecompositionBaseLog(1); + + let level_cbs = DecompositionLevelCount(4); + let base_log_cbs = DecompositionBaseLog(6); + + // Value was 0.000_000_000_000_000_221_486_881_160_055_68_513645324585951 + // But rust indicates it gets truncated anyways to + // 0.000_000_000_000_000_221_486_881_160_055_68 + let std_small = StandardDev::from_standard_dev(0.000_000_000_000_000_221_486_881_160_055_68); + // Value was 0.000_061_200_133_780_220_371_345 + // But rust indicates it gets truncated anyways to + // 0.000_061_200_133_780_220_36 + let std_big = StandardDev::from_standard_dev(0.000_061_200_133_780_220_36); + + const UNSAFE_SECRET: u128 = 0; + let mut seeder = UnixSeeder::new(UNSAFE_SECRET); + + let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), &mut seeder); + + //create GLWE and LWE secret key + let glwe_sk: GlweSecretKey<_, Vec> = + GlweSecretKey::generate_binary(glwe_dimension, polynomial_size, &mut secret_generator); + let lwe_small_sk: LweSecretKey<_, Vec> = + LweSecretKey::generate_binary(lwe_dimension, &mut secret_generator); + + let lwe_big_sk = LweSecretKey::binary_from_container(glwe_sk.as_tensor().as_slice()); + + // allocation and generation of the key in coef domain: + let mut coef_bsk = StandardBootstrapKey::allocate( + 0u64, + glwe_dimension.to_glwe_size(), + polynomial_size, + level_bsk, + base_log_bsk, + lwe_dimension, + ); + coef_bsk.fill_with_new_key( + &lwe_small_sk, + &glwe_sk, + Variance(std_small.get_variance()), + &mut encryption_generator, + ); + // allocation for the bootstrapping key + let mut fourier_bsk = FourierLweBootstrapKey::new( + vec![ + c64::default(); + lwe_dimension.0 * polynomial_size.0 / 2 + * level_bsk.0 + * glwe_dimension.to_glwe_size().0 + * glwe_dimension.to_glwe_size().0 + ], + lwe_dimension, + polynomial_size, + glwe_dimension.to_glwe_size(), + base_log_bsk, + level_bsk, + ); + + let fft = Fft::new(polynomial_size); + let fft = fft.as_view(); + + let mut mem = GlobalMemBuffer::new(fill_with_forward_fourier_scratch(fft).unwrap()); + fourier_bsk.as_mut_view().fill_with_forward_fourier( + coef_bsk.as_view(), + fft, + DynStack::new(&mut mem), + ); + + let mut ksk_lwe_big_to_small = LweKeyswitchKey::allocate( + 0u64, + level_ksk, + base_log_ksk, + lwe_big_sk.key_size(), + lwe_small_sk.key_size(), + ); + ksk_lwe_big_to_small.fill_with_keyswitch_key( + &lwe_big_sk, + &lwe_small_sk, + Variance(std_big.get_variance()), + &mut encryption_generator, + ); + + // Creation of all the pfksk for the circuit bootstrapping + let mut vec_fpksk = LwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u64, + level_pksk, + base_log_pksk, + lwe_big_sk.key_size(), + glwe_sk.key_size(), + glwe_sk.polynomial_size(), + FunctionalPackingKeyswitchKeyCount(glwe_dimension.to_glwe_size().0), + ); + + vec_fpksk.par_fill_with_fpksk_for_circuit_bootstrap( + &lwe_big_sk, + &glwe_sk, + std_small, + &mut encryption_generator, + ); + + let number_of_bits_in_input_lwe = 10; + let number_of_values_to_extract = ExtractedBitsCount(number_of_bits_in_input_lwe); + + let decomposer = SignedDecomposer::new(DecompositionBaseLog(10), DecompositionLevelCount(1)); + + // Here even thought the deltas have the same value, they can differ between ciphertexts and lut + // so keeping both separate + let delta_log = DeltaLog(64 - number_of_values_to_extract.0); + let delta_lut = DeltaLog(64 - number_of_values_to_extract.0); + + let number_of_test_runs = 10; + + for run_number in 0..number_of_test_runs { + let cleartext = + test_tools::random_uint_between(0..2u64.pow(number_of_bits_in_input_lwe as u32)); + + println!("{}", cleartext); + + let message = Plaintext(cleartext << delta_log.0); + let mut lwe_in = + LweCiphertext::allocate(0u64, LweSize(glwe_dimension.0 * polynomial_size.0 + 1)); + lwe_big_sk.encrypt_lwe( + &mut lwe_in, + &message, + Variance(std_big.get_variance()), + &mut encryption_generator, + ); + let mut extracted_bits_lwe_list = LweList::allocate( + 0u64, + ksk_lwe_big_to_small.lwe_size(), + CiphertextCount(number_of_values_to_extract.0), + ); + + let mut mem = GlobalMemBuffer::new( + extract_bits_scratch::( + lwe_dimension, + ksk_lwe_big_to_small.after_key_size(), + fourier_bsk.glwe_size(), + polynomial_size, + fft, + ) + .unwrap(), + ); + extract_bits( + extracted_bits_lwe_list.as_mut_view(), + lwe_in.as_view(), + ksk_lwe_big_to_small.as_view(), + fourier_bsk.as_view(), + delta_log, + number_of_values_to_extract, + fft, + DynStack::new(&mut mem), + ); + + // Decrypt all extracted bit for checking purposes in case of problems + for ct in extracted_bits_lwe_list.ciphertext_iter() { + let mut decrypted_message = Plaintext(0u64); + lwe_small_sk.decrypt_lwe(&mut decrypted_message, &ct); + let extract_bit_result = + (((decrypted_message.0 as f64) / (1u64 << (63)) as f64).round()) as u64; + println!("{:?}", extract_bit_result); + println!("{:?}", decrypted_message); + } + + // LUT creation + let number_of_luts_and_output_vp_ciphertexts = 1; + let mut lut_size = polynomial_size.0; + + let lut_poly_list = if run_number % 2 == 0 { + // Test with a small lut, only triggering a blind rotate + if lut_size < (1 << extracted_bits_lwe_list.count().0) { + lut_size = 1 << extracted_bits_lwe_list.count().0; + } + let mut lut = Vec::with_capacity(lut_size); + + for i in 0..lut_size { + lut.push((i as u64 % (1 << (64 - delta_log.0))) << delta_lut.0); + } + + // Here we have a single lut, so store it directly in the polynomial list + PolynomialList::from_container(lut, PolynomialSize(lut_size)) + } else { + // Test with a big lut, triggering an actual cmux tree + let mut lut_poly_list = PolynomialList::allocate( + 0u64, + PolynomialCount(1 << number_of_bits_in_input_lwe), + polynomial_size, + ); + for (i, mut polynomial) in lut_poly_list.polynomial_iter_mut().enumerate() { + polynomial + .as_mut_tensor() + .fill_with_element((i as u64 % (1 << (64 - delta_log.0))) << delta_lut.0); + } + lut_poly_list + }; + + // We need as many output ciphertexts as we have input luts + let mut vertical_packing_lwe_list_out = LweList::allocate( + 0u64, + LweDimension(polynomial_size.0 * glwe_dimension.0).to_lwe_size(), + CiphertextCount(number_of_luts_and_output_vp_ciphertexts), + ); + + // Perform circuit bootstrap + vertical packing + let mut mem = GlobalMemBuffer::new( + circuit_bootstrap_boolean_vertical_packing_scratch::( + extracted_bits_lwe_list.count(), + vertical_packing_lwe_list_out.count(), + extracted_bits_lwe_list.lwe_size(), + lut_poly_list.polynomial_count(), + fourier_bsk.output_lwe_dimension().to_lwe_size(), + vec_fpksk.output_polynomial_size(), + fourier_bsk.glwe_size(), + level_cbs, + fft, + ) + .unwrap(), + ); + circuit_bootstrap_boolean_vertical_packing( + lut_poly_list.as_view(), + fourier_bsk.as_view(), + vertical_packing_lwe_list_out.as_mut_view(), + extracted_bits_lwe_list.as_view(), + vec_fpksk.as_view(), + level_cbs, + base_log_cbs, + fft, + DynStack::new(&mut mem), + ); + + // We have a single output ct + let result_ct = vertical_packing_lwe_list_out + .ciphertext_iter() + .next() + .unwrap(); + + // decrypt result + let mut decrypted_message = Plaintext(0u64); + let lwe_sk = LweSecretKey::binary_from_container(glwe_sk.as_tensor().as_slice()); + lwe_sk.decrypt_lwe(&mut decrypted_message, &result_ct); + let decoded_message = decomposer.closest_representable(decrypted_message.0) >> delta_log.0; + + // print information if the result is wrong + if decoded_message != cleartext { + panic!( + "decoded_message ({:?}) != cleartext ({:?})\n\ + decrypted_message: {:?}, decoded_message: {:?}", + decoded_message, cleartext, decrypted_message, decoded_message + ); + } + println!("{:?}", decoded_message); + } +} diff --git a/tfhe/src/core_crypto/backends/fft/private/math/decomposition.rs b/tfhe/src/core_crypto/backends/fft/private/math/decomposition.rs new file mode 100644 index 000000000..ee86e914f --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/decomposition.rs @@ -0,0 +1,86 @@ +pub use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; +use dyn_stack::{DynArray, DynStack}; +use std::iter::Map; +use std::slice::IterMut; + +// copied from src/commons/math/decomposition/*.rs +// in order to avoid allocations + +pub struct TensorSignedDecompositionLendingIter<'buffers, Scalar: UnsignedInteger> { + // The base log of the decomposition + base_log: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: Scalar, + // The internal states of each decomposition + states: DynArray<'buffers, Scalar>, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, +} + +impl<'buffers, Scalar: UnsignedInteger> TensorSignedDecompositionLendingIter<'buffers, Scalar> { + #[inline] + pub(crate) fn new( + input: impl Iterator, + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + stack: DynStack<'buffers>, + ) -> (Self, DynStack<'buffers>) { + let shift = Scalar::BITS - base_log.0 * level.0; + let (states, stack) = + stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift)); + ( + TensorSignedDecompositionLendingIter { + base_log: base_log.0, + current_level: level.0, + mod_b_mask: (Scalar::ONE << base_log.0) - Scalar::ONE, + states, + fresh: true, + }, + stack, + ) + } + + // inlining this improves perf of external product by about 25%, even in LTO builds + #[inline] + pub fn next_term<'short>( + &'short mut self, + ) -> Option<( + DecompositionLevel, + DecompositionBaseLog, + Map, impl FnMut(&'short mut Scalar) -> Scalar>, + )> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + let current_level = self.current_level; + let base_log = self.base_log; + let mod_b_mask = self.mod_b_mask; + self.current_level -= 1; + + Some(( + DecompositionLevel(current_level), + DecompositionBaseLog(self.base_log), + self.states + .iter_mut() + .map(move |state| decompose_one_level(base_log, state, mod_b_mask)), + )) + } +} + +#[inline] +fn decompose_one_level(base_log: usize, state: &mut S, mod_b_mask: S) -> S { + let res = *state & mod_b_mask; + *state >>= base_log; + let mut carry = (res.wrapping_sub(S::ONE) | *state) & res; + carry >>= base_log - 1; + *state += carry; + res.wrapping_sub(carry << base_log) +} diff --git a/tfhe/src/core_crypto/backends/fft/private/math/fft/mod.rs b/tfhe/src/core_crypto/backends/fft/private/math/fft/mod.rs new file mode 100644 index 000000000..f555986c5 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/fft/mod.rs @@ -0,0 +1,661 @@ +use super::super::assume_init_mut; +use super::polynomial::{ + FourierPolynomialMutView, FourierPolynomialUninitMutView, FourierPolynomialView, + PolynomialUninitMutView, +}; +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::tensor::Container; +#[cfg(feature = "backend_fft_serialization")] +use crate::core_crypto::commons::math::tensor::ContainerOwned; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::CastInto; +use crate::core_crypto::commons::utils::izip; +use crate::core_crypto::prelude::PolynomialSize; +use aligned_vec::{avec, ABox}; +use concrete_fft::c64; +use concrete_fft::unordered::{Method, Plan}; +use dyn_stack::{DynStack, SizeOverflow, StackReq}; +use once_cell::sync::OnceCell; +use std::any::TypeId; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::mem::{align_of, size_of, MaybeUninit}; +use std::sync::{Arc, RwLock}; +use std::time::Duration; + +#[cfg(any(target_arch = "x86_64", target_arch = "x86"))] +mod x86; + +/// Twisting factors from the paper: +/// [Fast and Error-Free Negacyclic Integer Convolution using Extended Fourier Transform][paper] +/// +/// The real and imaginary parts form (the first `N/2`) `2N`-th roots of unity. +/// +/// [paper]: https://eprint.iacr.org/2021/480 +#[derive(Clone, Debug, PartialEq)] +pub struct Twisties { + re: ABox<[f64]>, + im: ABox<[f64]>, +} + +/// View type for [`Twisties`]. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct TwistiesView<'a> { + re: &'a [f64], + im: &'a [f64], +} + +impl Twisties { + pub fn as_view(&self) -> TwistiesView<'_> { + TwistiesView { + re: &self.re, + im: &self.im, + } + } +} + +impl Twisties { + /// Creates a new [`Twisties`] containing the `2N`-th roots of unity with `n = N/2`. + /// + /// # Panics + /// + /// Panics if `n` is not a power of two. + pub fn new(n: usize) -> Self { + debug_assert!(n.is_power_of_two()); + let mut re = avec![0.0; n].into_boxed_slice(); + let mut im = avec![0.0; n].into_boxed_slice(); + + let unit = core::f64::consts::PI / (2.0 * n as f64); + for (i, (re, im)) in izip!(&mut *re, &mut *im).enumerate() { + (*im, *re) = (i as f64 * unit).sin_cos(); + } + + Twisties { re, im } + } +} + +/// Negacyclic Fast Fourier Transform. See [`FftView`] for transform functions. +/// +/// This structure contains the twisting factors as well as the +/// FFT plan needed for the negacyclic convolution over the reals. +#[derive(Clone, Debug)] +pub struct Fft { + plan: Arc<(Twisties, Plan)>, +} + +/// View type for [`Fft`]. +#[derive(Clone, Copy, Debug)] +pub struct FftView<'a> { + plan: &'a Plan, + twisties: TwistiesView<'a>, +} + +impl Fft { + #[inline] + pub fn as_view(&self) -> FftView<'_> { + FftView { + plan: &self.plan.1, + twisties: self.plan.0.as_view(), + } + } +} + +type PlanMap = RwLock>>>>; +static PLANS: OnceCell = OnceCell::new(); +fn plans() -> &'static PlanMap { + PLANS.get_or_init(|| RwLock::new(HashMap::new())) +} + +/// Returns the input slice, cast to the same type. +/// +/// This is useful when the fact that `From` and `To` are the same type cannot be proven in the +/// type system, but is known to be true at runtime. +/// +/// # Panics +/// +/// Panics if `From` and `To` are not the same type +#[inline] +#[allow(dead_code)] +fn id_mut(slice: &mut [From]) -> &mut [To] { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + assert_eq!(TypeId::of::(), TypeId::of::()); + + let len = slice.len(); + let ptr = slice.as_mut_ptr(); + unsafe { core::slice::from_raw_parts_mut(ptr as *mut To, len) } +} + +/// Returns the input slice, cast to the same type. +/// +/// This is useful when the fact that `From` and `To` are the same type cannot be proven in the +/// type system, but is known to be true at runtime. +/// +/// # Panics +/// +/// Panics if `From` and `To` are not the same type +#[inline] +#[allow(dead_code)] +fn id(slice: &[From]) -> &[To] { + assert_eq!(size_of::(), size_of::()); + assert_eq!(align_of::(), align_of::()); + assert_eq!(TypeId::of::(), TypeId::of::()); + + let len = slice.len(); + let ptr = slice.as_ptr(); + unsafe { core::slice::from_raw_parts(ptr as *const To, len) } +} + +impl Fft { + /// Real polynomial of size `size`. + pub fn new(size: PolynomialSize) -> Self { + let global_plans = plans(); + + let n = size.0; + let get_plan = || { + let plans = global_plans.read().unwrap(); + let plan = plans.get(&n).cloned(); + drop(plans); + + plan.map(|p| { + p.get_or_init(|| { + Arc::new(( + Twisties::new(n / 2), + Plan::new(n / 2, Method::Measure(Duration::from_millis(10))), + )) + }) + .clone() + }) + }; + + // could not find a plan of the given size, we lock the map again and try to insert it + let mut plans = global_plans.write().unwrap(); + if let Entry::Vacant(v) = plans.entry(n) { + v.insert(Arc::new(OnceCell::new())); + } + + drop(plans); + + Self { + plan: get_plan().unwrap(), + } + } +} + +#[cfg_attr(__profiling, inline(never))] +fn convert_forward_torus( + out: &mut [MaybeUninit], + in_re: &[Scalar], + in_im: &[Scalar], + twisties: TwistiesView<'_>, +) { + let normalization = 2.0_f64.powi(-(Scalar::BITS as i32)); + + izip!(out, in_re, in_im, twisties.re, twisties.im).for_each( + |(out, in_re, in_im, w_re, w_im)| { + let in_re: f64 = in_re.into_signed().cast_into() * normalization; + let in_im: f64 = in_im.into_signed().cast_into() * normalization; + out.write( + c64 { + re: in_re, + im: in_im, + } * c64 { + re: *w_re, + im: *w_im, + }, + ); + }, + ); +} + +fn convert_forward_integer_scalar( + out: &mut [MaybeUninit], + in_re: &[Scalar], + in_im: &[Scalar], + twisties: TwistiesView<'_>, +) { + izip!(out, in_re, in_im, twisties.re, twisties.im).for_each( + |(out, in_re, in_im, w_re, w_im)| { + let in_re: f64 = in_re.into_signed().cast_into(); + let in_im: f64 = in_im.into_signed().cast_into(); + out.write( + c64 { + re: in_re, + im: in_im, + } * c64 { + re: *w_re, + im: *w_im, + }, + ); + }, + ); +} + +#[cfg_attr(__profiling, inline(never))] +fn convert_forward_integer( + out: &mut [MaybeUninit], + in_re: &[Scalar], + in_im: &[Scalar], + twisties: TwistiesView<'_>, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if Scalar::BITS == 32 { + x86::convert_forward_integer_u32(out, id(in_re), id(in_im), twisties); + } else if Scalar::BITS == 64 { + x86::convert_forward_integer_u64(out, id(in_re), id(in_im), twisties); + } else { + unreachable!(); + } + } + + // SAFETY: same as above + #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] + convert_forward_integer_scalar::(out, in_re, in_im, twisties) +} + +#[cfg_attr(__profiling, inline(never))] +fn convert_backward_torus( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let normalization = 1.0 / inp.len() as f64; + izip!(out_re, out_im, inp, twisties.re, twisties.im).for_each( + |(out_re, out_im, inp, w_re, w_im)| { + let tmp = inp + * (c64 { + re: *w_re, + im: -*w_im, + } * normalization); + + out_re.write(Scalar::from_torus(tmp.re)); + out_im.write(Scalar::from_torus(tmp.im)); + }, + ); +} + +/// See [`convert_add_backward_torus`]. +/// +/// # Safety +/// +/// - Same preconditions as [`convert_add_backward_torus`]. +unsafe fn convert_add_backward_torus_scalar( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let normalization = 1.0 / inp.len() as f64; + izip!(out_re, out_im, inp, twisties.re, twisties.im).for_each( + |(out_re, out_im, inp, w_re, w_im)| { + let tmp = inp + * (c64 { + re: *w_re, + im: -*w_im, + } * normalization); + + let out_re = out_re.assume_init_mut(); + let out_im = out_im.assume_init_mut(); + + *out_re = Scalar::wrapping_add(*out_re, Scalar::from_torus(tmp.re)); + *out_im = Scalar::wrapping_add(*out_im, Scalar::from_torus(tmp.im)); + }, + ); +} + +/// # Warning +/// +/// This function is actually unsafe, but can't be marked as such since we need it to implement +/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait. +/// +/// # Safety +/// +/// - `out_re` and `out_im` must not hold any uninitialized values. +#[cfg_attr(__profiling, inline(never))] +fn convert_add_backward_torus( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] + { + if Scalar::BITS == 32 { + x86::convert_add_backward_torus_u32(id_mut(out_re), id_mut(out_im), inp, twisties); + } else if Scalar::BITS == 64 { + x86::convert_add_backward_torus_u64(id_mut(out_re), id_mut(out_im), inp, twisties); + } else { + unreachable!(); + } + } + + // SAFETY: same as above + #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))] + unsafe { + convert_add_backward_torus_scalar::(out_re, out_im, inp, twisties) + }; +} + +impl<'a> FftView<'a> { + /// Returns the polynomial size that this FFT was made for. + pub fn polynomial_size(self) -> PolynomialSize { + PolynomialSize(2 * self.plan.fft_size()) + } + + /// Serializes data in the Fourier domain. + #[cfg(feature = "backend_fft_serialization")] + #[cfg_attr(docsrs, doc(cfg(feature = "backend_fft_serialization")))] + pub fn serialize_fourier_buffer( + self, + serializer: S, + buf: &[c64], + ) -> Result { + self.plan.serialize_fourier_buffer(serializer, buf) + } + + /// Deserializes data in the Fourier domain + #[cfg(feature = "backend_fft_serialization")] + #[cfg_attr(docsrs, doc(cfg(feature = "backend_fft_serialization")))] + pub fn deserialize_fourier_buffer<'de, D: serde::Deserializer<'de>>( + self, + deserializer: D, + buf: &mut [c64], + ) -> Result<(), D::Error> { + self.plan.deserialize_fourier_buffer(deserializer, buf) + } + + /// Returns the memory required for a forward negacyclic FFT. + pub fn forward_scratch(self) -> Result { + self.plan.fft_scratch() + } + + /// Returns the memory required for a backward negacyclic FFT. + pub fn backward_scratch(self) -> Result { + self.plan + .fft_scratch()? + .try_and(StackReq::try_new_aligned::( + self.polynomial_size().0 / 2, + aligned_vec::CACHELINE_ALIGN, + )?) + } + + /// Performs a negacyclic real FFT of `standard`, viewed as torus elements, and stores the + /// result in `fourier`. + /// + /// # Note + /// + /// this function leaves all the elements of `out` in an initialized state. + /// + /// # Panics + /// + /// Panics if `standard` and `self` have differing polynomial sizes, or if `fourier` doesn't + /// have size equal to that amount divided by two. + pub fn forward_as_torus<'out, Scalar: UnsignedTorus>( + self, + fourier: FourierPolynomialUninitMutView<'out>, + standard: Polynomial<&'_ [Scalar]>, + stack: DynStack<'_>, + ) -> FourierPolynomialMutView<'out> { + // SAFETY: `convert_forward_torus` initializes the output slice that is passed to it + unsafe { self.forward_with_conv(fourier, standard, convert_forward_torus, stack) } + } + + /// Performs a negacyclic real FFT of `standard`, viewed as integers, and stores the result in + /// `fourier`. + /// + /// # Note + /// + /// this function leaves all the elements of `out` in an initialized state. + /// + /// # Panics + /// + /// Panics if `standard` and `self` have differing polynomial sizes, or if `fourier` doesn't + /// have size equal to that amount divided by two. + pub fn forward_as_integer<'out, Scalar: UnsignedTorus>( + self, + fourier: FourierPolynomialUninitMutView<'out>, + standard: Polynomial<&'_ [Scalar]>, + stack: DynStack<'_>, + ) -> FourierPolynomialMutView<'out> { + // SAFETY: `convert_forward_integer` initializes the output slice that is passed to it + unsafe { self.forward_with_conv(fourier, standard, convert_forward_integer, stack) } + } + + /// Performs an inverse negacyclic real FFT of `fourier` and stores the result in `standard`, + /// viewed as torus elements. + /// + /// # Note + /// + /// this function leaves all the elements of `out_re` and `out_im` in an initialized state. + /// + /// # Panics + /// + /// See [`Self::forward_as_torus`] + pub fn backward_as_torus<'out, Scalar: UnsignedTorus>( + self, + standard: PolynomialUninitMutView<'out, Scalar>, + fourier: FourierPolynomialView<'_>, + stack: DynStack<'_>, + ) { + // SAFETY: `convert_backward_torus` initializes the output slices that are passed to it + unsafe { self.backward_with_conv(standard, fourier, convert_backward_torus, stack) } + } + + /// Performs an inverse negacyclic real FFT of `fourier` and adds the result to `standard`, + /// viewed as torus elements. + /// + /// # Note + /// + /// this function leaves all the elements of `out_re` and `out_im` in an initialized state. + /// + /// # Panics + /// + /// See [`Self::forward_as_torus`] + pub fn add_backward_as_torus<'out, Scalar: UnsignedTorus>( + self, + standard: Polynomial<&'out mut [Scalar]>, + fourier: FourierPolynomialView<'_>, + stack: DynStack<'_>, + ) { + // SAFETY: `convert_add_backward_torus` initializes the output slices that are passed to it + unsafe { + self.backward_with_conv( + standard.into_uninit(), + fourier, + convert_add_backward_torus, + stack, + ) + } + } + + /// # Safety + /// + /// `conv_fn` must initialize the entirety of the mutable slice that it receives. + unsafe fn forward_with_conv< + 'out, + Scalar: UnsignedTorus, + F: Fn(&mut [MaybeUninit], &[Scalar], &[Scalar], TwistiesView<'_>), + >( + self, + fourier: FourierPolynomialUninitMutView<'out>, + standard: Polynomial<&'_ [Scalar]>, + conv_fn: F, + stack: DynStack<'_>, + ) -> FourierPolynomialMutView<'out> { + let fourier = fourier.data; + let standard = standard.tensor.into_container(); + let n = standard.len(); + debug_assert_eq!(n, 2 * fourier.len()); + let (standard_re, standard_im) = standard.split_at(n / 2); + conv_fn(fourier, standard_re, standard_im, self.twisties); + let fourier = assume_init_mut(fourier); + self.plan.fwd(fourier, stack); + FourierPolynomialMutView { data: fourier } + } + + /// # Safety + /// + /// `conv_fn` must initialize the entirety of the mutable slices that it receives. + unsafe fn backward_with_conv< + 'out, + Scalar: UnsignedTorus, + F: Fn(&mut [MaybeUninit], &mut [MaybeUninit], &[c64], TwistiesView<'_>), + >( + self, + standard: PolynomialUninitMutView<'out, Scalar>, + fourier: FourierPolynomialView<'_>, + conv_fn: F, + stack: DynStack<'_>, + ) { + let fourier = fourier.data; + let standard = standard.tensor.into_container(); + let n = standard.len(); + debug_assert_eq!(n, 2 * fourier.len()); + let (mut tmp, stack) = + stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier.iter().copied()); + self.plan.inv(&mut tmp, stack); + + let (standard_re, standard_im) = standard.split_at_mut(n / 2); + conv_fn(standard_re, standard_im, &tmp, self.twisties); + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FourierPolynomialList> { + pub data: C, + pub polynomial_size: PolynomialSize, +} + +#[cfg(feature = "backend_fft_serialization")] +impl> serde::Serialize for FourierPolynomialList { + fn serialize(&self, serializer: S) -> Result { + fn serialize_impl( + data: &[c64], + polynomial_size: PolynomialSize, + serializer: S, + ) -> Result { + use crate::core_crypto::commons::math::tensor::Split; + + pub struct SingleFourierPolynomial<'a> { + fft: FftView<'a>, + buf: &'a [c64], + } + + impl<'a> serde::Serialize for SingleFourierPolynomial<'a> { + fn serialize( + &self, + serializer: S, + ) -> Result { + self.fft.serialize_fourier_buffer(serializer, self.buf) + } + } + + use serde::ser::SerializeSeq; + let chunk_count = if polynomial_size.0 == 0 { + 0 + } else { + data.len() / (polynomial_size.0 / 2) + }; + + let mut state = serializer.serialize_seq(Some(2 + chunk_count))?; + state.serialize_element(&polynomial_size)?; + state.serialize_element(&chunk_count)?; + if chunk_count != 0 { + let fft = Fft::new(polynomial_size); + for buf in data.split_into(chunk_count) { + state.serialize_element(&SingleFourierPolynomial { + fft: fft.as_view(), + buf, + })?; + } + } + state.end() + } + + serialize_impl(self.data.as_ref(), self.polynomial_size, serializer) + } +} + +#[cfg(feature = "backend_fft_serialization")] +impl<'de, C: ContainerOwned> serde::Deserialize<'de> for FourierPolynomialList { + fn deserialize>(deserializer: D) -> Result { + use std::marker::PhantomData; + struct SeqVisitor>(PhantomData C>); + + impl<'de, C: ContainerOwned> serde::de::Visitor<'de> for SeqVisitor { + type Value = FourierPolynomialList; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str( + "a sequence of two fields followed by polynomials in the Fourier domain", + ) + } + + fn visit_seq>( + self, + mut seq: A, + ) -> Result { + use crate::core_crypto::commons::math::tensor::Split; + + let str = "sequence of two fields and Fourier polynomials"; + let polynomial_size = match seq.next_element::()? { + Some(polynomial_size) => polynomial_size, + None => return Err(serde::de::Error::invalid_length(0, &str)), + }; + let chunk_count = match seq.next_element::()? { + Some(chunk_count) => chunk_count, + None => return Err(serde::de::Error::invalid_length(1, &str)), + }; + + struct FillFourier<'a> { + fft: FftView<'a>, + buf: &'a mut [c64], + } + + impl<'de, 'a> serde::de::DeserializeSeed<'de> for FillFourier<'a> { + type Value = (); + + fn deserialize>( + self, + deserializer: D, + ) -> Result { + self.fft.deserialize_fourier_buffer(deserializer, self.buf) + } + } + + let mut data = + C::collect((0..(polynomial_size.0 / 2 * chunk_count)).map(|_| c64::default())); + + if chunk_count != 0 { + let fft = Fft::new(polynomial_size); + for (i, buf) in data.as_mut().split_into(chunk_count).enumerate() { + match seq.next_element_seed(FillFourier { + fft: fft.as_view(), + buf, + })? { + Some(()) => (), + None => { + return Err(serde::de::Error::invalid_length( + i, + &&*format!("sequence of {chunk_count} Fourier polynomials"), + )) + } + }; + } + } + + Ok(FourierPolynomialList { + data, + polynomial_size, + }) + } + } + + deserializer.deserialize_seq(SeqVisitor::(PhantomData)) + } +} + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/core_crypto/backends/fft/private/math/fft/tests.rs b/tfhe/src/core_crypto/backends/fft/private/math/fft/tests.rs new file mode 100644 index 000000000..7004407a6 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/fft/tests.rs @@ -0,0 +1,250 @@ +use dyn_stack::{GlobalMemBuffer, ReborrowMut}; + +use super::super::super::super::private::math::polynomial::FourierPolynomial; +use super::*; +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::test_tools::new_random_generator; +use aligned_vec::avec; + +fn abs_diff(a: Scalar, b: Scalar) -> Scalar { + if a > b { + a - b + } else { + b - a + } +} + +fn test_roundtrip() { + let mut generator = new_random_generator(); + for i in 2..=14 { + let size = 1_usize << i; + + let fft = Fft::new(PolynomialSize(size)); + let fft = fft.as_view(); + + let mut poly = Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + let mut roundtrip = + Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + let mut fourier = FourierPolynomial { + data: avec![c64::default(); size / 2].into_boxed_slice(), + }; + + for x in poly.tensor.as_mut_container().iter_mut() { + *x = generator.random_uniform(); + } + + let mut mem = GlobalMemBuffer::new( + fft.forward_scratch() + .unwrap() + .and(fft.backward_scratch().unwrap()), + ); + let mut stack = DynStack::new(&mut mem); + + fft.forward_as_torus( + unsafe { fourier.as_mut_view().into_uninit() }, + poly.as_view(), + stack.rb_mut(), + ); + fft.backward_as_torus( + unsafe { roundtrip.as_mut_view().into_uninit() }, + fourier.as_view(), + stack.rb_mut(), + ); + + for (expected, actual) in izip!( + poly.tensor.as_container().iter(), + roundtrip.tensor.as_container().iter() + ) { + if Scalar::BITS == 32 { + assert!(abs_diff(*expected, *actual) == Scalar::ZERO); + } else { + assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (64 - 50))); + } + } + } +} + +fn test_product() { + fn convolution_naive( + out: &mut [Scalar], + lhs: &[Scalar], + rhs: &[Scalar], + ) { + assert_eq!(out.len(), lhs.len()); + assert_eq!(out.len(), rhs.len()); + let n = out.len(); + let mut full_prod = vec![Scalar::ZERO; 2 * n]; + for i in 0..n { + for j in 0..n { + full_prod[i + j] = full_prod[i + j].wrapping_add(lhs[i].wrapping_mul(rhs[j])); + } + } + for i in 0..n { + out[i] = full_prod[i].wrapping_sub(full_prod[i + n]); + } + } + + let mut generator = new_random_generator(); + for i in 5..=14 { + for _ in 0..100 { + let size = 1_usize << i; + + let fft = Fft::new(PolynomialSize(size)); + let fft = fft.as_view(); + + let mut poly0 = + Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + let mut poly1 = + Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + + let mut convolution_from_fft = + Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + let mut convolution_from_naive = + Polynomial::from_container(avec![Scalar::ZERO; size].into_boxed_slice()); + + let mut fourier0 = FourierPolynomial { + data: avec![c64::default(); size / 2].into_boxed_slice(), + }; + let mut fourier1 = FourierPolynomial { + data: avec![c64::default(); size / 2 ].into_boxed_slice(), + }; + + for (x, y) in izip!( + poly0.tensor.as_mut_container().iter_mut(), + poly1.tensor.as_mut_container().iter_mut() + ) { + *x = generator.random_uniform(); + *y = generator.random_uniform(); + if Scalar::BITS == 64 { + *x >>= 32; + *y >>= 32; + } else { + *x >>= 16; + *y >>= 16; + } + } + + let mut mem = GlobalMemBuffer::new( + fft.forward_scratch() + .unwrap() + .and(fft.backward_scratch().unwrap()), + ); + let mut stack = DynStack::new(&mut mem); + + // SAFETY: forward_as_torus doesn't write any uninitialized values into its output + fft.forward_as_torus( + unsafe { fourier0.as_mut_view().into_uninit() }, + poly0.as_view(), + stack.rb_mut(), + ); + // SAFETY: forward_as_integer doesn't write any uninitialized values into its output + fft.forward_as_integer( + unsafe { fourier1.as_mut_view().into_uninit() }, + poly1.as_view(), + stack.rb_mut(), + ); + + for (f0, f1) in izip!(&mut *fourier0.data, &*fourier1.data) { + *f0 *= *f1; + } + + // SAFETY: backward_as_torus doesn't write any uninitialized values into its output + fft.backward_as_torus( + unsafe { convolution_from_fft.as_mut_view().into_uninit() }, + fourier0.as_view(), + stack.rb_mut(), + ); + convolution_naive( + convolution_from_naive.tensor.as_mut_container(), + poly0.tensor.as_container(), + poly1.tensor.as_container(), + ); + + for (expected, actual) in izip!( + convolution_from_naive.tensor.as_container().iter(), + convolution_from_fft.tensor.as_container().iter() + ) { + assert!(abs_diff(*expected, *actual) < (Scalar::ONE << (Scalar::BITS - 5))); + } + } + } +} + +#[test] +fn test_product_u32() { + test_product::(); +} + +#[test] +fn test_product_u64() { + test_product::(); +} + +#[test] +fn test_roundtrip_u32() { + test_roundtrip::(); +} +#[test] +fn test_roundtrip_u64() { + test_roundtrip::(); +} + +#[test] +fn f64_to_i64_bit_twiddles() { + for x in [ + 0.0, + -0.0, + 37.1242161_f64, + -37.1242161_f64, + 0.1, + -0.1, + 1.0, + -1.0, + 0.9, + -0.9, + 2.0, + -2.0, + 1e-310, + -1e-310, + 2.0_f64.powi(62), + -(2.0_f64.powi(62)), + 1.1 * 2.0_f64.powi(62), + 1.1 * -(2.0_f64.powi(62)), + -(2.0_f64.powi(63)), + ] { + // this test checks the correctness of converting from f64 to i64 by manipulating the bits + // of the ieee754 representation of the floating point values. + // + // if the value is not representable as an i64, the result is unspecified. + // + // https://en.wikipedia.org/wiki/Double-precision_floating-point_format + let bits = x.to_bits(); + let implicit_mantissa = bits & 0xFFFFFFFFFFFFF; + let explicit_mantissa = implicit_mantissa | 0x10000000000000; + let biased_exp = ((bits >> 52) & 0x7FF) as i64; + let sign = bits >> 63; + + let explicit_mantissa_lshift = explicit_mantissa << 11; + + // equivalent to: + // + // let exp = biased_exp - 1023; + // let explicit_mantissa_shift = explicit_mantissa_lshift >> (63 - exp.max(0)); + let right_shift_amount = (1086 - biased_exp) as u64; + + let explicit_mantissa_shift = if right_shift_amount < 64 { + explicit_mantissa_lshift >> right_shift_amount + } else { + 0 + }; + + let value = if sign == 0 { + explicit_mantissa_shift as i64 + } else { + -(explicit_mantissa_shift as i64) + }; + + let value = if biased_exp == 0 { 0 } else { value }; + assert_eq!(value as i64, x as i64); + } +} diff --git a/tfhe/src/core_crypto/backends/fft/private/math/fft/x86.rs b/tfhe/src/core_crypto/backends/fft/private/math/fft/x86.rs new file mode 100644 index 000000000..f548dbe48 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/fft/x86.rs @@ -0,0 +1,1063 @@ +//! For documentation on the various intrinsics used here, refer to Intel's intrinsics guide. +//! +//! +//! currently we dispatch based on the availability of +//! - avx+avx2(advanced vector extensions for 256 intrinsics)+fma(fused multiply add for complex +//! multiplication, usually comes with avx+avx2), +//! - or the availability of avx512f[+avx512dq(doubleword/quadword intrinsics for conversion of f64 +//! to/from i64. usually comes with avx512f on modern cpus)] +//! +//! more dispatch options may be added in the future + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +use super::super::super::c64; +use super::TwistiesView; +use std::mem::MaybeUninit; + +/// Converts a vector of f64 values to a vector of i64 values. +/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version. +/// +/// # Safety +/// +/// - `is_x86_feature_detected!("avx2")` must be true. +#[inline(always)] +pub unsafe fn mm256_cvtpd_epi64(x: __m256d) -> __m256i { + // reinterpret the bits as u64 values + let bits = _mm256_castpd_si256(x); + // mask that covers the first 52 bits + let mantissa_mask = _mm256_set1_epi64x(0xFFFFFFFFFFFFF_u64 as i64); + // mask that covers the 52nd bit + let explicit_mantissa_bit = _mm256_set1_epi64x(0x10000000000000_u64 as i64); + // mask that covers the first 11 bits + let exp_mask = _mm256_set1_epi64x(0x7FF_u64 as i64); + + // extract the first 52 bits and add the implicit bit + let mantissa = _mm256_or_si256(_mm256_and_si256(bits, mantissa_mask), explicit_mantissa_bit); + + // extract the 52nd to 63rd (excluded) bits for the biased exponent + let biased_exp = _mm256_and_si256(_mm256_srli_epi64::<52>(bits), exp_mask); + + // extract the 63rd sign bit + let sign_is_negative_mask = + _mm256_sub_epi64(_mm256_setzero_si256(), _mm256_srli_epi64::<63>(bits)); + + // we need to shift the mantissa by some value that may be negative, so we first shift it to + // the left by the maximum amount, then shift it to the right by our value plus the offset we + // just shifted by + // + // the 52nd bit is set to 1, so we shift to the left by 11 so the 63rd (last) bit is set. + let mantissa_lshift = _mm256_slli_epi64::<11>(mantissa); + + // shift to the right and apply the exponent bias + let mantissa_shift = _mm256_srlv_epi64( + mantissa_lshift, + _mm256_sub_epi64(_mm256_set1_epi64x(1086), biased_exp), + ); + + // if the sign bit is unset, we keep our result + let value_if_positive = mantissa_shift; + // otherwise, we negate it + let value_if_negative = _mm256_sub_epi64(_mm256_setzero_si256(), value_if_positive); + + // if the biased exponent is all zeros, we have a subnormal value (or zero) + + // if it is not subnormal, we keep our results + let value_if_non_subnormal = + _mm256_blendv_epi8(value_if_positive, value_if_negative, sign_is_negative_mask); + + // if it is subnormal, the conversion to i64 (rounding towards zero) returns zero + let value_if_subnormal = _mm256_setzero_si256(); + + // compare the biased exponent to a zero value + let is_subnormal = _mm256_cmpeq_epi64(biased_exp, _mm256_setzero_si256()); + + // choose the result depending on subnormalness + _mm256_blendv_epi8(value_if_non_subnormal, value_if_subnormal, is_subnormal) +} + +/// Converts a vector of f64 values to a vector of i64 values. +/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version. +/// +/// # Safety +/// +/// - `is_x86_feature_detected!("avx2")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[inline(always)] +pub unsafe fn mm512_cvtpd_epi64(x: __m512d) -> __m512i { + // reinterpret the bits as u64 values + let bits = _mm512_castpd_si512(x); + // mask that covers the first 52 bits + let mantissa_mask = _mm512_set1_epi64(0xFFFFFFFFFFFFF_u64 as i64); + // mask that covers the 53rd bit + let explicit_mantissa_bit = _mm512_set1_epi64(0x10000000000000_u64 as i64); + // mask that covers the first 11 bits + let exp_mask = _mm512_set1_epi64(0x7FF_u64 as i64); + + // extract the first 52 bits and add the implicit bit + let mantissa = _mm512_or_si512(_mm512_and_si512(bits, mantissa_mask), explicit_mantissa_bit); + + // extract the 52nd to 63rd (excluded) bits for the biased exponent + let biased_exp = _mm512_and_si512(_mm512_srli_epi64::<52>(bits), exp_mask); + + // extract the 63rd sign bit + let sign_is_negative_mask = + _mm512_cmpneq_epi64_mask(_mm512_srli_epi64::<63>(bits), _mm512_set1_epi64(1)); + + // we need to shift the mantissa by some value that may be negative, so we first shift it to + // the left by the maximum amount, then shift it to the right by our value plus the offset we + // just shifted by + // + // the 53rd bit is set to 1, so we shift to the left by 10 so the 63rd (last) bit is set. + let mantissa_lshift = _mm512_slli_epi64::<11>(mantissa); + + // shift to the right and apply the exponent bias + let mantissa_shift = _mm512_srlv_epi64( + mantissa_lshift, + _mm512_sub_epi64(_mm512_set1_epi64(1086), biased_exp), + ); + + // if the sign bit is unset, we keep our result + let value_if_positive = mantissa_shift; + // otherwise, we negate it + let value_if_negative = _mm512_sub_epi64(_mm512_setzero_si512(), value_if_positive); + + // if the biased exponent is all zeros, we have a subnormal value (or zero) + + // if it is not subnormal, we keep our results + let value_if_non_subnormal = + _mm512_mask_blend_epi64(sign_is_negative_mask, value_if_positive, value_if_negative); + + // if it is subnormal, the conversion to i64 (rounding towards zero) returns zero + let value_if_subnormal = _mm512_setzero_si512(); + + // compare the biased exponent to a zero value + let is_subnormal = _mm512_cmpeq_epi64_mask(biased_exp, _mm512_setzero_si512()); + + // choose the result depending on subnormalness + _mm512_mask_blend_epi64(is_subnormal, value_if_non_subnormal, value_if_subnormal) +} + +/// Converts a vector of i64 values to a vector of f64 values. Not sure how it works. +/// Ported from . +/// +/// # Safety +/// +/// - `is_x86_feature_detected!("avx2")` must be true. +#[inline(always)] +pub unsafe fn mm256_cvtepi64_pd(x: __m256i) -> __m256d { + let mut x_hi = _mm256_srai_epi32::<16>(x); + x_hi = _mm256_blend_epi16::<0x33>(x_hi, _mm256_setzero_si256()); + x_hi = _mm256_add_epi64( + x_hi, + _mm256_castpd_si256(_mm256_set1_pd(442721857769029238784.0)), // 3*2^67 + ); + let x_lo = + _mm256_blend_epi16::<0x88>(x, _mm256_castpd_si256(_mm256_set1_pd(4503599627370496.0))); // 2^52 + + let f = _mm256_sub_pd( + _mm256_castsi256_pd(x_hi), + _mm256_set1_pd(442726361368656609280.0), // 3*2^67 + 2^52 + ); + + _mm256_add_pd(f, _mm256_castsi256_pd(x_lo)) +} + +/// Converts a vector of i64 values to a vector of f64 values. +/// +/// # Safety +/// +/// - `is_x86_feature_detected!("avx512dq")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[target_feature(enable = "avx512dq")] +#[inline] +pub unsafe fn mm512_cvtepi64_pd(x: __m512i) -> __m512d { + // hopefully this compiles to vcvtqq2pd + let i64x8: [i64; 8] = core::mem::transmute(x); + let as_f64x8 = [ + i64x8[0] as f64, + i64x8[1] as f64, + i64x8[2] as f64, + i64x8[3] as f64, + i64x8[4] as f64, + i64x8[5] as f64, + i64x8[6] as f64, + i64x8[7] as f64, + ]; + core::mem::transmute(as_f64x8) +} + +/// # Safety +/// +/// - `is_x86_feature_detected!("avx512f")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[target_feature(enable = "avx512f")] +pub unsafe fn convert_forward_integer_u32_avx512f( + out: &mut [MaybeUninit], + in_re: &[u32], + in_im: &[u32], + twisties: TwistiesView<'_>, +) { + let n = out.len(); + debug_assert_eq!(n % 8, 0); + debug_assert_eq!(n, out.len()); + debug_assert_eq!(n, in_re.len()); + debug_assert_eq!(n, in_im.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let out = out.as_mut_ptr() as *mut f64; + let in_re = in_re.as_ptr(); + let in_im = in_im.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 8 { + let i = i * 8; + // load i32 values and convert to f64 + let in_re = _mm512_cvtepi32_pd(_mm256_loadu_si256(in_re.add(i) as _)); + // load i32 values and convert to f64 + let in_im = _mm512_cvtepi32_pd(_mm256_loadu_si256(in_im.add(i) as _)); + // load f64 values + let w_re = _mm512_loadu_pd(w_re.add(i)); + // load f64 values + let w_im = _mm512_loadu_pd(w_im.add(i)); + let out = out.add(2 * i); + + // perform complex multiplication + let out_re = _mm512_fmsub_pd(in_re, w_re, _mm512_mul_pd(in_im, w_im)); + let out_im = _mm512_fmadd_pd(in_re, w_im, _mm512_mul_pd(in_im, w_re)); + + // we have + // x0 x1 x2 x3 x4 x5 x6 x7 + // y0 y1 y2 y3 y4 y5 y6 y7 + // + // we want + // x0 y0 x1 y1 x2 y2 x3 y3 + // x4 y4 x5 y5 x6 y6 x7 y7 + + // interleave real part and imaginary part + { + let idx0 = _mm512_setr_epi64( + 0b0000, 0b1000, 0b0001, 0b1001, 0b0010, 0b1010, 0b0011, 0b1011, + ); + let idx1 = _mm512_setr_epi64( + 0b0100, 0b1100, 0b0101, 0b1101, 0b0110, 0b1110, 0b0111, 0b1111, + ); + + let out0 = _mm512_permutex2var_pd(out_re, idx0, out_im); + let out1 = _mm512_permutex2var_pd(out_re, idx1, out_im); + + // store c64 values + _mm512_storeu_pd(out, out0); + _mm512_storeu_pd(out.add(8), out1); + } + } +} + +/// # Safety +/// +/// - `is_x86_feature_detected!("avx512f")` must be true. +/// - `is_x86_feature_detected!("avx512dq")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[target_feature(enable = "avx512f,avx512dq")] +pub unsafe fn convert_forward_integer_u64_avx512f_avx512dq( + out: &mut [MaybeUninit], + in_re: &[u64], + in_im: &[u64], + twisties: TwistiesView<'_>, +) { + let n = out.len(); + debug_assert_eq!(n % 8, 0); + debug_assert_eq!(n, out.len()); + debug_assert_eq!(n, in_re.len()); + debug_assert_eq!(n, in_im.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let out = out.as_mut_ptr() as *mut f64; + let in_re = in_re.as_ptr(); + let in_im = in_im.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 8 { + let i = i * 8; + // load i64 values and convert to f64 + let in_re = mm512_cvtepi64_pd(_mm512_loadu_si512(in_re.add(i) as _)); + // load i64 values and convert to f64 + let in_im = mm512_cvtepi64_pd(_mm512_loadu_si512(in_im.add(i) as _)); + // load f64 values + let w_re = _mm512_loadu_pd(w_re.add(i)); + // load f64 values + let w_im = _mm512_loadu_pd(w_im.add(i)); + let out = out.add(2 * i); + + // perform complex multiplication + let out_re = _mm512_fmsub_pd(in_re, w_re, _mm512_mul_pd(in_im, w_im)); + let out_im = _mm512_fmadd_pd(in_re, w_im, _mm512_mul_pd(in_im, w_re)); + + // we have + // x0 x1 x2 x3 x4 x5 x6 x7 + // y0 y1 y2 y3 y4 y5 y6 y7 + // + // we want + // x0 y0 x1 y1 x2 y2 x3 y3 + // x4 y4 x5 y5 x6 y6 x7 y7 + + // interleave real part and imaginary part + { + let idx0 = _mm512_setr_epi64( + 0b0000, 0b1000, 0b0001, 0b1001, 0b0010, 0b1010, 0b0011, 0b1011, + ); + let idx1 = _mm512_setr_epi64( + 0b0100, 0b1100, 0b0101, 0b1101, 0b0110, 0b1110, 0b0111, 0b1111, + ); + + let out0 = _mm512_permutex2var_pd(out_re, idx0, out_im); + let out1 = _mm512_permutex2var_pd(out_re, idx1, out_im); + + // store c64 values + _mm512_storeu_pd(out, out0); + _mm512_storeu_pd(out.add(8), out1); + } + } +} + +/// # Safety +/// +/// - `is_x86_feature_detected!("fma")` must be true. +#[target_feature(enable = "avx,fma")] +pub unsafe fn convert_forward_integer_u32_fma( + out: &mut [MaybeUninit], + in_re: &[u32], + in_im: &[u32], + twisties: TwistiesView<'_>, +) { + let n = out.len(); + debug_assert_eq!(n % 4, 0); + debug_assert_eq!(n, out.len()); + debug_assert_eq!(n, in_re.len()); + debug_assert_eq!(n, in_im.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let out = out.as_mut_ptr() as *mut f64; + let in_re = in_re.as_ptr(); + let in_im = in_im.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 4 { + let i = i * 4; + // load i32 values and convert to f64 + let in_re = _mm256_cvtepi32_pd(_mm_loadu_si128(in_re.add(i) as _)); + // load i32 values and convert to f64 + let in_im = _mm256_cvtepi32_pd(_mm_loadu_si128(in_im.add(i) as _)); + // load f64 values + let w_re = _mm256_loadu_pd(w_re.add(i)); + // load f64 values + let w_im = _mm256_loadu_pd(w_im.add(i)); + let out = out.add(2 * i); + + // perform complex multiplication + let out_re = _mm256_fmsub_pd(in_re, w_re, _mm256_mul_pd(in_im, w_im)); + let out_im = _mm256_fmadd_pd(in_re, w_im, _mm256_mul_pd(in_im, w_re)); + + // we have + // x0 x1 x2 x3 + // y0 y1 y2 y3 + // + // we want + // x0 y0 x1 y1 + // x2 y2 x3 y3 + + // interleave real part and imaginary part + { + // unpacklo/unpackhi + // x0 y0 x2 y2 + // x1 y1 x3 y3 + let lo = _mm256_unpacklo_pd(out_re, out_im); + let hi = _mm256_unpackhi_pd(out_re, out_im); + + let out0 = _mm256_permute2f128_pd::<0b00100000>(lo, hi); + let out1 = _mm256_permute2f128_pd::<0b00110001>(lo, hi); + + // store c64 values + _mm256_storeu_pd(out, out0); + _mm256_storeu_pd(out.add(4), out1); + } + } +} + +/// # Safety +/// +/// - `is_x86_feature_detected!("avx2")` must be true. +/// - `is_x86_feature_detected!("fma")` must be true. +#[target_feature(enable = "avx,avx2,fma")] +pub unsafe fn convert_forward_integer_u64_avx2_fma( + out: &mut [MaybeUninit], + in_re: &[u64], + in_im: &[u64], + twisties: TwistiesView<'_>, +) { + let n = out.len(); + debug_assert_eq!(n % 4, 0); + debug_assert_eq!(n, out.len()); + debug_assert_eq!(n, in_re.len()); + debug_assert_eq!(n, in_im.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let out = out.as_mut_ptr() as *mut f64; + let in_re = in_re.as_ptr(); + let in_im = in_im.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 4 { + let i = i * 4; + // load i64 values and convert to f64 + let in_re = mm256_cvtepi64_pd(_mm256_loadu_si256(in_re.add(i) as _)); + // load i64 values and convert to f64 + let in_im = mm256_cvtepi64_pd(_mm256_loadu_si256(in_im.add(i) as _)); + // load f64 values + let w_re = _mm256_loadu_pd(w_re.add(i)); + // load f64 values + let w_im = _mm256_loadu_pd(w_im.add(i)); + let out = out.add(2 * i); + + // perform complex multiplication + let out_re = _mm256_fmsub_pd(in_re, w_re, _mm256_mul_pd(in_im, w_im)); + let out_im = _mm256_fmadd_pd(in_re, w_im, _mm256_mul_pd(in_im, w_re)); + + // we have + // x0 x1 x2 x3 + // y0 y1 y2 y3 + // + // we want + // x0 y0 x1 y1 + // x2 y2 x3 y3 + + // interleave real part and imaginary part + { + // unpacklo/unpackhi + // x0 y0 x2 y2 + // x1 y1 x3 y3 + let lo = _mm256_unpacklo_pd(out_re, out_im); + let hi = _mm256_unpackhi_pd(out_re, out_im); + + let out0 = _mm256_permute2f128_pd::<0b00100000>(lo, hi); + let out1 = _mm256_permute2f128_pd::<0b00110001>(lo, hi); + + // store c64 values + _mm256_storeu_pd(out, out0); + _mm256_storeu_pd(out.add(4), out1); + } + } +} + +/// Performs common work for `u32` and `u64`, used by the backward torus transformation. +/// +/// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part, +/// then rounds to the nearest integer. +/// +/// # Safety +/// +/// - `w_re.add(i)`, `w_im.add(i)`, and `inp.add(i)` must point to an array of at least 8 +/// elements. +/// - `is_x86_feature_detected!("avx512f")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[inline(always)] +pub unsafe fn convert_torus_prologue_avx512f( + normalization: __m512d, + w_re: *const f64, + i: usize, + w_im: *const f64, + inp: *const c64, + scaling: __m512d, +) -> (__m512d, __m512d) { + let w_re = _mm512_mul_pd(normalization, _mm512_loadu_pd(w_re.add(i))); + let w_im = _mm512_mul_pd(normalization, _mm512_loadu_pd(w_im.add(i))); + + // re0 im0 re1 im1 re2 im2 re3 im3 + let inp0 = _mm512_loadu_pd(inp.add(i) as _); + // re4 im4 re5 im5 re6 im6 re7 im7 + let inp1 = _mm512_loadu_pd(inp.add(i + 4) as _); + + // real indices + let idx0 = _mm512_setr_epi64( + 0b0000, 0b0010, 0b0100, 0b0110, 0b1000, 0b1010, 0b1100, 0b1110, + ); + // imaginary indices + let idx1 = _mm512_setr_epi64( + 0b0001, 0b0011, 0b0101, 0b0111, 0b1001, 0b1011, 0b1101, 0b1111, + ); + + // re0 re1 re2 re3 re4 re5 re6 re7 + let inp_re = _mm512_permutex2var_pd(inp0, idx0, inp1); + // im0 im1 im2 im3 im4 im5 im6 im7 + let inp_im = _mm512_permutex2var_pd(inp0, idx1, inp1); + + // perform complex multiplication with conj(w) + let mul_re = _mm512_fmadd_pd(inp_re, w_re, _mm512_mul_pd(inp_im, w_im)); + let mul_im = _mm512_fnmadd_pd(inp_re, w_im, _mm512_mul_pd(inp_im, w_re)); + + // round to nearest integer and suppress exceptions + const ROUNDING: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC; + + // get the fractional part (centered around zero) by subtracting rounded value + let fract_re = _mm512_sub_pd(mul_re, _mm512_roundscale_pd::(mul_re)); + let fract_im = _mm512_sub_pd(mul_im, _mm512_roundscale_pd::(mul_im)); + // scale fractional part and round + let fract_re = _mm512_roundscale_pd::(_mm512_mul_pd(scaling, fract_re)); + let fract_im = _mm512_roundscale_pd::(_mm512_mul_pd(scaling, fract_im)); + + (fract_re, fract_im) +} + +/// See [`convert_add_backward_torus`]. +/// +/// # Safety +/// +/// - Same preconditions as [`convert_add_backward_torus`]. +/// - `is_x86_feature_detected!("avx512f")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[target_feature(enable = "avx512f")] +pub unsafe fn convert_add_backward_torus_u32_avx512f( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let n = out_re.len(); + debug_assert_eq!(n % 8, 0); + debug_assert_eq!(n, out_re.len()); + debug_assert_eq!(n, out_im.len()); + debug_assert_eq!(n, inp.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let normalization = _mm512_set1_pd(1.0 / n as f64); + let scaling = _mm512_set1_pd(2.0_f64.powi(u32::BITS as i32)); + let out_re = out_re.as_mut_ptr() as *mut u32; + let out_im = out_im.as_mut_ptr() as *mut u32; + let inp = inp.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 8 { + let i = i * 8; + + let (fract_re, fract_im) = + convert_torus_prologue_avx512f(normalization, w_re, i, w_im, inp, scaling); + + // convert f64 to i32 + let fract_re = _mm512_cvtpd_epi32(fract_re); + // convert f64 to i32 + let fract_im = _mm512_cvtpd_epi32(fract_im); + + // add to input and store + _mm256_storeu_si256( + out_re.add(i) as _, + _mm256_add_epi32(fract_re, _mm256_loadu_si256(out_re.add(i) as _)), + ); + // add to input and store + _mm256_storeu_si256( + out_im.add(i) as _, + _mm256_add_epi32(fract_im, _mm256_loadu_si256(out_im.add(i) as _)), + ); + } +} + +/// See [`convert_add_backward_torus`]. +/// +/// # Safety +/// +/// - Same preconditions as [`convert_add_backward_torus`]. +/// - `is_x86_feature_detected!("avx512f")` must be true. +#[cfg(feature = "backend_fft_nightly_avx512")] +#[target_feature(enable = "avx512f")] +pub unsafe fn convert_add_backward_torus_u64_avx512f( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let n = out_re.len(); + debug_assert_eq!(n % 8, 0); + debug_assert_eq!(n, out_re.len()); + debug_assert_eq!(n, out_im.len()); + debug_assert_eq!(n, inp.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let normalization = _mm512_set1_pd(1.0 / n as f64); + let scaling = _mm512_set1_pd(2.0_f64.powi(u64::BITS as i32)); + let out_re = out_re.as_mut_ptr() as *mut u64; + let out_im = out_im.as_mut_ptr() as *mut u64; + let inp = inp.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 8 { + let i = i * 8; + + let (fract_re, fract_im) = + convert_torus_prologue_avx512f(normalization, w_re, i, w_im, inp, scaling); + + // convert f64 to i64 + let fract_re = mm512_cvtpd_epi64(fract_re); + // convert f64 to i64 + let fract_im = mm512_cvtpd_epi64(fract_im); + + // add to input and store + _mm512_storeu_si512( + out_re.add(i) as _, + _mm512_add_epi64(fract_re, _mm512_loadu_si512(out_re.add(i) as _)), + ); + // add to input and store + _mm512_storeu_si512( + out_im.add(i) as _, + _mm512_add_epi64(fract_im, _mm512_loadu_si512(out_im.add(i) as _)), + ); + } +} + +/// Performs common work for `u32` and `u64`, used by the backward torus transformation. +/// +/// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part, +/// then rounds to the nearest integer. +/// +/// # Safety +/// +/// - `w_re.add(i)`, `w_im.add(i)`, and `inp.add(i)` must point to an array of at least 4 +/// elements. +/// - `is_x86_feature_detected!("fma")` must be true. +#[inline(always)] +pub unsafe fn convert_torus_prologue_fma( + normalization: __m256d, + w_re: *const f64, + i: usize, + w_im: *const f64, + inp: *const c64, + scaling: __m256d, +) -> (__m256d, __m256d) { + let w_re = _mm256_mul_pd(normalization, _mm256_loadu_pd(w_re.add(i))); + let w_im = _mm256_mul_pd(normalization, _mm256_loadu_pd(w_im.add(i))); + + // re0 im0 + let inp0 = _mm_loadu_pd(inp.add(i) as _); + // re1 im1 + let inp1 = _mm_loadu_pd(inp.add(i + 1) as _); + // re2 im2 + let inp2 = _mm_loadu_pd(inp.add(i + 2) as _); + // re3 im3 + let inp3 = _mm_loadu_pd(inp.add(i + 3) as _); + + // re0 re1 + let inp_re01 = _mm_unpacklo_pd(inp0, inp1); + // im0 im1 + let inp_im01 = _mm_unpackhi_pd(inp0, inp1); + // re2 re3 + let inp_re23 = _mm_unpacklo_pd(inp2, inp3); + // im2 im3 + let inp_im23 = _mm_unpackhi_pd(inp2, inp3); + + // re0 re1 re2 re3 + let inp_re = _mm256_insertf128_pd::<0b1>(_mm256_castpd128_pd256(inp_re01), inp_re23); + // im0 im1 im2 im3 + let inp_im = _mm256_insertf128_pd::<0b1>(_mm256_castpd128_pd256(inp_im01), inp_im23); + + // perform complex multiplication with conj(w) + let mul_re = _mm256_fmadd_pd(inp_re, w_re, _mm256_mul_pd(inp_im, w_im)); + let mul_im = _mm256_fnmadd_pd(inp_re, w_im, _mm256_mul_pd(inp_im, w_re)); + + // round to nearest integer and suppress exceptions + const ROUNDING: i32 = _MM_FROUND_NINT | _MM_FROUND_NO_EXC; + + // get the fractional part (centered around zero) by subtracting rounded value + let fract_re = _mm256_sub_pd(mul_re, _mm256_round_pd::(mul_re)); + let fract_im = _mm256_sub_pd(mul_im, _mm256_round_pd::(mul_im)); + // scale fractional part and round + let fract_re = _mm256_round_pd::(_mm256_mul_pd(scaling, fract_re)); + let fract_im = _mm256_round_pd::(_mm256_mul_pd(scaling, fract_im)); + + (fract_re, fract_im) +} + +/// See [`convert_add_backward_torus`]. +/// +/// # Safety +/// +/// - Same preconditions as [`convert_add_backward_torus`]. +/// - `is_x86_feature_detected!("fma")` must be true. +#[target_feature(enable = "avx,fma")] +pub unsafe fn convert_add_backward_torus_u32_fma( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let n = out_re.len(); + debug_assert_eq!(n % 4, 0); + debug_assert_eq!(n, out_re.len()); + debug_assert_eq!(n, out_im.len()); + debug_assert_eq!(n, inp.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let normalization = _mm256_set1_pd(1.0 / n as f64); + let scaling = _mm256_set1_pd(2.0_f64.powi(u32::BITS as i32)); + let out_re = out_re.as_mut_ptr() as *mut u32; + let out_im = out_im.as_mut_ptr() as *mut u32; + let inp = inp.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 4 { + let i = i * 4; + + let (fract_re, fract_im) = + convert_torus_prologue_fma(normalization, w_re, i, w_im, inp, scaling); + + // convert f64 to i32 + let fract_re = _mm256_cvtpd_epi32(fract_re); + // convert f64 to i32 + let fract_im = _mm256_cvtpd_epi32(fract_im); + + // add to input and store + _mm_storeu_si128( + out_re.add(i) as _, + _mm_add_epi32(fract_re, _mm_loadu_si128(out_re.add(i) as _)), + ); + // add to input and store + _mm_storeu_si128( + out_im.add(i) as _, + _mm_add_epi32(fract_im, _mm_loadu_si128(out_im.add(i) as _)), + ); + } +} + +/// See [`convert_add_backward_torus`]. +/// +/// # Safety +/// +/// - Same preconditions as [`convert_add_backward_torus`]. +/// - `is_x86_feature_detected!("avx2")` must be true. +/// - `is_x86_feature_detected!("fma")` must be true. +#[target_feature(enable = "avx2,fma")] +pub unsafe fn convert_add_backward_torus_u64_fma( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + let n = out_re.len(); + debug_assert_eq!(n % 4, 0); + debug_assert_eq!(n, out_re.len()); + debug_assert_eq!(n, out_im.len()); + debug_assert_eq!(n, inp.len()); + debug_assert_eq!(n, twisties.re.len()); + debug_assert_eq!(n, twisties.im.len()); + + let normalization = _mm256_set1_pd(1.0 / n as f64); + let scaling = _mm256_set1_pd(2.0_f64.powi(u64::BITS as i32)); + let out_re = out_re.as_mut_ptr() as *mut u64; + let out_im = out_im.as_mut_ptr() as *mut u64; + let inp = inp.as_ptr(); + let w_re = twisties.re.as_ptr(); + let w_im = twisties.im.as_ptr(); + + for i in 0..n / 4 { + let i = i * 4; + + let (fract_re, fract_im) = + convert_torus_prologue_fma(normalization, w_re, i, w_im, inp, scaling); + + // convert f64 to i64 + let fract_re = mm256_cvtpd_epi64(fract_re); + // convert f64 to i64 + let fract_im = mm256_cvtpd_epi64(fract_im); + + // add to input and store + _mm256_storeu_si256( + out_re.add(i) as _, + _mm256_add_epi64(fract_re, _mm256_loadu_si256(out_re.add(i) as _)), + ); + // add to input and store + _mm256_storeu_si256( + out_im.add(i) as _, + _mm256_add_epi64(fract_im, _mm256_loadu_si256(out_im.add(i) as _)), + ); + } +} + +pub fn convert_forward_integer_u32( + out: &mut [MaybeUninit], + in_re: &[u32], + in_im: &[u32], + twisties: TwistiesView<'_>, +) { + // this is a function that returns a function pointer to the right simd function + #[allow(clippy::type_complexity)] + let ptr_fn = || -> unsafe fn(&mut [MaybeUninit], &[u32], &[u32], TwistiesView<'_>) { + #[cfg(feature = "backend_fft_nightly_avx512")] + if is_x86_feature_detected!("avx512f") { + return convert_forward_integer_u32_avx512f; + } + + if is_x86_feature_detected!("fma") { + convert_forward_integer_u32_fma + } else { + super::convert_forward_integer_scalar:: + } + }; + // we call it to get the function pointer to the right simd function + let ptr = ptr_fn(); + + // SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im` + // do not hold any uninitialized values since that is a precondition of calling this + // function + unsafe { ptr(out, in_re, in_im, twisties) } +} + +pub fn convert_forward_integer_u64( + out: &mut [MaybeUninit], + in_re: &[u64], + in_im: &[u64], + twisties: TwistiesView<'_>, +) { + #[allow(clippy::type_complexity)] + // this is a function that returns a function pointer to the right simd function + let ptr_fn = || -> unsafe fn(&mut [MaybeUninit], &[u64], &[u64], TwistiesView<'_>) { + #[cfg(feature = "backend_fft_nightly_avx512")] + if is_x86_feature_detected!("avx512f") & is_x86_feature_detected!("avx512dq") { + return convert_forward_integer_u64_avx512f_avx512dq; + } + + if is_x86_feature_detected!("avx2") & is_x86_feature_detected!("fma") { + convert_forward_integer_u64_avx2_fma + } else { + super::convert_forward_integer_scalar:: + } + }; + // we call it to get the function pointer to the right simd function + let ptr = ptr_fn(); + + // SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im` + // do not hold any uninitialized values since that is a precondition of calling this + // function + unsafe { ptr(out, in_re, in_im, twisties) } +} + +/// # Warning +/// +/// This function is actually unsafe, but can't be marked as such since we need it to implement +/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait. +/// +/// # Safety +/// +/// - `out_re` and `out_im` must not hold any uninitialized values. +// TODO: revert when backwards as torus intrisics are fixed +#[allow(dead_code)] +pub fn convert_add_backward_torus_u32( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + // this is a function that returns a function pointer to the right simd function + #[allow(clippy::type_complexity)] + let ptr_fn = || -> unsafe fn ( + &mut [MaybeUninit], + &mut [MaybeUninit], + &[c64], + TwistiesView<'_>, + ) { + #[cfg(feature = "backend_fft_nightly_avx512")] + if is_x86_feature_detected!("avx512f") { + return convert_add_backward_torus_u32_avx512f; + } + + if is_x86_feature_detected!("fma") { + convert_add_backward_torus_u32_fma + } else { + super::convert_add_backward_torus_scalar:: + } + }; + // we call it to get the function pointer to the right simd function + let ptr = ptr_fn(); + + // SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im` + // do not hold any uninitialized values since that is a precondition of calling this + // function + unsafe { ptr(out_re, out_im, inp, twisties) } +} + +/// # Warning +/// +/// This function is actually unsafe, but can't be marked as such since we need it to implement +/// `Fn(...)`, as there's no equivalent `unsafe Fn(...)` trait. +/// +/// # Safety +/// +/// - `out_re` and `out_im` must not hold any uninitialized values. +// TODO: revert when backwards as torus intrisics are fixed +#[allow(dead_code)] +pub fn convert_add_backward_torus_u64( + out_re: &mut [MaybeUninit], + out_im: &mut [MaybeUninit], + inp: &[c64], + twisties: TwistiesView<'_>, +) { + // this is a function that returns a function pointer to the right simd function + #[allow(clippy::type_complexity)] + let ptr_fn = || -> unsafe fn ( + &mut [MaybeUninit], + &mut [MaybeUninit], + &[c64], + TwistiesView<'_>, + ) { + #[cfg(feature = "backend_fft_nightly_avx512")] + if is_x86_feature_detected!("avx512f") { + return convert_add_backward_torus_u64_avx512f; + } + + if is_x86_feature_detected!("avx2") & is_x86_feature_detected!("fma") { + convert_add_backward_torus_u64_fma + } else { + super::convert_add_backward_torus_scalar:: + } + }; + // we call it to get the function pointer to the right simd function + let ptr = ptr_fn(); + + // SAFETY: the target x86 feature availability was checked, and `out_re` and `out_im` + // do not hold any uninitialized values since that is a precondition of calling this + // function + unsafe { ptr(out_re, out_im, inp, twisties) } +} + +#[cfg(test)] +mod tests { + use std::mem::transmute; + + use crate::core_crypto::backends::fft::private::as_mut_uninit; + use crate::core_crypto::backends::fft::private::math::fft::{ + convert_add_backward_torus_scalar, Twisties, + }; + + use super::*; + + #[test] + fn convert_f64_i64() { + if is_x86_feature_detected!("avx2") { + for v in [ + [ + -(2.0_f64.powi(63)), + -(2.0_f64.powi(63)), + -(2.0_f64.powi(63)), + -(2.0_f64.powi(63)), + ], + [0.0, -0.0, 37.1242161_f64, -37.1242161_f64], + [0.1, -0.1, 1.0, -1.0], + [0.9, -0.9, 2.0, -2.0], + [2.0, -2.0, 1e-310, -1e-310], + [ + 2.0_f64.powi(62), + -(2.0_f64.powi(62)), + 1.1 * 2.0_f64.powi(62), + 1.1 * -(2.0_f64.powi(62)), + ], + [ + 0.9 * 2.0_f64.powi(63), + -(0.9 * 2.0_f64.powi(63)), + 0.1 * 2.0_f64.powi(63), + 0.1 * -(2.0_f64.powi(63)), + ], + ] { + let target = v.map(|x| x as i64); + + let computed: [i64; 4] = unsafe { transmute(mm256_cvtpd_epi64(transmute(v))) }; + assert_eq!(target, computed); + } + } + } + + #[test] + fn add_backward_torus_fma() { + let n = 1024; + let z = c64 { + re: -34384521907.303154, + im: 19013399110.689323, + }; + let input = vec![z; n]; + let mut out_fma_re = vec![0_u64; n]; + let mut out_fma_im = vec![0_u64; n]; + let mut out_scalar_re = vec![0_u64; n]; + let mut out_scalar_im = vec![0_u64; n]; + let twisties = Twisties::new(n); + + unsafe { + convert_add_backward_torus_u64_fma( + as_mut_uninit(&mut out_fma_re), + as_mut_uninit(&mut out_fma_im), + &input, + twisties.as_view(), + ); + + convert_add_backward_torus_scalar( + as_mut_uninit(&mut out_scalar_re), + as_mut_uninit(&mut out_scalar_im), + &input, + twisties.as_view(), + ); + } + + for i in 0..n { + assert!(out_fma_re[i].abs_diff(out_scalar_re[i]) < (1 << 38)); + assert!(out_fma_im[i].abs_diff(out_scalar_im[i]) < (1 << 38)); + } + } + + #[cfg(feature = "backend_fft_nightly_avx512")] + #[test] + fn add_backward_torus_avx512() { + let n = 1024; + let z = c64 { + re: -34384521907.303154, + im: 19013399110.689323, + }; + let input = vec![z; n]; + let mut out_avx_re = vec![0_u64; n]; + let mut out_avx_im = vec![0_u64; n]; + let mut out_scalar_re = vec![0_u64; n]; + let mut out_scalar_im = vec![0_u64; n]; + let twisties = Twisties::new(n); + + unsafe { + convert_add_backward_torus_u64_avx512f( + as_mut_uninit(&mut out_avx_re), + as_mut_uninit(&mut out_avx_im), + &input, + twisties.as_view(), + ); + + convert_add_backward_torus_scalar( + as_mut_uninit(&mut out_scalar_re), + as_mut_uninit(&mut out_scalar_im), + &input, + twisties.as_view(), + ); + } + + for i in 0..n { + assert!(out_avx_re[i].abs_diff(out_scalar_re[i]) < (1 << 38)); + assert!(out_avx_im[i].abs_diff(out_scalar_im[i]) < (1 << 38)); + } + } +} diff --git a/tfhe/src/core_crypto/backends/fft/private/math/mod.rs b/tfhe/src/core_crypto/backends/fft/private/math/mod.rs new file mode 100644 index 000000000..730a210b9 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/mod.rs @@ -0,0 +1,3 @@ +pub mod decomposition; +pub mod fft; +pub mod polynomial; diff --git a/tfhe/src/core_crypto/backends/fft/private/math/polynomial.rs b/tfhe/src/core_crypto/backends/fft/private/math/polynomial.rs new file mode 100644 index 000000000..6480a962b --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/math/polynomial.rs @@ -0,0 +1,77 @@ +use super::super::as_mut_uninit; +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::tensor::Container; +use concrete_fft::c64; + +//-------------------------------------------------------------------------------- +// Structure definitions +//-------------------------------------------------------------------------------- + +/// Polynomial in the Fourier domain. +/// +/// # Note +/// +/// Polynomials in the Fourier domain have half the size of the corresponding polynomials in +/// the standard domain. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct FourierPolynomial { + pub data: C, +} + +pub type FourierPolynomialView<'a> = FourierPolynomial<&'a [c64]>; +pub type FourierPolynomialMutView<'a> = FourierPolynomial<&'a mut [c64]>; + +/// Polynomial in the standard domain, with possibly uninitialized coefficients. +/// +/// This is used for the Fourier transforms to avoid the cost of initializing the output buffer, +/// which can be non negligible. +pub type PolynomialUninitMutView<'a, Scalar> = Polynomial<&'a mut [core::mem::MaybeUninit]>; + +/// Polynomial in the Fourier domain, with possibly uninitialized coefficients. +/// +/// This is used for the Fourier transforms to avoid the cost of initializing the output buffer, +/// which can be non negligible. +/// +/// # Note +/// +/// Polynomials in the Fourier domain have half the size of the corresponding polynomials in +/// the standard domain. +pub type FourierPolynomialUninitMutView<'a> = + FourierPolynomial<&'a mut [core::mem::MaybeUninit]>; + +impl> FourierPolynomial { + pub fn as_view(&self) -> FourierPolynomialView<'_> { + FourierPolynomial { + data: self.data.as_ref(), + } + } + + pub fn as_mut_view(&mut self) -> FourierPolynomialMutView<'_> + where + C: AsMut<[c64]>, + { + FourierPolynomial { + data: self.data.as_mut(), + } + } +} + +impl<'a, Scalar> Polynomial<&'a mut [Scalar]> { + /// # Safety + /// + /// No uninitialized values must be written into the output buffer when the borrow ends + pub unsafe fn into_uninit(self) -> PolynomialUninitMutView<'a, Scalar> { + PolynomialUninitMutView::from_container(as_mut_uninit(self.tensor.into_container())) + } +} + +impl<'a> FourierPolynomialMutView<'a> { + /// # Safety + /// + /// No uninitialized values must be written into the output buffer when the borrow ends + pub unsafe fn into_uninit(self) -> FourierPolynomialUninitMutView<'a> { + FourierPolynomialUninitMutView { + data: as_mut_uninit(self.data), + } + } +} diff --git a/tfhe/src/core_crypto/backends/fft/private/mod.rs b/tfhe/src/core_crypto/backends/fft/private/mod.rs new file mode 100644 index 000000000..808d08593 --- /dev/null +++ b/tfhe/src/core_crypto/backends/fft/private/mod.rs @@ -0,0 +1,33 @@ +#![allow(deprecated)] // For the time being + +pub use concrete_fft::c64; +use core::mem::MaybeUninit; + +pub mod crypto; +pub mod math; + +/// Convert a mutable slice reference to an uninitialized mutable slice reference. +/// +/// # Safety +/// +/// No uninitialized values must be written into the output slice by the time the borrow ends +#[inline] +pub unsafe fn as_mut_uninit(slice: &mut [T]) -> &mut [MaybeUninit] { + let len = slice.len(); + let ptr = slice.as_mut_ptr(); + // SAFETY: T and MaybeUninit have the same layout + core::slice::from_raw_parts_mut(ptr as *mut _, len) +} + +/// Convert an uninitialized mutable slice reference to an initialized mutable slice reference. +/// +/// # Safety +/// +/// All the elements of the input slice must be initialized and in a valid state. +#[inline] +pub unsafe fn assume_init_mut(slice: &mut [MaybeUninit]) -> &mut [T] { + let len = slice.len(); + let ptr = slice.as_mut_ptr(); + // SAFETY: T and MaybeUninit have the same layout + core::slice::from_raw_parts_mut(ptr as *mut _, len) +} diff --git a/tfhe/src/core_crypto/backends/mod.rs b/tfhe/src/core_crypto/backends/mod.rs new file mode 100644 index 000000000..e60ecd9be --- /dev/null +++ b/tfhe/src/core_crypto/backends/mod.rs @@ -0,0 +1,10 @@ +//! A module containing various backends implementing various FHE cryptographic primitives. + +#[cfg(feature = "backend_default")] +pub mod default; + +#[cfg(feature = "backend_fft")] +pub mod fft; + +#[cfg(feature = "backend_cuda")] +pub mod cuda; diff --git a/tfhe/src/core_crypto/commons/crypto/bootstrap/mod.rs b/tfhe/src/core_crypto/commons/crypto/bootstrap/mod.rs new file mode 100644 index 000000000..080c249e8 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/bootstrap/mod.rs @@ -0,0 +1,13 @@ +//! Bootstrapping keys. +//! +//! The bootstrapping operation allows to reduce the level of noise in an LWE ciphertext, while +//! evaluating an univariate function. + +mod seeded_standard; +mod standard; + +pub use seeded_standard::StandardSeededBootstrapKey; +pub use standard::StandardBootstrapKey; + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/core_crypto/commons/crypto/bootstrap/seeded_standard/mod.rs b/tfhe/src/core_crypto/commons/crypto/bootstrap/seeded_standard/mod.rs new file mode 100644 index 000000000..927a2b58c --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/bootstrap/seeded_standard/mod.rs @@ -0,0 +1,719 @@ +use super::StandardBootstrapKey; +use crate::core_crypto::commons::crypto::encoding::Plaintext; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; +use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Seeder, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, ck_dim_eq, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::commons::utils::{zip, zip_args}; +use crate::core_crypto::prelude::{ + BinaryKeyKind, DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweSize, + LweDimension, PolynomialSize, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A seeded bootstrapping key represented in the standard domain. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StandardSeededBootstrapKey { + tensor: Tensor, + poly_size: PolynomialSize, + glwe_size: GlweSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + compression_seed: CompressionSeed, +} + +tensor_traits!(StandardSeededBootstrapKey); + +impl StandardSeededBootstrapKey> { + /// Allocates a new seeded bootstrapping key in the standard domain whose polynomials + /// coefficients are all `value`. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(9)); + /// assert_eq!(bsk.glwe_size(), GlweSize(7)); + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(3)); + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(5)); + /// assert_eq!(bsk.key_size(), LweDimension(4)); + /// assert_eq!(bsk.compression_seed(), CompressionSeed { seed: Seed(42) }); + /// ``` + pub fn allocate( + glwe_size: GlweSize, + poly_size: PolynomialSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + key_size: LweDimension, + compression_seed: CompressionSeed, + ) -> Self + where + Scalar: UnsignedTorus, + { + StandardSeededBootstrapKey { + tensor: Tensor::from_container(vec![ + Scalar::ZERO; + key_size.0 + * decomp_level.0 + * glwe_size.0 + * poly_size.0 + ]), + decomp_level, + decomp_base_log, + glwe_size, + poly_size, + compression_seed, + } + } +} + +impl StandardSeededBootstrapKey { + /// Creates a seeded bootstrapping key from an existing container of values. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// + /// let key_size = LweDimension(4); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(5); + /// let glwe_size = GlweSize(7); + /// let poly_size = PolynomialSize(9); + /// + /// let container = vec![0u32; key_size.0 * decomp_level.0 * glwe_size.0 * poly_size.0]; + /// + /// let bsk = StandardSeededBootstrapKey::from_container( + /// container, + /// glwe_size, + /// poly_size, + /// decomp_level, + /// decomp_base_log, + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(9)); + /// assert_eq!(bsk.glwe_size(), GlweSize(7)); + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(3)); + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(5)); + /// assert_eq!(bsk.key_size(), LweDimension(4)); + /// assert_eq!(bsk.compression_seed(), CompressionSeed { seed: Seed(42) }); + /// ``` + pub fn from_container( + cont: Cont, + glwe_size: GlweSize, + poly_size: PolynomialSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + compression_seed: CompressionSeed, + ) -> Self + where + Cont: AsRefSlice, + Coef: UnsignedTorus, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => + decomp_level.0, + glwe_size.0, + poly_size.0 + ); + StandardSeededBootstrapKey { + tensor, + glwe_size, + poly_size, + decomp_level, + decomp_base_log, + compression_seed, + } + } + + /// Generate a new seeded bootstrap key from the input parameters, and fills the current + /// container with it. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut seeder = UnixSeeder::new(0); + /// + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut bsk = StandardSeededBootstrapKey::>::allocate( + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// bsk.fill_with_new_key::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut seeder, + /// ); + /// ``` + pub fn fill_with_new_key( + &mut self, + lwe_secret_key: &LweSecretKey, + glwe_secret_key: &GlweSecretKey, + noise_parameters: NoiseParameters, + seeder: &mut NoiseSeeder, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + NoiseParameters: DispersionParameter, + NoiseSeeder: Seeder, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.key_size().0 => lwe_secret_key.key_size().0); + self.as_mut_tensor() + .fill_with_element(::ZERO); + + let mut generator = + EncryptionRandomGenerator::::new(self.compression_seed().seed, seeder); + + let gen_iter = generator + .fork_bsk_to_ggsw::( + lwe_secret_key.key_size(), + self.decomp_level, + glwe_secret_key.key_size().to_glwe_size(), + self.poly_size, + ) + .unwrap(); + + for zip_args!(mut ggsw, sk_scalar, mut generator) in zip!( + self.ggsw_iter_mut(), + lwe_secret_key.as_tensor().iter(), + gen_iter + ) { + let encoded = Plaintext(*sk_scalar); + glwe_secret_key.encrypt_constant_seeded_ggsw_with_existing_generator( + &mut ggsw, + &encoded, + noise_parameters, + &mut generator, + ); + } + } + + /// Generate a new bootstrap key from the input parameters, and fills the current container + /// with it, using all the available threads. + /// + /// # Note + /// + /// This method uses _rayon_ internally, and is hidden behind the "__commons_parallel" feature + /// gate. + /// + /// # Example + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut seeder = UnixSeeder::new(0); + /// + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut bsk = StandardSeededBootstrapKey::>::allocate( + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// bsk.par_fill_with_new_key::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut seeder, + /// ); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_fill_with_new_key( + &mut self, + lwe_secret_key: &LweSecretKey, + glwe_secret_key: &GlweSecretKey, + noise_parameters: NoiseParameters, + seeder: &mut NoiseSeeder, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus + Sync + Send, + GlweCont: Sync + Send, + Cont: Sync + Send, + NoiseParameters: DispersionParameter + Sync + Send, + NoiseSeeder: Seeder + Sync + Send, + Gen: ParallelByteRandomGenerator, + { + ck_dim_eq!(self.key_size().0 => lwe_secret_key.key_size().0); + self.as_mut_tensor() + .fill_with_element(::ZERO); + + let mut generator = + EncryptionRandomGenerator::::new(self.compression_seed().seed, seeder); + + let gen_iter = generator + .par_fork_bsk_to_ggsw::( + lwe_secret_key.key_size(), + self.decomp_level, + glwe_secret_key.key_size().to_glwe_size(), + self.poly_size, + ) + .unwrap(); + + self.par_ggsw_iter_mut() + .zip(lwe_secret_key.as_tensor().par_iter()) + .zip(gen_iter) + .for_each(|((mut ggsw, sk_scalar), mut generator)| { + let encoded = Plaintext(*sk_scalar); + glwe_secret_key.par_encrypt_constant_seeded_ggsw_with_existing_generator( + &mut ggsw, + &encoded, + noise_parameters, + &mut generator, + ); + }); + } + + /// Returns the size of the polynomials used in the bootstrapping key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(9)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the size of the GLWE ciphertexts used in the bootstrapping key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.glwe_size(), GlweSize(7)); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Returns the number of levels used to decompose the key bits. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + self.decomp_level + } + + /// Returns the logarithm of the base used to decompose the key bits. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(5)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Returns the size of the LWE encrypted key. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.key_size(), LweDimension(4)); + /// ``` + pub fn key_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => + self.poly_size.0, + self.glwe_size.0, + self.decomp_level.0 + ); + LweDimension( + self.as_tensor().len() / (self.glwe_size.0 * self.poly_size.0 * self.decomp_level.0), + ) + } + + /// Returns the compression seed used for the seeded entity. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// assert_eq!(bsk.compression_seed(), CompressionSeed { seed: Seed(42) }); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns an iterator over the borrowed seeded GGSW ciphertext composing the key. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// for ggsw in bsk.ggsw_iter() { + /// assert_eq!(ggsw.polynomial_size(), PolynomialSize(9)); + /// assert_eq!(ggsw.glwe_size(), GlweSize(7)); + /// assert_eq!(ggsw.decomposition_level_count(), DecompositionLevelCount(3)); + /// } + /// assert_eq!(bsk.ggsw_iter().count(), 4); + /// ``` + pub fn ggsw_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + let chunks_size = self.glwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let glwe_size = self.glwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + let compression_seed = self.compression_seed; + self.as_tensor() + .subtensor_iter(chunks_size) + .map(move |tensor| { + StandardGgswSeededCiphertext::from_container( + tensor.into_container(), + poly_size, + glwe_size, + base_log, + compression_seed, + ) + }) + } + + /// Returns a parallel iterator over the mutably borrowed seeded GGSW ciphertext composing the + /// key. + /// + /// # Notes + /// + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let mut bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// bsk.par_ggsw_iter_mut().for_each(|mut ggsw| { + /// ggsw.as_mut_tensor().fill_with_element(1); + /// }); + /// assert!(bsk.as_tensor().iter().all(|a| *a == 1)); + /// assert_eq!(bsk.ggsw_iter_mut().count(), 4); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_ggsw_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator< + Item = StandardGgswSeededCiphertext<&mut [::Element]>, + > + where + Self: AsMutTensor, + ::Element: Sync + Send, + Cont: Sync + Send, + { + let chunks_size = self.glwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let glwe_size = self.glwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + let compression_seed = self.compression_seed; + + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .map(move |tensor| { + StandardGgswSeededCiphertext::from_container( + tensor.into_container(), + poly_size, + glwe_size, + base_log, + compression_seed, + ) + }) + } + + /// Returns an iterator over the mutably borrowed seeded GGSW ciphertext composing the key. + /// + /// # Example + /// # Example + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardSeededBootstrapKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let mut bsk = StandardSeededBootstrapKey::>::allocate( + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// for mut ggsw in bsk.ggsw_iter_mut() { + /// ggsw.as_mut_tensor().fill_with_element(1); + /// } + /// assert!(bsk.as_tensor().iter().all(|a| *a == 1)); + /// assert_eq!(bsk.ggsw_iter_mut().count(), 4); + /// ``` + pub fn ggsw_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.glwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let glwe_size = self.glwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + let compression_seed = self.compression_seed; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |tensor| { + StandardGgswSeededCiphertext::from_container( + tensor.into_container(), + poly_size, + glwe_size, + base_log, + compression_seed, + ) + }) + } + + /// Returns the key as a full fledged StandardBootstrapKey + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::{ + /// StandardBootstrapKey, StandardSeededBootstrapKey, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// LweDimension, PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut seeder = UnixSeeder::new(0); + /// + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut seeded_bsk = StandardSeededBootstrapKey::>::allocate( + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// seeded_bsk.fill_with_new_key::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut seeder, + /// ); + /// + /// // expansion of the bootstrapping key + /// let mut coef_bsk_expanded = StandardBootstrapKey::allocate( + /// 0u32, + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// ); + /// seeded_bsk.expand_into::<_, _, SoftwareRandomGenerator>(&mut coef_bsk_expanded); + /// ``` + pub fn expand_into(self, output: &mut StandardBootstrapKey) + where + Scalar: Copy + RandomGenerable + Numeric, + StandardBootstrapKey: AsMutTensor, + Self: AsRefTensor, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed().seed); + + output + .ggsw_iter_mut() + .zip(self.ggsw_iter()) + .for_each(|(mut ggsw_out, ggsw_in)| { + ggsw_in.expand_into_with_existing_generator(&mut ggsw_out, &mut generator); + }); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/bootstrap/standard/mod.rs b/tfhe/src/core_crypto/commons/crypto/bootstrap/standard/mod.rs new file mode 100644 index 000000000..545cb7f26 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/bootstrap/standard/mod.rs @@ -0,0 +1,749 @@ +use crate::core_crypto::commons::crypto::encoding::Plaintext; +use crate::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; +use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::random::ByteRandomGenerator; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, ck_dim_eq, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::commons::utils::{zip, zip_args}; +use crate::core_crypto::prelude::{ + BinaryKeyKind, DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweSize, + LweDimension, PolynomialSize, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A bootstrapping key represented in the standard domain. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct StandardBootstrapKey { + pub(crate) tensor: Tensor, + poly_size: PolynomialSize, + rlwe_size: GlweSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, +} + +tensor_traits!(StandardBootstrapKey); + +impl StandardBootstrapKey> { + /// Allocates a new bootstrapping key in the standard domain whose polynomials coefficients are + /// all `value`. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(9)); + /// assert_eq!(bsk.glwe_size(), GlweSize(7)); + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(3)); + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(5)); + /// assert_eq!(bsk.key_size(), LweDimension(4)); + /// ``` + pub fn allocate( + value: Scalar, + rlwe_size: GlweSize, + poly_size: PolynomialSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + key_size: LweDimension, + ) -> StandardBootstrapKey> + where + Scalar: UnsignedTorus, + { + StandardBootstrapKey { + tensor: Tensor::from_container(vec![ + value; + key_size.0 + * decomp_level.0 + * rlwe_size.0 + * rlwe_size.0 + * poly_size.0 + ]), + decomp_level, + decomp_base_log, + rlwe_size, + poly_size, + } + } +} + +impl StandardBootstrapKey { + /// Creates a bootstrapping key from an existing container of values. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let vector = vec![0u32; 10 * 5 * 4 * 4 * 15]; + /// let bsk = StandardBootstrapKey::from_container( + /// vector.as_slice(), + /// GlweSize(4), + /// PolynomialSize(10), + /// DecompositionLevelCount(5), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(bsk.glwe_size(), GlweSize(4)); + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(5)); + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(4)); + /// assert_eq!(bsk.key_size(), LweDimension(15)); + /// ``` + pub fn from_container( + cont: Cont, + glwe_size: GlweSize, + poly_size: PolynomialSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + ) -> StandardBootstrapKey + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => + decomp_level.0, + glwe_size.0 * glwe_size.0, + poly_size.0 + ); + StandardBootstrapKey { + tensor, + rlwe_size: glwe_size, + poly_size, + decomp_level, + decomp_base_log, + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> StandardBootstrapKey<&'_ [Cont::Element]> + where + Cont: Container, + { + StandardBootstrapKey { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + rlwe_size: self.rlwe_size, + poly_size: self.poly_size, + decomp_level: self.decomp_level, + decomp_base_log: self.decomp_base_log, + } + } + + pub fn as_mut_view(&mut self) -> StandardBootstrapKey<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + StandardBootstrapKey { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + rlwe_size: self.rlwe_size, + poly_size: self.poly_size, + decomp_level: self.decomp_level, + decomp_base_log: self.decomp_base_log, + } + } + + /// Generate a new bootstrap key from the input parameters, and fills the current container + /// with it. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// bsk.fill_with_new_key( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut encryption_generator, + /// ); + /// ``` + pub fn fill_with_new_key( + &mut self, + lwe_secret_key: &LweSecretKey, + glwe_secret_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.key_size().0 => lwe_secret_key.key_size().0); + self.as_mut_tensor() + .fill_with_element(::ZERO); + + let gen_iter = generator + .fork_bsk_to_ggsw::( + lwe_secret_key.key_size(), + self.decomp_level, + glwe_secret_key.key_size().to_glwe_size(), + self.poly_size, + ) + .unwrap(); + + for zip_args!(mut rgsw, sk_scalar, mut generator) in zip!( + self.ggsw_iter_mut(), + lwe_secret_key.as_tensor().iter(), + gen_iter + ) { + let encoded = Plaintext(*sk_scalar); + glwe_secret_key.encrypt_constant_ggsw( + &mut rgsw, + &encoded, + noise_parameters, + &mut generator, + ); + } + } + + /// Generate a new bootstrap key from the input parameters, and fills the current container + /// with it, using all the available threads. + /// + /// # Note + /// + /// This method uses _rayon_ internally, and is hidden behind the "__commons_parallel" feature + /// gate. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// PolynomialSize, + /// }; + /// + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// let mut secret_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// bsk.par_fill_with_new_key( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut secret_generator, + /// ); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_fill_with_new_key( + &mut self, + lwe_secret_key: &LweSecretKey, + glwe_secret_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter + Sync + Send, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus + Sync + Send, + RlweCont: Sync, + Gen: ParallelByteRandomGenerator, + { + ck_dim_eq!(self.key_size().0 => lwe_secret_key.key_size().0); + self.as_mut_tensor() + .fill_with_element(::ZERO); + let gen_iter = generator + .par_fork_bsk_to_ggsw::( + lwe_secret_key.key_size(), + self.decomp_level, + glwe_secret_key.key_size().to_glwe_size(), + self.poly_size, + ) + .unwrap(); + self.par_ggsw_iter_mut() + .zip(lwe_secret_key.as_tensor().par_iter()) + .zip(gen_iter) + .for_each(|((mut rgsw, sk_scalar), mut generator)| { + let encoded = Plaintext(*sk_scalar); + glwe_secret_key.par_encrypt_constant_ggsw( + &mut rgsw, + &encoded, + noise_parameters, + &mut generator, + ); + }); + } + + /// Generate a new bootstrap key from the input parameters, and fills the current container + /// with it. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// PolynomialSize, + /// }; + /// + /// let (lwe_dim, glwe_dim, poly_size) = (LweDimension(4), GlweDimension(6), PolynomialSize(9)); + /// let (dec_lc, dec_bl) = (DecompositionLevelCount(3), DecompositionBaseLog(5)); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// glwe_dim.to_glwe_size(), + /// poly_size, + /// dec_lc, + /// dec_bl, + /// lwe_dim, + /// ); + /// let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + /// let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// bsk.fill_with_new_trivial_key( + /// &lwe_sk, + /// &glwe_sk, + /// LogStandardDev::from_log_standard_dev(-15.), + /// &mut encryption_generator, + /// ); + /// ``` + pub fn fill_with_new_trivial_key( + &mut self, + lwe_secret_key: &LweSecretKey, + rlwe_secret_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.key_size().0 => lwe_secret_key.key_size().0); + for (mut rgsw, sk_scalar) in self.ggsw_iter_mut().zip(lwe_secret_key.as_tensor().iter()) { + let encoded = Plaintext(*sk_scalar); + rlwe_secret_key.trivial_encrypt_constant_ggsw( + &mut rgsw, + &encoded, + noise_parameters, + generator, + ); + } + } + + /// Returns the size of the polynomials used in the bootstrapping key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.polynomial_size(), PolynomialSize(9)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the size of the GLWE ciphertexts used in the bootstrapping key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.glwe_size(), GlweSize(7)); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.rlwe_size + } + + /// Returns the number of levels used to decompose the key bits. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + self.decomp_level + } + + /// Returns the logarithm of the base used to decompose the key bits. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.base_log(), DecompositionBaseLog(5)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Returns the size of the LWE encrypted key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// assert_eq!(bsk.key_size(), LweDimension(4)); + /// ``` + pub fn key_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => + self.poly_size.0, + self.rlwe_size.0 * self.rlwe_size.0, + self.decomp_level.0 + ); + LweDimension( + self.as_tensor().len() + / (self.rlwe_size.0 * self.rlwe_size.0 * self.poly_size.0 * self.decomp_level.0), + ) + } + + /// Returns an iterator over the borrowed GGSW ciphertext composing the key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// for ggsw in bsk.ggsw_iter() { + /// assert_eq!(ggsw.polynomial_size(), PolynomialSize(9)); + /// assert_eq!(ggsw.glwe_size(), GlweSize(7)); + /// assert_eq!(ggsw.decomposition_level_count(), DecompositionLevelCount(3)); + /// } + /// assert_eq!(bsk.ggsw_iter().count(), 4); + /// ``` + pub fn ggsw_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + let chunks_size = + self.rlwe_size.0 * self.rlwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let rlwe_size = self.rlwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + self.as_tensor() + .subtensor_iter(chunks_size) + .map(move |tensor| { + StandardGgswCiphertext::from_container( + tensor.into_container(), + rlwe_size, + poly_size, + base_log, + ) + }) + } + + /// Returns a parallel iterator over the mutably borrowed GGSW ciphertext composing the + /// key. + /// + /// # Notes + /// + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example + /// + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// bsk.par_ggsw_iter_mut().for_each(|mut ggsw| { + /// ggsw.as_mut_tensor().fill_with_element(0); + /// }); + /// assert!(bsk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(bsk.ggsw_iter_mut().count(), 4); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_ggsw_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + let chunks_size = + self.rlwe_size.0 * self.rlwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let rlwe_size = self.rlwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .map(move |tensor| { + StandardGgswCiphertext::from_container( + tensor.into_container(), + rlwe_size, + poly_size, + base_log, + ) + }) + } + + /// Returns an iterator over the mutably borrowed GGSW ciphertext composing the key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(9), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// for mut ggsw in bsk.ggsw_iter_mut() { + /// ggsw.as_mut_tensor().fill_with_element(0); + /// } + /// assert!(bsk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(bsk.ggsw_iter_mut().count(), 4); + /// ``` + pub fn ggsw_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = + self.rlwe_size.0 * self.rlwe_size.0 * self.poly_size.0 * self.decomp_level.0; + let rlwe_size = self.rlwe_size; + let poly_size = self.poly_size; + let base_log = self.decomp_base_log; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |tensor| { + StandardGgswCiphertext::from_container( + tensor.into_container(), + rlwe_size, + poly_size, + base_log, + ) + }) + } + + /// Returns an iterator over borrowed polynomials composing the key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(256), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// for poly in bsk.poly_iter() { + /// assert_eq!(poly.polynomial_size(), PolynomialSize(256)); + /// } + /// assert_eq!(bsk.poly_iter().count(), 7 * 7 * 3 * 4) + /// ``` + pub fn poly_iter(&self) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + ::Element: UnsignedTorus, + { + let poly_size = self.poly_size.0; + self.as_tensor() + .subtensor_iter(poly_size) + .map(|chunk| Polynomial::from_container(chunk.into_container())) + } + + /// Returns an iterator over mutably borrowed polynomials composing the key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + /// }; + /// let mut bsk = StandardBootstrapKey::allocate( + /// 9u32, + /// GlweSize(7), + /// PolynomialSize(256), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(5), + /// LweDimension(4), + /// ); + /// for mut poly in bsk.poly_iter_mut() { + /// poly.as_mut_tensor().fill_with_element(0u32); + /// } + /// assert!(bsk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(bsk.poly_iter_mut().count(), 7 * 7 * 3 * 4) + /// ``` + pub fn poly_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + ::Element: UnsignedTorus, + { + let poly_size = self.poly_size.0; + self.as_mut_tensor() + .subtensor_iter_mut(poly_size) + .map(|chunk| Polynomial::from_container(chunk.into_container())) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/bootstrap/tests.rs b/tfhe/src/core_crypto/commons/crypto/bootstrap/tests.rs new file mode 100644 index 000000000..7029b718d --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/bootstrap/tests.rs @@ -0,0 +1,295 @@ +use crate::core_crypto::commons::crypto::bootstrap::{ + StandardBootstrapKey, StandardSeededBootstrapKey, +}; +use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, +}; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +use crate::core_crypto::commons::math::random::CompressionSeed; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::test_tools::new_secret_random_generator; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, +}; +use concrete_csprng::generators::SoftwareRandomGenerator; +use concrete_csprng::seeders::Seed; + +fn test_bsk_seeded_gen_equivalence() { + for _ in 0..10 { + let lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let glwe_dim = + GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let poly_size = + PolynomialSize(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let level = DecompositionLevelCount( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let base_log = DecompositionBaseLog( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let deterministic_seeder_seed = + Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + + let compression_seed = CompressionSeed { seed: mask_seed }; + + let mut secret_generator = new_secret_random_generator(); + let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + let glwe_sk = GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + + let mut bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + + let mut encryption_generator = EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new(deterministic_seeder_seed), + ); + + bsk.fill_with_new_key( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + let mut seeded_bsk = StandardSeededBootstrapKey::allocate( + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + compression_seed, + ); + + seeded_bsk.fill_with_new_key::<_, _, _, _, _, SoftwareRandomGenerator>( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut DeterministicSeeder::::new(deterministic_seeder_seed), + ); + + let mut expanded_bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + + seeded_bsk.expand_into::<_, _, SoftwareRandomGenerator>(&mut expanded_bsk); + + assert_eq!(bsk, expanded_bsk); + } +} + +#[test] +fn test_bsk_seeded_gen_equivalence_u32() { + test_bsk_seeded_gen_equivalence::() +} + +#[test] +fn test_bsk_seeded_gen_equivalence_u64() { + test_bsk_seeded_gen_equivalence::() +} + +#[cfg(all(test, feature = "__commons_parallel"))] +mod parallel { + use crate::core_crypto::commons::crypto::bootstrap::{ + StandardBootstrapKey, StandardSeededBootstrapKey, + }; + use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, + }; + use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + use crate::core_crypto::commons::math::random::CompressionSeed; + use crate::core_crypto::commons::math::torus::UnsignedTorus; + use crate::core_crypto::commons::test_tools::{new_secret_random_generator, UnsafeRandSeeder}; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + use concrete_csprng::seeders::Seed; + + fn test_bsk_gen_equivalence() { + for _ in 0..10 { + let lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let glwe_dim = + GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let poly_size = PolynomialSize( + crate::core_crypto::commons::test_tools::random_usize_between(5..10), + ); + let level = DecompositionLevelCount( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let base_log = DecompositionBaseLog( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let mask_seed = crate::core_crypto::commons::test_tools::any_usize() as u128; + let noise_seed = crate::core_crypto::commons::test_tools::any_usize() as u128; + + let mut secret_generator = new_secret_random_generator(); + let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + let glwe_sk = + GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + + let mut mono_bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + let mut encryption_generator = + EncryptionRandomGenerator::::new( + Seed(mask_seed), + &mut UnsafeRandSeeder, + ); + encryption_generator.seed_noise_generator(Seed(noise_seed)); + mono_bsk.fill_with_new_key( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + let mut multi_bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + let mut encryption_generator = + EncryptionRandomGenerator::::new( + Seed(mask_seed), + &mut UnsafeRandSeeder, + ); + encryption_generator.seed_noise_generator(Seed(noise_seed)); + multi_bsk.par_fill_with_new_key( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + assert_eq!(mono_bsk, multi_bsk); + } + } + + fn test_bsk_par_seeded_gen_equivalence() { + for _ in 0..10 { + let lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let glwe_dim = + GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let poly_size = PolynomialSize( + crate::core_crypto::commons::test_tools::random_usize_between(5..10), + ); + let level = DecompositionLevelCount( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let base_log = DecompositionBaseLog( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let deterministic_seeder_seed = + Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + + let compression_seed = CompressionSeed { seed: mask_seed }; + + let mut secret_generator = new_secret_random_generator(); + let lwe_sk = LweSecretKey::generate_binary(lwe_dim, &mut secret_generator); + let glwe_sk = + GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + + let mut bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + + // To mitigate current issues with forking SoftwareRandomGenerator generators + // We know parallel and sequential generation of bsk are the same thanks to the + // test_bsk_gen_equivalence based tests + bsk.fill_with_new_key( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + let mut par_seeded_bsk = StandardSeededBootstrapKey::allocate( + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + compression_seed, + ); + + par_seeded_bsk.par_fill_with_new_key::<_, _, _, _, _, SoftwareRandomGenerator>( + &lwe_sk, + &glwe_sk, + StandardDev::from_standard_dev(10.), + &mut DeterministicSeeder::::new(deterministic_seeder_seed), + ); + + let mut expanded_bsk = StandardBootstrapKey::allocate( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + level, + base_log, + lwe_dim, + ); + + par_seeded_bsk.expand_into::<_, _, SoftwareRandomGenerator>(&mut expanded_bsk); + + assert_eq!(bsk, expanded_bsk); + } + } + + #[test] + fn test_bsk_gen_equivalence_u32() { + test_bsk_gen_equivalence::() + } + + #[test] + fn test_bsk_gen_equivalence_u64() { + test_bsk_gen_equivalence::() + } + + #[test] + fn test_bsk_par_seeded_gen_equivalence_u32() { + test_bsk_par_seeded_gen_equivalence::() + } + + #[test] + fn test_bsk_par_seeded_gen_equivalence_u64() { + test_bsk_par_seeded_gen_equivalence::() + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/encoding/cleartext.rs b/tfhe/src/core_crypto/commons/crypto/encoding/cleartext.rs new file mode 100644 index 000000000..72f40bf7d --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/encoding/cleartext.rs @@ -0,0 +1,181 @@ +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::CleartextCount; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A clear, non-encoded, value. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct Cleartext(pub T); + +/// A list of clear, non-encoded, values. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CleartextList { + tensor: Tensor, +} + +tensor_traits!(CleartextList); + +impl CleartextList> +where + Scalar: Copy, +{ + /// Allocates a new list of cleartexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::CleartextCount; + /// let clear_list = CleartextList::allocate(1 as u8, CleartextCount(100)); + /// assert_eq!(clear_list.count(), CleartextCount(100)); + /// ``` + pub fn allocate(value: Scalar, count: CleartextCount) -> CleartextList> { + CleartextList::from_container(vec![value; count.0]) + } +} + +impl CleartextList { + /// Creates a cleartext list from a container of values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::CleartextList; + /// use tfhe::core_crypto::prelude::CleartextCount; + /// let clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// assert_eq!(clear_list.count(), CleartextCount(100)); + /// ``` + pub fn from_container(cont: Cont) -> CleartextList { + CleartextList { + tensor: Tensor::from_container(cont), + } + } + + /// Returns the number of elements in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::CleartextList; + /// use tfhe::core_crypto::prelude::CleartextCount; + /// let clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// assert_eq!(clear_list.count(), CleartextCount(100)); + /// ``` + pub fn count(&self) -> CleartextCount + where + Self: AsRefTensor, + { + CleartextCount(self.as_tensor().len()) + } + + /// Creates an iterator over borrowed cleartexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::CleartextList; + /// let clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// clear_list.cleartext_iter().for_each(|a| assert_eq!(a.0, 1)); + /// assert_eq!(clear_list.cleartext_iter().count(), 100); + /// ``` + pub fn cleartext_iter(&self) -> impl Iterator::Element>> + where + Self: AsRefTensor, + ::Element: Numeric, + { + self.as_tensor().iter().map(|refe| unsafe { + &*{ + refe as *const ::Element + as *const Cleartext<::Element> + } + }) + } + + /// Creates an iterator over mutably borrowed cleartexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::{Cleartext, CleartextList}; + /// let mut clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// clear_list + /// .cleartext_iter_mut() + /// .for_each(|a| *a = Cleartext(2)); + /// clear_list.cleartext_iter().for_each(|a| assert_eq!(a.0, 2)); + /// assert_eq!(clear_list.cleartext_iter_mut().count(), 100); + /// ``` + pub fn cleartext_iter_mut( + &mut self, + ) -> impl Iterator::Element>> + where + Self: AsMutTensor, + ::Element: Numeric, + { + self.as_mut_tensor().iter_mut().map(|refe| unsafe { + &mut *{ + refe as *mut ::Element + as *mut Cleartext<::Element> + } + }) + } + + /// Creates an iterator over borrowed sub-lists. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::encoding::CleartextList; + /// use tfhe::core_crypto::prelude::CleartextCount; + /// let clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// clear_list + /// .sublist_iter(CleartextCount(10)) + /// .for_each(|a| assert_eq!(a.count(), CleartextCount(10))); + /// assert_eq!(clear_list.sublist_iter(CleartextCount(10)).count(), 10); + /// ``` + pub fn sublist_iter( + &self, + sub_len: CleartextCount, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => sub_len.0); + self.as_tensor() + .subtensor_iter(sub_len.0) + .map(|sub| CleartextList::from_container(sub.into_container())) + } + + /// Creates an iterator over mutably borrowed sub-lists. + /// + /// #Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::encoding::{Cleartext, CleartextList}; + /// use tfhe::core_crypto::prelude::CleartextCount; + /// let mut clear_list = CleartextList::from_container(vec![1 as u8; 100]); + /// clear_list + /// .sublist_iter_mut(CleartextCount(10)) + /// .for_each(|mut a| a.cleartext_iter_mut().for_each(|b| *b = Cleartext(3))); + /// clear_list + /// .cleartext_iter() + /// .for_each(|a| assert_eq!(*a, Cleartext(3))); + /// assert_eq!(clear_list.sublist_iter_mut(CleartextCount(10)).count(), 10); + /// ``` + pub fn sublist_iter_mut( + &mut self, + sub_len: CleartextCount, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => sub_len.0); + self.as_mut_tensor() + .subtensor_iter_mut(sub_len.0) + .map(|sub| CleartextList::from_container(sub.into_container())) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/encoding/mod.rs b/tfhe/src/core_crypto/commons/crypto/encoding/mod.rs new file mode 100644 index 000000000..c8a66a423 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/encoding/mod.rs @@ -0,0 +1,7 @@ +//! Encoding cleartexts into plaintexts + +mod cleartext; +pub use cleartext::*; + +mod plaintext; +pub use plaintext::*; diff --git a/tfhe/src/core_crypto/commons/crypto/encoding/plaintext.rs b/tfhe/src/core_crypto/commons/crypto/encoding/plaintext.rs new file mode 100644 index 000000000..811bc064e --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/encoding/plaintext.rs @@ -0,0 +1,256 @@ +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::PlaintextCount; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// An plaintext (encoded) value. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct Plaintext(pub T); + +/// A list of plaintexts +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PlaintextList { + pub(crate) tensor: Tensor, +} + +tensor_traits!(PlaintextList); + +impl PlaintextList> +where + Scalar: Copy, +{ + /// Allocates a new list of plaintexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PlaintextCount; + /// let plain_list = PlaintextList::allocate(1 as u8, PlaintextCount(100)); + /// assert_eq!(plain_list.count(), PlaintextCount(100)); + /// ``` + pub fn allocate(value: Scalar, count: PlaintextCount) -> PlaintextList> { + PlaintextList::from_container(vec![value; count.0]) + } +} + +impl PlaintextList { + /// Creates a plaintext list from a container of values. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PlaintextCount; + /// let plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// assert_eq!(plain_list.count(), PlaintextCount(100)); + /// ``` + pub fn from_container(cont: Cont) -> PlaintextList { + PlaintextList { + tensor: Tensor::from_container(cont), + } + } + + pub fn from_tensor(tensor: Tensor) -> PlaintextList { + PlaintextList { tensor } + } + + /// Returns the number of elements in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PlaintextCount; + /// let plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// assert_eq!(plain_list.count(), PlaintextCount(100)); + /// ``` + pub fn count(&self) -> PlaintextCount + where + Self: AsRefTensor, + { + PlaintextCount(self.as_tensor().len()) + } + + /// Creates an iterator over borrowed plaintexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// let plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// plain_list.plaintext_iter().for_each(|a| assert_eq!(a.0, 1)); + /// assert_eq!(plain_list.plaintext_iter().count(), 100); + /// ``` + pub fn plaintext_iter(&self) -> impl Iterator::Element>> + where + Self: AsRefTensor, + ::Element: Numeric, + { + self.as_tensor().iter().map(|refe| unsafe { + &*{ + refe as *const ::Element + as *const Plaintext<::Element> + } + }) + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_plaintext_iter( + &self, + ) -> impl IndexedParallelIterator::Element>> + where + Self: AsRefTensor, + ::Element: Numeric + Sync, + { + self.as_tensor().par_iter().map(|refe| unsafe { + &*{ + refe as *const ::Element + as *const Plaintext<::Element> + } + }) + } + + /// Creates an iterator over mutably borrowed plaintexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// let mut plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// plain_list + /// .plaintext_iter_mut() + /// .for_each(|a| *a = Plaintext(2)); + /// plain_list.plaintext_iter().for_each(|a| assert_eq!(a.0, 2)); + /// assert_eq!(plain_list.plaintext_iter_mut().count(), 100); + /// ``` + pub fn plaintext_iter_mut( + &mut self, + ) -> impl Iterator::Element>> + where + Self: AsMutTensor, + ::Element: Numeric, + { + self.as_mut_tensor().iter_mut().map(|refe| unsafe { + &mut *{ + refe as *mut ::Element + as *mut Plaintext<::Element> + } + }) + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_plaintext_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element>> + where + Self: AsMutTensor, + ::Element: Numeric + Send + Sync, + { + self.as_mut_tensor().par_iter_mut().map(|refe| unsafe { + &mut *{ + refe as *mut ::Element + as *mut Plaintext<::Element> + } + }) + } + + /// Creates an iterator over borrowed sub-lists. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PlaintextCount; + /// let mut plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// plain_list + /// .sublist_iter(PlaintextCount(10)) + /// .for_each(|a| assert_eq!(a.count(), PlaintextCount(10))); + /// assert_eq!(plain_list.sublist_iter(PlaintextCount(10)).count(), 10); + /// ``` + pub fn sublist_iter( + &self, + count: PlaintextCount, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(count.0) + .map(|sub| PlaintextList::from_container(sub.into_container())) + } + + /// Creates an iterator over mutably borrowed sub-lists. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PlaintextCount; + /// let mut plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// plain_list + /// .sublist_iter_mut(PlaintextCount(10)) + /// .for_each(|mut a| a.plaintext_iter_mut().for_each(|b| *b = Plaintext(2))); + /// plain_list + /// .plaintext_iter() + /// .for_each(|a| assert_eq!(*a, Plaintext(2))); + /// assert_eq!(plain_list.sublist_iter_mut(PlaintextCount(10)).count(), 10); + /// ``` + pub fn sublist_iter_mut( + &mut self, + count: PlaintextCount, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.count().0 => count.0); + self.as_mut_tensor() + .subtensor_iter_mut(count.0) + .map(|sub| PlaintextList::from_container(sub.into_container())) + } + + /// Return a borrowed polynomial whose coefficients are the plaintexts of this list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// let plain_poly = plain_list.as_polynomial(); + /// assert_eq!(plain_poly.polynomial_size(), PolynomialSize(100)); + /// ``` + pub fn as_polynomial(&self) -> Polynomial<&[::Element]> + where + Self: AsRefTensor, + { + Polynomial::from_container(self.as_tensor().as_slice()) + } + + /// Return a mutably borrowed polynomial whose coefficients are the plaintexts of this list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut plain_list = PlaintextList::from_container(vec![1 as u8; 100]); + /// let mut plain_poly = plain_list.as_mut_polynomial(); + /// assert_eq!(plain_poly.polynomial_size(), PolynomialSize(100)); + /// ``` + pub fn as_mut_polynomial(&mut self) -> Polynomial<&mut [::Element]> + where + Self: AsMutTensor, + { + Polynomial::from_container(self.as_mut_tensor().as_mut_slice()) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/levels.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/levels.rs new file mode 100644 index 000000000..a2182fdb8 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/levels.rs @@ -0,0 +1,353 @@ +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::prelude::{GlweSize, PolynomialSize}; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; + +/// A matrix containing a single level of gadget decomposition. +pub struct GgswLevelMatrix { + tensor: Tensor, + poly_size: PolynomialSize, + glwe_size: GlweSize, + level: DecompositionLevel, +} + +tensor_traits!(GgswLevelMatrix); + +impl GgswLevelMatrix { + /// Creates a GGSW level matrix from an arbitrary container. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_matrix.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(level_matrix.glwe_size(), GlweSize(7)); + /// assert_eq!(level_matrix.decomposition_level(), DecompositionLevel(1)); + /// ``` + pub fn from_container( + cont: Cont, + poly_size: PolynomialSize, + rlwe_size: GlweSize, + level: DecompositionLevel, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => rlwe_size.0, poly_size.0); + GgswLevelMatrix { + tensor, + poly_size, + glwe_size: rlwe_size, + level, + } + } + + /// Returns the size of the GLWE ciphertexts composing the GGSW level matrix. + /// + /// This is also the number of columns of the matrix (assuming it is a matrix of + /// polynomials), as well as its number of rows. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_matrix.glwe_size(), GlweSize(7)); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Returns the index of the level corresponding to this matrix. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_matrix.decomposition_level(), DecompositionLevel(1)); + /// ``` + pub fn decomposition_level(&self) -> DecompositionLevel { + self.level + } + + /// Returns the size of the polynomials of the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_matrix.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns an iterator over the borrowed rows of the matrix. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// for row in level_matrix.row_iter() { + /// assert_eq!(row.glwe_size(), GlweSize(7)); + /// assert_eq!(row.polynomial_size(), PolynomialSize(10)); + /// } + /// assert_eq!(level_matrix.row_iter().count(), 7); + /// ``` + pub fn row_iter(&self) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.poly_size.0 * self.glwe_size.0) + .map(move |tens| { + GgswLevelRow::from_container(tens.into_container(), self.poly_size, self.level) + }) + } + + /// Returns an iterator over the mutably borrowed rows of the matrix. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// for mut row in level_matrix.row_iter_mut() { + /// row.as_mut_tensor().fill_with_element(9); + /// } + /// assert!(level_matrix.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(level_matrix.row_iter_mut().count(), 7); + /// ``` + pub fn row_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.poly_size.0 * self.glwe_size.0; + let poly_size = self.poly_size; + let level = self.level; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |tens| GgswLevelRow::from_container(tens.into_container(), poly_size, level)) + } + + /// Returns a parallel iterator over the mutably borrowed rows of the matrix. + /// + /// # Note + /// + /// This method uses _rayon_ internally, and is hidden behind the "__commons_parallel" feature + /// gate. + /// + /// # Example + /// + /// ```rust + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelMatrix; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut level_matrix = GgswLevelMatrix::from_container( + /// vec![0 as u8; 10 * 7 * 7], + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevel(1), + /// ); + /// level_matrix.par_row_iter_mut().for_each(|mut row| { + /// row.as_mut_tensor().fill_with_element(9); + /// }); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_row_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Send + Sync, + { + let chunks_size = self.poly_size.0 * self.glwe_size.0; + let poly_size = self.poly_size; + let level = self.level; + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .map(move |tens| GgswLevelRow::from_container(tens.into_container(), poly_size, level)) + } +} + +/// A row of a GGSW level matrix. +pub struct GgswLevelRow { + tensor: Tensor, + poly_size: PolynomialSize, + level: DecompositionLevel, +} + +tensor_traits!(GgswLevelRow); + +impl GgswLevelRow { + /// Creates an Rgsw level row from an arbitrary container. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelRow; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_row = GgswLevelRow::from_container( + /// vec![0 as u8; 10 * 7], + /// PolynomialSize(10), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_row.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(level_row.glwe_size(), GlweSize(7)); + /// assert_eq!(level_row.decomposition_level(), DecompositionLevel(1)); + /// ``` + pub fn from_container(cont: Cont, poly_size: PolynomialSize, level: DecompositionLevel) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => poly_size.0); + GgswLevelRow { + tensor, + poly_size, + level, + } + } + + /// Returns the size of the glwe ciphertext composing this level row. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelRow; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_row = GgswLevelRow::from_container( + /// vec![0 as u8; 10 * 7], + /// PolynomialSize(10), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_row.glwe_size(), GlweSize(7)); + /// ``` + pub fn glwe_size(&self) -> GlweSize + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.poly_size.0); + GlweSize(self.as_tensor().len() / self.poly_size.0) + } + + /// Returns the index of the level corresponding to this row. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelRow; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let level_row = GgswLevelRow::from_container( + /// vec![0 as u8; 10 * 7], + /// PolynomialSize(10), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_row.decomposition_level(), DecompositionLevel(1)); + /// ``` + pub fn decomposition_level(&self) -> DecompositionLevel { + self.level + } + + /// Returns the size of the polynomials used in the row. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelRow; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let level_row = GgswLevelRow::from_container( + /// vec![0 as u8; 10 * 7], + /// PolynomialSize(10), + /// DecompositionLevel(1), + /// ); + /// assert_eq!(level_row.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Consumes the row and returns its container wrapped into an `GlweCiphertext`. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::ggsw::GgswLevelRow; + /// use tfhe::core_crypto::commons::math::decomposition::DecompositionLevel; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let level_row = GgswLevelRow::from_container( + /// vec![0 as u8; 10 * 7], + /// PolynomialSize(10), + /// DecompositionLevel(1), + /// ); + /// let glwe = level_row.into_glwe(); + /// assert_eq!(glwe.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(glwe.size(), GlweSize(7)); + /// ``` + pub fn into_glwe(self) -> GlweCiphertext { + GlweCiphertext { + tensor: self.tensor, + poly_size: self.poly_size, + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/mod.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/mod.rs new file mode 100644 index 000000000..f20c11dde --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/mod.rs @@ -0,0 +1,14 @@ +//! GGSW encryption scheme. + +mod levels; +mod seeded_levels; +mod seeded_standard; +mod standard; + +pub use levels::*; +pub use seeded_levels::*; +pub use seeded_standard::*; +pub use standard::*; + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_levels.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_levels.rs new file mode 100644 index 000000000..4f1c25899 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_levels.rs @@ -0,0 +1,205 @@ +use crate::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; +use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::math::random::CompressionSeed; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A matrix containing a single level of gadget decomposition. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct GgswSeededLevelMatrix { + tensor: Tensor, + poly_size: PolynomialSize, + glwe_size: GlweSize, + level: DecompositionLevel, + compression_seed: CompressionSeed, +} + +tensor_traits!(GgswSeededLevelMatrix); + +impl GgswSeededLevelMatrix { + /// Creates a GGSW seeded level matrix from an arbitrary container. + pub fn from_container( + cont: Cont, + poly_size: PolynomialSize, + glwe_size: GlweSize, + level: DecompositionLevel, + compression_seed: CompressionSeed, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => poly_size.0); + Self { + tensor, + poly_size, + glwe_size, + level, + compression_seed, + } + } + + /// Returns the size of the GLWE ciphertexts composing the GGSW level matrix. + /// + /// This is also the number of columns of the expanded matrix (assuming it is a matrix of + /// polynomials), as well as the number of rows of the matrix. + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Returns the index of the level corresponding to this matrix. + pub fn decomposition_level(&self) -> DecompositionLevel { + self.level + } + + /// Returns the size of the polynomials of the current ciphertext. + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns an iterator over the borrowed rows of the matrix. + pub fn row_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.poly_size.0) + .map(move |sub| { + GgswSeededLevelRow::from_container( + sub.into_container(), + self.poly_size, + self.level, + self.glwe_size.to_glwe_dimension(), + self.compression_seed, + ) + }) + } + + /// Returns an iterator over the mutably borrowed rows of the matrix. + pub fn row_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.poly_size.0; + let poly_size = self.poly_size; + let glwe_dimension = self.glwe_size.to_glwe_dimension(); + let level = self.level; + let compression_seed = self.compression_seed; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| { + GgswSeededLevelRow::from_container( + sub.into_container(), + poly_size, + level, + glwe_dimension, + compression_seed, + ) + }) + } + + /// Returns a parallel iterator over the mutably borrowed rows of the matrix. + /// + /// # Note + /// + /// This method uses _rayon_ internally, and is hidden behind the "multithread" feature + /// gate. + #[cfg(feature = "__commons_parallel")] + pub fn par_row_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Send + Sync, + { + let chunks_size = self.poly_size.0; + let poly_size = self.poly_size; + let glwe_dimension = self.glwe_size.to_glwe_dimension(); + let level = self.level; + let compression_seed = self.compression_seed; + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .map(move |sub| { + GgswSeededLevelRow::from_container( + sub.into_container(), + poly_size, + level, + glwe_dimension, + compression_seed, + ) + }) + } +} + +/// A row of a GGSW level matrix. +pub struct GgswSeededLevelRow { + tensor: Tensor, + poly_size: PolynomialSize, + level: DecompositionLevel, + glwe_dimension: GlweDimension, + compression_seed: CompressionSeed, +} + +tensor_traits!(GgswSeededLevelRow); + +impl GgswSeededLevelRow { + /// Creates an Rgsw seeded level row from an arbitrary container. + pub fn from_container( + cont: Cont, + poly_size: PolynomialSize, + level: DecompositionLevel, + glwe_dimension: GlweDimension, + compression_seed: CompressionSeed, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.as_slice().len() => poly_size.0); + Self { + tensor, + poly_size, + level, + glwe_dimension, + compression_seed, + } + } + + /// Returns the size of the glwe ciphertext composing this level row. + pub fn glwe_size(&self) -> GlweSize { + self.glwe_dimension.to_glwe_size() + } + + /// Returns the index of the level corresponding to this row. + pub fn decomposition_level(&self) -> DecompositionLevel { + self.level + } + + /// Returns the size of the polynomials used in the row. + pub fn polynomial_size(&self) -> PolynomialSize + where + Cont: AsRefSlice, + { + self.poly_size + } + + /// Consumes the row and returns its container wrapped into an `GlweCiphertext`. + pub fn into_seeded_glwe(self) -> GlweSeededCiphertext { + GlweSeededCiphertext::from_container( + self.tensor.into_container(), + self.glwe_dimension, + self.compression_seed, + ) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_standard.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_standard.rs new file mode 100644 index 000000000..58e6639a2 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/seeded_standard.rs @@ -0,0 +1,588 @@ +use super::{GgswSeededLevelMatrix, StandardGgswCiphertext}; +use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A GGSW seeded ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct StandardGgswSeededCiphertext { + tensor: Tensor, + poly_size: PolynomialSize, + glwe_size: GlweSize, + decomp_base_log: DecompositionBaseLog, + compression_seed: CompressionSeed, +} + +tensor_traits!(StandardGgswSeededCiphertext); + +impl StandardGgswSeededCiphertext> { + /// Allocates a new GGSW ciphertext whose coefficients are all `value`. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.glwe_size(), glwe_size); + /// assert_eq!(seeded_ggsw.decomposition_level_count(), decomp_level); + /// assert_eq!(seeded_ggsw.decomposition_base_log(), decomp_base_log); + /// assert_eq!(seeded_ggsw.polynomial_size(), polynomial_size); + /// assert_eq!(seeded_ggsw.compression_seed(), compression_seed); + /// ``` + pub fn allocate( + poly_size: PolynomialSize, + glwe_size: GlweSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + compression_seed: CompressionSeed, + ) -> Self + where + Scalar: Numeric, + { + Self::from_container( + vec![Scalar::ZERO; decomp_level.0 * glwe_size.0 * poly_size.0], + poly_size, + glwe_size, + decomp_base_log, + compression_seed, + ) + } +} + +impl StandardGgswSeededCiphertext { + /// Creates a ggsw seeded ciphertext from an existing container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let container = vec![0u8; decomp_level.0 * glwe_size.0 * polynomial_size.0]; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::from_container( + /// container, + /// polynomial_size, + /// glwe_size, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.glwe_size(), glwe_size); + /// assert_eq!(seeded_ggsw.decomposition_level_count(), decomp_level); + /// assert_eq!(seeded_ggsw.decomposition_base_log(), decomp_base_log); + /// assert_eq!(seeded_ggsw.polynomial_size(), polynomial_size); + /// assert_eq!(seeded_ggsw.compression_seed(), compression_seed); + /// ``` + pub fn from_container( + cont: Cont, + poly_size: PolynomialSize, + glwe_size: GlweSize, + decomp_base_log: DecompositionBaseLog, + compression_seed: CompressionSeed, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => poly_size.0, glwe_size.0); + Self { + tensor, + glwe_size, + poly_size, + decomp_base_log, + compression_seed, + } + } + + /// Returns the size of the glwe ciphertexts composing the ggsw ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.glwe_size(), glwe_size); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Returns the compression seed used to fill masks of the GLWE ciphertext making up the GGSW + /// ciphertext. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.compression_seed(), compression_seed); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns the number of decomposition levels used in the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.decomposition_level_count(), decomp_level); + /// ``` + pub fn decomposition_level_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => + self.glwe_size.0, + self.poly_size.0 + ); + DecompositionLevelCount(self.as_tensor().len() / (self.glwe_size.0 * self.poly_size.0)) + } + + /// Returns the size of the polynomials used in the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.polynomial_size(), polynomial_size); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the logarithm of the base used for the gadget decomposition. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// assert_eq!(seeded_ggsw.decomposition_base_log(), decomp_base_log); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Returns an iterator over borrowed seeded level matrices. + /// + /// # Note + /// + /// This iterator iterates over the levels from the lower to the higher level in the usual + /// order. To iterate in the reverse order, you can use `rev()` on the iterator. + /// + /// # Example + /// + /// Returns an iterator over mutably borrowed seeded level matrices. + /// + /// # Note + /// + /// This iterator iterates over the levels from the lower to the higher level in the usual + /// order. To iterate in the reverse order, you can use `rev()` on the iterator. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// for level_matrix in seeded_ggsw.level_matrix_iter() { + /// assert_eq!(level_matrix.row_iter().count(), glwe_size.0); + /// assert_eq!(level_matrix.polynomial_size(), polynomial_size); + /// for rlwe in level_matrix.row_iter() { + /// assert_eq!(rlwe.glwe_size(), glwe_size); + /// assert_eq!(rlwe.polynomial_size(), polynomial_size); + /// } + /// } + /// + /// assert_eq!(seeded_ggsw.level_matrix_iter().count(), decomp_level.0); + /// ``` + pub fn level_matrix_iter( + &self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsRefTensor, + { + // The factor two is to get the coefficient with the message and the body with unpredictable + // noise + let chunks_size = self.poly_size.0 * self.glwe_size.0; + let poly_size = self.poly_size; + let glwe_size = self.glwe_size; + self.as_tensor() + .subtensor_iter(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswSeededLevelMatrix::from_container( + tensor.into_container(), + poly_size, + glwe_size, + DecompositionLevel(index + 1), + self.compression_seed, + ) + }) + } + + /// Returns an iterator over mutably borrowed seeded level matrices. + /// + /// # Note + /// + /// This iterator iterates over the levels from the lower to the higher level in the usual + /// order. To iterate in the reverse order, you can use `rev()` on the iterator. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// for mut level_matrix in seeded_ggsw.level_matrix_iter_mut() { + /// for mut rlwe in level_matrix.row_iter_mut() { + /// rlwe.as_mut_tensor().fill_with_element(9); + /// } + /// } + /// + /// assert!(seeded_ggsw.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(seeded_ggsw.level_matrix_iter_mut().count(), 3); + /// ``` + pub fn level_matrix_iter_mut( + &mut self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsMutTensor, + { + // The factor two is to get the coefficient with the message and the body with unpredictable + // noise + let chunks_size = self.poly_size.0 * self.glwe_size.0; + let poly_size = self.poly_size; + let glwe_size = self.glwe_size; + let compression_seed = self.compression_seed; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswSeededLevelMatrix::from_container( + tensor.into_container(), + poly_size, + glwe_size, + DecompositionLevel(index + 1), + compression_seed, + ) + }) + } + + /// Returns a parallel iterator over mutably borrowed level seeded matrices. + /// + /// # Notes + /// This iterator is hidden behind the "multithread" feature gate. + /// + /// # Example + /// + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// seeded_ggsw + /// .par_level_matrix_iter_mut() + /// .for_each(|mut level_matrix| { + /// for mut rlwe in level_matrix.row_iter_mut() { + /// rlwe.as_mut_tensor().fill_with_element(9); + /// } + /// }); + /// + /// assert!(seeded_ggsw.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(seeded_ggsw.level_matrix_iter_mut().count(), 3); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_level_matrix_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + let chunks_size = self.poly_size.0 * self.glwe_size.0; + let poly_size = self.poly_size; + let glwe_size = self.glwe_size; + let compression_seed = self.compression_seed; + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswSeededLevelMatrix::from_container( + tensor.into_container(), + poly_size, + glwe_size, + DecompositionLevel(index + 1), + compression_seed, + ) + }) + } + + pub fn expand_into_with_existing_generator( + self, + output: &mut StandardGgswCiphertext, + generator: &mut RandomGenerator, + ) where + Scalar: Copy + RandomGenerable + Numeric, + StandardGgswCiphertext: AsMutTensor, + Self: AsRefTensor, + Gen: ByteRandomGenerator, + { + for (matrix_in, mut matrix_out) in + self.level_matrix_iter().zip(output.level_matrix_iter_mut()) + { + for (row_in, row_out) in matrix_in.row_iter().zip(matrix_out.row_iter_mut()) { + let mut glwe_out = row_out.into_glwe(); + + let glwe_seeded = row_in.into_seeded_glwe(); + + glwe_seeded + .expand_into_with_existing_generator::<_, _, Gen>(&mut glwe_out, generator); + } + } + } + + /// Returns the ciphertext as a full fledged GgswCiphertext + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::ggsw::{ + /// StandardGgswCiphertext, StandardGgswSeededCiphertext, + /// }; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_size = GlweSize(7); + /// let decomp_level = DecompositionLevelCount(3); + /// let decomp_base_log = DecompositionBaseLog(4); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ggsw = StandardGgswSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// compression_seed, + /// ); + /// + /// let mut ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// polynomial_size, + /// glwe_size, + /// decomp_level, + /// decomp_base_log, + /// ); + /// + /// seeded_ggsw.expand_into::<_, _, SoftwareRandomGenerator>(&mut ggsw); + /// + /// assert_eq!(ggsw.polynomial_size(), polynomial_size); + /// assert_eq!(ggsw.glwe_size(), glwe_size); + /// assert_eq!(ggsw.decomposition_base_log(), decomp_base_log); + /// assert_eq!(ggsw.decomposition_level_count(), decomp_level); + /// ``` + pub fn expand_into(self, output: &mut StandardGgswCiphertext) + where + Scalar: Copy + RandomGenerable + Numeric, + StandardGgswCiphertext: AsMutTensor, + Self: AsRefTensor, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed().seed); + + self.expand_into_with_existing_generator(output, &mut generator); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/standard.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/standard.rs new file mode 100644 index 000000000..eaf9443e4 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/standard.rs @@ -0,0 +1,537 @@ +use crate::core_crypto::commons::crypto::encoding::Plaintext; +use crate::core_crypto::commons::math::tensor::Container; + +use crate::core_crypto::commons::crypto::glwe::GlweList; +use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; + +use super::GgswLevelMatrix; + +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A GGSW ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct StandardGgswCiphertext { + pub(crate) tensor: Tensor, + poly_size: PolynomialSize, + rlwe_size: GlweSize, + decomp_base_log: DecompositionBaseLog, +} + +tensor_traits!(StandardGgswCiphertext); + +impl StandardGgswCiphertext> { + /// Allocates a new GGSW ciphertext whose coefficients are all `value`. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.glwe_size(), GlweSize(7)); + /// assert_eq!(ggsw.decomposition_level_count(), DecompositionLevelCount(3)); + /// assert_eq!(ggsw.decomposition_base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn allocate( + value: Scalar, + poly_size: PolynomialSize, + rlwe_size: GlweSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + ) -> Self + where + Scalar: Copy, + { + StandardGgswCiphertext { + tensor: Tensor::from_container(vec![ + value; + decomp_level.0 + * rlwe_size.0 + * rlwe_size.0 + * poly_size.0 + ]), + poly_size, + rlwe_size, + decomp_base_log, + } + } +} + +impl StandardGgswCiphertext> +where + Scalar: UnsignedTorus, +{ + pub fn new_trivial_encryption( + poly_size: PolynomialSize, + glwe_size: GlweSize, + decomp_level: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + plaintext: &Plaintext, + ) -> Self { + let mut ciphertext = Self::allocate( + Scalar::ZERO, + poly_size, + glwe_size, + decomp_level, + decomp_base_log, + ); + ciphertext.fill_with_trivial_encryption(plaintext); + ciphertext + } +} + +impl StandardGgswCiphertext { + /// Creates an Rgsw ciphertext from an existing container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::from_container( + /// vec![9 as u8; 7 * 7 * 10 * 3], + /// GlweSize(7), + /// PolynomialSize(10), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.glwe_size(), GlweSize(7)); + /// assert_eq!(ggsw.decomposition_level_count(), DecompositionLevelCount(3)); + /// assert_eq!(ggsw.decomposition_base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn from_container( + cont: Cont, + rlwe_size: GlweSize, + poly_size: PolynomialSize, + decomp_base_log: DecompositionBaseLog, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => rlwe_size.0, poly_size.0, rlwe_size.0 * rlwe_size.0); + StandardGgswCiphertext { + tensor, + poly_size, + rlwe_size, + decomp_base_log, + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> StandardGgswCiphertext<&'_ [Cont::Element]> + where + Cont: Container, + { + StandardGgswCiphertext { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + poly_size: self.poly_size, + rlwe_size: self.rlwe_size, + decomp_base_log: self.decomp_base_log, + } + } + + pub fn as_mut_view(&mut self) -> StandardGgswCiphertext<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + StandardGgswCiphertext { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + poly_size: self.poly_size, + rlwe_size: self.rlwe_size, + decomp_base_log: self.decomp_base_log, + } + } + + /// Returns the size of the glwe ciphertexts composing the ggsw ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.glwe_size(), GlweSize(7)); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.rlwe_size + } + + /// Returns the number of decomposition levels used in the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.decomposition_level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn decomposition_level_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => + self.rlwe_size.0, + self.poly_size.0, + self.rlwe_size.0 * self.rlwe_size.0 + ); + DecompositionLevelCount( + self.as_tensor().len() / (self.rlwe_size.0 * self.rlwe_size.0 * self.poly_size.0), + ) + } + + /// Returns the size of the polynomials used in the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns a borrowed list composed of all the GLWE ciphertext composing current ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, + /// PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// let list = ggsw.as_glwe_list(); + /// assert_eq!(list.glwe_dimension(), GlweDimension(6)); + /// assert_eq!(list.ciphertext_count(), CiphertextCount(3 * 7)); + /// ``` + pub fn as_glwe_list(&self) -> GlweList<&[Scalar]> + where + Self: AsRefTensor, + { + GlweList::from_container( + self.as_tensor().as_slice(), + self.rlwe_size.to_glwe_dimension(), + self.poly_size, + ) + } + + /// Returns a mutably borrowed `GlweList` composed of all the GLWE ciphertext composing + /// current ciphertext. + /// + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, + /// PolynomialSize, + /// }; + /// + /// let mut ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// let mut list = ggsw.as_mut_glwe_list(); + /// list.as_mut_tensor().fill_with_element(0); + /// assert_eq!(list.glwe_dimension(), GlweDimension(6)); + /// assert_eq!(list.ciphertext_count(), CiphertextCount(3 * 7)); + /// ggsw.as_tensor().iter().for_each(|a| assert_eq!(*a, 0)); + /// ``` + pub fn as_mut_glwe_list(&mut self) -> GlweList<&mut [Scalar]> + where + Self: AsMutTensor, + { + let dimension = self.rlwe_size.to_glwe_dimension(); + let size = self.poly_size; + GlweList::from_container(self.as_mut_tensor().as_mut_slice(), dimension, size) + } + + /// Returns the logarithm of the base used for the gadget decomposition. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(10), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// assert_eq!(ggsw.decomposition_base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Returns an iterator over borrowed level matrices. + /// + /// # Note + /// + /// This iterator iterates over the levels from the lower to the higher level in the usual + /// order. To iterate in the reverse order, you can use `rev()` on the iterator. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(9), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// for level_matrix in ggsw.level_matrix_iter() { + /// assert_eq!(level_matrix.row_iter().count(), 7); + /// assert_eq!(level_matrix.polynomial_size(), PolynomialSize(9)); + /// for rlwe in level_matrix.row_iter() { + /// assert_eq!(rlwe.glwe_size(), GlweSize(7)); + /// assert_eq!(rlwe.polynomial_size(), PolynomialSize(9)); + /// } + /// } + /// assert_eq!(ggsw.level_matrix_iter().count(), 3); + /// ``` + pub fn level_matrix_iter( + &self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsRefTensor, + { + let chunks_size = self.poly_size.0 * self.rlwe_size.0 * self.rlwe_size.0; + let poly_size = self.poly_size; + let rlwe_size = self.rlwe_size; + self.as_tensor() + .subtensor_iter(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswLevelMatrix::from_container( + tensor.into_container(), + poly_size, + rlwe_size, + DecompositionLevel(index + 1), + ) + }) + } + + /// Returns an iterator over mutably borrowed level matrices. + /// + /// # Note + /// + /// This iterator iterates over the levels from the lower to the higher level in the usual + /// order. To iterate in the reverse order, you can use `rev()` on the iterator. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let mut ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(9), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// for mut level_matrix in ggsw.level_matrix_iter_mut() { + /// for mut rlwe in level_matrix.row_iter_mut() { + /// rlwe.as_mut_tensor().fill_with_element(9); + /// } + /// } + /// assert!(ggsw.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(ggsw.level_matrix_iter_mut().count(), 3); + /// ``` + pub fn level_matrix_iter_mut( + &mut self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.poly_size.0 * self.rlwe_size.0 * self.rlwe_size.0; + let poly_size = self.poly_size; + let rlwe_size = self.rlwe_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswLevelMatrix::from_container( + tensor.into_container(), + poly_size, + rlwe_size, + DecompositionLevel(index + 1), + ) + }) + } + + /// Returns a parallel iterator over mutably borrowed level matrices. + /// + /// # Notes + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example + /// + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweSize, PolynomialSize, + /// }; + /// + /// let mut ggsw = StandardGgswCiphertext::allocate( + /// 9 as u8, + /// PolynomialSize(9), + /// GlweSize(7), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(4), + /// ); + /// ggsw.par_level_matrix_iter_mut() + /// .for_each(|mut level_matrix| { + /// for mut rlwe in level_matrix.row_iter_mut() { + /// rlwe.as_mut_tensor().fill_with_element(9); + /// } + /// }); + /// assert!(ggsw.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(ggsw.level_matrix_iter_mut().count(), 3); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_level_matrix_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + let chunks_size = self.poly_size.0 * self.rlwe_size.0 * self.rlwe_size.0; + let poly_size = self.poly_size; + let rlwe_size = self.rlwe_size; + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .enumerate() + .map(move |(index, tensor)| { + GgswLevelMatrix::from_container( + tensor.into_container(), + poly_size, + rlwe_size, + DecompositionLevel(index + 1), + ) + }) + } + + pub fn fill_with_trivial_encryption(&mut self, plaintext: &Plaintext) + where + Self: AsMutTensor, + Scalar: UnsignedTorus, + { + // We fill the ggsw with trivial glwe encryptions of zero: + for mut glwe in self.as_mut_glwe_list().ciphertext_iter_mut() { + let mut mask = glwe.get_mut_mask(); + mask.as_mut_tensor().fill_with_element(Scalar::ZERO); + } + let base_log = self.decomposition_base_log(); + for mut matrix in self.level_matrix_iter_mut() { + let decomposition = plaintext.0.wrapping_mul( + Scalar::ONE + << (::BITS + - (base_log.0 * (matrix.decomposition_level().0))), + ); + // We iterate over the rows of the level matrix + for (index, row) in matrix.row_iter_mut().enumerate() { + let rlwe_ct = row.into_glwe(); + // We retrieve the row as a polynomial list + let mut polynomial_list = rlwe_ct.into_polynomial_list(); + // We retrieve the polynomial in the diagonal + let mut level_polynomial = polynomial_list.get_mut_polynomial(index); + // We get the first coefficient + let first_coef = level_polynomial.as_mut_tensor().first_mut(); + // We update the first coefficient + *first_coef = first_coef.wrapping_add(decomposition); + } + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/ggsw/tests.rs b/tfhe/src/core_crypto/commons/crypto/ggsw/tests.rs new file mode 100644 index 000000000..93903de36 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/ggsw/tests.rs @@ -0,0 +1,207 @@ +use crate::core_crypto::commons::crypto::encoding::PlaintextList; +use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, +}; +use crate::core_crypto::commons::crypto::secret::GlweSecretKey; +use crate::core_crypto::commons::math::random::{CompressionSeed, Seeder}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::test_tools; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LogStandardDev}; +use concrete_csprng::generators::SoftwareRandomGenerator; + +use super::{StandardGgswCiphertext, StandardGgswSeededCiphertext}; + +fn test_seeded_ggsw() { + // random settings + let nb_ct = test_tools::random_ciphertext_count(10); + let dimension = test_tools::random_glwe_dimension(5); + let polynomial_size = test_tools::random_polynomial_size(200); + let noise_parameters = LogStandardDev::from_log_standard_dev(-50.); + let decomp_level = DecompositionLevelCount(3); + let decomp_base_log = DecompositionBaseLog(7); + let mut secret_generator = test_tools::new_secret_random_generator(); + + // generates a secret key + let sk = GlweSecretKey::generate_binary(dimension, polynomial_size, &mut secret_generator); + + // generates random plaintexts + let plaintext_vector = + PlaintextList::from_tensor(secret_generator.random_uniform_tensor(nb_ct.0)); + + for plaintext in plaintext_vector.plaintext_iter() { + let main_seed = test_tools::random_seed(); + + // Use a deterministic seeder to get the seeds that will be used during the tests + let mut deterministic_seeder = + DeterministicSeeder::::new(main_seed); + let noise_seed = deterministic_seeder.seed(); + let mask_seed = deterministic_seeder.seed(); + + // encrypts + let mut seeded_ggsw = StandardGgswSeededCiphertext::allocate( + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + CompressionSeed { seed: mask_seed }, + ); + + // Recreate a second deterministic seeder to control the behavior of the seeded encryption + let mut seeder = DeterministicSeeder::::new(main_seed); + + sk.encrypt_constant_seeded_ggsw::<_, _, _, _, SoftwareRandomGenerator>( + &mut seeded_ggsw, + plaintext, + noise_parameters, + &mut seeder, + ); + + // expands + let mut ggsw_expanded = StandardGgswCiphertext::allocate( + T::ZERO, + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + ); + seeded_ggsw.expand_into::<_, _, SoftwareRandomGenerator>(&mut ggsw_expanded); + + // control encryption + let mut ggsw = StandardGgswCiphertext::allocate( + T::ZERO, + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + ); + + // Recreate a generator with the known mask seed + let mut generator = EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new(main_seed), + ); + // And force the noise seed (only available in tests) to the noise seed we know was used + generator.seed_noise_generator(noise_seed); + + sk.encrypt_constant_ggsw(&mut ggsw, plaintext, noise_parameters, &mut generator); + + assert_eq!(ggsw_expanded, ggsw); + } +} + +#[test] +fn test_seeded_ggsw_u32() { + test_seeded_ggsw::() +} + +#[test] +fn test_seeded_ggsw_u64() { + test_seeded_ggsw::() +} + +#[cfg(feature = "__commons_parallel")] +mod parallel { + use crate::core_crypto::commons::crypto::encoding::PlaintextList; + use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, + }; + use crate::core_crypto::commons::crypto::secret::GlweSecretKey; + use crate::core_crypto::commons::math::random::{CompressionSeed, Seeder}; + use crate::core_crypto::commons::math::torus::UnsignedTorus; + use crate::core_crypto::commons::test_tools; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, LogStandardDev, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + + use super::{StandardGgswCiphertext, StandardGgswSeededCiphertext}; + + fn test_par_seeded_ggsw() { + // random settings + let nb_ct = test_tools::random_ciphertext_count(10); + let dimension = test_tools::random_glwe_dimension(5); + let polynomial_size = test_tools::random_polynomial_size(200); + let noise_parameters = LogStandardDev::from_log_standard_dev(-50.); + let decomp_level = DecompositionLevelCount(3); + let decomp_base_log = DecompositionBaseLog(7); + let mut secret_generator = test_tools::new_secret_random_generator(); + + // generates a secret key + let sk = GlweSecretKey::generate_binary(dimension, polynomial_size, &mut secret_generator); + + // generates random plaintexts + let plaintext_vector = + PlaintextList::from_tensor(secret_generator.random_uniform_tensor(nb_ct.0)); + + for plaintext in plaintext_vector.plaintext_iter() { + let main_seed = test_tools::random_seed(); + + // Use a deterministic seeder to get the seeds that will be used during the tests + let mut deterministic_seeder = + DeterministicSeeder::::new(main_seed); + let noise_seed = deterministic_seeder.seed(); + let mask_seed = deterministic_seeder.seed(); + + // encrypts + let mut seeded_ggsw = StandardGgswSeededCiphertext::allocate( + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + CompressionSeed { seed: mask_seed }, + ); + + // Recreate a second deterministic seeder to control the behavior of the seeded + // encryption + let mut seeder = DeterministicSeeder::::new(main_seed); + + sk.par_encrypt_constant_seeded_ggsw::<_, _, _, _, SoftwareRandomGenerator>( + &mut seeded_ggsw, + plaintext, + noise_parameters, + &mut seeder, + ); + + // expands + let mut ggsw_expanded = StandardGgswCiphertext::allocate( + T::ZERO, + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + ); + seeded_ggsw.expand_into::<_, _, SoftwareRandomGenerator>(&mut ggsw_expanded); + + // control encryption + let mut ggsw = StandardGgswCiphertext::allocate( + T::ZERO, + polynomial_size, + dimension.to_glwe_size(), + decomp_level, + decomp_base_log, + ); + + // Recreate a generator with the known mask seed + let mut generator = EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new(main_seed), + ); + // And force the noise seed (only available in tests) to the noise seed we know was used + generator.seed_noise_generator(noise_seed); + + sk.par_encrypt_constant_ggsw(&mut ggsw, plaintext, noise_parameters, &mut generator); + + assert_eq!(ggsw_expanded, ggsw); + } + } + + #[test] + fn test_par_seeded_ggsw_u32() { + test_par_seeded_ggsw::() + } + + #[test] + fn test_par_seeded_ggsw_u64() { + test_par_seeded_ggsw::() + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/body.rs b/tfhe/src/core_crypto/commons/crypto/glwe/body.rs new file mode 100644 index 000000000..9d3c60f66 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/body.rs @@ -0,0 +1,75 @@ +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, IntoTensor, Tensor, +}; + +/// The body of a GLWE ciphertext. +pub struct GlweBody { + pub(crate) tensor: Tensor, +} + +tensor_traits!(GlweBody); + +impl GlweBody { + /// Consumes the current ciphertext body, and return a polynomial over the original container. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let body = glwe.get_body(); + /// let poly = body.into_polynomial(); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn into_polynomial(self) -> Polynomial + where + Self: IntoTensor, + { + Polynomial::from_container(self.into_tensor().into_container()) + } + + /// Returns a borrowed polynomial from the current body. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let body = glwe.get_body(); + /// let poly = body.as_polynomial(); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn as_polynomial(&self) -> Polynomial<&[::Element]> + where + Self: AsRefTensor, + { + Polynomial::from_container(self.as_tensor().as_slice()) + } + + /// Returns a mutably borrowed polynomial from the current body. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mut body = glwe.get_mut_body(); + /// let mut poly = body.as_mut_polynomial(); + /// poly.as_mut_tensor().fill_with_element(9); + /// assert!(body.as_tensor().iter().all(|a| *a == 9)); + /// ``` + pub fn as_mut_polynomial(&mut self) -> Polynomial<&mut [::Element]> + where + Self: AsMutTensor, + { + Polynomial::from_container(self.as_mut_tensor().as_mut_slice()) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/ciphertext.rs b/tfhe/src/core_crypto/commons/crypto/glwe/ciphertext.rs new file mode 100644 index 000000000..906c147b7 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/ciphertext.rs @@ -0,0 +1,534 @@ +use super::{GlweBody, GlweMask}; +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::lwe::LweCiphertext; +use crate::core_crypto::commons::math::polynomial::PolynomialList; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{GlweDimension, GlweSize, MonomialDegree, PolynomialSize}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// An GLWE ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweCiphertext { + pub(crate) tensor: Tensor, + pub(crate) poly_size: PolynomialSize, +} + +tensor_traits!(GlweCiphertext); + +impl GlweCiphertext> { + /// Allocates a new GLWE ciphertext, whose body and masks coefficients are all `value`. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// let glwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// assert_eq!(glwe_ciphertext.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(glwe_ciphertext.mask_size(), GlweDimension(99)); + /// assert_eq!(glwe_ciphertext.size(), GlweSize(100)); + /// ``` + pub fn allocate( + value: Scalar, + poly_size: PolynomialSize, + size: GlweSize, + ) -> GlweCiphertext> + where + GlweCiphertext>: AsMutTensor, + Scalar: Copy, + { + GlweCiphertext::from_container(vec![value; poly_size.0 * size.0], poly_size) + } +} + +impl GlweCiphertext> +where + Scalar: Numeric, +{ + pub fn new_trivial_encryption( + glwe_size: GlweSize, + plaintexts: &PlaintextList, + ) -> Self + where + PlaintextList: AsRefTensor, + { + let poly_size = PolynomialSize(plaintexts.count().0); + let mut ciphertext = Self::allocate(Scalar::ZERO, poly_size, glwe_size); + ciphertext.fill_with_trivial_encryption(plaintexts); + ciphertext + } +} + +impl GlweCiphertext { + /// Creates a new GLWE ciphertext from an existing container. + /// + /// # Note + /// + /// This method does not perform any transformation of the container data. Those are assumed to + /// represent a valid glwe ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// let glwe = GlweCiphertext::from_container(vec![0 as u8; 1100], PolynomialSize(10)); + /// assert_eq!(glwe.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(glwe.mask_size(), GlweDimension(109)); + /// assert_eq!(glwe.size(), GlweSize(110)); + /// ``` + pub fn from_container(cont: Cont, poly_size: PolynomialSize) -> GlweCiphertext { + GlweCiphertext { + tensor: Tensor::from_container(cont), + poly_size, + } + } + + pub fn as_view(&self) -> GlweCiphertext<&'_ [Cont::Element]> + where + Cont: Container, + { + GlweCiphertext { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + poly_size: self.poly_size, + } + } + + pub fn as_mut_view(&mut self) -> GlweCiphertext<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + GlweCiphertext { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + poly_size: self.poly_size, + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + /// Returns the size of the ciphertext, e.g. the number of masks + 1. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// assert_eq!(glwe.size(), GlweSize(100)); + /// ``` + pub fn size(&self) -> GlweSize + where + Self: AsRefTensor, + { + GlweSize(self.as_tensor().len() / self.poly_size.0) + } + + /// Returns the number of masks of the ciphertext, e.g. its size - 1. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// let glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// assert_eq!(glwe.mask_size(), GlweDimension(99)); + /// ``` + pub fn mask_size(&self) -> GlweDimension + where + Self: AsRefTensor, + { + GlweDimension(self.size().0 - 1) + } + + /// Returns the number of coefficients of the polynomials of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// assert_eq!(rlwe_ciphertext.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns a borrowed [`GlweBody`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let body = rlwe_ciphertext.get_body(); + /// assert_eq!(body.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn get_body(&self) -> GlweBody<&[::Element]> + where + Self: AsRefTensor, + { + GlweBody { + tensor: self + .as_tensor() + .get_sub((self.mask_size().0 * self.polynomial_size().0)..), + } + } + + /// Returns a borrowed [`GlweMask`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mask = rlwe_ciphertext.get_mask(); + /// assert_eq!(mask.mask_element_iter().count(), 99); + /// ``` + pub fn get_mask(&self) -> GlweMask<&[::Element]> + where + Self: AsRefTensor, + { + GlweMask { + tensor: self + .as_tensor() + .get_sub(..(self.mask_size().0 * self.polynomial_size().0)), + poly_size: self.poly_size, + } + } + + /// Returns a mutably borrowed [`GlweBody`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mut body = glwe.get_mut_body(); + /// body.as_mut_tensor().fill_with_element(9); + /// let body = glwe.get_body(); + /// assert!(body.as_tensor().iter().all(|a| *a == 9)); + /// ``` + pub fn get_mut_body(&mut self) -> GlweBody<&mut [::Element]> + where + Self: AsMutTensor, + { + let body_index = self.mask_size().0 * self.polynomial_size().0; + GlweBody { + tensor: self.as_mut_tensor().get_sub_mut(body_index..), + } + } + + /// Returns a mutably borrowed [`GlweMask`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mut masks = glwe.get_mut_mask(); + /// for mut mask in masks.mask_element_iter_mut() { + /// mask.as_mut_tensor().fill_with_element(9); + /// } + /// assert_eq!(masks.mask_element_iter_mut().count(), 99); + /// assert!(!glwe.as_tensor().iter().all(|a| *a == 9)); + /// ``` + pub fn get_mut_mask(&mut self) -> GlweMask<&mut [::Element]> + where + Self: AsMutTensor, + { + let body_index = self.mask_size().0 * self.polynomial_size().0; + let poly_size = self.poly_size; + GlweMask { + tensor: self.as_mut_tensor().get_sub_mut(..body_index), + poly_size, + } + } + + /// Returns borrowed [`GlweBody`] and [`GlweMask`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let (body, masks) = glwe.get_body_and_mask(); + /// assert_eq!(body.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// assert_eq!(masks.mask_element_iter().count(), 99); + /// ``` + #[allow(clippy::type_complexity)] + pub fn get_body_and_mask( + &self, + ) -> ( + GlweBody<&[::Element]>, + GlweMask<&[::Element]>, + ) + where + Self: AsRefTensor, + { + let index = self.mask_size().0 * self.polynomial_size().0; + ( + GlweBody { + tensor: self.as_tensor().get_sub(index..), + }, + GlweMask { + tensor: self.as_tensor().get_sub(..index), + poly_size: self.poly_size, + }, + ) + } + + /// Returns borrowed [`GlweBody`] and [`GlweMask`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let (mut body, mut masks) = glwe.get_mut_body_and_mask(); + /// body.as_mut_tensor().fill_with_element(9); + /// for mut mask in masks.mask_element_iter_mut() { + /// mask.as_mut_tensor().fill_with_element(9); + /// } + /// assert_eq!(body.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// assert!(glwe.as_tensor().iter().all(|a| *a == 9)); + /// ``` + #[allow(clippy::type_complexity)] + pub fn get_mut_body_and_mask( + &mut self, + ) -> ( + GlweBody<&mut [::Element]>, + GlweMask<&mut [::Element]>, + ) + where + Self: AsMutTensor, + { + let body_index = self.mask_size().0 * self.polynomial_size().0; + let poly_size = self.poly_size; + let (masks, body) = self.as_mut_tensor().as_mut_slice().split_at_mut(body_index); + ( + GlweBody { + tensor: Tensor::from_container(body), + }, + GlweMask { + tensor: Tensor::from_container(masks), + poly_size, + }, + ) + } + + /// Consumes the current ciphertext and turn it to a list of polynomial. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialCount, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let poly_list = rlwe_ciphertext.into_polynomial_list(); + /// assert_eq!(poly_list.polynomial_count(), PolynomialCount(100)); + /// assert_eq!(poly_list.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn into_polynomial_list(self) -> PolynomialList { + PolynomialList { + tensor: self.tensor, + poly_size: self.poly_size, + } + } + + /// Returns a borrowed polynomial list from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialCount, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let poly_list = rlwe_ciphertext.as_polynomial_list(); + /// assert_eq!(poly_list.polynomial_count(), PolynomialCount(100)); + /// assert_eq!(poly_list.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn as_polynomial_list(&self) -> PolynomialList<&[::Element]> + where + Self: AsRefTensor, + { + PolynomialList { + tensor: Tensor::from_container(self.as_tensor().as_slice()), + poly_size: self.poly_size, + } + } + + /// Returns a mutably borrowed polynomial list from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut glwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mut poly_list = glwe.as_mut_polynomial_list(); + /// for mut poly in poly_list.polynomial_iter_mut() { + /// poly.as_mut_tensor().fill_with_element(9); + /// } + /// assert!(glwe.as_tensor().iter().all(|a| *a == 9)); + /// ``` + pub fn as_mut_polynomial_list( + &mut self, + ) -> PolynomialList<&mut [::Element]> + where + Self: AsMutTensor, + { + let poly_size = self.poly_size; + PolynomialList { + tensor: Tensor::from_container(self.as_mut_tensor().as_mut_slice()), + poly_size, + } + } + + /// Fills an LWE ciphertext with the extraction of one coefficient of the current GLWE + /// ciphertext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::commons::math::polynomial::MonomialDegree; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{GlweDimension, LogStandardDev, LweDimension, PolynomialSize}; + /// + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let poly_size = PolynomialSize(4); + /// let glwe_dim = GlweDimension(2); + /// let glwe_secret_key = + /// GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// let mut plaintext_list = + /// PlaintextList::from_container(vec![100000 as u32, 200000, 300000, 400000]); + /// let mut glwe_ct = GlweCiphertext::allocate(0u32, poly_size, glwe_dim.to_glwe_size()); + /// let mut lwe_ct = + /// LweCiphertext::allocate(0u32, LweDimension(poly_size.0 * glwe_dim.0).to_lwe_size()); + /// glwe_secret_key.encrypt_glwe( + /// &mut glwe_ct, + /// &plaintext_list, + /// LogStandardDev(-60.), + /// &mut encryption_generator, + /// ); + /// let lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); + /// + /// // Check for the first + /// for i in 0..4 { + /// // We sample extract + /// glwe_ct.fill_lwe_with_sample_extraction(&mut lwe_ct, MonomialDegree(i)); + /// // We decrypt + /// let mut output = Plaintext(0u32); + /// lwe_secret_key.decrypt_lwe(&mut output, &lwe_ct); + /// // We check that the decryption is correct + /// let plain = plaintext_list.as_tensor().get_element(i); + /// let d0 = output.0.wrapping_sub(*plain); + /// let d1 = plain.wrapping_sub(output.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400); + /// } + /// ``` + pub fn fill_lwe_with_sample_extraction( + &self, + lwe: &mut LweCiphertext, + n_th: MonomialDegree, + ) where + Self: AsRefTensor, + LweCiphertext: AsMutTensor, + Element: UnsignedTorus, + { + // We retrieve the bodies and masks of the two ciphertexts. + let (lwe_body, mut lwe_mask) = lwe.get_mut_body_and_mask(); + let (glwe_body, glwe_mask) = self.get_body_and_mask(); + + // We copy the body + lwe_body.0 = *glwe_body + .as_polynomial() + .get_monomial(n_th) + .get_coefficient(); + + // We copy the mask (each polynomial is in the wrong order) + lwe_mask + .as_mut_tensor() + .fill_with_copy(glwe_mask.as_tensor()); + + // We compute the number of elements which must be + // turned into their opposite + let opposite_count = self.poly_size.0 - n_th.0 - 1; + + // We loop through the polynomials (as mut tensors) + for mut lwe_mask_poly in lwe_mask + .as_mut_tensor() + .subtensor_iter_mut(self.poly_size.0) + { + // We reverse the polynomial + lwe_mask_poly.reverse(); + // We compute the opposite of the proper coefficients + lwe_mask_poly + .get_sub_mut(0..opposite_count) + .update_with_wrapping_neg(); + // We rotate the polynomial properly + lwe_mask_poly.rotate_left(opposite_count); + } + } + + pub fn fill_with_trivial_encryption( + &mut self, + plaintexts: &PlaintextList, + ) where + PlaintextList: AsRefTensor, + Self: AsMutTensor, + Scalar: Numeric, + { + debug_assert_eq!(plaintexts.count().0, self.poly_size.0); + let (mut body, mut mask) = self.get_mut_body_and_mask(); + + mask.as_mut_polynomial_list() + .polynomial_iter_mut() + .for_each(|mut polynomial| { + polynomial + .coefficient_iter_mut() + .for_each(|mask_coeff| *mask_coeff = ::ZERO) + }); + + body.as_mut_polynomial() + .coefficient_iter_mut() + .zip(plaintexts.plaintext_iter()) + .for_each( + |(body_coeff, plaintext): (&mut Scalar, &Plaintext)| { + *body_coeff = plaintext.0; + }, + ); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/keyswitch.rs b/tfhe/src/core_crypto/commons/crypto/glwe/keyswitch.rs new file mode 100644 index 000000000..85c3c967a --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/keyswitch.rs @@ -0,0 +1,2251 @@ +use super::{GlweCiphertext, GlweList}; +use crate::core_crypto::commons::crypto::encoding::PlaintextList; +use crate::core_crypto::commons::crypto::lwe::{LweCiphertext, LweList}; +use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; +use crate::core_crypto::commons::math::decomposition::{ + DecompositionLevel, DecompositionTerm, SignedDecomposer, +}; +use crate::core_crypto::commons::math::polynomial::{Polynomial, PolynomialList}; +use crate::core_crypto::commons::math::random::ByteRandomGenerator; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, ck_dim_eq, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::prelude::{ + BinaryKeyKind, CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, + DispersionParameter, FunctionalPackingKeyswitchKeyCount, GlweDimension, GlweSize, LweDimension, + LweSize, MonomialDegree, PlaintextCount, PolynomialCount, PolynomialSize, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A packing keyswitching key. +/// +/// A packing keyswitching key allows to pack several LWE ciphertexts into a single GLWE +/// ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LwePackingKeyswitchKey { + tensor: Tensor, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + output_polynomial_size: PolynomialSize, +} + +tensor_traits!(LwePackingKeyswitchKey); + +impl LwePackingKeyswitchKey> +where + Scalar: Copy, +{ + /// Allocates a packing keyswitching key whose masks and bodies are all `value`. + /// + /// # Note + /// + /// This function does *not* generate a keyswitch key, but merely allocates a container of the + /// right size. See [`LwePackingKeyswitchKey::fill_with_keyswitch_key`] to fill the container + /// with a proper keyswitching key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!( + /// pksk.decomposition_level_count(), + /// DecompositionLevelCount(10) + /// ); + /// assert_eq!(pksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// assert_eq!(pksk.output_glwe_key_dimension(), GlweDimension(2)); + /// assert_eq!(pksk.input_lwe_key_dimension(), LweDimension(10)); + /// ``` + pub fn allocate( + value: Scalar, + decomp_size: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + input_dimension: LweDimension, + output_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + ) -> Self { + LwePackingKeyswitchKey { + tensor: Tensor::from_container(vec![ + value; + decomp_size.0 + * output_dimension.to_glwe_size().0 + * output_polynomial_size.0 + * input_dimension.0 + ]), + decomp_base_log, + decomp_level_count: decomp_size, + output_glwe_size: output_dimension.to_glwe_size(), + output_polynomial_size, + } + } +} + +impl LwePackingKeyswitchKey { + /// Creates a packing keyswitching key from a container. + /// + /// # Notes + /// + /// This method does not create a packing keyswitch key, but merely wraps the container in + /// the proper type. It assumes that either the container already contains a proper keyswitching + /// key, or that [`LwePackingKeyswitchKey::fill_with_keyswitch_key`] will be called right + /// after. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// let input_size = LweDimension(200); + /// let output_size = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(7); + /// let decomp_level_count = DecompositionLevelCount(4); + /// + /// let pksk = LwePackingKeyswitchKey::from_container( + /// vec![ + /// 0 as u8; + /// input_size.0 * (output_size.0 + 1) * polynomial_size.0 * decomp_level_count.0 + /// ], + /// decomp_base_log, + /// decomp_level_count, + /// output_size, + /// polynomial_size, + /// ); + /// + /// assert_eq!(pksk.decomposition_level_count(), DecompositionLevelCount(4)); + /// assert_eq!(pksk.decomposition_base_log(), DecompositionBaseLog(7)); + /// assert_eq!(pksk.output_glwe_key_dimension(), GlweDimension(2)); + /// assert_eq!(pksk.input_lwe_key_dimension(), LweDimension(200)); + /// ``` + pub fn from_container( + cont: Cont, + decomp_base_log: DecompositionBaseLog, + decomp_size: DecompositionLevelCount, + output_glwe_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + ) -> LwePackingKeyswitchKey + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => output_glwe_dimension.to_glwe_size().0 * output_polynomial_size.0, decomp_size.0); + LwePackingKeyswitchKey { + tensor, + decomp_base_log, + decomp_level_count: decomp_size, + output_glwe_size: output_glwe_dimension.to_glwe_size(), + output_polynomial_size, + } + } + + /// Returns the dimension of the output GLWE key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pksk.output_glwe_key_dimension(), GlweDimension(2)); + /// ``` + pub fn output_glwe_key_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Returns the size of the polynomials composing the GLWE ciphertext + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, LweSize, + /// PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pksk.output_polynomial_size(), PolynomialSize(256)); + /// ``` + pub fn output_polynomial_size(&self) -> PolynomialSize { + self.output_polynomial_size + } + + /// Returns the dimension of the input LWE key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pksk.input_lwe_key_dimension(), LweDimension(10)); + /// ``` + pub fn input_lwe_key_dimension(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension( + self.as_tensor().len() + / (self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0), + ) + } + + /// Returns the number of levels used for the decomposition of the input key bits. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!( + /// pksk.decomposition_level_count(), + /// DecompositionLevelCount(10) + /// ); + /// ``` + pub fn decomposition_level_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + self.decomp_level_count + } + + /// Returns the logarithm of the base used for the decomposition of the input key bits. + /// + /// Indeed, the basis used is always of the form $2^b$. This function returns $b$. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog + where + Self: AsRefTensor, + { + self.decomp_base_log + } + + /// Fills the current keyswitch key container with an actual keyswitching key constructed from + /// an input and an output key. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::glwe::LwePackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// + /// let input_size = LweDimension(10); + /// let output_size = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let cipher_size = LweSize(55); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_size, polynomial_size, &mut secret_generator); + /// + /// let mut pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u32, + /// decomp_level_count, + /// decomp_base_log, + /// input_size, + /// output_size, + /// polynomial_size, + /// ); + /// pksk.fill_with_packing_keyswitch_key(&input_key, &output_key, noise, &mut encryption_generator); + /// + /// assert!(!pksk.as_tensor().iter().all(|a| *a == 0)); + /// ``` + pub fn fill_with_packing_keyswitch_key( + &mut self, + input_lwe_key: &LweSecretKey, + output_glwe_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + // We instantiate a buffer + let mut messages = PlaintextList::from_container(vec![ + ::Element::ZERO; + self.decomp_level_count.0 + * self.output_polynomial_size.0 + ]); + + // We retrieve decomposition arguments + let decomp_level_count = self.decomp_level_count; + let decomp_base_log = self.decomp_base_log; + let polynomial_size = self.output_polynomial_size; + + // loop over the before key blocks + for (input_key_bit, keyswitch_key_block) in input_lwe_key + .as_tensor() + .iter() + .zip(self.bit_decomp_iter_mut()) + { + // We reset the buffer + messages + .as_mut_tensor() + .fill_with_element(::Element::ZERO); + + // We fill the buffer with the powers of the key bits + for (level, mut message) in (1..=decomp_level_count.0) + .map(DecompositionLevel) + .zip(messages.sublist_iter_mut(PlaintextCount(polynomial_size.0))) + { + *message.as_mut_tensor().first_mut() = + DecompositionTerm::new(level, decomp_base_log, *input_key_bit) + .to_recomposition_summand(); + } + + // We encrypt the buffer + output_glwe_key.encrypt_glwe_list( + &mut keyswitch_key_block.into_glwe_list(), + &messages, + noise_parameters, + generator, + ); + } + } + + /// Iterates over borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of LWE ciphertexts, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LwePackingKeyswitchKey}; + /// use tfhe::core_crypto::backends::default::private::math::decomposition::{DecompositionLevelCount, DecompositionBaseLog}; + /// let pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// LweDimension(20) + /// ); + /// for decomp in pksk.bit_decomp_iter() { + /// assert_eq!(decomp.lwe_size(), pksk.lwe_size()); + /// assert_eq!(decomp.count().0, 10); + /// } + /// assert_eq!(pksk.bit_decomp_iter().count(), 15); + /// ``` + pub(crate) fn bit_decomp_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.output_glwe_size.0 * self.output_polynomial_size.0, self.decomp_level_count.0); + let size = + self.decomp_level_count.0 * self.output_glwe_size.0 * self.output_polynomial_size.0; + let glwe_size = self.output_glwe_size; + let poly_size = self.output_polynomial_size; + self.as_tensor().subtensor_iter(size).map(move |sub| { + LweKeyBitDecomposition::from_container(sub.into_container(), glwe_size, poly_size) + }) + } + + /// Iterates over mutably borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of LWE ciphertexts, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LwePackingKeyswitchKey}; + /// use tfhe::core_crypto::backends::default::private::math::tensor::{AsRefTensor, AsMutTensor}; + /// use tfhe::core_crypto::backends::default::private::math::decomposition::{DecompositionLevelCount, DecompositionBaseLog}; + /// let mut pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// LweDimension(20) + /// ); + /// for mut decomp in pksk.bit_decomp_iter_mut() { + /// for mut ciphertext in decomp.ciphertext_iter_mut() { + /// ciphertext.as_mut_tensor().fill_with_element(0); + /// } + /// } + /// assert!(pksk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(pksk.bit_decomp_iter_mut().count(), 15); + /// ``` + pub(crate) fn bit_decomp_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.output_glwe_size.0 * self.output_polynomial_size.0, self.decomp_level_count.0); + let chunks_size = + self.decomp_level_count.0 * self.output_glwe_size.0 * self.output_polynomial_size.0; + let glwe_size = self.output_glwe_size; + let poly_size = self.output_polynomial_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| { + LweKeyBitDecomposition::from_container(sub.into_container(), glwe_size, poly_size) + }) + } + + /// Keyswitches a single LWE ciphertext into a GLWE + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// LweDimension, LweSize, PolynomialSize, + /// }; + /// + /// let input_size = LweDimension(1024); + /// let output_size = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(8); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_size, polynomial_size, &mut secret_generator); + /// + /// let mut pksk = LwePackingKeyswitchKey::allocate( + /// 0 as u64, + /// decomp_level_count, + /// decomp_base_log, + /// input_size, + /// output_size, + /// polynomial_size, + /// ); + /// pksk.fill_with_packing_keyswitch_key(&input_key, &output_key, noise, &mut encryption_generator); + /// + /// let plaintext: Plaintext = Plaintext(1432154329994324); + /// let mut ciphertext = LweCiphertext::allocate(0. as u64, LweSize(1025)); + /// let mut switched_ciphertext = + /// GlweCiphertext::allocate(0. as u64, PolynomialSize(256), GlweSize(3)); + /// input_key.encrypt_lwe( + /// &mut ciphertext, + /// &plaintext, + /// noise, + /// &mut encryption_generator, + /// ); + /// + /// pksk.keyswitch_ciphertext(&mut switched_ciphertext, &ciphertext); + /// + /// let mut decrypted = PlaintextList::from_container(vec![0 as u64; 256]); + /// output_key.decrypt_glwe(&mut decrypted, &switched_ciphertext); + /// ``` + pub fn keyswitch_ciphertext( + &self, + after: &mut GlweCiphertext, + before: &LweCiphertext, + ) where + Self: AsRefTensor, + GlweCiphertext: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + ck_dim_eq!(self.input_lwe_key_dimension().0 => before.lwe_size().to_lwe_dimension().0); + ck_dim_eq!(self.output_glwe_key_dimension().0 => after.size().to_glwe_dimension().0); + + // We reset the output + after.as_mut_tensor().fill_with(|| Scalar::ZERO); + + // We copy the body + *after.get_mut_body().tensor.as_mut_tensor().first_mut() = before.get_body().0; + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new(self.decomp_base_log, self.decomp_level_count); + + // Loop over the number of levels: + // We compute the multiplication of a ciphertext from the keyswitching key with a + // piece of the decomposition and subtract it to the buffer + for (block, input_lwe_mask) in self + .bit_decomp_iter() + .zip(before.get_mask().mask_element_iter()) + { + // We decompose + let mask_rounded = decomposer.closest_representable(*input_lwe_mask); + let decomp = decomposer.decompose(mask_rounded); + + // Loop over the number of levels: + // We compute the multiplication of a ciphertext from the keyswitching key with a + // piece of the decomposition and subtract it to the buffer + for (level_key_cipher, decomposed) in block + .as_tensor() + .subtensor_iter(self.output_glwe_size.0 * self.output_polynomial_size.0) + .rev() + .zip(decomp) + { + after + .as_mut_tensor() + .update_with_wrapping_sub_element_mul(&level_key_cipher, decomposed.value()); + } + } + } + + /// Packs several LweCiphertext into a single GlweCiphertext + /// with a keyswitch technique + pub fn packing_keyswitch( + &self, + output: &mut GlweCiphertext, + input: &LweList, + ) where + Self: AsRefTensor, + LweList: AsRefTensor, + GlweCiphertext: AsMutTensor, + OutCont: Clone, + Scalar: UnsignedTorus, + { + debug_assert!(input.count().0 <= output.polynomial_size().0); + output.as_mut_tensor().fill_with_element(Scalar::ZERO); + let mut buffer = output.clone(); + // for each ciphertext, call mono_key_switch + for (degree, input_cipher) in input.ciphertext_iter().enumerate() { + self.keyswitch_ciphertext(&mut buffer, &input_cipher); + buffer + .as_mut_polynomial_list() + .polynomial_iter_mut() + .for_each(|mut poly| { + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(degree)) + }); + output + .as_mut_tensor() + .update_with_wrapping_add(buffer.as_tensor()); + } + } +} + +/// A private functional packing keyswitching key. +/// +/// A private functional packing keyswitching key allows to pack several LWE ciphertexts +/// into a single GLWE ciphertext while performing a private function on each +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LwePrivateFunctionalPackingKeyswitchKey { + tensor: Tensor, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + output_glwe_size: GlweSize, + output_polynomial_size: PolynomialSize, +} + +tensor_traits!(LwePrivateFunctionalPackingKeyswitchKey); + +impl LwePrivateFunctionalPackingKeyswitchKey> +where + Scalar: Copy, +{ + /// Allocates a private functional packing keyswitching key whose masks and bodies are all + /// `value`. + /// + /// # Note + /// + /// This function does *not* generate a private functional packing keyswitching key , but + /// merely allocates a container of the right size. + /// See [`LwePrivateFunctionalPackingKeyswitchKey::fill_with_private_functional_keyswitch_key`] + /// to fill the container with a proper functional keyswitching key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!( + /// pfpksk.decomposition_level_count(), + /// DecompositionLevelCount(10) + /// ); + /// assert_eq!(pfpksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// assert_eq!(pfpksk.output_glwe_key_dimension(), GlweDimension(2)); + /// assert_eq!(pfpksk.input_lwe_key_dimension(), LweDimension(10)); + /// ``` + pub fn allocate( + value: Scalar, + decomp_size: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + input_dimension: LweDimension, + output_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + ) -> Self { + LwePrivateFunctionalPackingKeyswitchKey { + tensor: Tensor::from_container(vec![ + value; + decomp_size.0 + * output_dimension.to_glwe_size().0 + * output_polynomial_size.0 + * input_dimension.to_lwe_size().0 + ]), + decomp_base_log, + decomp_level_count: decomp_size, + output_glwe_size: output_dimension.to_glwe_size(), + output_polynomial_size, + } + } +} + +impl LwePrivateFunctionalPackingKeyswitchKey { + /// Creates a private functional packing keyswitching key from a container. + /// + /// # Notes + /// + /// This method does not create a private functional packing keyswitch key, but merely wraps + /// the container in the proper type. It assumes that either the container already contains a + /// proper functional keyswitching key, or that + /// [`LwePrivateFunctionalPackingKeyswitchKey::fill_with_private_functional_keyswitch_key`] will + /// be called right after. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// let input_lwe_dim = LweDimension(200); + /// let output_glwe_dim = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(7); + /// let decomp_level_count = DecompositionLevelCount(4); + /// + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::from_container( + /// vec![ + /// 0 as u8; + /// input_lwe_dim.to_lwe_size().0 + /// * output_glwe_dim.to_glwe_size().0 + /// * polynomial_size.0 + /// * decomp_level_count.0 + /// ], + /// decomp_base_log, + /// decomp_level_count, + /// output_glwe_dim, + /// polynomial_size, + /// ); + /// + /// assert_eq!(pfpksk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(pfpksk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(pfpksk.output_glwe_key_dimension(), output_glwe_dim); + /// assert_eq!(pfpksk.input_lwe_key_dimension(), input_lwe_dim); + /// ``` + pub fn from_container( + cont: Cont, + decomp_base_log: DecompositionBaseLog, + decomp_size: DecompositionLevelCount, + output_glwe_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + ) -> LwePrivateFunctionalPackingKeyswitchKey + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => output_glwe_dimension.to_glwe_size().0 * output_polynomial_size.0, decomp_size.0); + LwePrivateFunctionalPackingKeyswitchKey { + tensor, + decomp_base_log, + decomp_level_count: decomp_size, + output_glwe_size: output_glwe_dimension.to_glwe_size(), + output_polynomial_size, + } + } + + /// Returns the dimension of the output GLWE key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pfpksk.output_glwe_key_dimension(), GlweDimension(2)); + /// ``` + pub fn output_glwe_key_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Returns the size of the polynomials composing the GLWE ciphertext + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, LweSize, + /// PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pfpksk.output_polynomial_size(), PolynomialSize(256)); + /// ``` + pub fn output_polynomial_size(&self) -> PolynomialSize { + self.output_polynomial_size + } + + /// Returns the dimension of the input LWE key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pfpksk.input_lwe_key_dimension(), LweDimension(10)); + /// ``` + pub fn input_lwe_key_dimension(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension( + self.as_tensor().len() + / (self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0) + - 1, + ) + } + + /// Returns the number of levels used for the decomposition of the input key bits. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!( + /// pfpksk.decomposition_level_count(), + /// DecompositionLevelCount(10) + /// ); + /// ``` + pub fn decomposition_level_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + self.decomp_level_count + } + + /// Returns the logarithm of the base used for the decomposition of the input key bits. + /// + /// Indeed, the basis used is always of the form $2^b$. This function returns $b$. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// GlweDimension(2), + /// PolynomialSize(256), + /// ); + /// assert_eq!(pfpksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog + where + Self: AsRefTensor, + { + self.decomp_base_log + } + + /// Fills the current private functional keyswitch key container with an actual private + /// functional keyswitching key constructed from an input and an output key. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// + /// let input_size = LweDimension(10); + /// let output_size = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let cipher_size = LweSize(55); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_size, polynomial_size, &mut secret_generator); + /// + /// let mut pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u32, + /// decomp_level_count, + /// decomp_base_log, + /// input_size, + /// output_size, + /// polynomial_size, + /// ); + /// pfpksk.fill_with_private_functional_packing_keyswitch_key( + /// &input_key, + /// &output_key, + /// noise, + /// &mut encryption_generator, + /// &|x| x, + /// &Polynomial::allocate(1 as u32, output_key.polynomial_size()), + /// ); + /// + /// assert!(!pfpksk.as_tensor().iter().all(|a| *a == 0)); + /// ``` + pub fn fill_with_private_functional_packing_keyswitch_key< + InKeyCont, + OutKeyCont, + PolyCont, + Scalar, + Gen, + ScalarFunc, + >( + &mut self, + input_lwe_key: &LweSecretKey, + output_glwe_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + f: &ScalarFunc, + polynomial: &Polynomial, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Polynomial: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + ScalarFunc: Fn(Scalar) -> Scalar + ?Sized, + { + // We instantiate a buffer + let mut messages = PlaintextList::from_container(vec![ + ::Element::ZERO; + self.decomp_level_count.0 + * self.output_polynomial_size.0 + ]); + // We retrieve decomposition arguments + let decomp_level_count = self.decomp_level_count; + let decomp_base_log = self.decomp_base_log; + let polynomial_size = self.output_polynomial_size; + + let last_key_iter_bit = [Scalar::MAX]; + // add minus one for the function which will be applied to the decomposed body + // ( Scalar::MAX = -Scalar::ONE ) + let input_key_bit_iter = input_lwe_key + .as_tensor() + .as_slice() + .iter() + .chain(last_key_iter_bit.iter()); + + let gen_iter = generator + .fork_pfpksk_to_pfpksk_chunks::( + decomp_level_count, + output_glwe_key.key_size().to_glwe_size(), + output_glwe_key.polynomial_size(), + input_lwe_key.key_size().to_lwe_size(), + ) + .unwrap(); + + // loop over the before key blocks + for ((&input_key_bit, keyswitch_key_block), mut loop_generator) in input_key_bit_iter + .zip(self.bit_decomp_iter_mut()) + .zip(gen_iter) + { + // We reset the buffer + messages + .as_mut_tensor() + .fill_with_element(::Element::ZERO); + + // We fill the buffer with the powers of the key bits + for (level, mut message) in (1..=decomp_level_count.0) + .map(DecompositionLevel) + .zip(messages.sublist_iter_mut(PlaintextCount(polynomial_size.0))) + { + message + .as_mut_tensor() + .update_with_wrapping_add_element_mul( + polynomial.as_tensor(), + DecompositionTerm::new( + level, + decomp_base_log, + f(Scalar::ONE).wrapping_mul(input_key_bit), + ) + .to_recomposition_summand(), + ); + } + + // We encrypt the buffer + output_glwe_key.encrypt_glwe_list( + &mut keyswitch_key_block.into_glwe_list(), + &messages, + noise_parameters, + &mut loop_generator, + ); + } + } + + /// Fills the current private functional keyswitch key container with an actual private + /// functional keyswitching key constructed from an input and an output key using multiple + /// threads . + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LogStandardDev, LweDimension, + /// LweSize, PolynomialSize, + /// }; + /// + /// let input_size = LweDimension(10); + /// let output_size = GlweDimension(3); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let cipher_size = LweSize(55); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_size, polynomial_size, &mut secret_generator); + /// + /// let mut pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u32, + /// decomp_level_count, + /// decomp_base_log, + /// input_size, + /// output_size, + /// polynomial_size, + /// ); + /// pfpksk.par_fill_with_private_functional_packing_keyswitch_key( + /// &input_key, + /// &output_key, + /// noise, + /// &mut encryption_generator, + /// &|x| x, + /// &Polynomial::allocate(1 as u32, output_key.polynomial_size()), + /// ); + /// + /// assert!(!pfpksk.as_tensor().iter().all(|a| *a == 0)); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_fill_with_private_functional_packing_keyswitch_key< + InKeyCont, + OutKeyCont, + PolyCont, + Scalar, + Gen, + Noise, + ScalarFunc, + >( + &mut self, + input_lwe_key: &LweSecretKey, + output_glwe_key: &GlweSecretKey, + noise_parameters: Noise, + generator: &mut EncryptionRandomGenerator, + f: &ScalarFunc, + polynomial: &Polynomial, + ) where + Self: AsMutTensor, + ::Element: Sync + Send, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + OutKeyCont: AsRefSlice + Sync + Send, + Polynomial: AsRefTensor, + PolyCont: AsRefSlice + Sync + Send, + Scalar: UnsignedTorus + Sync + Send, + Gen: ParallelByteRandomGenerator, + Noise: DispersionParameter + Sync + Send, + ScalarFunc: Fn(Scalar) -> Scalar + Sync + Send + ?Sized, + Cont: Sync + Send, + { + // We retrieve decomposition arguments + let decomp_level_count = self.decomp_level_count; + let decomp_base_log = self.decomp_base_log; + let polynomial_size = self.output_polynomial_size; + + let last_key_iter_bit = [Scalar::MAX]; + // add minus one for the function which will be applied to the decomposed body + // ( Scalar::MAX = -Scalar::ONE ) + let input_key_bit_iter = input_lwe_key + .as_tensor() + .as_slice() + .par_iter() + .chain(last_key_iter_bit.par_iter()); + + let gen_iter = generator + .par_fork_pfpksk_to_pfpksk_chunks::( + decomp_level_count, + output_glwe_key.key_size().to_glwe_size(), + output_glwe_key.polynomial_size(), + input_lwe_key.key_size().to_lwe_size(), + ) + .unwrap(); + + // loop over the before key blocks + input_key_bit_iter + .zip(self.par_bit_decomp_iter_mut()) + .zip(gen_iter) + .for_each( + move |((&input_key_bit, keyswitch_key_block), mut loop_generator)| { + // We instantiate a buffer + let mut messages = PlaintextList::from_container(vec![ + ::Element::ZERO; + decomp_level_count.0 + * polynomial_size.0 + ]); + + // We fill the buffer with the powers of the key bits + for (level, mut message) in (1..=decomp_level_count.0) + .map(DecompositionLevel) + .zip(messages.sublist_iter_mut(PlaintextCount(polynomial_size.0))) + { + message + .as_mut_tensor() + .update_with_wrapping_add_element_mul( + polynomial.as_tensor(), + DecompositionTerm::new( + level, + decomp_base_log, + f(Scalar::ONE).wrapping_mul(input_key_bit), + ) + .to_recomposition_summand(), + ); + } + + // We encrypt the buffer + output_glwe_key.encrypt_glwe_list( + &mut keyswitch_key_block.into_glwe_list(), + &messages, + noise_parameters, + &mut loop_generator, + ); + }, + ) + } + + /// Iterates over borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of LWE ciphertexts, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::commons::crypto::{*, glwe::LwePrivateFunctionalPackingKeyswitchKey}; + /// use tfhe::core_crypto::prelude::{DecompositionLevelCount, DecompositionBaseLog, + /// GlweDimension, LweDimension, PolynomialSize}; + /// let pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// GlweDimension(20), + /// PolynomialSize(256) + /// ); + /// for decomp in pfpksk.bit_decomp_iter() { + /// assert_eq!(decomp.glwe_size(), pfpksk.output_glwe_size()); + /// assert_eq!(decomp.count().0, 10); + /// } + /// assert_eq!(pfpksk.bit_decomp_iter().count(), 15 + 1); + /// ``` + pub(crate) fn bit_decomp_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.output_glwe_size.0 * self.output_polynomial_size.0, self.decomp_level_count.0); + let size = + self.decomp_level_count.0 * self.output_glwe_size.0 * self.output_polynomial_size.0; + let glwe_size = self.output_glwe_size; + let poly_size = self.output_polynomial_size; + self.as_tensor().subtensor_iter(size).map(move |sub| { + LweKeyBitDecomposition::from_container(sub.into_container(), glwe_size, poly_size) + }) + } + + /// Iterates over mutably borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of LWE ciphertexts, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::commons::crypto::{*, glwe::LwePrivateFunctionalPackingKeyswitchKey}; + /// use tfhe::core_crypto::commons::math::tensor::{AsRefTensor, AsMutTensor}; + /// use tfhe::core_crypto::prelude::{DecompositionLevelCount, DecompositionBaseLog, + /// GlweDimension, LweDimension, PolynomialSize}; + /// let mut pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// GlweDimension(20), + /// PolynomialSize(256) + /// ); + /// for mut decomp in pfpksk.bit_decomp_iter_mut() { + /// for mut ciphertext in decomp.ciphertext_iter_mut() { + /// ciphertext.as_mut_tensor().fill_with_element(0); + /// } + /// } + /// assert!(pfpksk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(pfpksk.bit_decomp_iter_mut().count(), 15 + 1); + /// ``` + pub(crate) fn bit_decomp_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.output_glwe_size.0 * self.output_polynomial_size.0, self.decomp_level_count.0); + let chunks_size = + self.decomp_level_count.0 * self.output_glwe_size.0 * self.output_polynomial_size.0; + let glwe_size = self.output_glwe_size; + let poly_size = self.output_polynomial_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| { + LweKeyBitDecomposition::from_container(sub.into_container(), glwe_size, poly_size) + }) + } + + #[cfg(feature = "__commons_parallel")] + pub(crate) fn par_bit_decomp_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + ck_dim_div!(self.as_tensor().len() => self.output_glwe_size.0 * self.output_polynomial_size.0, self.decomp_level_count.0); + let chunks_size = + self.decomp_level_count.0 * self.output_glwe_size.0 * self.output_polynomial_size.0; + let glwe_size = self.output_glwe_size; + let poly_size = self.output_polynomial_size; + self.as_mut_tensor() + .par_subtensor_iter_mut(chunks_size) + .map(move |sub| { + LweKeyBitDecomposition::from_container(sub.into_container(), glwe_size, poly_size) + }) + } + + /// Keyswitches a single LWE ciphertext into a GLWE using a + /// private functional packing keyswitch key + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// LweDimension, LweSize, PolynomialSize, + /// }; + /// + /// let input_lwe_dim = LweDimension(1024); + /// let output_glwe_dim = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(8); + /// let noise = LogStandardDev::from_log_standard_dev(-60.); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let input_key = LweSecretKey::generate_binary(input_lwe_dim, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_glwe_dim, polynomial_size, &mut secret_generator); + /// + /// let mut pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u64, + /// decomp_level_count, + /// decomp_base_log, + /// input_lwe_dim, + /// output_glwe_dim, + /// polynomial_size, + /// ); + /// pfpksk.fill_with_private_functional_packing_keyswitch_key( + /// &input_key, + /// &output_key, + /// noise, + /// &mut encryption_generator, + /// &|x| x, + /// &Polynomial::allocate(1 as u64, polynomial_size), + /// ); + /// + /// let plaintext: Plaintext = Plaintext(5 << 60); + /// let mut ciphertext = LweCiphertext::allocate(0. as u64, input_lwe_dim.to_lwe_size()); + /// let mut switched_ciphertext = + /// GlweCiphertext::allocate(0. as u64, polynomial_size, output_glwe_dim.to_glwe_size()); + /// input_key.encrypt_lwe( + /// &mut ciphertext, + /// &plaintext, + /// noise, + /// &mut encryption_generator, + /// ); + /// + /// pfpksk.private_functional_keyswitch_ciphertext(&mut switched_ciphertext, &ciphertext); + /// + /// let mut decrypted = PlaintextList::from_container(vec![0 as u64; polynomial_size.0]); + /// output_key.decrypt_glwe(&mut decrypted, &switched_ciphertext); + /// ``` + pub fn private_functional_keyswitch_ciphertext( + &self, + after: &mut GlweCiphertext, + before: &LweCiphertext, + ) where + Self: AsRefTensor, + GlweCiphertext: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + ck_dim_eq!(self.input_lwe_key_dimension().0 => before.lwe_size().to_lwe_dimension().0 ); + ck_dim_eq!(self.output_glwe_key_dimension().0 => after.size().to_glwe_dimension().0); + + // We reset the output + after.as_mut_tensor().fill_with(|| Scalar::ZERO); + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new(self.decomp_base_log, self.decomp_level_count); + + for (block, input_lwe) in self.bit_decomp_iter().zip(before.as_tensor().iter()) { + // We decompose + let rounded = decomposer.closest_representable(*input_lwe); + let decomp = decomposer.decompose(rounded); + + // Loop over the number of levels: + // We compute the multiplication of a ciphertext from the private functional + // keyswitching key with a piece of the decomposition and subtract it to the buffer + for (level_key_cipher, decomposed) in block + .as_tensor() + .subtensor_iter(self.output_glwe_size.0 * self.output_polynomial_size.0) + .rev() + .zip(decomp) + { + after + .as_mut_tensor() + .update_with_wrapping_sub_element_mul(&level_key_cipher, decomposed.value()); + } + } + } + + /// Packs several LweCiphertext into a single GlweCiphertext + /// with a private functional keyswitch technique + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::commons::crypto::glwe::*; + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// LweDimension, LweSize, PlaintextCount, PolynomialSize, + /// }; + /// + /// let input_lwe_dim = LweDimension(1024); + /// let output_glwe_dim = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(8); + /// let noise = LogStandardDev::from_log_standard_dev(-60.); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let input_key = LweSecretKey::generate_binary(input_lwe_dim, &mut secret_generator); + /// let output_key = + /// GlweSecretKey::generate_binary(output_glwe_dim, polynomial_size, &mut secret_generator); + /// + /// let mut pfpksk = LwePrivateFunctionalPackingKeyswitchKey::allocate( + /// 0 as u64, + /// decomp_level_count, + /// decomp_base_log, + /// input_lwe_dim, + /// output_glwe_dim, + /// polynomial_size, + /// ); + /// let mut vec = vec![0u64; polynomial_size.0]; + /// vec[0] = 1; + /// + /// pfpksk.fill_with_private_functional_packing_keyswitch_key( + /// &input_key, + /// &output_key, + /// noise, + /// &mut encryption_generator, + /// &|x| x, + /// &Polynomial::from_container(vec), + /// ); + /// + /// let plaintext_list = PlaintextList::allocate(1 << 60 as u64, PlaintextCount(10)); + /// let ciphertext_list = + /// LweList::new_trivial_encryption(input_key.key_size().to_lwe_size(), &plaintext_list); + /// let mut switched_ciphertext = + /// GlweCiphertext::allocate(0 as u64, polynomial_size, output_glwe_dim.to_glwe_size()); + /// + /// pfpksk.private_functional_packing_keyswitch(&mut switched_ciphertext, &ciphertext_list); + /// + /// let mut decrypted = PlaintextList::from_container(vec![0 as u64; polynomial_size.0]); + /// output_key.decrypt_glwe(&mut decrypted, &switched_ciphertext); + /// ``` + pub fn private_functional_packing_keyswitch( + &self, + output: &mut GlweCiphertext, + input: &LweList, + ) where + Self: AsRefTensor, + LweList: AsRefTensor, + GlweCiphertext: AsMutTensor, + OutCont: Clone, + Scalar: UnsignedTorus, + { + debug_assert!(input.count().0 <= output.polynomial_size().0); + output.as_mut_tensor().fill_with_element(Scalar::ZERO); + let mut buffer = output.clone(); + // for each ciphertext, call mono_key_switch + for (degree, input_cipher) in input.ciphertext_iter().enumerate() { + self.private_functional_keyswitch_ciphertext(&mut buffer, &input_cipher); + buffer + .as_mut_polynomial_list() + .polynomial_iter_mut() + .for_each(|mut poly| { + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(degree)) + }); + output + .as_mut_tensor() + .update_with_wrapping_add(buffer.as_tensor()); + } + } +} + +/// The encryption of a single bit of the output key. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] +pub(crate) struct LweKeyBitDecomposition { + pub(crate) tensor: Tensor, + pub(crate) glwe_size: GlweSize, + pub(crate) poly_size: PolynomialSize, +} + +tensor_traits!(LweKeyBitDecomposition); + +impl LweKeyBitDecomposition { + /// Creates a key bit decomposition from a container. + /// + /// # Notes + /// + /// This method does not decompose a key bit in a basis, but merely wraps a container in the + /// right structure. See [`LwePackingKeyswitchKey::bit_decomp_iter`] for an iterator that + /// returns key bit decompositions. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 1500], GlweSize(10), + /// PolynomialSize(10); + /// assert_eq!(kbd.count(), CiphertextCount(15)); + /// assert_eq!(kbd.glwe_size(), GlweSize(10)); + /// assert_eq!(kbd.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn from_container(cont: Cont, glwe_size: GlweSize, poly_size: PolynomialSize) -> Self + where + Tensor: AsRefSlice, + { + LweKeyBitDecomposition { + tensor: Tensor::from_container(cont), + glwe_size, + poly_size, + } + } + + /// Returns the size of the GLWE ciphertexts encoding each level of the key bit decomposition. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.lwe_size(), LweSize(10)); + /// ``` + #[allow(dead_code)] + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + /// Returns the size of the lwe ciphertexts encoding each level of the key bit decomposition. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.lwe_size(), LweSize(10)); + /// ``` + #[allow(dead_code)] + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the number of ciphertexts in the decomposition. + /// + /// Note that this is actually equals to the number of levels in the decomposition. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.count(), CiphertextCount(15)); + /// ``` + #[allow(dead_code)] + pub fn count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.glwe_size.0 * self.poly_size.0); + CiphertextCount(self.as_tensor().len() / (self.glwe_size.0 * self.poly_size.0)) + } + + /// Returns an iterator over borrowed `GlweCiphertext`. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// for ciphertext in kbd.ciphertext_iter(){ + /// assert_eq!(ciphertext.lwe_size(), LweSize(10)); + /// } + /// assert_eq!(kbd.ciphertext_iter().count(), 15); + /// ``` + #[allow(dead_code)] + pub fn ciphertext_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.glwe_size.0 * self.poly_size.0) + .map(move |sub| GlweCiphertext::from_container(sub.into_container(), self.poly_size)) + } + + /// Returns an iterator over mutably borrowed `GlweCiphertext`. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// use tfhe::core_crypto::backends::default::private::math::tensor::{AsRefTensor, AsMutTensor}; + /// let mut kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// for mut ciphertext in kbd.ciphertext_iter_mut(){ + /// ciphertext.as_mut_tensor().fill_with_element(9); + /// } + /// assert!(kbd.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(kbd.ciphertext_iter().count(), 15); + /// ``` + #[allow(dead_code)] + pub fn ciphertext_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.glwe_size.0 * self.poly_size.0; + let poly_size = self.poly_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| GlweCiphertext::from_container(sub.into_container(), poly_size)) + } + + /// Consumes the current key bit decomposition and returns a GLWE. + /// + /// Note that this operation is super cheap, as it merely rewraps the current container in + /// a GLWE structure. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, glwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// let glwe = kbd.into_glwe_list(); + /// assert_eq!(list.count(), CiphertextCount(15)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn into_glwe_list(self) -> GlweList { + GlweList { + tensor: self.tensor, + rlwe_size: self.glwe_size, + poly_size: self.poly_size, + } + } +} + +/// A private functional packing keyswitching key list. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LwePrivateFunctionalPackingKeyswitchKeyList { + tensor: Tensor, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_lwe_size: LweSize, + output_glwe_size: GlweSize, + output_polynomial_size: PolynomialSize, +} + +tensor_traits!(LwePrivateFunctionalPackingKeyswitchKeyList); + +impl LwePrivateFunctionalPackingKeyswitchKeyList> +where + Scalar: Copy, +{ + /// Allocates storage for an owned [`LwePrivateFunctionalPackingKeyswitchKeyList`]. + /// + /// # Note + /// + /// This function does *not* generate a private functional packing keyswitch key list, but + /// merely allocates a container of the right size. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, GlweSize, LweDimension, LweSize, PolynomialSize, + /// }; + /// let input_lwe_dim = LweDimension(200); + /// let output_glwe_dim = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(7); + /// let decomp_level_count = DecompositionLevelCount(4); + /// let fpksk_count = FunctionalPackingKeyswitchKeyCount(3); + /// + /// let pfpksk_list = LwePrivateFunctionalPackingKeyswitchKeyList::allocate( + /// 0u8, + /// decomp_level_count, + /// decomp_base_log, + /// input_lwe_dim, + /// output_glwe_dim, + /// polynomial_size, + /// fpksk_count, + /// ); + /// + /// assert_eq!(pfpksk_list.decomposition_level_count(), decomp_level_count); + /// assert_eq!(pfpksk_list.decomposition_base_log(), decomp_base_log); + /// assert_eq!(pfpksk_list.output_glwe_key_dimension(), output_glwe_dim); + /// assert_eq!(pfpksk_list.input_lwe_key_dimension(), input_lwe_dim); + /// assert_eq!(pfpksk_list.fpksk_count(), fpksk_count); + /// ``` + pub fn allocate( + value: Scalar, + decomp_size: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + input_dimension: LweDimension, + output_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + fpksk_count: FunctionalPackingKeyswitchKeyCount, + ) -> Self { + LwePrivateFunctionalPackingKeyswitchKeyList { + tensor: Tensor::from_container(vec![ + value; + decomp_size.0 + * output_dimension.to_glwe_size().0 + * output_polynomial_size.0 + * input_dimension.to_lwe_size().0 + * fpksk_count.0 + ]), + decomp_base_log, + decomp_level_count: decomp_size, + input_lwe_size: input_dimension.to_lwe_size(), + output_glwe_size: output_dimension.to_glwe_size(), + output_polynomial_size, + } + } +} + +impl LwePrivateFunctionalPackingKeyswitchKeyList { + /// Creates a list from a container of values. + /// + /// # Notes + /// + /// This method does not create a private functional packing keyswitch key list, but merely + /// wraps the container in the proper type. It assumes that either the container already + /// contains a proper functional keyswitching key list, or that it will be filled right after. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + /// GlweDimension, GlweSize, LweDimension, LweSize, PolynomialSize, + /// }; + /// let input_lwe_dim = LweDimension(200); + /// let output_glwe_dim = GlweDimension(2); + /// let polynomial_size = PolynomialSize(256); + /// let decomp_base_log = DecompositionBaseLog(7); + /// let decomp_level_count = DecompositionLevelCount(4); + /// let fpksk_count = FunctionalPackingKeyswitchKeyCount(3); + /// + /// let pfpksk_list = LwePrivateFunctionalPackingKeyswitchKeyList::from_container( + /// vec![ + /// 0 as u8; + /// input_lwe_dim.to_lwe_size().0 + /// * output_glwe_dim.to_glwe_size().0 + /// * polynomial_size.0 + /// * decomp_level_count.0 + /// * fpksk_count.0 + /// ], + /// decomp_base_log, + /// decomp_level_count, + /// input_lwe_dim, + /// output_glwe_dim, + /// polynomial_size, + /// fpksk_count, + /// ); + /// + /// assert_eq!(pfpksk_list.decomposition_level_count(), decomp_level_count); + /// assert_eq!(pfpksk_list.decomposition_base_log(), decomp_base_log); + /// assert_eq!(pfpksk_list.output_glwe_key_dimension(), output_glwe_dim); + /// assert_eq!(pfpksk_list.input_lwe_key_dimension(), input_lwe_dim); + /// assert_eq!(pfpksk_list.fpksk_count(), fpksk_count); + /// ``` + pub fn from_container( + cont: Cont, + decomp_base_log: DecompositionBaseLog, + decomp_size: DecompositionLevelCount, + input_dimension: LweDimension, + output_glwe_dimension: GlweDimension, + output_polynomial_size: PolynomialSize, + fpksk_count: FunctionalPackingKeyswitchKeyCount, + ) -> LwePrivateFunctionalPackingKeyswitchKeyList + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => + output_glwe_dimension.to_glwe_size().0 * output_polynomial_size.0, + decomp_size.0, + input_dimension.to_lwe_size().0, + fpksk_count.0); + LwePrivateFunctionalPackingKeyswitchKeyList { + tensor, + decomp_base_log, + decomp_level_count: decomp_size, + input_lwe_size: input_dimension.to_lwe_size(), + output_glwe_size: output_glwe_dimension.to_glwe_size(), + output_polynomial_size, + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> LwePrivateFunctionalPackingKeyswitchKeyList<&'_ [Cont::Element]> + where + Cont: Container, + { + LwePrivateFunctionalPackingKeyswitchKeyList { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + decomp_base_log: self.decomp_base_log, + decomp_level_count: self.decomp_level_count, + input_lwe_size: self.input_lwe_size, + output_glwe_size: self.output_glwe_size, + output_polynomial_size: self.output_polynomial_size, + } + } + + pub fn as_mut_view( + &mut self, + ) -> LwePrivateFunctionalPackingKeyswitchKeyList<&'_ mut [Cont::Element]> + where + Cont: Container + AsMut<[Cont::Element]>, + { + LwePrivateFunctionalPackingKeyswitchKeyList { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + decomp_base_log: self.decomp_base_log, + decomp_level_count: self.decomp_level_count, + input_lwe_size: self.input_lwe_size, + output_glwe_size: self.output_glwe_size, + output_polynomial_size: self.output_polynomial_size, + } + } + + /// Returns the dimension of the output GLWE key. + pub fn output_glwe_key_dimension(&self) -> GlweDimension { + self.output_glwe_size.to_glwe_dimension() + } + + /// Returns the size of the polynomials composing the GLWE ciphertext + pub fn output_polynomial_size(&self) -> PolynomialSize { + self.output_polynomial_size + } + + /// Returns the dimension of the input LWE key. + pub fn input_lwe_key_dimension(&self) -> LweDimension { + self.input_lwe_size.to_lwe_dimension() + } + + /// Returns the number of levels used for the decomposition of the input key bits. + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomp_level_count + } + + /// Returns the logarithm of the base used for the decomposition of the input key bits. + /// + /// Indeed, the basis used is always of the form $2^b$. This function returns $b$. + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomp_base_log + } + + /// Returns the number of private functional packing keyswitch key in the list. + pub fn fpksk_count(&self) -> FunctionalPackingKeyswitchKeyCount + where + Self: AsRefTensor, + { + let single_ksk_size = self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0 + * self.input_lwe_size.0; + ck_dim_div!(self.as_tensor().len() => single_ksk_size); + FunctionalPackingKeyswitchKeyCount(self.as_tensor().len() / single_ksk_size) + } + + /// Returns an iterator over keys borrowed from the list. + pub fn fpksk_iter( + &self, + ) -> impl DoubleEndedIterator< + Item = LwePrivateFunctionalPackingKeyswitchKey<&[::Element]>, + > + where + Self: AsRefTensor, + { + let single_ksk_size = self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0 + * self.input_lwe_size.0; + ck_dim_div!(self.as_tensor().len() => single_ksk_size); + self.as_tensor() + .subtensor_iter(single_ksk_size) + .map(move |sub| { + LwePrivateFunctionalPackingKeyswitchKey::from_container( + sub.into_container(), + self.decomposition_base_log(), + self.decomposition_level_count(), + self.output_glwe_key_dimension(), + self.output_polynomial_size(), + ) + }) + } + + /// Returns an iterator over keys borrowed from the list. + pub fn fpksk_iter_mut( + &mut self, + ) -> impl DoubleEndedIterator< + Item = LwePrivateFunctionalPackingKeyswitchKey<&mut [::Element]>, + > + where + Self: AsMutTensor, + { + let single_ksk_size = self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0 + * self.input_lwe_size.0; + ck_dim_div!(self.as_mut_tensor().len() => single_ksk_size); + + let decomposition_base_log = self.decomposition_base_log(); + let decomposition_level_count = self.decomposition_level_count(); + let output_glwe_key_dimension = self.output_glwe_key_dimension(); + let output_polynomial_size = self.output_polynomial_size(); + + self.as_mut_tensor() + .subtensor_iter_mut(single_ksk_size) + .map(move |sub| { + LwePrivateFunctionalPackingKeyswitchKey::from_container( + sub.into_container(), + decomposition_base_log, + decomposition_level_count, + output_glwe_key_dimension, + output_polynomial_size, + ) + }) + } + + /// Returns an iterator over keys borrowed from the list. + #[cfg(feature = "__commons_parallel")] + pub fn par_fpksk_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator< + Item = LwePrivateFunctionalPackingKeyswitchKey<&mut [::Element]>, + > + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + let single_ksk_size = self.output_glwe_size.0 + * self.output_polynomial_size.0 + * self.decomp_level_count.0 + * self.input_lwe_size.0; + ck_dim_div!(self.as_mut_tensor().len() => single_ksk_size); + + let decomposition_base_log = self.decomposition_base_log(); + let decomposition_level_count = self.decomposition_level_count(); + let output_glwe_key_dimension = self.output_glwe_key_dimension(); + let output_polynomial_size = self.output_polynomial_size(); + + self.as_mut_tensor() + .par_subtensor_iter_mut(single_ksk_size) + .map(move |sub| { + LwePrivateFunctionalPackingKeyswitchKey::from_container( + sub.into_container(), + decomposition_base_log, + decomposition_level_count, + output_glwe_key_dimension, + output_polynomial_size, + ) + }) + } + + pub fn fill_with_fpksk_for_circuit_bootstrap( + &mut self, + input_lwe_key: &LweSecretKey, + output_glwe_key: &GlweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Scalar: UnsignedTorus, + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + Gen: ByteRandomGenerator, + { + debug_assert!( + self.fpksk_count().0 == output_glwe_key.key_size().to_glwe_size().0, + "Current list has {} fpksk, need to have {} \ + (encrypted_glwe_key.key_size().to_glwe_size())", + self.fpksk_count().0, + output_glwe_key.key_size().to_glwe_size().0 + ); + + let decomp_level_count = self.decomp_level_count; + + let gen_iter = generator + .fork_cbs_pfpksk_to_pfpksk::( + decomp_level_count, + output_glwe_key.key_size().to_glwe_size(), + output_glwe_key.polynomial_size(), + input_lwe_key.key_size().to_lwe_size(), + self.fpksk_count(), + ) + .unwrap(); + + let mut last_polynomial_as_list = PolynomialList::allocate( + Scalar::ZERO, + PolynomialCount(1), + output_glwe_key.polynomial_size(), + ); + // We apply the x -> -x function so instead of putting one in the first coeff of the + // polynomial, we put Scalar::MAX == - Sclar::One so that we can use a single function in + // the loop avoiding branching + *last_polynomial_as_list + .get_mut_polynomial(0) + .get_mut_monomial(MonomialDegree(0)) + .get_mut_coefficient() = Scalar::MAX; + + for ((mut fpksk, polynomial_to_encrypt), mut loop_generator) in self + .fpksk_iter_mut() + .zip( + output_glwe_key + .as_polynomial_list() + .polynomial_iter() + .chain(last_polynomial_as_list.polynomial_iter()), + ) + .zip(gen_iter) + { + fpksk.fill_with_private_functional_packing_keyswitch_key( + input_lwe_key, + output_glwe_key, + noise_parameters, + &mut loop_generator, + &|x| Scalar::ZERO.wrapping_sub(x), + &polynomial_to_encrypt, + ); + } + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_fill_with_fpksk_for_circuit_bootstrap( + &mut self, + input_lwe_key: &LweSecretKey, + output_glwe_key: &GlweSecretKey, + noise_parameters: Noise, + generator: &mut EncryptionRandomGenerator, + ) where + Scalar: UnsignedTorus + Sync + Send, + Self: AsMutTensor, + ::Element: Sync + Send, + LweSecretKey: AsRefTensor, + GlweSecretKey: AsRefTensor, + C1: AsRefSlice + Sync + Send, + C2: AsRefSlice + Sync + Send, + Gen: ParallelByteRandomGenerator, + Noise: DispersionParameter + Sync + Send, + { + debug_assert!( + self.fpksk_count().0 == output_glwe_key.key_size().to_glwe_size().0, + "Current list has {} fpksk, need to have {} \ + (encrypted_glwe_key.key_size().to_glwe_size())", + self.fpksk_count().0, + output_glwe_key.key_size().to_glwe_size().0 + ); + + let decomp_level_count = self.decomp_level_count; + + let gen_iter = generator + .par_fork_cbs_pfpksk_to_pfpksk::( + decomp_level_count, + output_glwe_key.key_size().to_glwe_size(), + output_glwe_key.polynomial_size(), + input_lwe_key.key_size().to_lwe_size(), + self.fpksk_count(), + ) + .unwrap(); + + let mut last_polynomial_as_list = PolynomialList::allocate( + Scalar::ZERO, + PolynomialCount(1), + output_glwe_key.polynomial_size(), + ); + // We apply the x -> -x function so instead of putting one in the first coeff of the + // polynomial, we put Scalar::MAX == - Sclar::One so that we can use a single function + // in the loop avoiding branching + *last_polynomial_as_list + .get_mut_polynomial(0) + .get_mut_monomial(MonomialDegree(0)) + .get_mut_coefficient() = Scalar::MAX; + + self.par_fpksk_iter_mut() + .zip( + output_glwe_key + .as_polynomial_list() + .par_polynomial_iter() + .chain(last_polynomial_as_list.par_polynomial_iter()), + ) + .zip(gen_iter) + .for_each(|((mut fpksk, polynomial_to_encrypt), mut loop_generator)| { + fpksk.par_fill_with_private_functional_packing_keyswitch_key( + input_lwe_key, + output_glwe_key, + noise_parameters, + &mut loop_generator, + &|x| Scalar::ZERO.wrapping_sub(x), + &polynomial_to_encrypt, + ); + }) + } +} + +#[cfg(feature = "__commons_parallel")] +#[cfg(test)] +mod test { + use crate::core_crypto::commons::crypto::glwe::LwePrivateFunctionalPackingKeyswitchKeyList; + use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator, + }; + use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + GlweDimension, LogStandardDev, LweDimension, PolynomialSize, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + use concrete_csprng::seeders::Seed; + + #[test] + fn check_equivalence_serial_parallel_pfpksk_gen() { + let input_lwe_dimension = LweDimension(10); + let output_glwe_dimension = GlweDimension(3); + let polynomial_size = PolynomialSize(256); + let decomp_base_log = DecompositionBaseLog(3); + let decomp_level_count = DecompositionLevelCount(5); + let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + let noise = LogStandardDev::from_log_standard_dev(-15.); + + let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let deterministic_seeder_seed = + Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let number_of_runs = 10_usize; + + for _ in 0..number_of_runs { + let input_key = + LweSecretKey::generate_binary(input_lwe_dimension, &mut secret_generator); + let output_key = GlweSecretKey::generate_binary( + output_glwe_dimension, + polynomial_size, + &mut secret_generator, + ); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + + let mut pfpksk_serial = LwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u32, + decomp_level_count, + decomp_base_log, + input_lwe_dimension, + output_glwe_dimension, + polynomial_size, + FunctionalPackingKeyswitchKeyCount(output_glwe_dimension.to_glwe_size().0), + ); + pfpksk_serial.fill_with_fpksk_for_circuit_bootstrap( + &input_key, + &output_key, + noise, + &mut encryption_generator, + ); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + let mut pfpksk_par = LwePrivateFunctionalPackingKeyswitchKeyList::allocate( + 0u32, + decomp_level_count, + decomp_base_log, + input_lwe_dimension, + output_glwe_dimension, + polynomial_size, + FunctionalPackingKeyswitchKeyCount(output_glwe_dimension.to_glwe_size().0), + ); + pfpksk_par.par_fill_with_fpksk_for_circuit_bootstrap( + &input_key, + &output_key, + noise, + &mut encryption_generator, + ); + + assert_eq!(pfpksk_par, pfpksk_serial); + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/list.rs b/tfhe/src/core_crypto/commons/crypto/glwe/list.rs new file mode 100644 index 000000000..cdb312bac --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/list.rs @@ -0,0 +1,277 @@ +use super::GlweCiphertext; +use crate::core_crypto::commons::crypto::encoding::PlaintextList; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{ + CiphertextCount, GlweDimension, GlweSize, PlaintextCount, PolynomialSize, +}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A list of ciphertexts encoded with the GLWE scheme. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweList { + pub(crate) tensor: Tensor, + pub(crate) rlwe_size: GlweSize, + pub(crate) poly_size: PolynomialSize, +} + +tensor_traits!(GlweList); + +impl GlweList> +where + Scalar: Copy, +{ + /// Allocates storage for an owned [`GlweList`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// assert_eq!(list.ciphertext_count(), CiphertextCount(30)); + /// assert_eq!(list.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(list.glwe_size(), GlweSize(21)); + /// assert_eq!(list.glwe_dimension(), GlweDimension(20)); + /// ``` + pub fn allocate( + value: Scalar, + poly_size: PolynomialSize, + glwe_dimension: GlweDimension, + ciphertext_number: CiphertextCount, + ) -> Self { + GlweList { + tensor: Tensor::from_container(vec![ + value; + poly_size.0 + * (glwe_dimension.0 + 1) + * ciphertext_number.0 + ]), + rlwe_size: GlweSize(glwe_dimension.0 + 1), + poly_size, + } + } +} + +impl GlweList { + /// Creates a list from a container of values. + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// let list = GlweList::from_container( + /// vec![0 as u8; 10 * 21 * 30], + /// GlweDimension(20), + /// PolynomialSize(10), + /// ); + /// assert_eq!(list.ciphertext_count(), CiphertextCount(30)); + /// assert_eq!(list.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(list.glwe_size(), GlweSize(21)); + /// assert_eq!(list.glwe_dimension(), GlweDimension(20)); + /// ``` + pub fn from_container( + cont: Cont, + rlwe_dimension: GlweDimension, + poly_size: PolynomialSize, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => rlwe_dimension.0 + 1, poly_size.0); + GlweList { + tensor, + rlwe_size: GlweSize(rlwe_dimension.0 + 1), + poly_size, + } + } + + /// Returns the number of ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// assert_eq!(list.ciphertext_count(), CiphertextCount(30)); + /// ``` + pub fn ciphertext_count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.rlwe_size.0, self.poly_size.0); + CiphertextCount(self.as_tensor().len() / (self.rlwe_size.0 * self.polynomial_size().0)) + } + + /// Returns the size of the glwe ciphertexts contained in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// assert_eq!(list.glwe_size(), GlweSize(21)); + /// ``` + pub fn glwe_size(&self) -> GlweSize + where + Self: AsRefTensor, + { + self.rlwe_size + } + + /// Returns the number of coefficients of the polynomials used for the list ciphertexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// assert_eq!(list.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the number of masks of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweList; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// assert_eq!(list.glwe_dimension(), GlweDimension(20)); + /// ``` + pub fn glwe_dimension(&self) -> GlweDimension + where + Self: AsRefTensor, + { + GlweDimension(self.rlwe_size.0 - 1) + } + + /// Returns an iterator over ciphertexts borrowed from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweBody, GlweList}; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, PolynomialSize}; + /// let list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// for ciphertext in list.ciphertext_iter() { + /// let (body, masks) = ciphertext.get_body_and_mask(); + /// assert_eq!(body.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// } + /// assert_eq!(list.ciphertext_iter().count(), 30); + /// ``` + pub fn ciphertext_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.rlwe_size.0, self.poly_size.0); + let poly_size = self.poly_size; + let size = self.rlwe_size.0 * self.polynomial_size().0; + self.as_tensor() + .subtensor_iter(size) + .map(move |sub| GlweCiphertext::from_container(sub.into_container(), poly_size)) + } + + /// Returns an iterator over ciphertexts borrowed from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweBody, GlweList}; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, PolynomialSize}; + /// let mut list = GlweList::allocate( + /// 0 as u8, + /// PolynomialSize(10), + /// GlweDimension(20), + /// CiphertextCount(30), + /// ); + /// for mut ciphertext in list.ciphertext_iter_mut() { + /// let mut body = ciphertext.get_mut_body(); + /// body.as_mut_tensor().fill_with_element(9); + /// } + /// for ciphertext in list.ciphertext_iter() { + /// let body = ciphertext.get_body(); + /// assert!(body.as_tensor().iter().all(|a| *a == 9)); + /// } + /// assert_eq!(list.ciphertext_iter_mut().count(), 30); + /// ``` + pub fn ciphertext_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.rlwe_size.0, self.poly_size.0); + let poly_size = self.poly_size; + let chunks_size = self.rlwe_size.0 * self.polynomial_size().0; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| GlweCiphertext::from_container(sub.into_container(), poly_size)) + } + + pub fn fill_with_trivial_encryption( + &mut self, + plaintexts: &PlaintextList, + ) where + PlaintextList: AsRefTensor, + for<'a> PlaintextList<&'a [Scalar]>: AsRefTensor, + Self: AsMutTensor, + Scalar: Numeric, + { + debug_assert_eq!( + plaintexts.count().0, + self.poly_size.0 * self.ciphertext_count().0 + ); + let plaintext_count = PlaintextCount(self.poly_size.0); + for (mut ciphertext, plaintext) in self + .ciphertext_iter_mut() + .zip(plaintexts.sublist_iter(plaintext_count)) + { + ciphertext.fill_with_trivial_encryption(&plaintext); + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/mask.rs b/tfhe/src/core_crypto/commons/crypto/glwe/mask.rs new file mode 100644 index 000000000..b8f0ed84f --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/mask.rs @@ -0,0 +1,173 @@ +use crate::core_crypto::commons::math::polynomial::{Polynomial, PolynomialList}; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::prelude::PolynomialSize; + +/// The mask of a GLWE ciphertext +pub struct GlweMask { + pub(crate) tensor: Tensor, + pub(crate) poly_size: PolynomialSize, +} + +tensor_traits!(GlweMask); + +impl GlweMask { + /// Returns an iterator over borrowed mask elements contained in the mask. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let rlwe_ciphertext = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// for mask in rlwe_ciphertext.get_mask().mask_element_iter() { + /// assert_eq!(mask.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// } + /// assert_eq!(rlwe_ciphertext.get_mask().mask_element_iter().count(), 99); + /// ``` + pub fn mask_element_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.poly_size.0) + .map(|sub| GlweMaskElement::from_container(sub.into_container())) + } + + /// Returns an iterator over mutably borrowed mask elements contained in the mask. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialSize}; + /// let mut rlwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// for mut mask in rlwe.get_mut_mask().mask_element_iter_mut() { + /// mask.as_mut_tensor().fill_with_element(9); + /// } + /// assert!(rlwe.get_mask().as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(rlwe.get_mask().mask_element_iter().count(), 99); + /// ``` + pub fn mask_element_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.poly_size.0; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(|sub| GlweMaskElement::from_container(sub.into_container())) + } + + /// Returns a borrowed polynomial list from the current mask. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialCount, PolynomialSize}; + /// let rlwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let masks = rlwe.get_mask(); + /// let list = masks.as_polynomial_list(); + /// assert_eq!(list.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(list.polynomial_count(), PolynomialCount(99)); + /// ``` + pub fn as_polynomial_list(&self) -> PolynomialList<&[::Element]> + where + Self: AsRefTensor, + { + PolynomialList::from_container(self.as_tensor().as_slice(), self.poly_size) + } + + /// Returns a mutably borrowed polynomial list from the current mask list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweSize, PolynomialCount, PolynomialSize}; + /// let mut rlwe = GlweCiphertext::allocate(0 as u8, PolynomialSize(10), GlweSize(100)); + /// let mut masks = rlwe.get_mut_mask(); + /// let mut tensor = masks.as_mut_polynomial_list(); + /// assert_eq!(tensor.polynomial_size(), PolynomialSize(10)); + /// assert_eq!(tensor.polynomial_count(), PolynomialCount(99)); + /// ``` + pub fn as_mut_polynomial_list( + &mut self, + ) -> PolynomialList<&mut [::Element]> + where + Self: AsMutTensor, + { + let poly_size = self.poly_size; + PolynomialList::from_container(self.as_mut_tensor().as_mut_slice(), poly_size) + } +} + +/// A mask of an GLWE ciphertext. +pub struct GlweMaskElement { + tensor: Tensor, +} + +tensor_traits!(GlweMaskElement); + +impl GlweMaskElement { + /// Creates a mask element from a container. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweMaskElement; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mask = GlweMaskElement::from_container(vec![0 as u8; 10]); + /// assert_eq!(mask.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn from_container(cont: Container) -> GlweMaskElement { + GlweMaskElement { + tensor: Tensor::from_container(cont), + } + } + + /// Returns a borrowed polynomial from the current mask element. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweMaskElement; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mask = GlweMaskElement::from_container(vec![0 as u8; 10]); + /// assert_eq!(mask.as_polynomial().polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn as_polynomial(&self) -> Polynomial<&[::Element]> + where + Self: AsRefTensor, + { + Polynomial::from_container(self.as_tensor().as_slice()) + } + + /// Returns a mutably borrowed polynomial from the current mask element. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweMaskElement; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// let mut mask = GlweMaskElement::from_container(vec![0 as u8; 10]); + /// mask.as_mut_polynomial() + /// .as_mut_tensor() + /// .fill_with_element(9); + /// assert!(mask.as_tensor().iter().all(|a| *a == 9)); + /// ``` + pub fn as_mut_polynomial(&mut self) -> Polynomial<&mut [::Element]> + where + Self: AsMutTensor, + { + Polynomial::from_container(self.as_mut_tensor().as_mut_slice()) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/mod.rs b/tfhe/src/core_crypto/commons/crypto/glwe/mod.rs new file mode 100644 index 000000000..0dc4c2079 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/mod.rs @@ -0,0 +1,17 @@ +//! GLWE encryption scheme + +mod body; +mod ciphertext; +mod keyswitch; +mod list; +mod mask; +mod seeded_ciphertext; +mod seeded_list; + +pub use body::*; +pub use ciphertext::*; +pub use keyswitch::*; +pub use list::*; +pub use mask::*; +pub use seeded_ciphertext::*; +pub use seeded_list::*; diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/seeded_ciphertext.rs b/tfhe/src/core_crypto/commons/crypto/glwe/seeded_ciphertext.rs new file mode 100644 index 000000000..f59257b81 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/seeded_ciphertext.rs @@ -0,0 +1,367 @@ +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, IntoTensor, Tensor, +}; + +use super::{GlweBody, GlweCiphertext}; + +/// An GLWE seeded ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweSeededCiphertext { + tensor: Tensor, + glwe_dimension: GlweDimension, + compression_seed: CompressionSeed, +} + +tensor_traits!(GlweSeededCiphertext); + +impl GlweSeededCiphertext> { + /// Allocates a new GLWE seeded ciphertext, whose body coefficients are all 0. The underlying + /// container has a size of `poly_size`. This seeded version of the GLWE ciphertext stores the + /// coefficients of the body. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.polynomial_size(), polynomial_size); + /// assert_eq!(glwe_seeded_ciphertext.mask_size(), glwe_dimension); + /// assert_eq!(glwe_seeded_ciphertext.compression_seed(), compression_seed); + /// assert_eq!(glwe_seeded_ciphertext.size(), glwe_dimension.to_glwe_size()); + /// ``` + pub fn allocate( + poly_size: PolynomialSize, + dimension: GlweDimension, + compression_seed: CompressionSeed, + ) -> Self + where + Self: AsMutTensor, + Scalar: Numeric, + { + Self { + tensor: Tensor::from_container(vec![Scalar::ZERO; poly_size.0]), + glwe_dimension: dimension, + compression_seed, + } + } +} + +impl GlweSeededCiphertext { + /// Creates a new GLWE seeded ciphertext from an existing container. + /// + /// # Note + /// + /// This method does not perform any transformation of the container data. Those are assumed to + /// represent a valid glwe body. + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweBody, GlweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor, Tensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// let tensor_container = vec![0u8; polynomial_size.0]; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::from_container( + /// tensor_container, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.polynomial_size(), polynomial_size); + /// assert_eq!(glwe_seeded_ciphertext.mask_size(), glwe_dimension); + /// assert_eq!(glwe_seeded_ciphertext.compression_seed(), compression_seed); + /// assert_eq!(glwe_seeded_ciphertext.size(), glwe_dimension.to_glwe_size()); + /// ``` + pub fn from_container( + cont: Cont, + dimension: GlweDimension, + compression_seed: CompressionSeed, + ) -> Self { + Self { + tensor: Tensor::from_container(cont), + glwe_dimension: dimension, + compression_seed, + } + } + + /// Returns the size of the ciphertext, i.e. the number of masks + 1. + /// + /// # Example + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.size(), glwe_dimension.to_glwe_size()); + /// ``` + pub fn size(&self) -> GlweSize { + self.glwe_dimension.to_glwe_size() + } + + /// Returns the number of masks of the ciphertext, i.e. its size - 1. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.mask_size(), glwe_dimension); + /// ``` + pub fn mask_size(&self) -> GlweDimension { + self.glwe_dimension + } + + /// Returns the number of coefficients of the polynomials of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.polynomial_size(), polynomial_size); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize + where + Self: AsRefTensor, + { + PolynomialSize(self.as_tensor().len()) + } + + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(glwe_seeded_ciphertext.compression_seed(), compression_seed); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns a borrowed [`GlweBody`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweBody, GlweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor, Tensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// let tensor_container = vec![0u8; polynomial_size.0]; + /// + /// assert_eq!( + /// glwe_seeded_ciphertext.get_body().as_tensor().as_slice(), + /// &tensor_container[..] + /// ); + /// ``` + pub fn get_body(&self) -> GlweBody<&[::Element]> + where + Self: AsRefTensor, + { + GlweBody { + tensor: self.as_tensor().get_sub(0..), + } + } + + /// Returns a mutably borrowed [`GlweBody`] from the current ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweBody, GlweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{ + /// AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, + /// }; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(99); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut glwe_seeded_ciphertext = GlweSeededCiphertext::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// compression_seed, + /// ); + /// + /// let mut tensor_container = vec![0u8; polynomial_size.0]; + /// + /// assert_eq!( + /// glwe_seeded_ciphertext.get_mut_body().as_tensor().as_slice(), + /// &tensor_container[..] + /// ); + /// + /// glwe_seeded_ciphertext + /// .get_mut_body() + /// .as_mut_tensor() + /// .as_mut_slice()[0] = 1; + /// + /// tensor_container[0] = 1; + /// + /// assert_eq!( + /// glwe_seeded_ciphertext.get_mut_body().as_tensor().as_slice(), + /// &tensor_container[..] + /// ); + /// ``` + pub fn get_mut_body(&mut self) -> GlweBody<&mut [::Element]> + where + Self: AsMutTensor, + { + GlweBody { + tensor: self.as_mut_tensor().get_sub_mut(0..), + } + } + + pub fn expand_into_with_existing_generator( + self, + output: &mut GlweCiphertext, + generator: &mut RandomGenerator, + ) where + Scalar: Copy + RandomGenerable + Numeric, + GlweCiphertext: AsMutTensor, + Self: IntoTensor + AsRefTensor, + Gen: ByteRandomGenerator, + { + let (mut output_body, mut output_mask) = output.get_mut_body_and_mask(); + + // generate a uniformly random mask + generator.fill_tensor_with_random_uniform(output_mask.as_mut_tensor()); + + output_body + .as_mut_tensor() + .as_mut_slice() + .clone_from_slice(self.into_tensor().as_slice()); + } + + /// Returns the ciphertext as a full fledged GlweCiphertext + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweSeededCiphertext}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(5); + /// let glwe_dimension = GlweDimension(256); + /// + /// let mut seeded_ciphertext = GlweSeededCiphertext::allocate( + /// polynomial_size, + /// glwe_dimension, + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// + /// let mut ciphertext = GlweCiphertext::allocate( + /// 0 as u32, + /// seeded_ciphertext.polynomial_size(), + /// seeded_ciphertext.size(), + /// ); + /// + /// seeded_ciphertext.expand_into::<_, _, SoftwareRandomGenerator>(&mut ciphertext); + /// + /// assert_eq!(ciphertext.mask_size(), glwe_dimension); + /// assert_eq!(ciphertext.polynomial_size(), polynomial_size); + /// ``` + pub fn expand_into(self, output: &mut GlweCiphertext) + where + Scalar: Copy + RandomGenerable + Numeric, + GlweCiphertext: AsMutTensor, + Self: IntoTensor + AsRefTensor, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed().seed); + + self.expand_into_with_existing_generator::<_, _, Gen>(output, &mut generator); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/glwe/seeded_list.rs b/tfhe/src/core_crypto/commons/crypto/glwe/seeded_list.rs new file mode 100644 index 000000000..df980fdb5 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/glwe/seeded_list.rs @@ -0,0 +1,384 @@ +use crate::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use super::{GlweBody, GlweList}; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::numeric::Numeric; + +/// A list of ciphertexts encoded with the GLWE scheme. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweSeededList { + pub tensor: Tensor, + pub glwe_dimension: GlweDimension, + pub poly_size: PolynomialSize, + pub compression_seed: CompressionSeed, +} + +tensor_traits!(GlweSeededList); + +impl GlweSeededList> +where + Scalar: Numeric, +{ + /// Allocates storage for an owned [`GlweSeededList`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.polynomial_size(), polynomial_size); + /// assert_eq!(list.glwe_dimension(), GlweDimension(20)); + /// assert_eq!(list.glwe_size(), GlweSize(21)); + /// assert_eq!(list.ciphertext_count(), ciphertext_count); + /// assert_eq!(list.compression_seed(), compression_seed); + /// ``` + pub fn allocate( + poly_size: PolynomialSize, + glwe_dimension: GlweDimension, + ciphertext_number: CiphertextCount, + compression_seed: CompressionSeed, + ) -> Self { + GlweSeededList { + tensor: Tensor::from_container(vec![Scalar::ZERO; poly_size.0 * ciphertext_number.0]), + glwe_dimension, + poly_size, + compression_seed, + } + } +} + +impl GlweSeededList { + /// Creates a list from a container of values. + /// + /// # Example + /// + /// TODO + pub fn from_container( + cont: Cont, + glwe_dimension: GlweDimension, + poly_size: PolynomialSize, + compression_seed: CompressionSeed, + ) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => poly_size.0); + GlweSeededList { + tensor, + glwe_dimension, + poly_size, + compression_seed, + } + } + + /// Returns the number of ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.ciphertext_count(), ciphertext_count); + /// ``` + pub fn ciphertext_count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.poly_size.0); + CiphertextCount(self.as_tensor().len() / self.polynomial_size().0) + } + + /// Returns the size of the glwe ciphertexts contained in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.glwe_size(), GlweSize(21)); + /// ``` + pub fn glwe_size(&self) -> GlweSize { + self.glwe_dimension.to_glwe_size() + } + + /// Returns the number of coefficients of the polynomials used for the list ciphertexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.polynomial_size(), polynomial_size); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns the number of masks of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.glwe_dimension(), GlweDimension(20)); + /// ``` + pub fn glwe_dimension(&self) -> GlweDimension { + self.glwe_dimension + } + + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// assert_eq!(list.compression_seed(), compression_seed); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns an iterator over ciphertexts bodies from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor, Tensor}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// for body in list.body_iter() { + /// let tensor_container = vec![0u8; polynomial_size.0]; + /// + /// assert_eq!(body.as_tensor().as_slice(), &tensor_container[..]); + /// } + /// ``` + pub fn body_iter(&self) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .as_slice() + .chunks(self.poly_size.0) + .map(|body| GlweBody { + tensor: Tensor::from_container(body), + }) + } + + /// Returns an iterator over mutable ciphertexts bodies from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::glwe::GlweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::{ + /// AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, + /// }; + /// use tfhe::core_crypto::prelude::{CiphertextCount, GlweDimension, GlweSize, PolynomialSize}; + /// + /// let polynomial_size = PolynomialSize(10); + /// let glwe_dimension = GlweDimension(20); + /// let ciphertext_count = CiphertextCount(30); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut list = GlweSeededList::>::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// for mut body in list.body_iter_mut() { + /// let mut tensor_container = vec![0u8; polynomial_size.0]; + /// + /// assert_eq!(body.as_tensor().as_slice(), &tensor_container[..]); + /// + /// body.as_mut_tensor().as_mut_slice()[0] = 1; + /// + /// tensor_container[0] = 1; + /// + /// assert_eq!(body.as_tensor().as_slice(), &tensor_container[..]); + /// } + /// ``` + pub fn body_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let poly_size = self.poly_size.0; + + self.as_mut_tensor() + .as_mut_slice() + .chunks_mut(poly_size) + .map(|body| GlweBody { + tensor: Tensor::from_container(body), + }) + } + + /// Returns the ciphertext list as a full fledged GlweList + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweList, GlweSeededList}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, GlweDimension, LogStandardDev, PolynomialSize, + /// }; + /// + /// let polynomial_size = PolynomialSize(2); + /// let glwe_dimension = GlweDimension(256); + /// let ciphertext_count = CiphertextCount(2); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut seeded_ciphertexts = GlweSeededList::allocate( + /// polynomial_size, + /// glwe_dimension, + /// ciphertext_count, + /// compression_seed, + /// ); + /// + /// let mut ciphertexts = GlweList::allocate( + /// 0 as u32, + /// seeded_ciphertexts.polynomial_size(), + /// seeded_ciphertexts.glwe_size().to_glwe_dimension(), + /// seeded_ciphertexts.ciphertext_count(), + /// ); + /// + /// seeded_ciphertexts.expand_into::<_, _, SoftwareRandomGenerator>(&mut ciphertexts); + /// + /// assert_eq!(ciphertexts.polynomial_size(), polynomial_size); + /// assert_eq!(ciphertexts.glwe_dimension(), glwe_dimension); + /// assert_eq!(ciphertexts.ciphertext_count(), ciphertext_count); + /// ``` + pub fn expand_into(self, output: &mut GlweList) + where + Self: AsRefTensor, + GlweList: AsMutTensor, + Scalar: Numeric + RandomGenerable, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed().seed); + + for (mut glwe_out, body_in) in output.ciphertext_iter_mut().zip(self.body_iter()) { + let (mut body, mut mask) = glwe_out.get_mut_body_and_mask(); + generator.fill_tensor_with_random_uniform(mask.as_mut_tensor()); + body.as_mut_tensor().fill_with_copy(body_in.as_tensor()); + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/ciphertext.rs b/tfhe/src/core_crypto/commons/crypto/lwe/ciphertext.rs new file mode 100644 index 000000000..c21d65a74 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/ciphertext.rs @@ -0,0 +1,527 @@ +use super::LweList; +use crate::core_crypto::commons::crypto::encoding::{Cleartext, CleartextList, Plaintext}; +use crate::core_crypto::commons::crypto::glwe::GlweCiphertext; +use crate::core_crypto::commons::crypto::secret::LweSecretKey; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutTensor, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger}; +use crate::core_crypto::prelude::{KeyKind, LweDimension, LweSize, MonomialDegree}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A ciphertext encrypted using the LWE scheme. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct LweCiphertext { + pub(crate) tensor: Tensor, +} + +tensor_traits!(LweCiphertext); + +impl LweCiphertext> +where + Scalar: Copy, +{ + /// Allocates a new ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// let ct = LweCiphertext::allocate(0 as u8, LweSize(4)); + /// assert_eq!(ct.lwe_size(), LweSize(4)); + /// assert_eq!(ct.get_mask().mask_size(), LweDimension(3)); + /// ``` + pub fn allocate(value: Scalar, size: LweSize) -> Self { + LweCiphertext { + tensor: Tensor::from_container(vec![value; size.0]), + } + } +} + +impl LweCiphertext> +where + Scalar: Numeric, +{ + /// Creates a new ciphertext containing the trivial encryption of the plain text + pub fn new_trivial_encryption(lwe_size: LweSize, plaintext: &Plaintext) -> Self { + let mut ciphertext = Self::allocate(Scalar::ZERO, lwe_size); + ciphertext.fill_with_trivial_encryption(plaintext); + ciphertext + } +} + +impl LweCiphertext { + /// Creates a ciphertext from a container of values. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// let vector = vec![0 as u8; 10]; + /// let ct = LweCiphertext::from_container(vector.as_slice()); + /// assert_eq!(ct.lwe_size(), LweSize(10)); + /// assert_eq!(ct.get_mask().mask_size(), LweDimension(9)); + /// ``` + pub fn from_container(cont: Cont) -> LweCiphertext { + let tensor = Tensor::from_container(cont); + LweCiphertext { tensor } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> LweCiphertext<&'_ [Cont::Element]> + where + Cont: Container, + { + LweCiphertext::from_container(self.tensor.as_container().as_ref()) + } + + pub fn as_mut_view(&mut self) -> LweCiphertext<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + LweCiphertext::from_container(self.tensor.as_mut_container().as_mut()) + } + + /// Returns the size of the cipher, e.g. the size of the mask + 1 for the body. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::prelude::LweSize; + /// let ct = LweCiphertext::allocate(0 as u8, LweSize(4)); + /// assert_eq!(ct.lwe_size(), LweSize(4)); + /// ``` + pub fn lwe_size(&self) -> LweSize + where + Self: AsRefTensor, + { + LweSize(self.as_tensor().len()) + } + + /// Returns the body of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweCiphertext}; + /// let ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let body = ciphertext.get_body(); + /// assert_eq!(body, &LweBody(0 as u8)); + /// ``` + pub fn get_body(&self) -> &LweBody + where + Self: AsRefTensor, + { + unsafe { &*{ self.as_tensor().last() as *const Scalar as *const LweBody } } + } + + /// Returns the mask of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let mask = ciphertext.get_mask(); + /// assert_eq!(mask.mask_size(), LweDimension(9)); + /// ``` + pub fn get_mask(&self) -> LweMask<&[Scalar]> + where + Self: AsRefTensor, + { + let (_, mask) = self.as_tensor().split_last(); + LweMask { tensor: mask } + } + + /// Returns the body and the mask of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweCiphertext}; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let (body, mask) = ciphertext.get_body_and_mask(); + /// assert_eq!(body, &LweBody(0)); + /// assert_eq!(mask.mask_size(), LweDimension(9)); + /// ``` + pub fn get_body_and_mask(&self) -> (&LweBody, LweMask<&[Scalar]>) + where + Self: AsRefTensor, + { + let (body, mask) = self.as_tensor().split_last(); + let body = unsafe { &*{ body as *const Scalar as *const LweBody } }; + (body, LweMask { tensor: mask }) + } + + /// Returns the mutable body of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweCiphertext}; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let mut body = ciphertext.get_mut_body(); + /// *body = LweBody(8); + /// let body = ciphertext.get_body(); + /// assert_eq!(body, &LweBody(8 as u8)); + /// ``` + pub fn get_mut_body(&mut self) -> &mut LweBody + where + Self: AsMutTensor, + { + unsafe { &mut *{ self.as_mut_tensor().last_mut() as *mut Scalar as *mut LweBody } } + } + + /// Returns the mutable mask of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let mut mask = ciphertext.get_mut_mask(); + /// for mut elt in mask.mask_element_iter_mut() { + /// *elt = 8; + /// } + /// let mask = ciphertext.get_mask(); + /// for elt in mask.mask_element_iter() { + /// assert_eq!(*elt, 8); + /// } + /// assert_eq!(mask.mask_element_iter().count(), 9); + /// ``` + pub fn get_mut_mask(&mut self) -> LweMask<&mut [Scalar]> + where + Self: AsMutTensor, + { + let (_, masks) = self.as_mut_tensor().split_last_mut(); + LweMask { tensor: masks } + } + + /// Returns the mutable body and mask of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let (body, mask) = ciphertext.get_mut_body_and_mask(); + /// assert_eq!(body, &mut LweBody(0)); + /// assert_eq!(mask.mask_size(), LweDimension(9)); + /// ``` + pub fn get_mut_body_and_mask( + &mut self, + ) -> (&mut LweBody, LweMask<&mut [Scalar]>) + where + Self: AsMutTensor, + { + let (body, masks) = self.as_mut_tensor().split_last_mut(); + let body = unsafe { &mut *{ body as *mut Scalar as *mut LweBody } }; + (body, LweMask { tensor: masks }) + } + + /// Fills the ciphertext with the result of the multiplication of the `input` ciphertext by the + /// `scalar` cleartext. + pub fn fill_with_scalar_mul( + &mut self, + input: &LweCiphertext, + scalar: &Cleartext, + ) where + Self: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedInteger, + { + self.as_mut_tensor() + .fill_with_one(input.as_tensor(), |o| o.wrapping_mul(scalar.0)); + } + + /// Fills the ciphertext with the result of the multisum of the `input_list` with the + /// `weights` values, and adds a bias. + /// + /// Said differently, this function fills `self` with: + /// $$ + /// bias + \sum\_i input\_list\[i\] * weights\[i\] + /// $$ + pub fn fill_with_multisum_with_bias( + &mut self, + input_list: &LweList, + weights: &CleartextList, + bias: &Plaintext, + ) where + Self: AsMutTensor, + LweList: AsRefTensor, + CleartextList: AsRefTensor, + Scalar: UnsignedInteger, + { + // loop over the ciphertexts and the weights + for (input_cipher, weight) in input_list.ciphertext_iter().zip(weights.cleartext_iter()) { + let cipher_tens = input_cipher.as_tensor(); + self.as_mut_tensor().update_with_one(cipher_tens, |o, c| { + *o = o.wrapping_add(c.wrapping_mul(weight.0)) + }); + } + + // add the bias + let new_body = (self.get_body().0).wrapping_add(bias.0); + *self.get_mut_body() = LweBody(new_body); + } + + /// Adds the `other` ciphertext to the current one. + pub fn update_with_add(&mut self, other: &LweCiphertext) + where + Self: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + self.as_mut_tensor() + .update_with_wrapping_add(other.as_tensor()) + } + + /// Subtracts the `other` ciphertext from the current one. + pub fn update_with_sub(&mut self, other: &LweCiphertext) + where + Self: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + self.as_mut_tensor() + .update_with_wrapping_sub(other.as_tensor()) + } + + /// Computes the opposite of the ciphertext. + pub fn update_with_neg(&mut self) + where + Self: AsMutTensor, + Scalar: UnsignedTorus, + { + self.as_mut_tensor().update_with_wrapping_neg() + } + + /// Multiplies the current ciphertext with a scalar value inplace. + pub fn update_with_scalar_mul(&mut self, scalar: Cleartext) + where + Self: AsMutTensor, + Scalar: UnsignedTorus, + { + self.as_mut_tensor() + .update_with_wrapping_scalar_mul(&scalar.0) + } + + /// Fills an LWE ciphertext with the sample extraction of one of the coefficients of a GLWE + /// ciphertext. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::commons::math::polynomial::MonomialDegree; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{GlweDimension, LogStandardDev, LweDimension, PolynomialSize}; + /// + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let poly_size = PolynomialSize(4); + /// let glwe_dim = GlweDimension(2); + /// let glwe_secret_key = + /// GlweSecretKey::generate_binary(glwe_dim, poly_size, &mut secret_generator); + /// let mut plaintext_list = + /// PlaintextList::from_container(vec![100000 as u32, 200000, 300000, 400000]); + /// let mut glwe_ct = GlweCiphertext::allocate(0u32, poly_size, glwe_dim.to_glwe_size()); + /// let mut lwe_ct = + /// LweCiphertext::allocate(0u32, LweDimension(poly_size.0 * glwe_dim.0).to_lwe_size()); + /// glwe_secret_key.encrypt_glwe( + /// &mut glwe_ct, + /// &plaintext_list, + /// LogStandardDev(-50.), + /// &mut encryption_generator, + /// ); + /// let lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); + /// + /// // Check for the first + /// for i in 0..4 { + /// // We sample extract + /// lwe_ct.fill_with_glwe_sample_extraction(&glwe_ct, MonomialDegree(i)); + /// // We decrypt + /// let mut output = Plaintext(0u32); + /// lwe_secret_key.decrypt_lwe(&mut output, &lwe_ct); + /// // We check that the decryption is correct + /// let plain = plaintext_list.as_tensor().get_element(i); + /// let d0 = output.0.wrapping_sub(*plain); + /// let d1 = plain.wrapping_sub(output.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400); + /// } + /// ``` + pub fn fill_with_glwe_sample_extraction( + &mut self, + glwe: &GlweCiphertext, + n_th: MonomialDegree, + ) where + Self: AsMutTensor, + GlweCiphertext: AsRefTensor, + Element: UnsignedTorus, + { + glwe.fill_lwe_with_sample_extraction(self, n_th); + } + + pub fn fill_with_trivial_encryption(&mut self, plaintext: &Plaintext) + where + Scalar: Numeric, + Self: AsMutTensor, + { + let (output_body, mut output_mask) = self.get_mut_body_and_mask(); + + // generate a uniformly random mask + output_mask.as_mut_tensor().fill_with_element(Scalar::ZERO); + + // No need to do the multisum between the secret key and the mask + // as the mask only contains zeros + + // add the encoded message + output_body.0 = plaintext.0; + } +} + +/// The mask of an LWE encrypted ciphertext. +#[derive(Debug, PartialEq, Eq)] +pub struct LweMask { + tensor: Tensor, +} + +tensor_traits!(LweMask); + +impl LweMask { + /// Creates a mask from a scalar container. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let masks = LweMask::from_container(vec![0 as u8; 10]); + /// assert_eq!(masks.mask_size(), LweDimension(10)); + /// ``` + pub fn from_container(cont: Cont) -> LweMask { + LweMask { + tensor: Tensor::from_container(cont), + } + } + + /// Returns an iterator over the mask elements. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let masks = ciphertext.get_mask(); + /// for mask in masks.mask_element_iter() { + /// assert_eq!(mask, &0); + /// } + /// assert_eq!(masks.mask_element_iter().count(), 9); + /// ``` + pub fn mask_element_iter(&self) -> impl Iterator::Element> + where + Self: AsRefTensor, + { + self.as_tensor().iter() + } + + /// Returns an iterator over mutable mask elements. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// let mut masks = ciphertext.get_mut_mask(); + /// for mask in masks.mask_element_iter_mut() { + /// *mask = 9; + /// } + /// for mask in masks.mask_element_iter() { + /// assert_eq!(mask, &9); + /// } + /// assert_eq!(masks.mask_element_iter_mut().count(), 9); + /// ``` + pub fn mask_element_iter_mut( + &mut self, + ) -> impl Iterator::Element> + where + Self: AsMutTensor, + { + self.as_mut_tensor().iter_mut() + } + + /// Returns the number of masks. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut ciphertext = LweCiphertext::from_container(vec![0 as u8; 10]); + /// assert_eq!(ciphertext.get_mask().mask_size(), LweDimension(9)); + /// ``` + pub fn mask_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.as_tensor().len()) + } + + /// Computes sum of the mask elements weighted by the key elements. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::LweSecretKey; + /// let ciphertext = LweCiphertext::from_container(vec![1u32, 2, 3, 4, 5]); + /// let mask = ciphertext.get_mask(); + /// let key = LweSecretKey::binary_from_container(vec![1, 1, 0, 1]); + /// let multisum = mask.compute_multisum(&key); + /// assert_eq!(multisum, 7); + /// ``` + pub fn compute_multisum(&self, key: &LweSecretKey) -> Scalar + where + Self: AsRefTensor, + LweSecretKey: AsRefTensor, + Kind: KeyKind, + Scalar: UnsignedTorus, + { + self.as_tensor().fold_with_one( + key.as_tensor(), + ::ZERO, + |ac, s_i, o_i| ac.wrapping_add(*s_i * *o_i), + ) + } +} + +/// The body of an Lwe ciphertext. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct LweBody(pub T); diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/keyswitch.rs b/tfhe/src/core_crypto/commons/crypto/lwe/keyswitch.rs new file mode 100644 index 000000000..7745ecd89 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/keyswitch.rs @@ -0,0 +1,737 @@ +use super::{LweCiphertext, LweList}; +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::crypto::secret::LweSecretKey; +use crate::core_crypto::commons::math::decomposition::{ + DecompositionLevel, DecompositionTerm, SignedDecomposer, +}; +use crate::core_crypto::commons::math::random::ByteRandomGenerator; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, ck_dim_eq, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::prelude::{ + BinaryKeyKind, CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, + DispersionParameter, LweDimension, LweSize, +}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// An Lwe Keyswithing key. +/// +/// A keyswitching key allows to change the key of a cipher text. Lets assume the following +/// elements: +/// +/// + The input key $s\_{in}$ is composed of $n$ bits +/// + The output key $s\_{out}$ is composed of $m$ bits +/// +/// The keyswitch key will be composed of $m$ encryptions of each bits of the $s\_{out}$ key, under +/// the key $s\_{in}$; encryptions which will be stored as their decomposition over a given basis +/// $B\_{ks}\in\mathbb{N}$, up to a level $l\_{ks}\in\mathbb{N}$. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweKeyswitchKey { + tensor: Tensor, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + lwe_size: LweSize, +} + +tensor_traits!(LweKeyswitchKey); + +impl LweKeyswitchKey> +where + Scalar: Copy, +{ + /// Allocates a keyswitching key whose masks and bodies are all `value`. + /// + /// # Note + /// + /// This function does *not* generate a keyswitch key, but merely allocates a container of the + /// right size. See [`LweKeyswitchKey::fill_with_keyswitch_key`] to fill the container with a + /// proper keyswitching key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, LweSize, + /// }; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!( + /// ksk.decomposition_levels_count(), + /// DecompositionLevelCount(10) + /// ); + /// assert_eq!(ksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// assert_eq!(ksk.lwe_size(), LweSize(21)); + /// assert_eq!(ksk.before_key_size(), LweDimension(10)); + /// assert_eq!(ksk.after_key_size(), LweDimension(20)); + /// ``` + pub fn allocate( + value: Scalar, + decomp_size: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + input_size: LweDimension, + output_size: LweDimension, + ) -> Self { + LweKeyswitchKey { + tensor: Tensor::from_container(vec![ + value; + decomp_size.0 * (output_size.0 + 1) * input_size.0 + ]), + decomp_base_log, + decomp_level_count: decomp_size, + lwe_size: LweSize(output_size.0 + 1), + } + } +} + +impl LweKeyswitchKey { + /// Creates an LWE key switching key from a container. + /// + /// # Notes + /// + /// This method does not create a keyswitching key, but merely wrap the container in the proper + /// type. It assumes that either the container already contains a proper keyswitching key, or + /// that [`LweKeyswitchKey::fill_with_keyswitch_key`] will be called right after. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, LweSize, + /// }; + /// let input_size = LweDimension(256); + /// let output_size = LweDimension(35); + /// let decomp_log_base = DecompositionBaseLog(7); + /// let decomp_level_count = DecompositionLevelCount(4); + /// + /// let ksk = LweKeyswitchKey::from_container( + /// vec![0 as u8; input_size.0 * (output_size.0 + 1) * decomp_level_count.0], + /// decomp_log_base, + /// decomp_level_count, + /// output_size, + /// ); + /// + /// assert_eq!(ksk.decomposition_levels_count(), DecompositionLevelCount(4)); + /// assert_eq!(ksk.decomposition_base_log(), DecompositionBaseLog(7)); + /// assert_eq!(ksk.lwe_size(), LweSize(36)); + /// assert_eq!(ksk.before_key_size(), LweDimension(256)); + /// assert_eq!(ksk.after_key_size(), LweDimension(35)); + /// ``` + pub fn from_container( + cont: Cont, + decomp_base_log: DecompositionBaseLog, + decomp_size: DecompositionLevelCount, + output_size: LweDimension, + ) -> LweKeyswitchKey + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + ck_dim_div!(tensor.len() => output_size.0 + 1, decomp_size.0); + LweKeyswitchKey { + tensor, + decomp_base_log, + decomp_level_count: decomp_size, + lwe_size: LweSize(output_size.0 + 1), + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> LweKeyswitchKey<&'_ [Cont::Element]> + where + Cont: Container, + { + LweKeyswitchKey { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + decomp_base_log: self.decomp_base_log, + decomp_level_count: self.decomp_level_count, + lwe_size: self.lwe_size, + } + } + + pub fn as_mut_view(&mut self) -> LweKeyswitchKey<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + LweKeyswitchKey { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + decomp_base_log: self.decomp_base_log, + decomp_level_count: self.decomp_level_count, + lwe_size: self.lwe_size, + } + } + + /// Return the size of the output key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!(ksk.after_key_size(), LweDimension(20)); + /// ``` + pub fn after_key_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.lwe_size.0 - 1) + } + + /// Returns the size of the ciphertexts encoding each level of the decomposition of each bits + /// of the input key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LweDimension, LweSize, + /// }; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!(ksk.lwe_size(), LweSize(21)); + /// ``` + pub fn lwe_size(&self) -> LweSize + where + Self: AsRefTensor, + { + self.lwe_size + } + + /// Returns the size of the input key. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!(ksk.before_key_size(), LweDimension(10)); + /// ``` + pub fn before_key_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.as_tensor().len() / (self.lwe_size.0 * self.decomp_level_count.0)) + } + + /// Returns the number of levels used for the decomposition of the input key bits. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!( + /// ksk.decomposition_levels_count(), + /// DecompositionLevelCount(10) + /// ); + /// ``` + pub fn decomposition_levels_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + self.decomp_level_count + } + + /// Returns the logarithm of the base used for the decomposition of the input key bits. + /// + /// Indeed, the basis used is always of the form $2^N$. This function returns $N$. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(10), + /// LweDimension(20), + /// ); + /// assert_eq!(ksk.decomposition_base_log(), DecompositionBaseLog(16)); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog + where + Self: AsRefTensor, + { + self.decomp_base_log + } + + /// Fills the current keyswitch key container with an actual keyswitching key constructed from + /// an input and an output key. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::lwe::LweKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::LweSecretKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LogStandardDev, LweDimension, LweSize, + /// }; + /// + /// let input_size = LweDimension(10); + /// let output_size = LweDimension(20); + /// let decomp_log_base = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let cipher_size = LweSize(55); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = LweSecretKey::generate_binary(output_size, &mut secret_generator); + /// + /// let mut ksk = LweKeyswitchKey::allocate( + /// 0 as u32, + /// decomp_level_count, + /// decomp_log_base, + /// input_size, + /// output_size, + /// ); + /// ksk.fill_with_keyswitch_key(&input_key, &output_key, noise, &mut encryption_generator); + /// + /// assert!(!ksk.as_tensor().iter().all(|a| *a == 0)); + /// ``` + pub fn fill_with_keyswitch_key( + &mut self, + before_key: &LweSecretKey, + after_key: &LweSecretKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + LweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + // We instantiate a buffer + let mut messages = PlaintextList::from_container(vec![ + ::Element::ZERO; + self.decomp_level_count.0 + ]); + + // We retrieve decomposition arguments + let decomp_level_count = self.decomp_level_count; + let decomp_base_log = self.decomp_base_log; + + // loop over the before key blocks + for (input_key_bit, keyswitch_key_block) in before_key + .as_tensor() + .iter() + .zip(self.bit_decomp_iter_mut()) + { + // We reset the buffer + messages + .as_mut_tensor() + .fill_with_element(::Element::ZERO); + + // We fill the buffer with the powers of the key bits + for (level, message) in (1..=decomp_level_count.0) + .map(DecompositionLevel) + .zip(messages.plaintext_iter_mut()) + { + *message = Plaintext( + DecompositionTerm::new(level, decomp_base_log, *input_key_bit) + .to_recomposition_summand(), + ); + } + + // We encrypt the buffer + after_key.encrypt_lwe_list( + &mut keyswitch_key_block.into_lwe_list(), + &messages, + noise_parameters, + generator, + ); + } + } + + /// Iterates over borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of lwe ciphertext, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyswitchKey}; + /// use tfhe::core_crypto::backends::default::private::math::decomposition::{DecompositionLevelCount, DecompositionBaseLog}; + /// let ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// LweDimension(20) + /// ); + /// for decomp in ksk.bit_decomp_iter() { + /// assert_eq!(decomp.lwe_size(), ksk.lwe_size()); + /// assert_eq!(decomp.count().0, 10); + /// } + /// assert_eq!(ksk.bit_decomp_iter().count(), 15); + /// ``` + pub(crate) fn bit_decomp_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0, self.decomp_level_count.0); + let size = self.decomp_level_count.0 * self.lwe_size.0; + let lwe_size = self.lwe_size; + self.as_tensor() + .subtensor_iter(size) + .map(move |sub| LweKeyBitDecomposition::from_container(sub.into_container(), lwe_size)) + } + + /// Iterates over mutably borrowed `LweKeyBitDecomposition` elements. + /// + /// One `LweKeyBitDecomposition` being a set of lwe ciphertext, encrypting under the output + /// key, the $l$ levels of the signed decomposition of a single bit of the input key. + /// + /// # Example + /// + /// ```ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyswitchKey}; + /// use tfhe::core_crypto::backends::default::private::math::tensor::{AsRefTensor, AsMutTensor}; + /// use tfhe::core_crypto::backends::default::private::math::decomposition::{DecompositionLevelCount, DecompositionBaseLog}; + /// let mut ksk = LweKeyswitchKey::allocate( + /// 0 as u8, + /// DecompositionLevelCount(10), + /// DecompositionBaseLog(16), + /// LweDimension(15), + /// LweDimension(20) + /// ); + /// for mut decomp in ksk.bit_decomp_iter_mut() { + /// for mut ciphertext in decomp.ciphertext_iter_mut() { + /// ciphertext.as_mut_tensor().fill_with_element(0); + /// } + /// } + /// assert!(ksk.as_tensor().iter().all(|a| *a == 0)); + /// assert_eq!(ksk.bit_decomp_iter_mut().count(), 15); + /// ``` + pub(crate) fn bit_decomp_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0, self.decomp_level_count.0); + let chunks_size = self.decomp_level_count.0 * self.lwe_size.0; + let lwe_size = self.lwe_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| LweKeyBitDecomposition::from_container(sub.into_container(), lwe_size)) + } + + /// Switches the key of a signel Lwe ciphertext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::LweSecretKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LogStandardDev, LweDimension, LweSize, + /// }; + /// + /// let input_size = LweDimension(1024); + /// let output_size = LweDimension(1024); + /// let decomp_log_base = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(8); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = LweSecretKey::generate_binary(output_size, &mut secret_generator); + /// + /// let mut ksk = LweKeyswitchKey::allocate( + /// 0 as u64, + /// decomp_level_count, + /// decomp_log_base, + /// input_size, + /// output_size, + /// ); + /// ksk.fill_with_keyswitch_key(&input_key, &output_key, noise, &mut encryption_generator); + /// + /// let plaintext: Plaintext = Plaintext(1432154329994324); + /// let mut ciphertext = LweCiphertext::allocate(0. as u64, LweSize(1025)); + /// let mut switched_ciphertext = LweCiphertext::allocate(0. as u64, LweSize(1025)); + /// input_key.encrypt_lwe( + /// &mut ciphertext, + /// &plaintext, + /// noise, + /// &mut encryption_generator, + /// ); + /// + /// ksk.keyswitch_ciphertext(&mut switched_ciphertext, &ciphertext); + /// + /// let mut decrypted = Plaintext(0 as u64); + /// output_key.decrypt_lwe(&mut decrypted, &switched_ciphertext); + /// ``` + pub fn keyswitch_ciphertext( + &self, + after: &mut LweCiphertext, + before: &LweCiphertext, + ) where + Self: AsRefTensor, + LweCiphertext: AsMutTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + ck_dim_eq!(self.before_key_size().0 => before.get_mask().mask_size().0); + ck_dim_eq!(self.after_key_size().0 => after.get_mask().mask_size().0); + + // We reset the output + after.as_mut_tensor().fill_with(|| Scalar::ZERO); + + // We copy the body + *after.get_mut_body() = *before.get_body(); + // We instantiate a decomposer + let decomposer = SignedDecomposer::new(self.decomp_base_log, self.decomp_level_count); + + for (block, before_mask) in self + .bit_decomp_iter() + .zip(before.get_mask().mask_element_iter()) + { + let mask_rounded = decomposer.closest_representable(*before_mask); + let decomp = decomposer.decompose(mask_rounded); + // loop over the number of levels + for (level_key_cipher, decomposed) in block + .as_tensor() + .subtensor_iter(self.after_key_size().0 + 1) + .rev() + .zip(decomp) + { + after + .as_mut_tensor() + .update_with_wrapping_sub_element_mul(&level_key_cipher, decomposed.value()); + } + } + } + + pub fn keyswitch_list( + &self, + output: &mut LweList, + input: &LweList, + ) where + Self: AsRefTensor, + LweList: AsRefTensor, + LweList: AsMutTensor, + Scalar: UnsignedTorus, + { + ck_dim_eq!(input.count().0 => output.count().0); + // for each ciphertext, call mono_key_switch + for (input_cipher, mut output_cipher) in + input.ciphertext_iter().zip(output.ciphertext_iter_mut()) + { + self.keyswitch_ciphertext(&mut output_cipher, &input_cipher); + } + } +} + +/// The encryption of a single bit of the output key. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] +pub(crate) struct LweKeyBitDecomposition { + pub(crate) tensor: Tensor, + pub(crate) lwe_size: LweSize, +} + +tensor_traits!(LweKeyBitDecomposition); + +impl LweKeyBitDecomposition { + /// Creates a key bit decomposition from a container. + /// + /// # Notes + /// + /// This method does not decompose a key bit in a basis, but merely wraps a container in the + /// right structure. See [`LweKeyswitchKey::bit_decomp_iter`] for an iterator that returns key + /// bit decompositions. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.count(), CiphertextCount(15)); + /// assert_eq!(kbd.lwe_size(), LweSize(10)); + /// ``` + pub fn from_container(cont: Cont, lwe_size: LweSize) -> Self { + LweKeyBitDecomposition { + tensor: Tensor::from_container(cont), + lwe_size, + } + } + + /// Returns the size of the lwe ciphertexts encoding each level of the key bit decomposition. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.lwe_size(), LweSize(10)); + /// ``` + #[allow(dead_code)] + pub fn lwe_size(&self) -> LweSize { + self.lwe_size + } + + /// Returns the number of ciphertexts in the decomposition. + /// + /// Note that this is actually equals to the number of levels in the decomposition. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// assert_eq!(kbd.count(), CiphertextCount(15)); + /// ``` + #[allow(dead_code)] + pub fn count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + CiphertextCount(self.as_tensor().len() / self.lwe_size.0) + } + + /// Returns an iterator over borrowed `LweCiphertext`. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// for ciphertext in kbd.ciphertext_iter(){ + /// assert_eq!(ciphertext.lwe_size(), LweSize(10)); + /// } + /// assert_eq!(kbd.ciphertext_iter().count(), 15); + /// ``` + #[allow(dead_code)] + pub fn ciphertext_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.lwe_size.0) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + /// Returns an iterator over mutably borrowed `LweCiphertext`. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// use tfhe::core_crypto::backends::default::private::math::tensor::{AsRefTensor, AsMutTensor}; + /// let mut kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// for mut ciphertext in kbd.ciphertext_iter_mut(){ + /// ciphertext.as_mut_tensor().fill_with_element(9); + /// } + /// assert!(kbd.as_tensor().iter().all(|a| *a == 9)); + /// assert_eq!(kbd.ciphertext_iter().count(), 15); + /// ``` + #[allow(dead_code)] + pub fn ciphertext_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.lwe_size.0; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + /// Consumes the current key bit decomposition and returns an lwe list. + /// + /// Note that this operation is super cheap, as it merely rewraps the current container in an + /// lwe list structure. + /// + /// # Example + /// + /// ```rust,ignore + /// use tfhe::core_crypto::backends::default::private::crypto::{*, lwe::LweKeyBitDecomposition}; + /// let kbd = LweKeyBitDecomposition::from_container(vec![0 as u8; 150], LweSize(10)); + /// let list = kbd.into_lwe_list(); + /// assert_eq!(list.count(), CiphertextCount(15)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn into_lwe_list(self) -> LweList { + LweList { + tensor: self.tensor, + lwe_size: self.lwe_size, + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/list.rs b/tfhe/src/core_crypto/commons/crypto/lwe/list.rs new file mode 100644 index 000000000..dc94c7591 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/list.rs @@ -0,0 +1,398 @@ +use super::LweCiphertext; +use crate::core_crypto::commons::crypto::encoding::{CleartextList, PlaintextList}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Container, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::utils::{zip, zip_args}; +use crate::core_crypto::prelude::{CiphertextCount, CleartextCount, LweDimension, LweSize}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A list of ciphertext encoded with the LWE scheme. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweList { + pub(crate) tensor: Tensor, + pub(crate) lwe_size: LweSize, +} + +tensor_traits!(LweList); + +impl LweList> +where + Scalar: Copy, +{ + /// Allocates a list of lwe ciphertext whose all masks and bodies have the value `value`. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize}; + /// let list = LweList::allocate(0 as u8, LweSize(10), CiphertextCount(20)); + /// assert_eq!(list.count(), CiphertextCount(20)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn allocate(value: Scalar, lwe_size: LweSize, lwe_count: CiphertextCount) -> Self { + LweList { + tensor: Tensor::from_container(vec![value; lwe_size.0 * lwe_count.0]), + lwe_size, + } + } +} + +impl LweList> +where + Scalar: UnsignedTorus, +{ + /// Creates a new ciphertext containing the trivial encryption of the plain text + /// + /// `Trivial` means that the LWE masks consist of zeros only and can therefore be decrypted with + /// any key. + pub fn new_trivial_encryption( + lwe_size: LweSize, + plaintexts: &PlaintextList, + ) -> Self + where + PlaintextList: AsRefTensor, + { + let mut ciphertexts = Self::allocate( + Scalar::ZERO, + lwe_size, + CiphertextCount(plaintexts.count().0), + ); + ciphertexts.fill_with_trivial_encryption(plaintexts); + ciphertexts + } +} + +impl LweList { + /// Creates a list from a container and a lwe size. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize}; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// assert_eq!(list.count(), CiphertextCount(20)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn from_container(cont: Cont, lwe_size: LweSize) -> Self + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => lwe_size.0); + let tensor = Tensor::from_container(cont); + LweList { tensor, lwe_size } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> LweList<&'_ [Cont::Element]> + where + Cont: Container, + { + LweList { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + lwe_size: self.lwe_size, + } + } + + pub fn as_mut_view(&mut self) -> LweList<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + LweList { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + lwe_size: self.lwe_size, + } + } + + /// Returns the number of ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize}; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// assert_eq!(list.count(), CiphertextCount(20)); + /// ``` + pub fn count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + CiphertextCount(self.as_tensor().len() / self.lwe_size.0) + } + + /// Returns the size of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweSize; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn lwe_size(&self) -> LweSize + where + Self: AsRefTensor, + { + self.lwe_size + } + + /// Returns the number of masks of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweList; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// assert_eq!(list.mask_size(), LweDimension(9)); + /// ``` + pub fn mask_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.lwe_size.0 - 1) + } + + /// Returns an iterator over ciphertexts borrowed from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweSize; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// for ciphertext in list.ciphertext_iter() { + /// let (body, masks) = ciphertext.get_body_and_mask(); + /// assert_eq!(body, &LweBody(0)); + /// assert_eq!( + /// masks, + /// LweMask::from_container(&[0 as u8, 0, 0, 0, 0, 0, 0, 0, 0][..]) + /// ); + /// } + /// assert_eq!(list.ciphertext_iter().count(), 20); + /// ``` + pub fn ciphertext_iter( + &self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + self.as_tensor() + .subtensor_iter(self.lwe_size.0) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_ciphertext_iter( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsRefTensor, + ::Element: Sync, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + let lwe_size = self.lwe_size.0; + self.as_tensor() + .par_subtensor_iter(lwe_size) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + /// Returns an iterator over ciphers mutably borrowed from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweSize; + /// let mut list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// for mut ciphertext in list.ciphertext_iter_mut() { + /// let body = ciphertext.get_mut_body(); + /// *body = LweBody(2); + /// } + /// for ciphertext in list.ciphertext_iter() { + /// let body = ciphertext.get_body(); + /// assert_eq!(body, &LweBody(2)); + /// } + /// assert_eq!(list.ciphertext_iter_mut().count(), 20); + /// ``` + pub fn ciphertext_iter_mut( + &mut self, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + let lwe_size = self.lwe_size.0; + self.as_mut_tensor() + .subtensor_iter_mut(lwe_size) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_ciphertext_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutTensor, + ::Element: Sync + Send, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0); + let lwe_size = self.lwe_size.0; + self.as_mut_tensor() + .par_subtensor_iter_mut(lwe_size) + .map(|sub| LweCiphertext::from_container(sub.into_container())) + } + + /// Returns an iterator over sub lists borrowed from the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize}; + /// let list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// for sublist in list.sublist_iter(CiphertextCount(5)) { + /// assert_eq!(sublist.count(), CiphertextCount(5)); + /// for ciphertext in sublist.ciphertext_iter() { + /// let (body, masks) = ciphertext.get_body_and_mask(); + /// assert_eq!(body, &LweBody(0)); + /// assert_eq!( + /// masks, + /// LweMask::from_container(&[0 as u8, 0, 0, 0, 0, 0, 0, 0, 0][..]) + /// ); + /// } + /// } + /// assert_eq!(list.sublist_iter(CiphertextCount(5)).count(), 4); + /// ``` + pub fn sublist_iter( + &self, + sub_len: CiphertextCount, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0, sub_len.0); + let lwe_size = self.lwe_size; + self.as_tensor() + .subtensor_iter(self.lwe_size.0 * sub_len.0) + .map(move |sub| LweList::from_container(sub.into_container(), lwe_size)) + } + + /// Returns an iterator over sub lists borrowed from the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweSize}; + /// let mut list = LweList::from_container(vec![0 as u8; 200], LweSize(10)); + /// for mut sublist in list.sublist_iter_mut(CiphertextCount(5)) { + /// assert_eq!(sublist.count(), CiphertextCount(5)); + /// for mut ciphertext in sublist.ciphertext_iter_mut() { + /// let (body, mut masks) = ciphertext.get_mut_body_and_mask(); + /// *body = LweBody(9); + /// for mut mask in masks.mask_element_iter_mut() { + /// *mask = 8; + /// } + /// } + /// } + /// for ciphertext in list.ciphertext_iter() { + /// let (body, masks) = ciphertext.get_body_and_mask(); + /// assert_eq!(body, &LweBody(9)); + /// assert_eq!( + /// masks, + /// LweMask::from_container(&[8 as u8, 8, 8, 8, 8, 8, 8, 8, 8][..]) + /// ); + /// } + /// assert_eq!(list.sublist_iter_mut(CiphertextCount(5)).count(), 4); + /// ``` + pub fn sublist_iter_mut( + &mut self, + sub_len: CiphertextCount, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.lwe_size.0, sub_len.0); + let chunks_size = self.lwe_size.0 * sub_len.0; + let size = self.lwe_size; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(move |sub| LweList::from_container(sub.into_container(), size)) + } + + /// Fills each ciphertexts of the list with the result of the multisum of a subpart of the + /// `input_list` ciphers, with a subset of the `weights_list` values, and one value of + /// `biases_list`. + /// + /// Said differently, this function fills `self` with: + /// $$ + /// bias\[i\] + \sum\_j input\_list\[i\]\[j\] * weights\[i\]\[j\] + /// $$ + pub fn fill_with_multisums_with_biases( + &mut self, + input_list: &LweList, + weights_list: &CleartextList, + biases_list: &PlaintextList, + ) where + Self: AsMutTensor, + LweList: AsRefTensor, + CleartextList: AsRefTensor, + PlaintextList: AsRefTensor, + for<'a> CleartextList<&'a [Scalar]>: AsRefTensor, + Scalar: UnsignedTorus, + { + ck_dim_div!(input_list.count().0 => weights_list.count().0, biases_list.count().0); + ck_dim_div!(input_list.count().0 => self.count().0); + let count = input_list.count().0 / self.count().0; + for zip_args!(mut output, input, weights, bias) in zip!( + self.ciphertext_iter_mut(), + input_list.sublist_iter(CiphertextCount(count)), + weights_list.sublist_iter(CleartextCount(count)), + biases_list.plaintext_iter() + ) { + output.fill_with_multisum_with_bias(&input, &weights, bias); + } + } + + pub fn fill_with_trivial_encryption( + &mut self, + encoded: &PlaintextList, + ) where + Self: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + { + debug_assert!( + self.count().0 == encoded.count().0, + "Lwe cipher list size and encoded list size are not compatible" + ); + for (mut cipher, plaintext) in self.ciphertext_iter_mut().zip(encoded.plaintext_iter()) { + cipher.fill_with_trivial_encryption(plaintext); + } + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/mod.rs b/tfhe/src/core_crypto/commons/crypto/lwe/mod.rs new file mode 100644 index 000000000..7e31babe6 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/mod.rs @@ -0,0 +1,106 @@ +//! LWE encryption scheme. +mod ciphertext; +mod keyswitch; +mod list; +mod seeded_ciphertext; +mod seeded_keyswitch; +mod seeded_list; + +pub use ciphertext::*; +pub use keyswitch::*; +pub use list::*; +pub use seeded_ciphertext::*; +pub use seeded_keyswitch::*; +pub use seeded_list::*; + +#[cfg(test)] +mod test { + use crate::core_crypto::commons::crypto::lwe::{LweKeyswitchKey, LweSeededKeyswitchKey}; + use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, + }; + use crate::core_crypto::commons::crypto::secret::LweSecretKey; + use crate::core_crypto::commons::math::random::CompressionSeed; + use crate::core_crypto::commons::math::torus::UnsignedTorus; + use crate::core_crypto::commons::test_tools::new_secret_random_generator; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, LweDimension, StandardDev, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + use concrete_csprng::seeders::Seed; + + fn test_ksk_seeded_gen_equivalence() { + for _ in 0..10 { + let input_lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let output_lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let level = DecompositionLevelCount( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let base_log = DecompositionBaseLog( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let deterministic_seeder_seed = + Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + + let compression_seed = CompressionSeed { seed: mask_seed }; + + let mut secret_generator = new_secret_random_generator(); + + let input_key = LweSecretKey::generate_binary(input_lwe_dim, &mut secret_generator); + let output_key = LweSecretKey::generate_binary(output_lwe_dim, &mut secret_generator); + + let mut ksk = + LweKeyswitchKey::allocate(T::ZERO, level, base_log, input_lwe_dim, output_lwe_dim); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + + ksk.fill_with_keyswitch_key( + &input_key, + &output_key, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + let mut seeded_ksk = LweSeededKeyswitchKey::allocate( + level, + base_log, + input_lwe_dim, + output_lwe_dim, + compression_seed, + ); + + seeded_ksk.fill_with_seeded_keyswitch_key::<_, _, _, _, _, SoftwareRandomGenerator>( + &input_key, + &output_key, + StandardDev::from_standard_dev(10.), + &mut DeterministicSeeder::::new(deterministic_seeder_seed), + ); + + let mut expanded_ksk = + LweKeyswitchKey::allocate(T::ZERO, level, base_log, input_lwe_dim, output_lwe_dim); + + seeded_ksk.expand_into::<_, _, SoftwareRandomGenerator>(&mut expanded_ksk); + + assert_eq!(ksk, expanded_ksk); + } + } + + #[test] + fn test_ksk_seeded_gen_equivalence_u32() { + test_ksk_seeded_gen_equivalence::() + } + + #[test] + fn test_ksk_seeded_gen_equivalence_u64() { + test_ksk_seeded_gen_equivalence::() + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/seeded_ciphertext.rs b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_ciphertext.rs new file mode 100644 index 000000000..03c5858c3 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_ciphertext.rs @@ -0,0 +1,179 @@ +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{LweDimension, LweSize}; + +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Uniform, +}; +use crate::core_crypto::commons::math::tensor::AsMutTensor; + +use super::{LweBody, LweCiphertext}; + +/// A seeded ciphertext encrypted using the LWE scheme. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSeededCiphertext { + pub(crate) body: LweBody, + pub(crate) lwe_dimension: LweDimension, + pub(crate) compression_seed: CompressionSeed, +} + +impl LweSeededCiphertext { + /// Allocates a seeded ciphertext whose body is 0. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ciphertext = LweSeededCiphertext::allocate(LweDimension(3), compression_seed); + /// assert_eq!(*ciphertext.get_body(), LweBody(0_u8)); + /// assert_eq!(ciphertext.lwe_size(), LweSize(4)); + /// assert_eq!(ciphertext.compression_seed(), compression_seed); + /// ``` + pub fn allocate(lwe_dimension: LweDimension, seed: CompressionSeed) -> Self { + Self::from_scalar(Scalar::ZERO, lwe_dimension, seed) + } + + /// Allocates a new seeded ciphertext from elementary components. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ciphertext = LweSeededCiphertext::from_scalar(0_u8, LweDimension(3), compression_seed); + /// assert_eq!(*ciphertext.get_body(), LweBody(0_u8)); + /// assert_eq!(ciphertext.lwe_size(), LweSize(4)); + /// assert_eq!(ciphertext.compression_seed(), compression_seed); + /// ``` + pub fn from_scalar(value: Scalar, lwe_dimension: LweDimension, seed: CompressionSeed) -> Self { + Self { + body: LweBody(value), + lwe_dimension, + compression_seed: seed, + } + } + + /// Returns the size of the ciphertext, e.g. the size of the mask + 1 for the body. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededCiphertext; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ciphertext = LweSeededCiphertext::from_scalar(0_u8, LweDimension(3), compression_seed); + /// assert_eq!(ciphertext.lwe_size(), LweSize(4)); + /// ``` + pub fn lwe_size(&self) -> LweSize { + self.lwe_dimension.to_lwe_size() + } + + /// Returns the body of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::LweDimension; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ciphertext = LweSeededCiphertext::from_scalar(0_u8, LweDimension(3), compression_seed); + /// let body = ciphertext.get_body(); + /// assert_eq!(*body, LweBody(0_u8)); + /// ``` + pub fn get_body(&self) -> &LweBody { + &self.body + } + + /// Returns the mutable body of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::LweDimension; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut ciphertext = LweSeededCiphertext::from_scalar(0_u8, LweDimension(3), compression_seed); + /// let mut body = ciphertext.get_mut_body(); + /// assert_eq!(*body, LweBody(0_u8)); + /// *body = LweBody(8); + /// let body = ciphertext.get_body(); + /// assert_eq!(body, &LweBody(8_u8)); + /// ``` + pub fn get_mut_body(&mut self) -> &mut LweBody { + &mut self.body + } + + /// Returns the seed of the ciphertext. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::LweDimension; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ciphertext = LweSeededCiphertext::from_scalar(0_u8, LweDimension(3), compression_seed); + /// assert_eq!(ciphertext.compression_seed(), compression_seed); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns the ciphertext as a fully fledged LweCiphertext + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweCiphertext, LweSeededCiphertext}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_ciphertext: LweSeededCiphertext = + /// LweSeededCiphertext::allocate(LweDimension(9), compression_seed); + /// let mut ciphertext = LweCiphertext::allocate(0_u8, LweSize(10)); + /// seeded_ciphertext.expand_into::<_, SoftwareRandomGenerator>(&mut ciphertext); + /// let (body, mask) = ciphertext.get_mut_body_and_mask(); + /// assert_eq!(body, &mut LweBody(0)); + /// assert_eq!(mask.mask_size(), LweDimension(9)); + /// ``` + pub fn expand_into(self, output: &mut LweCiphertext) + where + LweCiphertext: AsMutTensor, + Scalar: Copy + RandomGenerable + Numeric, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed.seed); + let (output_body, mut output_mask) = output.get_mut_body_and_mask(); + + // generate a uniformly random mask + generator.fill_tensor_with_random_uniform(output_mask.as_mut_tensor()); + + output_body.0 = self.body.0; + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/seeded_keyswitch.rs b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_keyswitch.rs new file mode 100644 index 000000000..376d44718 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_keyswitch.rs @@ -0,0 +1,541 @@ +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use crate::core_crypto::commons::numeric::Numeric; + +use crate::core_crypto::prelude::{ + BinaryKeyKind, CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, + DispersionParameter, LweDimension, LweSize, +}; + +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::crypto::secret::LweSecretKey; +use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm}; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Seeder, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; + +use super::{LweKeyswitchKey, LweList, LweSeededList}; + +/// A seeded Lwe Keyswithing key. +/// +/// See [`LweKeyswitchKey`] for more details on keyswitching keys. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSeededKeyswitchKey { + tensor: Tensor, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + // Output LweSize + lwe_size: LweSize, + compression_seed: CompressionSeed, +} + +tensor_traits!(LweSeededKeyswitchKey); + +impl LweSeededKeyswitchKey> +where + Scalar: Copy + Numeric, +{ + /// Allocates a seeded keyswitching key, the underlying container has a size of + /// `level_decomp * input_dimension`. This seeded version of the keyswitch key stores the + /// bodies of ciphertexts encrypting each bit of the input LWE secret key level_decomp times. + /// + /// # Note + /// + /// This function does *not* generate a seeded keyswitch key, but merely allocates a container + /// of the right size. See [`LweSeededKeyswitchKey::fill_with_seeded_keyswitch_key`] to fill + /// the container with a proper seeded keyswitching key. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.decomposition_level_count(), levels); + /// assert_eq!(ksk.decomposition_base_log(), base_log); + /// assert_eq!(ksk.input_lwe_dimension(), input_dimension); + /// assert_eq!(ksk.output_lwe_dimension(), output_dimension); + /// assert_eq!(ksk.compression_seed(), compression_seed); + /// ``` + pub fn allocate( + decomp_level_count: DecompositionLevelCount, + decomp_base_log: DecompositionBaseLog, + input_dimension: LweDimension, + output_dimension: LweDimension, + compression_seed: CompressionSeed, + ) -> Self { + Self { + tensor: Tensor::from_container(vec![ + Scalar::ZERO; + decomp_level_count.0 * input_dimension.0 + ]), + decomp_base_log, + decomp_level_count, + lwe_size: output_dimension.to_lwe_size(), + compression_seed, + } + } +} + +impl LweSeededKeyswitchKey { + /// Return the LWE dimension of the output key. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.output_lwe_dimension(), output_dimension); + /// ``` + pub fn output_lwe_dimension(&self) -> LweDimension + where + Self: AsRefTensor, + { + self.lwe_size.to_lwe_dimension() + } + + /// Returns the LWE dimension of the input key. This is also the LWE dimension of the + /// ciphertexts encoding each level of decomposition of the input key bits. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.input_lwe_dimension(), input_dimension); + /// ``` + pub fn input_lwe_dimension(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.as_tensor().len() / self.decomp_level_count.0) + } + + /// Returns the number of levels used for the decomposition of the input key bits. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.decomposition_level_count(), levels); + /// ``` + pub fn decomposition_level_count(&self) -> DecompositionLevelCount + where + Self: AsRefTensor, + { + self.decomp_level_count + } + + /// Returns the logarithm of the base used for the decomposition of the input key bits. + /// + /// Indeed, the basis used is always of the form $2^N$. This function returns $N$. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.decomposition_base_log(), base_log); + /// ``` + pub fn decomposition_base_log(&self) -> DecompositionBaseLog + where + Self: AsRefTensor, + { + self.decomp_base_log + } + + /// Fills the current seeded keyswitch key container with an actual seeded keyswitching key + /// constructed from an input and an output key. + /// + /// # Example + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::UnixSeeder; + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::LweSecretKey; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::commons::math::tensor::AsRefTensor; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, LogStandardDev, LweDimension, LweSize, + /// }; + /// + /// let input_size = LweDimension(10); + /// let output_size = LweDimension(20); + /// let decomp_log_base = DecompositionBaseLog(3); + /// let decomp_level_count = DecompositionLevelCount(5); + /// let cipher_size = LweSize(55); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut seeder = UnixSeeder::new(0); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut seeder); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let input_key = LweSecretKey::generate_binary(input_size, &mut secret_generator); + /// let output_key = LweSecretKey::generate_binary(output_size, &mut secret_generator); + /// + /// let mut ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// decomp_level_count, + /// decomp_log_base, + /// input_size, + /// output_size, + /// compression_seed, + /// ); + /// + /// ksk.fill_with_seeded_keyswitch_key::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &input_key, + /// &output_key, + /// noise, + /// &mut seeder, + /// ); + /// + /// assert!(!ksk.as_tensor().iter().all(|a| *a == 0)); + /// ``` + pub fn fill_with_seeded_keyswitch_key< + InKeyCont, + OutKeyCont, + Scalar, + NoiseParameter, + NoiseSeeder, + Gen, + >( + &mut self, + before_key: &LweSecretKey, + after_key: &LweSecretKey, + noise_parameters: NoiseParameter, + noise_seeder: &mut NoiseSeeder, + ) where + Self: AsMutTensor, + LweSecretKey: AsRefTensor, + LweSecretKey: AsRefTensor, + Scalar: UnsignedTorus, + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder, + Gen: ByteRandomGenerator, + { + // We instantiate a buffer + let mut messages = PlaintextList::from_container(vec![ + ::Element::ZERO; + self.decomp_level_count.0 + ]); + + // We retrieve decomposition arguments + let decomp_level_count = self.decomp_level_count; + let decomp_base_log = self.decomp_base_log; + + let mut generator = + EncryptionRandomGenerator::::new(self.compression_seed.seed, noise_seeder); + + // loop over the before key blocks + for (input_key_bit, keyswitch_key_block) in before_key + .as_tensor() + .iter() + .zip(self.bit_decomp_iter_mut()) + { + // We reset the buffer + messages + .as_mut_tensor() + .fill_with_element(::Element::ZERO); + + // We fill the buffer with the powers of the key bits + for (level, message) in (1..=decomp_level_count.0) + .map(DecompositionLevel) + .zip(messages.plaintext_iter_mut()) + { + *message = Plaintext( + DecompositionTerm::new(level, decomp_base_log, *input_key_bit) + .to_recomposition_summand(), + ); + } + + // We encrypt the buffer + after_key.encrypt_seeded_lwe_list_with_existing_generator::<_, _, _, _, Gen>( + &mut keyswitch_key_block.into_seeded_lwe_list(), + &messages, + noise_parameters, + &mut generator, + ); + } + } + + /// Iterates over borrowed `SeededLweKeyBitDecomposition` elements. + /// + /// One `SeededLweKeyBitDecomposition` being a set of seeded lwe ciphertexts, encrypting under + /// the output key, the $l$ levels of the signed decomposition of a single bit of the input + /// key. + pub(crate) fn bit_decomp_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.as_tensor().len() => self.decomp_level_count.0); + let level_count = self.decomp_level_count.0; + let lwe_size = self.lwe_size; + let compression_seed = self.compression_seed(); + self.as_tensor() + .subtensor_iter(level_count) + .map(move |sub| { + SeededLweKeyBitDecomposition::from_container( + sub.into_container(), + lwe_size, + compression_seed, + ) + }) + } + + /// Iterates over mutably borrowed `SeededLweKeyBitDecomposition` elements. + /// + /// One `SeededLweKeyBitDecomposition` being a set of seeded lwe ciphertexts, encrypting under + /// the output key, the $l$ levels of the signed decomposition of a single bit of the input + /// key. + pub(crate) fn bit_decomp_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.as_tensor().len() => self.decomp_level_count.0); + let level_count = self.decomp_level_count.0; + let lwe_size = self.lwe_size; + let compression_seed = self.compression_seed(); + self.as_mut_tensor() + .subtensor_iter_mut(level_count) + .map(move |sub| { + SeededLweKeyBitDecomposition::from_container( + sub.into_container(), + lwe_size, + compression_seed, + ) + }) + } + + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededKeyswitchKey; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(10); + /// let base_log = DecompositionBaseLog(16); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// assert_eq!(ksk.compression_seed(), compression_seed); + /// ``` + pub fn compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// # Example + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::lwe::{LweKeyswitchKey, LweSeededKeyswitchKey}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; + /// + /// let levels = DecompositionLevelCount(3); + /// let base_log = DecompositionBaseLog(5); + /// let input_dimension = LweDimension(15); + /// let output_dimension = LweDimension(20); + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let ksk: LweSeededKeyswitchKey> = LweSeededKeyswitchKey::allocate( + /// levels, + /// base_log, + /// input_dimension, + /// output_dimension, + /// compression_seed, + /// ); + /// + /// let mut output_ksk = LweKeyswitchKey::allocate( + /// 0, + /// ksk.decomposition_level_count(), + /// ksk.decomposition_base_log(), + /// ksk.input_lwe_dimension(), + /// ksk.output_lwe_dimension(), + /// ); + /// + /// ksk.expand_into::<_, _, SoftwareRandomGenerator>(&mut output_ksk); + /// ``` + pub fn expand_into(self, output: &mut LweKeyswitchKey) + where + LweKeyswitchKey: AsMutTensor, + Self: AsRefTensor, + Scalar: Copy + RandomGenerable + Numeric, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed.seed); + + for (mut output_tensor, keyswitch_key_block) in output + .as_mut_tensor() + // We need enough space for decomp_level_count ciphertexts of size lwe_size + .subtensor_iter_mut(self.decomp_level_count.0 * self.lwe_size.0) + .zip(self.bit_decomp_iter()) + { + let mut lwe_list = LweList::from_container(output_tensor.as_mut_slice(), self.lwe_size); + keyswitch_key_block + .into_seeded_lwe_list() + .expand_into_with_existing_generator::<_, _, Gen>(&mut lwe_list, &mut generator); + } + } +} + +/// The encryption of a single bit of the output key. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq)] +pub(crate) struct SeededLweKeyBitDecomposition { + pub(super) tensor: Tensor, + pub(super) lwe_size: LweSize, + pub(super) compression_seed: CompressionSeed, +} + +tensor_traits!(SeededLweKeyBitDecomposition); + +impl SeededLweKeyBitDecomposition { + /// Creates a key bit decomposition from a container. + /// + /// # Notes + /// + /// This method does not decompose a key bit in a basis, but merely wraps a container in the + /// right structure. See [`LweSeededKeyswitchKey::bit_decomp_iter`] for an iterator that returns + /// key bit decompositions. + pub fn from_container( + cont: Cont, + lwe_size: LweSize, + compression_seed: CompressionSeed, + ) -> Self { + SeededLweKeyBitDecomposition { + tensor: Tensor::from_container(cont), + lwe_size, + compression_seed, + } + } + + /// Returns the size of the lwe ciphertexts encoding each level of the key bit decomposition. + #[allow(dead_code)] + pub fn lwe_size(&self) -> LweSize { + self.lwe_size + } + + /// Returns the number of ciphertexts in the decomposition. + /// + /// Note that this is actually equals to the number of levels in the decomposition. + #[allow(dead_code)] + pub fn count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + CiphertextCount(self.as_tensor().len()) + } + + /// Consumes the current key bit decomposition and returns a seeded lwe list. + /// + /// Note that this operation is super cheap, as it merely rewraps the current container in a + /// seeded lwe list structure. + pub fn into_seeded_lwe_list(self) -> LweSeededList + where + Cont: AsRefSlice, + { + LweSeededList::from_container( + self.tensor.into_container(), + self.lwe_size.to_lwe_dimension(), + self.compression_seed, + ) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/lwe/seeded_list.rs b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_list.rs new file mode 100644 index 000000000..91efb9337 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/lwe/seeded_list.rs @@ -0,0 +1,294 @@ +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + +use crate::core_crypto::commons::crypto::lwe::LweList; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, CompressionSeed, RandomGenerable, RandomGenerator, Uniform, +}; +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Tensor, +}; + +use super::LweBody; + +/// A seeded list of ciphertexts encrypted using the LWE scheme. +/// +/// Note: all ciphertexts in an [`LweSeededList`] share the same seed for mask generation and have +/// the same [`LweDimension`]. If you need mixed seeds or dimensions you can use a container storing +/// seeded ciphertexts directly. The bytes used to generate their masks however are not the same. +/// The bytes index to use for each mask is dependant on the ciphertext position in the list. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSeededList { + tensor: Tensor, + lwe_dimension: LweDimension, + pub(crate) compression_seed: CompressionSeed, +} + +tensor_traits!(LweSeededList); + +impl LweSeededList> +where + Scalar: Numeric, +{ + /// Allocates a list of seeded LWE ciphertexts whose bodies are 0. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.count(), CiphertextCount(20)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// assert_eq!(list.get_compression_seed(), compression_seed); + /// ``` + pub fn allocate( + lwe_dimension: LweDimension, + lwe_count: CiphertextCount, + compression_seed: CompressionSeed, + ) -> Self { + LweSeededList { + tensor: Tensor::from_container(vec![Scalar::ZERO; lwe_count.0]), + lwe_dimension, + compression_seed, + } + } +} + +impl LweSeededList { + /// Creates a list from a container, an [`LweDimension`] a [`Seed`] and a usize representing the + /// index of the first byte to generate for the masks of the list. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.count(), CiphertextCount(20)); + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// assert_eq!(list.get_compression_seed(), compression_seed); + /// ``` + pub fn from_container(cont: Cont, lwe_dimension: LweDimension, seed: CompressionSeed) -> Self + where + Cont: AsRefSlice, + { + let tensor = Tensor::from_container(cont); + LweSeededList { + tensor, + lwe_dimension, + compression_seed: seed, + } + } + + /// Returns the number of ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.count(), CiphertextCount(20)); + /// ``` + pub fn count(&self) -> CiphertextCount + where + Self: AsRefTensor, + { + CiphertextCount(self.as_tensor().len()) + } + + /// Returns the size of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.lwe_size(), LweSize(10)); + /// ``` + pub fn lwe_size(&self) -> LweSize { + self.lwe_dimension.to_lwe_size() + } + + /// Returns the number of masks of the ciphertexts in the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.mask_size(), LweDimension(9)); + /// ``` + pub fn mask_size(&self) -> LweDimension { + self.lwe_dimension + } + + /// Returns the seed of the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::LweSeededList; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// assert_eq!(list.get_compression_seed(), compression_seed); + /// ``` + pub fn get_compression_seed(&self) -> CompressionSeed { + self.compression_seed + } + + /// Returns an iterator over seeded ciphertexts from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededList}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// for body in list.body_iter() { + /// assert_eq!(body, &LweBody(0)); + /// } + /// assert_eq!(list.body_iter().count(), 20); + /// ``` + pub fn body_iter(&self) -> impl Iterator::Element>> + where + Self: AsRefTensor, + { + self.as_tensor() + .iter() + .map(|scalar| unsafe { std::mem::transmute(scalar) }) + } + + /// Returns an iterator over seeded ciphertexts from the list. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::lwe::{LweBody, LweSeededList}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let mut list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// for mut body in list.body_iter_mut() { + /// assert_eq!(body, &LweBody(0)); + /// body.0 = 1; + /// } + /// for mut body in list.body_iter() { + /// assert_eq!(body, &LweBody(1)); + /// } + /// assert_eq!(list.body_iter().count(), 20); + /// ``` + pub fn body_iter_mut( + &mut self, + ) -> impl Iterator::Element>> + where + Self: AsMutTensor, + { + self.as_mut_tensor() + .iter_mut() + .map(|scalar| unsafe { std::mem::transmute(scalar) }) + } + + pub fn expand_into_with_existing_generator( + self, + output: &mut LweList, + generator: &mut RandomGenerator, + ) where + LweList: AsMutTensor, + Self: AsRefTensor, + Scalar: RandomGenerable + Numeric, + Gen: ByteRandomGenerator, + { + for (mut lwe_out, body_in) in output.ciphertext_iter_mut().zip(self.body_iter()) { + let (output_body, mut output_mask) = lwe_out.get_mut_body_and_mask(); + + // generate a uniformly random mask + generator.fill_tensor_with_random_uniform(output_mask.as_mut_tensor()); + output_body.0 = body_in.0; + } + } + + /// Returns the ciphertext list as a full fledged LweList + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::lwe::{LweList, LweSeededList}; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// use tfhe::core_crypto::prelude::{CiphertextCount, LweDimension, LweSize}; + /// + /// let compression_seed = CompressionSeed { seed: Seed(42) }; + /// + /// let seeded_list = + /// LweSeededList::>::allocate(LweDimension(9), CiphertextCount(20), compression_seed); + /// + /// let mut list = LweList::allocate(0u8, seeded_list.lwe_size(), seeded_list.count()); + /// seeded_list.expand_into::<_, _, SoftwareRandomGenerator>(&mut list); + /// assert_eq!(list.mask_size(), LweDimension(9)); + /// ``` + pub fn expand_into(self, output: &mut LweList) + where + LweList: AsMutTensor, + Self: AsRefTensor, + Scalar: RandomGenerable + Numeric, + Gen: ByteRandomGenerator, + { + let mut generator = RandomGenerator::::new(self.compression_seed.seed); + + self.expand_into_with_existing_generator(output, &mut generator); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/mod.rs b/tfhe/src/core_crypto/commons/crypto/mod.rs new file mode 100644 index 000000000..5aba90ae0 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/mod.rs @@ -0,0 +1,10 @@ +//! Low-overhead homomorphic primitives. +//! +//! This module implements low-overhead fully homomorphic operations. + +pub mod bootstrap; +pub mod encoding; +pub mod ggsw; +pub mod glwe; +pub mod lwe; +pub mod secret; diff --git a/tfhe/src/core_crypto/commons/crypto/secret/generators/encryption.rs b/tfhe/src/core_crypto/commons/crypto/secret/generators/encryption.rs new file mode 100644 index 000000000..22ba9b6f7 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/generators/encryption.rs @@ -0,0 +1,465 @@ +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, Gaussian, RandomGenerable, RandomGenerator, Seed, Seeder, Uniform, +}; +use crate::core_crypto::commons::math::tensor::AsMutTensor; + +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{ + DecompositionLevelCount, DispersionParameter, FunctionalPackingKeyswitchKeyCount, + GlweDimension, GlweSize, LweCiphertextCount, LweDimension, LweSize, PolynomialSize, +}; +use concrete_csprng::generators::ForkError; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; + +/// A random number generator which can be used to encrypt messages. +pub struct EncryptionRandomGenerator { + // A separate mask generator, only used to generate the mask elements. + mask: RandomGenerator, + // A separate noise generator, only used to generate the noise elements. + noise: RandomGenerator, +} + +impl EncryptionRandomGenerator { + /// Creates a new encryption, optionally seeding it with the given value. + // S is ?Sized to allow Box to be passed. + pub fn new(seed: Seed, seeder: &mut S) -> EncryptionRandomGenerator { + EncryptionRandomGenerator { + mask: RandomGenerator::new(seed), + noise: RandomGenerator::new(seeder.seed()), + } + } + + // Allows to seed the noise generator. For testing purpose only. + #[cfg(test)] + pub(crate) fn seed_noise_generator(&mut self, seed: Seed) { + println!("WARNING: The noise generator of the encryption random generator was seeded."); + self.noise = RandomGenerator::new(seed); + } + + /// Returns the number of remaining bytes for the mask generator, if the generator is bounded. + pub fn remaining_bytes(&self) -> Option { + self.mask.remaining_bytes() + } + + // Forks the generator, when splitting a bootstrap key into ggsw ct. + #[allow(dead_code)] + pub(crate) fn fork_bsk_to_ggsw( + &mut self, + lwe_dimension: LweDimension, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_ggsw::(level, glwe_size, polynomial_size); + let noise_bytes = noise_bytes_per_ggsw(level, glwe_size, polynomial_size); + self.try_fork(lwe_dimension.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a ggsw into level matrices. + pub(crate) fn fork_ggsw_to_ggsw_levels( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_ggsw_level::(glwe_size, polynomial_size); + let noise_bytes = noise_bytes_per_ggsw_level(glwe_size, polynomial_size); + self.try_fork(level.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a ggsw level matrix to glwe. + pub(crate) fn fork_ggsw_level_to_glwe( + &mut self, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_glwe::(glwe_size.to_glwe_dimension(), polynomial_size); + let noise_bytes = noise_bytes_per_glwe(polynomial_size); + self.try_fork(glwe_size.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a ggsw into level matrices. + pub(crate) fn fork_gsw_to_gsw_levels( + &mut self, + level: DecompositionLevelCount, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_gsw_level::(lwe_size); + let noise_bytes = noise_bytes_per_gsw_level(lwe_size); + self.try_fork(level.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a ggsw level matrix to glwe. + pub(crate) fn fork_gsw_level_to_lwe( + &mut self, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_lwe::(lwe_size.to_lwe_dimension()); + let noise_bytes = noise_bytes_per_lwe(); + self.try_fork(lwe_size.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting an lwe ciphertext list into ciphertexts. + pub(crate) fn fork_lwe_list_to_lwe( + &mut self, + lwe_count: LweCiphertextCount, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_lwe::(lwe_size.to_lwe_dimension()); + let noise_bytes = noise_bytes_per_lwe(); + self.try_fork(lwe_count.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a collection of pfpksk for cbs + pub(crate) fn fork_cbs_pfpksk_to_pfpksk( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, + lwe_size: LweSize, + pfpksk_count: FunctionalPackingKeyswitchKeyCount, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_pfpksk::(level, glwe_size, poly_size, lwe_size); + let noise_bytes = noise_bytes_per_pfpksk(level, poly_size, lwe_size); + self.try_fork(pfpksk_count.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a pfpksk into chunks + pub(crate) fn fork_pfpksk_to_pfpksk_chunks( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_pfpksk_chunk::(level, glwe_size, poly_size); + let noise_bytes = noise_bytes_per_pfpksk_chunk(level, poly_size); + self.try_fork(lwe_size.0, mask_bytes, noise_bytes) + } + + // Forks both generators into an iterator + fn try_fork( + &mut self, + n_child: usize, + mask_bytes: usize, + noise_bytes: usize, + ) -> Result>, ForkError> { + // We try to fork the generators + let mask_iter = self.mask.try_fork(n_child, mask_bytes)?; + let noise_iter = self.noise.try_fork(n_child, noise_bytes)?; + + // We return a proper iterator. + Ok(mask_iter + .zip(noise_iter) + .map(|(mask, noise)| EncryptionRandomGenerator { mask, noise })) + } + + // Fills the tensor with random uniform values, using the mask generator. + pub(crate) fn fill_tensor_with_random_mask( + &mut self, + output: &mut Tensorable, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + self.mask.fill_tensor_with_random_uniform(output) + } + + // Sample a noise value, using the noise generator. + pub(crate) fn random_noise(&mut self, std: impl DispersionParameter) -> Scalar + where + Scalar: RandomGenerable>, + { + ::generate_one( + &mut self.noise, + Gaussian { + std: std.get_standard_dev(), + mean: 0., + }, + ) + } + + // Fills the input tensor with random noise, using the noise generator. + pub(crate) fn fill_tensor_with_random_noise( + &mut self, + output: &mut Tensorable, + std: impl DispersionParameter, + ) where + (Scalar, Scalar): RandomGenerable>, + Tensorable: AsMutTensor, + { + self.noise + .fill_tensor_with_random_gaussian(output, 0., std.get_standard_dev()); + } +} + +#[cfg(feature = "__commons_parallel")] +impl EncryptionRandomGenerator { + // Forks the generator into a parallel iterator, when splitting a bootstrap key into ggsw ct. + pub(crate) fn par_fork_bsk_to_ggsw( + &mut self, + lwe_dimension: LweDimension, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_ggsw::(level, glwe_size, polynomial_size); + let noise_bytes = noise_bytes_per_ggsw(level, glwe_size, polynomial_size); + // panic!("{:?} {:?} {:?}", lwe_dimension.0, mask_bytes, noise_bytes); + self.par_try_fork(lwe_dimension.0, mask_bytes, noise_bytes) + } + + // Forks the generator into a parallel iterator, when splitting a ggsw into level matrices. + pub(crate) fn par_fork_ggsw_to_ggsw_levels( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_ggsw_level::(glwe_size, polynomial_size); + let noise_bytes = noise_bytes_per_ggsw_level(glwe_size, polynomial_size); + self.par_try_fork(level.0, mask_bytes, noise_bytes) + } + + // Forks the generator into a parallel iterator, when splitting a ggsw level matrix to glwe. + pub(crate) fn par_fork_ggsw_level_to_glwe( + &mut self, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_glwe::(glwe_size.to_glwe_dimension(), polynomial_size); + let noise_bytes = noise_bytes_per_glwe(polynomial_size); + self.par_try_fork(glwe_size.0, mask_bytes, noise_bytes) + } + + // Forks the generator into a parallel iterator, when splitting a ggsw into level matrices. + pub(crate) fn par_fork_gsw_to_gsw_levels( + &mut self, + level: DecompositionLevelCount, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_gsw_level::(lwe_size); + let noise_bytes = noise_bytes_per_gsw_level(lwe_size); + self.par_try_fork(level.0, mask_bytes, noise_bytes) + } + + // Forks the generator into a parallel iterator, when splitting a ggsw level matrix to glwe. + pub(crate) fn par_fork_gsw_level_to_lwe( + &mut self, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_lwe::(lwe_size.to_lwe_dimension()); + let noise_bytes = noise_bytes_per_lwe(); + self.par_try_fork(lwe_size.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting an lwe ciphertext list into ciphertexts. + pub(crate) fn par_fork_lwe_list_to_lwe( + &mut self, + lwe_count: LweCiphertextCount, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_lwe::(lwe_size.to_lwe_dimension()); + let noise_bytes = noise_bytes_per_lwe(); + self.par_try_fork(lwe_count.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a collection of pfpksk for cbs + pub(crate) fn par_fork_cbs_pfpksk_to_pfpksk( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, + lwe_size: LweSize, + pfpksk_count: FunctionalPackingKeyswitchKeyCount, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_pfpksk::(level, glwe_size, poly_size, lwe_size); + let noise_bytes = noise_bytes_per_pfpksk(level, poly_size, lwe_size); + self.par_try_fork(pfpksk_count.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a pfpksk into chunks + pub(crate) fn par_fork_pfpksk_to_pfpksk_chunks( + &mut self, + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, + lwe_size: LweSize, + ) -> Result>, ForkError> { + let mask_bytes = mask_bytes_per_pfpksk_chunk::(level, glwe_size, poly_size); + let noise_bytes = noise_bytes_per_pfpksk_chunk(level, poly_size); + self.par_try_fork(lwe_size.0, mask_bytes, noise_bytes) + } + + // Forks both generators into a parallel iterator. + fn par_try_fork( + &mut self, + n_child: usize, + mask_bytes: usize, + noise_bytes: usize, + ) -> Result>, ForkError> { + // We try to fork the generators + let mask_iter = self.mask.par_try_fork(n_child, mask_bytes)?; + let noise_iter = self.noise.par_try_fork(n_child, noise_bytes)?; + + // We return a proper iterator. + Ok(mask_iter + .zip(noise_iter) + .map(|(mask, noise)| EncryptionRandomGenerator { mask, noise })) + } +} + +fn mask_bytes_per_coef() -> usize { + T::BITS / 8 +} + +fn mask_bytes_per_polynomial(poly_size: PolynomialSize) -> usize { + poly_size.0 * mask_bytes_per_coef::() +} + +fn mask_bytes_per_glwe( + glwe_dimension: GlweDimension, + poly_size: PolynomialSize, +) -> usize { + glwe_dimension.0 * mask_bytes_per_polynomial::(poly_size) +} + +fn mask_bytes_per_ggsw_level( + glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + glwe_size.0 * mask_bytes_per_glwe::(glwe_size.to_glwe_dimension(), poly_size) +} + +fn mask_bytes_per_lwe(lwe_dimension: LweDimension) -> usize { + lwe_dimension.0 * mask_bytes_per_coef::() +} + +fn mask_bytes_per_gsw_level(lwe_size: LweSize) -> usize { + lwe_size.0 * mask_bytes_per_lwe::(lwe_size.to_lwe_dimension()) +} + +fn mask_bytes_per_ggsw( + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + level.0 * mask_bytes_per_ggsw_level::(glwe_size, poly_size) +} + +fn mask_bytes_per_pfpksk_chunk( + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + level.0 * mask_bytes_per_glwe::(glwe_size.to_glwe_dimension(), poly_size) +} + +fn mask_bytes_per_pfpksk( + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, + lwe_size: LweSize, +) -> usize { + lwe_size.0 * mask_bytes_per_pfpksk_chunk::(level, glwe_size, poly_size) +} + +fn noise_bytes_per_coef() -> usize { + // We use f64 to sample the noise for every precision, and we need 4/pi inputs to generate + // such an output (here we take 32 to keep a safety margin). + 8 * 32 +} +fn noise_bytes_per_polynomial(poly_size: PolynomialSize) -> usize { + poly_size.0 * noise_bytes_per_coef() +} + +fn noise_bytes_per_glwe(poly_size: PolynomialSize) -> usize { + noise_bytes_per_polynomial(poly_size) +} + +fn noise_bytes_per_ggsw_level(glwe_size: GlweSize, poly_size: PolynomialSize) -> usize { + glwe_size.0 * noise_bytes_per_glwe(poly_size) +} + +fn noise_bytes_per_lwe() -> usize { + // Here we take 3 to keep a safety margin + noise_bytes_per_coef() * 3 +} + +fn noise_bytes_per_gsw_level(lwe_size: LweSize) -> usize { + lwe_size.0 * noise_bytes_per_lwe() +} + +fn noise_bytes_per_ggsw( + level: DecompositionLevelCount, + glwe_size: GlweSize, + poly_size: PolynomialSize, +) -> usize { + level.0 * noise_bytes_per_ggsw_level(glwe_size, poly_size) +} + +fn noise_bytes_per_pfpksk_chunk( + level: DecompositionLevelCount, + poly_size: PolynomialSize, +) -> usize { + level.0 * noise_bytes_per_glwe(poly_size) +} + +fn noise_bytes_per_pfpksk( + level: DecompositionLevelCount, + poly_size: PolynomialSize, + lwe_size: LweSize, +) -> usize { + lwe_size.0 * noise_bytes_per_pfpksk_chunk(level, poly_size) +} + +#[cfg(all(test, feature = "__commons_parallel"))] +mod test { + use crate::core_crypto::commons::crypto::bootstrap::StandardBootstrapKey; + use crate::core_crypto::commons::crypto::secret::{GlweSecretKey, LweSecretKey}; + use crate::core_crypto::commons::test_tools::{ + new_encryption_random_generator, new_secret_random_generator, + }; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSize, LweDimension, PolynomialSize, + Variance, + }; + + #[test] + fn test_gaussian_sampling_margin_factor_does_not_panic() { + struct Params { + glwe_size: GlweSize, + poly_size: PolynomialSize, + dec_level_count: DecompositionLevelCount, + dec_base_log: DecompositionBaseLog, + lwe_dim: LweDimension, + } + let params = Params { + glwe_size: GlweSize(2), + poly_size: PolynomialSize(1), + dec_level_count: DecompositionLevelCount(1), + dec_base_log: DecompositionBaseLog(4), + lwe_dim: LweDimension(17000), + }; + let mut enc_generator = new_encryption_random_generator(); + let mut sec_generator = new_secret_random_generator(); + let mut bsk = StandardBootstrapKey::allocate( + 0u32, + params.glwe_size, + params.poly_size, + params.dec_level_count, + params.dec_base_log, + params.lwe_dim, + ); + let lwe_sk = LweSecretKey::generate_binary(params.lwe_dim, &mut sec_generator); + let glwe_sk = GlweSecretKey::generate_binary( + params.glwe_size.to_glwe_dimension(), + params.poly_size, + &mut sec_generator, + ); + bsk.par_fill_with_new_key(&lwe_sk, &glwe_sk, Variance(0.), &mut enc_generator); + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/secret/generators/mod.rs b/tfhe/src/core_crypto/commons/crypto/secret/generators/mod.rs new file mode 100644 index 000000000..6e92ced03 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/generators/mod.rs @@ -0,0 +1,8 @@ +mod encryption; +pub use encryption::EncryptionRandomGenerator; + +mod secret; +pub use secret::SecretRandomGenerator; + +mod seeder; +pub use seeder::DeterministicSeeder; diff --git a/tfhe/src/core_crypto/commons/crypto/secret/generators/secret.rs b/tfhe/src/core_crypto/commons/crypto/secret/generators/secret.rs new file mode 100644 index 000000000..dbdc35427 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/generators/secret.rs @@ -0,0 +1,55 @@ +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, Gaussian, RandomGenerable, RandomGenerator, Seed, +}; +use crate::core_crypto::commons::math::tensor::Tensor; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::prelude::DispersionParameter; + +/// A random number generator which can be used to generate secret keys. +pub struct SecretRandomGenerator(RandomGenerator); + +impl SecretRandomGenerator { + /// Creates a new generator, optionally seeding it with the given value. + pub fn new(seed: Seed) -> SecretRandomGenerator { + SecretRandomGenerator(RandomGenerator::new(seed)) + } + + /// Returns the number of remaining bytes, if the generator is bounded. + pub fn remaining_bytes(&self) -> Option { + self.0.remaining_bytes() + } + + // Returns a tensor with random uniform binary values. + pub(crate) fn random_binary_tensor(&mut self, length: usize) -> Tensor> + where + Scalar: UnsignedTorus, + { + self.0.random_uniform_binary_tensor(length) + } + + // Returns a tensor with random uniform ternary values. + pub(crate) fn random_ternary_tensor(&mut self, length: usize) -> Tensor> + where + Scalar: UnsignedTorus, + { + self.0.random_uniform_ternary_tensor(length) + } + + // Returns a tensor with random uniform values. + pub(crate) fn random_uniform_tensor(&mut self, length: usize) -> Tensor> + where + Scalar: UnsignedTorus, + { + self.0.random_uniform_tensor(length) + } + + // Returns a tensor with random gaussian values. + pub(crate) fn random_gaussian_tensor(&mut self, length: usize) -> Tensor> + where + (Scalar, Scalar): RandomGenerable>, + Scalar: UnsignedTorus, + { + self.0 + .random_gaussian_tensor(length, 0.0, Scalar::GAUSSIAN_KEY_LOG_STD.get_standard_dev()) + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/secret/generators/seeder.rs b/tfhe/src/core_crypto/commons/crypto/secret/generators/seeder.rs new file mode 100644 index 000000000..d5c3e207b --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/generators/seeder.rs @@ -0,0 +1,56 @@ +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, RandomGenerable, RandomGenerator, Seed, Seeder, Uniform, +}; + +/// Seeder backed by a CSPRNG +/// +/// ------------ +/// ## Why this Seeder implementation? +/// +/// [`Seeder`] is a trait available to the external user, and we expect some of them to implement +/// their own seeding strategy. Since this trait is public, it means that the implementor can be +/// arbitrarily slow. For this reason, it is better to only use it once when we initialize the +/// engine, and use the CSPRNG to generate other seeds when needed, because that gives us the +/// control on the performances. +/// +/// ## Is it safe? +/// +/// The answer to this question is the following: as long as the the CSPRNG used in this [`Seeder`] +/// is seeded with a [`Seed`] coming from an entropy source then yes, seeding other CSPRNGs using +/// this CSPRNG is safe. +/// +/// ## Why is it deterministic? +/// +/// A CSPRNG is a Cryptograhically Secure Pseudo Random Number Generator. +/// +/// Cryptographically Secure means that if one looks at the numbers it outputs, it looks exactly +/// like numbers drawn from a random distribution, this property is also known as "indistinguishable +/// from random". Here our CSPRNG outputs numbers uniformly so each value for a byte should appear +/// with the same probability. +/// +/// Pseudo Random indicates that for the same initial state (here Seed) it will generate the same +/// exact set of numbers in the same order, making it deterministic. +pub struct DeterministicSeeder { + generator: RandomGenerator, +} + +impl DeterministicSeeder { + pub fn new(seed: Seed) -> Self { + DeterministicSeeder { + generator: RandomGenerator::new(seed), + } + } +} + +impl Seeder for DeterministicSeeder { + fn seed(&mut self) -> Seed { + Seed(u128::generate_one(&mut self.generator, Uniform)) + } + + fn is_available() -> bool + where + Self: Sized, + { + true + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/secret/glwe.rs b/tfhe/src/core_crypto/commons/crypto/secret/glwe.rs new file mode 100644 index 000000000..7664ec072 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/glwe.rs @@ -0,0 +1,1684 @@ +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::ggsw::{ + StandardGgswCiphertext, StandardGgswSeededCiphertext, +}; +use crate::core_crypto::commons::crypto::glwe::{ + GlweBody, GlweCiphertext, GlweList, GlweMask, GlweSeededCiphertext, GlweSeededList, +}; +use crate::core_crypto::commons::crypto::secret::generators::{ + EncryptionRandomGenerator, SecretRandomGenerator, +}; +use crate::core_crypto::commons::crypto::secret::LweSecretKey; +use crate::core_crypto::commons::math::polynomial::PolynomialList; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, Gaussian, RandomGenerable, Seeder, +}; +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, ck_dim_eq, AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, IntoTensor, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::Numeric; +use crate::core_crypto::prelude::{ + BinaryKeyKind, DispersionParameter, GaussianKeyKind, GlweDimension, KeyKind, PlaintextCount, + PolynomialSize, TernaryKeyKind, UniformKeyKind, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; +use std::ops::Add; + +/// A GLWE secret key +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct GlweSecretKey +where + Kind: KeyKind, +{ + tensor: Tensor, + poly_size: PolynomialSize, + kind: PhantomData, +} + +impl GlweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Allocates a container for a new key, and fills it with random binary values. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{BinaryKeyKind, GlweDimension, PolynomialSize}; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey> = + /// GlweSecretKey::generate_binary(GlweDimension(256), PolynomialSize(10), &mut generator); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn generate_binary( + dimension: GlweDimension, + poly_size: PolynomialSize, + generator: &mut SecretRandomGenerator, + ) -> Self { + GlweSecretKey { + tensor: generator.random_binary_tensor(poly_size.0 * dimension.0), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Allocates a container for a new key, and fill it with random ternary values. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, TernaryKeyKind}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey<_, Vec> = GlweSecretKey::generate_ternary( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn generate_ternary( + dimension: GlweDimension, + poly_size: PolynomialSize, + generator: &mut SecretRandomGenerator, + ) -> Self { + GlweSecretKey { + tensor: generator.random_ternary_tensor(poly_size.0 * dimension.0), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey> +where + (Scalar, Scalar): RandomGenerable>, + Scalar: UnsignedTorus, +{ + /// Allocates a container for a new key, and fill it with random gaussian values. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{ + /// GaussianKeyKind, GlweDimension, LweDimension, PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey> = GlweSecretKey::generate_gaussian( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn generate_gaussian( + dimension: GlweDimension, + poly_size: PolynomialSize, + generator: &mut SecretRandomGenerator, + ) -> Self { + GlweSecretKey { + tensor: generator.random_gaussian_tensor(poly_size.0 * dimension.0), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Allocates a container for a new key, and fill it with random uniform values. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize, UniformKeyKind}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey> = GlweSecretKey::generate_uniform( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn generate_uniform( + dimension: GlweDimension, + poly_size: PolynomialSize, + generator: &mut SecretRandomGenerator, + ) -> Self { + GlweSecretKey { + tensor: generator.random_uniform_tensor(poly_size.0 * dimension.0), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey { + /// Creates a binary key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random data. It merely wraps the container in + /// the appropriate type. For a method that generate a new random key see + /// [`GlweSecretKey::generate_binary`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let secret_key = + /// GlweSecretKey::binary_from_container(vec![0 as u8; 11 * 256], PolynomialSize(11)); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(11)); + /// ``` + pub fn binary_from_container(cont: Cont, poly_size: PolynomialSize) -> Self + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => poly_size.0); + GlweSecretKey { + tensor: Tensor::from_container(cont), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey { + /// Creates a ternary key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random data. It merely wraps the container in + /// the appropriate type. For a method that generate a new random key see + /// [`GlweSecretKey::generate_ternary`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let secret_key = + /// GlweSecretKey::ternary_from_container(vec![0 as u8; 11 * 256], PolynomialSize(11)); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(11)); + /// ``` + pub fn ternary_from_container(cont: Cont, poly_size: PolynomialSize) -> Self + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => poly_size.0); + GlweSecretKey { + tensor: Tensor::from_container(cont), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey { + /// Creates a gaussian key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random data. It merely wraps the container in + /// the appropriate type. For a method that generate a new random key see + /// [`GlweSecretKey::generate_gaussian`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let secret_key = + /// GlweSecretKey::binary_from_container(vec![0 as u8; 11 * 256], PolynomialSize(11)); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(11)); + /// ``` + pub fn gaussian_from_container(cont: Cont, poly_size: PolynomialSize) -> Self + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => poly_size.0); + GlweSecretKey { + tensor: Tensor::from_container(cont), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey { + /// Creates a uniform key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random data. It merely wraps the container in + /// the appropriate type. For a method that generate a new random key see + /// [`GlweSecretKey::generate_uniform`]. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let secret_key = + /// GlweSecretKey::binary_from_container(vec![0 as u8; 11 * 256], PolynomialSize(11)); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(11)); + /// ``` + pub fn uniform_from_container(cont: Cont, poly_size: PolynomialSize) -> Self + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => poly_size.0); + GlweSecretKey { + tensor: Tensor::from_container(cont), + poly_size, + kind: PhantomData, + } + } +} + +impl GlweSecretKey> +where + Kind: KeyKind, +{ + /// Consumes the current GLWE secret key and turns it into an LWE secret key. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::prelude::{GlweDimension, LweDimension, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let glwe_secret_key: GlweSecretKey<_, Vec> = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut secret_generator); + /// let lwe_secret_key = glwe_secret_key.into_lwe_secret_key(); + /// assert_eq!(lwe_secret_key.key_size(), LweDimension(20)) + /// ``` + pub fn into_lwe_secret_key(self) -> LweSecretKey> { + LweSecretKey { + tensor: self.tensor, + kind: PhantomData, + } + } +} + +impl GlweSecretKey +where + Kind: KeyKind, +{ + /// Returns the size of the secret key. + /// + /// This is equivalent to the number of masks in the [`GlweCiphertext`]. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey<_, Vec> = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// assert_eq!(secret_key.key_size(), GlweDimension(256)); + /// ``` + pub fn key_size(&self) -> GlweDimension + where + Self: AsRefTensor, + { + GlweDimension(self.as_tensor().len() / self.poly_size.0) + } + + /// Returns the size of the secret key polynomials. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey<_, Vec> = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// assert_eq!(secret_key.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns a borrowed polynomial list from the current key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialCount, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey<_, Vec> = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// let poly = secret_key.as_polynomial_list(); + /// assert_eq!(poly.polynomial_count(), PolynomialCount(256)); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(10)); + /// ``` + pub fn as_polynomial_list(&self) -> PolynomialList<&[::Element]> + where + Self: AsRefTensor, + { + PolynomialList::from_container(self.as_tensor().as_slice(), self.poly_size) + } + + /// Returns a mutably borrowed polynomial list from the current key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let mut secret_key: GlweSecretKey<_, Vec> = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(10), + /// &mut secret_generator, + /// ); + /// let mut poly = secret_key.as_mut_polynomial_list(); + /// poly.as_mut_tensor().fill_with_element(1); + /// assert!(secret_key.as_tensor().iter().all(|a| *a == 1)); + /// ``` + pub fn as_mut_polynomial_list( + &mut self, + ) -> PolynomialList<&mut [::Element]> + where + Self: AsMutTensor, + { + let poly_size = self.poly_size; + PolynomialList::from_container(self.as_mut_tensor().as_mut_slice(), poly_size) + } + + fn fill_glwe_mask_and_body_for_encryption( + &self, + mut output_body: GlweBody, + mut output_mask: GlweMask, + encoded: &PlaintextList, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + GlweBody: AsMutTensor, + GlweMask: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + generator.fill_tensor_with_random_noise(&mut output_body, noise_parameters); + + generator.fill_tensor_with_random_mask(&mut output_mask); + + output_body + .as_mut_polynomial() + .update_with_wrapping_add_multisum( + &output_mask.as_polynomial_list(), + &self.as_polynomial_list(), + ); + output_body + .as_mut_polynomial() + .update_with_wrapping_add(&encoded.as_polynomial()); + } + + /// Encrypts a single GLWE ciphertext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, LogStandardDev, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(5), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-50.); + /// let plaintexts = + /// PlaintextList::from_container(vec![100000 as u32, 200000, 300000, 400000, 500000]); + /// let mut ciphertext = GlweCiphertext::allocate(0 as u32, PolynomialSize(5), GlweSize(257)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.encrypt_glwe( + /// &mut ciphertext, + /// &plaintexts, + /// noise, + /// &mut encryption_generator, + /// ); + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0, 0]); + /// secret_key.decrypt_glwe(&mut decrypted, &ciphertext); + /// for (dec, plain) in decrypted.plaintext_iter().zip(plaintexts.plaintext_iter()) { + /// let d0 = dec.0.wrapping_sub(plain.0); + /// let d1 = plain.0.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_glwe( + &self, + encrypted: &mut GlweCiphertext, + encoded: &PlaintextList, + noise_parameter: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + GlweCiphertext: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(encoded.count().0 => encrypted.polynomial_size().0); + ck_dim_eq!(encrypted.mask_size().0 => self.key_size().0); + + let (body, masks) = encrypted.get_mut_body_and_mask(); + + self.fill_glwe_mask_and_body_for_encryption( + body, + masks, + encoded, + noise_parameter, + generator, + ); + } + + pub fn encrypt_seeded_glwe_with_existing_generator( + &self, + encrypted: &mut GlweSeededCiphertext, + encoded: &PlaintextList, + noise_parameter: NoiseParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Scalar: UnsignedTorus, + Self: AsRefTensor, + GlweSeededCiphertext: AsMutTensor, + PlaintextList: AsRefTensor, + NoiseParameter: DispersionParameter, + Gen: ByteRandomGenerator, + { + let masks = GlweMask { + tensor: Tensor::allocate(Scalar::ZERO, self.polynomial_size().0 * self.key_size().0), + poly_size: encrypted.polynomial_size(), + }; + let body = encrypted.get_mut_body(); + + self.fill_glwe_mask_and_body_for_encryption( + body, + masks, + encoded, + noise_parameter, + generator, + ); + } + + /// Encrypts a single seeded GLWE ciphertext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweSeededCiphertext}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, LogStandardDev, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(5), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-50.); + /// let plaintexts = + /// PlaintextList::from_container(vec![100000 as u32, 200000, 300000, 400000, 500000]); + /// let mut seeded_ciphertext = GlweSeededCiphertext::allocate( + /// PolynomialSize(5), + /// GlweDimension(256), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let mut seeder = UnixSeeder::new(0); + /// secret_key.encrypt_seeded_glwe::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &mut seeded_ciphertext, + /// &plaintexts, + /// noise, + /// &mut seeder, + /// ); + /// + /// let mut ciphertext = GlweCiphertext::allocate( + /// 0 as u32, + /// seeded_ciphertext.polynomial_size(), + /// seeded_ciphertext.size(), + /// ); + /// + /// seeded_ciphertext.expand_into::<_, _, SoftwareRandomGenerator>(&mut ciphertext); + /// + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0, 0]); + /// secret_key.decrypt_glwe(&mut decrypted, &ciphertext); + /// for (dec, plain) in decrypted.plaintext_iter().zip(plaintexts.plaintext_iter()) { + /// let d0 = dec.0.wrapping_sub(plain.0); + /// let d1 = plain.0.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_seeded_glwe( + &self, + encrypted: &mut GlweSeededCiphertext, + encoded: &PlaintextList, + noise_parameter: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + GlweSeededCiphertext: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(encrypted.mask_size().0 => self.key_size().0); + + let mut generator = + EncryptionRandomGenerator::::new(encrypted.compression_seed().seed, seeder); + + self.encrypt_seeded_glwe_with_existing_generator( + encrypted, + encoded, + noise_parameter, + &mut generator, + ); + } + + /// Encrypts a zero plaintext into a GLWE ciphertext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::GlweCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{GlweDimension, GlweSize, LogStandardDev, PolynomialSize}; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(5), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-50.); + /// let mut ciphertext = GlweCiphertext::allocate(0 as u32, PolynomialSize(5), GlweSize(257)); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.encrypt_zero_glwe(&mut ciphertext, noise, &mut encryption_generator); + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0, 0]); + /// secret_key.decrypt_glwe(&mut decrypted, &ciphertext); + /// for dec in decrypted.plaintext_iter() { + /// let d0 = dec.0.wrapping_sub(0u32); + /// let d1 = 0u32.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 500, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_zero_glwe( + &self, + encrypted: &mut GlweCiphertext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + GlweCiphertext: AsMutTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(encrypted.mask_size().0 => self.key_size().0); + let (mut body, mut masks) = encrypted.get_mut_body_and_mask(); + generator.fill_tensor_with_random_noise(&mut body, noise_parameters); + generator.fill_tensor_with_random_mask(&mut masks); + body.as_mut_polynomial().update_with_wrapping_add_multisum( + &masks.as_mut_polynomial_list(), + &self.as_polynomial_list(), + ); + } + + /// Encrypts a list of GLWE ciphertexts. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweList}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, GlweDimension, LogStandardDev, PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(2), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-60.); + /// let plaintexts = PlaintextList::from_container(vec![1000 as u32, 2000, 3000, 4000]); + /// let mut ciphertexts = GlweList::allocate( + /// 0 as u32, + /// PolynomialSize(2), + /// GlweDimension(256), + /// CiphertextCount(2), + /// ); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.encrypt_glwe_list( + /// &mut ciphertexts, + /// &plaintexts, + /// noise, + /// &mut encryption_generator, + /// ); + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0]); + /// secret_key.decrypt_glwe_list(&mut decrypted, &ciphertexts); + /// for (dec, plain) in decrypted.plaintext_iter().zip(plaintexts.plaintext_iter()) { + /// let d0 = dec.0.wrapping_sub(plain.0); + /// let d1 = plain.0.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_glwe_list( + &self, + encrypt: &mut GlweList, + encoded: &PlaintextList, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + GlweList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + for<'a> PlaintextList<&'a [Scalar]>: AsRefTensor, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(encrypt.ciphertext_count().0 * encrypt.polynomial_size().0 => encoded.count().0); + ck_dim_eq!(encrypt.glwe_dimension().0 => self.key_size().0); + + let count = PlaintextCount(encrypt.polynomial_size().0); + for (mut ciphertext, encoded) in encrypt + .ciphertext_iter_mut() + .zip(encoded.sublist_iter(count)) + { + self.encrypt_glwe(&mut ciphertext, &encoded, noise_parameters, generator); + } + } + + /// Encrypts a list of seeded GLWE ciphertexts. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweList, GlweSeededList}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, GlweDimension, LogStandardDev, PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(2), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-60.); + /// let plaintexts = PlaintextList::from_container(vec![1000 as u32, 2000, 3000, 4000]); + /// let mut seeded_ciphertexts = GlweSeededList::allocate( + /// PolynomialSize(2), + /// GlweDimension(256), + /// CiphertextCount(2), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let mut seeder = UnixSeeder::new(0); + /// secret_key.encrypt_seeded_glwe_list::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &mut seeded_ciphertexts, + /// &plaintexts, + /// noise, + /// &mut seeder, + /// ); + /// + /// let mut ciphertexts = GlweList::allocate( + /// 0 as u32, + /// seeded_ciphertexts.polynomial_size(), + /// seeded_ciphertexts.glwe_size().to_glwe_dimension(), + /// seeded_ciphertexts.ciphertext_count(), + /// ); + /// + /// seeded_ciphertexts.expand_into::<_, _, SoftwareRandomGenerator>(&mut ciphertexts); + /// + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0]); + /// secret_key.decrypt_glwe_list(&mut decrypted, &ciphertexts); + /// for (dec, plain) in decrypted.plaintext_iter().zip(plaintexts.plaintext_iter()) { + /// let d0 = dec.0.wrapping_sub(plain.0); + /// let d1 = plain.0.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_seeded_glwe_list( + &self, + encrypt: &mut GlweSeededList, + encoded: &PlaintextList, + noise_parameters: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + GlweSeededList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + for<'a> PlaintextList<&'a [Scalar]>: AsRefTensor, + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(encrypt.ciphertext_count().0 * encrypt.polynomial_size().0 => encoded.count().0); + ck_dim_eq!(encrypt.glwe_dimension().0 => self.key_size().0); + + let mut generator = + EncryptionRandomGenerator::::new(encrypt.compression_seed().seed, seeder); + + let count = PlaintextCount(encrypt.polynomial_size().0); + let polynomial_size = encrypt.polynomial_size(); + for (body, encoded) in encrypt.body_iter_mut().zip(encoded.sublist_iter(count)) { + let masks = GlweMask { + tensor: Tensor::allocate( + Scalar::ZERO, + self.polynomial_size().0 * self.key_size().0, + ), + poly_size: polynomial_size, + }; + + self.fill_glwe_mask_and_body_for_encryption( + body, + masks, + &encoded, + noise_parameters, + &mut generator, + ); + } + } + + /// Encrypts a list of GLWE ciphertexts, with a zero plaintext. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::PlaintextList; + /// use tfhe::core_crypto::commons::crypto::glwe::{GlweCiphertext, GlweList}; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor}; + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, GlweDimension, LogStandardDev, PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = GlweSecretKey::generate_binary( + /// GlweDimension(256), + /// PolynomialSize(2), + /// &mut secret_generator, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-60.); + /// let mut ciphertexts = GlweList::allocate( + /// 0 as u32, + /// PolynomialSize(2), + /// GlweDimension(256), + /// CiphertextCount(2), + /// ); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.encrypt_zero_glwe_list(&mut ciphertexts, noise, &mut encryption_generator); + /// let mut decrypted = PlaintextList::from_container(vec![0 as u32, 0, 0, 0]); + /// secret_key.decrypt_glwe_list(&mut decrypted, &ciphertexts); + /// for dec in decrypted.plaintext_iter() { + /// let d0 = dec.0.wrapping_sub(0u32); + /// let d1 = 0u32.wrapping_sub(dec.0); + /// let dist = std::cmp::min(d0, d1); + /// assert!(dist < 400, "dist: {:?}", dist); + /// } + /// ``` + pub fn encrypt_zero_glwe_list( + &self, + encrypted: &mut GlweList, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + GlweList: AsMutTensor, + Scalar: UnsignedTorus + Add, + Gen: ByteRandomGenerator, + { + for mut ciphertext in encrypted.ciphertext_iter_mut() { + self.encrypt_zero_glwe(&mut ciphertext, noise_parameters, generator); + } + } + + /// Decrypts a single GLWE ciphertext. + /// + /// See ['GlweSecretKey::encrypt_glwe`] for an example. + pub fn decrypt_glwe( + &self, + encoded: &mut PlaintextList, + encrypted: &GlweCiphertext, + ) where + Self: AsRefTensor, + PlaintextList: AsMutTensor, + GlweCiphertext: AsRefTensor, + Scalar: UnsignedTorus + Add, + { + ck_dim_eq!(encoded.count().0 => encrypted.polynomial_size().0); + let (body, masks) = encrypted.get_body_and_mask(); + encoded + .as_mut_tensor() + .fill_with_one(body.as_tensor(), |a| *a); + encoded + .as_mut_polynomial() + .update_with_wrapping_sub_multisum( + &masks.as_polynomial_list(), + &self.as_polynomial_list(), + ); + } + + /// Decrypts a list of GLWE ciphertexts. + /// + /// See ['GlweSecretKey::encrypt_glwe_list`] for an example. + pub fn decrypt_glwe_list( + &self, + encoded: &mut PlaintextList, + encrypted: &GlweList, + ) where + Self: AsRefTensor, + PlaintextList: AsMutTensor, + GlweList: AsRefTensor, + Scalar: UnsignedTorus + Add, + for<'a> PlaintextList<&'a mut [Scalar]>: AsMutTensor, + { + ck_dim_eq!(encrypted.ciphertext_count().0 * encrypted.polynomial_size().0 => encoded.count().0); + ck_dim_eq!(encrypted.glwe_dimension().0 => self.key_size().0); + for (ciphertext, mut encoded) in encrypted + .ciphertext_iter() + .zip(encoded.sublist_iter_mut(PlaintextCount(encrypted.polynomial_size().0))) + { + self.decrypt_glwe(&mut encoded, &ciphertext); + } + } + + fn encrypt_constant_ggsw_row( + &self, + (row_index, last_row_index): (usize, usize), + factor: &Scalar, + sk_poly_list: &PolynomialList, + row_as_glwe: &mut GlweCiphertext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Scalar: UnsignedTorus, + Self: AsRefTensor, + PolynomialList: AsRefTensor, + GlweCiphertext: AsMutTensor, + Gen: ByteRandomGenerator, + { + if row_index < last_row_index { + // Not the last row + let sk_poly = sk_poly_list.get_polynomial(row_index); + + // We need a copy of the polynomial to not modify the GLWE secret key + let mut sk_factored = + Tensor::from_container(sk_poly.as_tensor().as_container().to_vec()); + + sk_factored.update_with_wrapping_scalar_mul(factor); + + let encoded = PlaintextList::from_tensor(sk_factored); + + self.encrypt_glwe(row_as_glwe, &encoded, noise_parameters, generator) + } else { + // The last row needs a slightly different treatment + let mut encoded = + PlaintextList::allocate(Scalar::ZERO, PlaintextCount(self.poly_size.0)); + let first_coeff = encoded.as_mut_tensor().first_mut(); + *first_coeff = first_coeff.wrapping_add(factor.wrapping_neg()); + + self.encrypt_glwe(row_as_glwe, &encoded, noise_parameters, generator); + } + } + + /// This function encrypts a message as a GGSW ciphertext. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::Plaintext; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// PolynomialSize, + /// }; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut generator); + /// let mut ciphertext = StandardGgswCiphertext::allocate( + /// 0 as u32, + /// PolynomialSize(10), + /// GlweSize(3), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(7), + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// let mut secret_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.encrypt_constant_ggsw( + /// &mut ciphertext, + /// &Plaintext(10), + /// noise, + /// &mut secret_generator, + /// ); + /// ``` + pub fn encrypt_constant_ggsw( + &self, + encrypted: &mut StandardGgswCiphertext, + encoded: &Plaintext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + StandardGgswCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.polynomial_size() => encrypted.polynomial_size()); + ck_dim_eq!(self.key_size() => encrypted.glwe_size().to_glwe_dimension()); + + let gen_iter = generator + .fork_ggsw_to_ggsw_levels::( + encrypted.decomposition_level_count(), + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into ggsw levels"); + + let base_log = encrypted.decomposition_base_log(); + for (mut matrix, mut generator) in encrypted.level_matrix_iter_mut().zip(gen_iter) { + let factor = encoded.0.wrapping_neg().wrapping_mul( + Scalar::ONE << (Scalar::BITS - (base_log.0 * (matrix.decomposition_level().0))), + ); + + // We iterate over the rows of the level matrix, the last row needs special treatment + let gen_iter = generator + .fork_ggsw_level_to_glwe::(self.key_size().to_glwe_size(), self.poly_size) + .expect("Failed to split generator into rlwe"); + + let last_row_index = matrix.glwe_size().0 - 1; + let sk_poly_list = &self.as_polynomial_list(); + + for ((row_index, row), mut generator) in matrix.row_iter_mut().enumerate().zip(gen_iter) + { + self.encrypt_constant_ggsw_row( + (row_index, last_row_index), + &factor, + sk_poly_list, + &mut row.into_glwe(), + noise_parameters, + &mut generator, + ); + } + } + } + + fn encrypt_constant_seeded_ggsw_row( + &self, + (row_index, last_row_index): (usize, usize), + factor: &Scalar, + sk_poly_list: &PolynomialList, + row_as_seeded_glwe: &mut GlweSeededCiphertext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Scalar: UnsignedTorus, + Self: AsRefTensor, + PolynomialList: AsRefTensor, + GlweCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Gen: ByteRandomGenerator, + { + if row_index < last_row_index { + // Not the last row + let sk_poly = sk_poly_list.get_polynomial(row_index); + + // We need a copy of the polynomial to not modify the GLWE secret key + let mut sk_factored = + Tensor::from_container(sk_poly.as_tensor().as_container().to_vec()); + + sk_factored.update_with_wrapping_scalar_mul(factor); + + let encoded = PlaintextList::from_tensor(sk_factored); + + self.encrypt_seeded_glwe_with_existing_generator( + row_as_seeded_glwe, + &encoded, + noise_parameters, + generator, + ); + } else { + // The last row needs a slightly different treatment + let mut encoded = + PlaintextList::allocate(Scalar::ZERO, PlaintextCount(self.poly_size.0)); + let first_coeff = encoded.as_mut_tensor().first_mut(); + *first_coeff = first_coeff.wrapping_add(factor.wrapping_neg()); + + self.encrypt_seeded_glwe_with_existing_generator( + row_as_seeded_glwe, + &encoded, + noise_parameters, + generator, + ); + } + } + + /// Factorized function to be able to encrypt a GGSW with a generator in a particular state i.e. + /// not freshly instantiated. The caller is responsible for maintaining consistency. + pub fn encrypt_constant_seeded_ggsw_with_existing_generator< + OutputCont, + Scalar, + NoiseParameter, + Gen, + >( + &self, + encrypted: &mut StandardGgswSeededCiphertext, + encoded: &Plaintext, + noise_parameters: NoiseParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + StandardGgswCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus, + NoiseParameter: DispersionParameter, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.polynomial_size() => encrypted.polynomial_size()); + ck_dim_eq!(self.key_size() => encrypted.glwe_size().to_glwe_dimension()); + + let gen_iter = generator + .fork_ggsw_to_ggsw_levels::( + encrypted.decomposition_level_count(), + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into ggsw levels"); + + let base_log = encrypted.decomposition_base_log(); + for (mut matrix, mut generator) in encrypted.level_matrix_iter_mut().zip(gen_iter) { + let factor = encoded.0.wrapping_neg().wrapping_mul( + Scalar::ONE << (Scalar::BITS - (base_log.0 * (matrix.decomposition_level().0))), + ); + + // We iterate over the rows of the level matrix, the last row needs special treatment + let gen_iter = generator + .fork_ggsw_level_to_glwe::(self.key_size().to_glwe_size(), self.poly_size) + .expect("Failed to split generator into glwe"); + + let last_row_index = matrix.glwe_size().0 - 1; + let sk_poly_list = &self.as_polynomial_list(); + + for ((row_index, row), mut generator) in matrix.row_iter_mut().enumerate().zip(gen_iter) + { + self.encrypt_constant_seeded_ggsw_row( + (row_index, last_row_index), + &factor, + sk_poly_list, + &mut row.into_seeded_glwe(), + noise_parameters, + &mut generator, + ); + } + } + } + + /// This function encrypts a message as a GGSW seeded ciphertext. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::Plaintext; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// PolynomialSize, + /// }; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut generator); + /// let mut seeded_ciphertext = StandardGgswSeededCiphertext::>::allocate( + /// PolynomialSize(10), + /// GlweSize(3), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(7), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let mut seeder = UnixSeeder::new(0); + /// + /// secret_key.encrypt_constant_seeded_ggsw::<_, _, _, _, SoftwareRandomGenerator>( + /// &mut seeded_ciphertext, + /// &Plaintext(10), + /// noise, + /// &mut seeder, + /// ); + /// ``` + pub fn encrypt_constant_seeded_ggsw( + &self, + encrypted: &mut StandardGgswSeededCiphertext, + encoded: &Plaintext, + noise_parameters: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + StandardGgswCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus, + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder, + Gen: ByteRandomGenerator, + { + let mut generator = + EncryptionRandomGenerator::::new(encrypted.compression_seed().seed, seeder); + + self.encrypt_constant_seeded_ggsw_with_existing_generator( + encrypted, + encoded, + noise_parameters, + &mut generator, + ) + } + + /// This function encrypts a message as a GGSW ciphertext, using as many threads as possible. + /// + /// # Notes + /// This method is hidden behind the "__commons_parallel" feature gate. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::Plaintext; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut secret_generator); + /// let mut ciphertext = StandardGgswCiphertext::allocate( + /// 0 as u32, + /// PolynomialSize(10), + /// GlweSize(3), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(7), + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.par_encrypt_constant_ggsw( + /// &mut ciphertext, + /// &Plaintext(10), + /// noise, + /// &mut encryption_generator, + /// ); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_encrypt_constant_ggsw( + &self, + encrypted: &mut StandardGgswCiphertext, + encoded: &Plaintext, + noise_parameters: impl DispersionParameter + Send + Sync, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + StandardGgswCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus + Send + Sync, + Cont: Sync, + Gen: ParallelByteRandomGenerator, + { + ck_dim_eq!(self.polynomial_size() => encrypted.polynomial_size()); + ck_dim_eq!(self.key_size() => encrypted.glwe_size().to_glwe_dimension()); + let generators = generator + .par_fork_ggsw_to_ggsw_levels::( + encrypted.decomposition_level_count(), + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into ggsw levels"); + let base_log = encrypted.decomposition_base_log(); + encrypted + .par_level_matrix_iter_mut() + .zip(generators) + .for_each(move |(mut matrix, mut generator)| { + let factor = encoded.0.wrapping_neg().wrapping_mul( + Scalar::ONE << (Scalar::BITS - (base_log.0 * (matrix.decomposition_level().0))), + ); + + let gen_iter = generator + .par_fork_ggsw_level_to_glwe::( + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into glwe"); + + let last_row_index = matrix.glwe_size().0 - 1; + + let sk_poly_list = &self.as_polynomial_list(); + + // We iterate over the rows of the level matrix + matrix + .par_row_iter_mut() + .enumerate() + .zip(gen_iter) + .for_each(|((row_index, row), mut generator)| { + self.encrypt_constant_ggsw_row( + (row_index, last_row_index), + &factor, + sk_poly_list, + &mut row.into_glwe(), + noise_parameters, + &mut generator, + ); + }) + }) + } + + /// Factorized function to be able to encrypt a GGSW with a generator in a particular state i.e. + /// not freshly instantiated. The caller is responsible for maintaining consistency. + #[cfg(feature = "__commons_parallel")] + pub fn par_encrypt_constant_seeded_ggsw_with_existing_generator< + OutputCont, + Scalar, + NoiseParameter, + Gen, + >( + &self, + encrypted: &mut StandardGgswSeededCiphertext, + encoded: &Plaintext, + noise_parameters: NoiseParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus + Send + Sync, + Cont: Send + Sync, + NoiseParameter: DispersionParameter + Sync + Send, + Gen: ParallelByteRandomGenerator, + { + ck_dim_eq!(self.polynomial_size() => encrypted.polynomial_size()); + ck_dim_eq!(self.key_size() => encrypted.glwe_size().to_glwe_dimension()); + + let gen_iter = generator + .par_fork_ggsw_to_ggsw_levels::( + encrypted.decomposition_level_count(), + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into ggsw levels"); + + let base_log = encrypted.decomposition_base_log(); + encrypted + .par_level_matrix_iter_mut() + .zip(gen_iter) + .for_each(move |(mut matrix, mut generator)| { + let factor = encoded.0.wrapping_neg().wrapping_mul( + Scalar::ONE << (Scalar::BITS - (base_log.0 * (matrix.decomposition_level().0))), + ); + + // We iterate over the rows of the level matrix, the last row needs special + // treatment + let gen_iter = generator + .par_fork_ggsw_level_to_glwe::( + self.key_size().to_glwe_size(), + self.poly_size, + ) + .expect("Failed to split generator into rlwe"); + + let last_row_index = matrix.glwe_size().0 - 1; + let sk_poly_list = &self.as_polynomial_list(); + + // We iterate over the rows of the level matrix + matrix + .par_row_iter_mut() + .zip(gen_iter) + .enumerate() + .for_each(|(row_index, (row, mut generator))| { + self.encrypt_constant_seeded_ggsw_row( + (row_index, last_row_index), + &factor, + sk_poly_list, + &mut row.into_seeded_glwe(), + noise_parameters, + &mut generator, + ); + }); + }); + } + + /// This function encrypts a message as a GGSW seeded ciphertext. + /// + /// # Notes + /// This method is hidden behind the "multithread" feature gate. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::Plaintext; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswSeededCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::commons::math::random::CompressionSeed; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// PolynomialSize, + /// }; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut generator); + /// let mut seeded_ciphertext = StandardGgswSeededCiphertext::>::allocate( + /// PolynomialSize(10), + /// GlweSize(3), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(7), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let mut seeder = UnixSeeder::new(0); + /// + /// secret_key.par_encrypt_constant_seeded_ggsw::<_, _, _, _, SoftwareRandomGenerator>( + /// &mut seeded_ciphertext, + /// &Plaintext(10), + /// noise, + /// &mut seeder, + /// ); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_encrypt_constant_seeded_ggsw( + &self, + encrypted: &mut StandardGgswSeededCiphertext, + encoded: &Plaintext, + noise_parameters: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus + Send + Sync, + Cont: Send + Sync, + NoiseParameter: DispersionParameter + Sync + Send, + NoiseSeeder: Seeder + Send + Sync, + Gen: ParallelByteRandomGenerator, + { + let mut generator = + EncryptionRandomGenerator::::new(encrypted.compression_seed().seed, seeder); + + self.par_encrypt_constant_seeded_ggsw_with_existing_generator( + encrypted, + encoded, + noise_parameters, + &mut generator, + ); + } + + /// This function encrypts a message as a GGSW ciphertext whose rlwe masks are all zeros. + /// + /// # Examples + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seed, UnixSeeder}; + /// use tfhe::core_crypto::commons::crypto::encoding::Plaintext; + /// use tfhe::core_crypto::commons::crypto::ggsw::StandardGgswCiphertext; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::GlweSecretKey; + /// use tfhe::core_crypto::prelude::{ + /// DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LogStandardDev, + /// PolynomialSize, + /// }; + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: GlweSecretKey<_, Vec> = + /// GlweSecretKey::generate_binary(GlweDimension(2), PolynomialSize(10), &mut secret_generator); + /// let mut ciphertext = StandardGgswCiphertext::allocate( + /// 0 as u32, + /// PolynomialSize(10), + /// GlweSize(3), + /// DecompositionLevelCount(3), + /// DecompositionBaseLog(7), + /// ); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// let mut encryption_generator = + /// EncryptionRandomGenerator::::new(Seed(0), &mut UnixSeeder::new(0)); + /// secret_key.trivial_encrypt_constant_ggsw( + /// &mut ciphertext, + /// &Plaintext(10), + /// noise, + /// &mut encryption_generator, + /// ); + /// ``` + pub fn trivial_encrypt_constant_ggsw( + &self, + encrypted: &mut StandardGgswCiphertext, + encoded: &Plaintext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + StandardGgswCiphertext: AsMutTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + ck_dim_eq!(self.polynomial_size() => encrypted.polynomial_size()); + ck_dim_eq!(self.key_size() => encrypted.glwe_size().to_glwe_dimension()); + // We fill the ggsw with trivial glwe encryptions of zero: + for mut glwe in encrypted.as_mut_glwe_list().ciphertext_iter_mut() { + let (mut body, mut mask) = glwe.get_mut_body_and_mask(); + mask.as_mut_tensor().fill_with_element(Scalar::ZERO); + generator.fill_tensor_with_random_noise(&mut body, noise_parameters); + } + let base_log = encrypted.decomposition_base_log(); + for mut matrix in encrypted.level_matrix_iter_mut() { + let decomposition = encoded.0.wrapping_mul( + Scalar::ONE + << (::BITS + - (base_log.0 * (matrix.decomposition_level().0))), + ); + // We iterate over the rowe of the level matrix + for (index, row) in matrix.row_iter_mut().enumerate() { + let rlwe_ct = row.into_glwe(); + // We retrieve the row as a polynomial list + let mut polynomial_list = rlwe_ct.into_polynomial_list(); + // We retrieve the polynomial in the diagonal + let mut level_polynomial = polynomial_list.get_mut_polynomial(index); + // We get the first coefficient + let first_coef = level_polynomial.as_mut_tensor().first_mut(); + // We update the first coefficient + *first_coef = first_coef.wrapping_add(decomposition); + } + } + } +} + +impl AsRefTensor for GlweSecretKey +where + Kind: KeyKind, + Cont: AsRefSlice, +{ + type Element = Element; + type Container = Cont; + fn as_tensor(&self) -> &Tensor { + &self.tensor + } +} + +impl AsMutTensor for GlweSecretKey +where + Kind: KeyKind, + Cont: AsMutSlice, +{ + type Element = Element; + type Container = Cont; + fn as_mut_tensor(&mut self) -> &mut Tensor<::Container> { + &mut self.tensor + } +} + +impl IntoTensor for GlweSecretKey +where + Kind: KeyKind, + Cont: AsRefSlice, +{ + type Element = ::Element; + type Container = Cont; + fn into_tensor(self) -> Tensor { + self.tensor + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/secret/lwe.rs b/tfhe/src/core_crypto/commons/crypto/secret/lwe.rs new file mode 100644 index 000000000..658448b1c --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/lwe.rs @@ -0,0 +1,647 @@ +use crate::core_crypto::commons::crypto::encoding::{Plaintext, PlaintextList}; +use crate::core_crypto::commons::crypto::lwe::{ + LweBody, LweCiphertext, LweList, LweMask, LweSeededCiphertext, LweSeededList, +}; +use crate::core_crypto::commons::crypto::secret::generators::{ + EncryptionRandomGenerator, SecretRandomGenerator, +}; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; +use crate::core_crypto::commons::math::random::{ + ByteRandomGenerator, Gaussian, RandomGenerable, Seeder, +}; +use crate::core_crypto::commons::math::tensor::{ + AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor, IntoTensor, Tensor, +}; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::prelude::LweCiphertextCount; +use crate::core_crypto::prelude::{ + BinaryKeyKind, DispersionParameter, GaussianKeyKind, KeyKind, LweDimension, TernaryKeyKind, + UniformKeyKind, +}; +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +/// A LWE secret key. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LweSecretKey +where + Kind: KeyKind, +{ + pub(crate) tensor: Tensor, + pub(crate) kind: PhantomData, +} + +impl LweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Generates a new binary secret key; e.g. allocates a storage and samples random values for + /// the key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: LweSecretKey<_, Vec> = + /// LweSecretKey::generate_binary(LweDimension(256), &mut generator); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn generate_binary( + size: LweDimension, + generator: &mut SecretRandomGenerator, + ) -> Self { + LweSecretKey { + tensor: generator.random_binary_tensor(size.0), + kind: PhantomData, + } + } +} + +impl LweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Generates a new ternary secret key; e.g. allocates a storage and samples random values for + /// the key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: LweSecretKey<_, Vec> = + /// LweSecretKey::generate_ternary(LweDimension(256), &mut generator); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn generate_ternary( + size: LweDimension, + generator: &mut SecretRandomGenerator, + ) -> Self { + LweSecretKey { + tensor: generator.random_ternary_tensor(size.0), + kind: PhantomData, + } + } +} + +impl LweSecretKey> +where + (Scalar, Scalar): RandomGenerable>, + Scalar: UnsignedTorus, +{ + /// Generates a new gaussian secret key; e.g. allocates a storage and samples random values for + /// the key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: LweSecretKey<_, Vec> = + /// LweSecretKey::generate_gaussian(LweDimension(256), &mut generator); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn generate_gaussian( + size: LweDimension, + generator: &mut SecretRandomGenerator, + ) -> Self { + LweSecretKey { + tensor: generator.random_gaussian_tensor(size.0), + kind: PhantomData, + } + } +} + +impl LweSecretKey> +where + Scalar: UnsignedTorus, +{ + /// Generates a new gaussian secret key; e.g. allocates a storage and samples random values for + /// the key. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::crypto::secret::generators::SecretRandomGenerator; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let mut generator = SecretRandomGenerator::::new(Seed(0)); + /// let secret_key: LweSecretKey<_, Vec> = + /// LweSecretKey::generate_uniform(LweDimension(256), &mut generator); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn generate_uniform( + size: LweDimension, + generator: &mut SecretRandomGenerator, + ) -> Self { + LweSecretKey { + tensor: generator.random_uniform_tensor(size.0), + kind: PhantomData, + } + } +} + +impl LweSecretKey { + /// Creates a binary lwe secret key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random values to create a new key. It merely + /// wraps a container into the appropriate type. See [`LweSecretKey::generate_binary`] for a + /// generation method. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let secret_key = LweSecretKey::binary_from_container(vec![true; 256]); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn binary_from_container(cont: Cont) -> Self + where + Cont: AsRefSlice, + { + LweSecretKey { + tensor: Tensor::from_container(cont), + kind: PhantomData, + } + } +} + +impl LweSecretKey { + /// Creates a ternary lwe secret key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random values to create a new key. It merely + /// wraps a container into the appropriate type. See [`LweSecretKey::generate_ternary`] for a + /// generation method. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let secret_key = LweSecretKey::ternary_from_container(vec![true; 256]); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn ternary_from_container(cont: Cont) -> Self + where + Cont: AsRefSlice, + { + LweSecretKey { + tensor: Tensor::from_container(cont), + kind: PhantomData, + } + } +} + +impl LweSecretKey { + /// Creates a gaussian lwe secret key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random values to create a new key. It merely + /// wraps a container into the appropriate type. See [`LweSecretKey::generate_gaussian`] for a + /// generation method. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let secret_key = LweSecretKey::gaussian_from_container(vec![true; 256]); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn gaussian_from_container(cont: Cont) -> Self + where + Cont: AsRefSlice, + { + LweSecretKey { + tensor: Tensor::from_container(cont), + kind: PhantomData, + } + } +} + +impl LweSecretKey { + /// Creates a uniform lwe secret key from a container. + /// + /// # Notes + /// + /// This method does not fill the container with random values to create a new key. It merely + /// wraps a container into the appropriate type. See [`LweSecretKey::generate_uniform`] for a + /// generation method. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let secret_key = LweSecretKey::uniform_from_container(vec![true; 256]); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn uniform_from_container(cont: Cont) -> Self + where + Cont: AsRefSlice, + { + LweSecretKey { + tensor: Tensor::from_container(cont), + kind: PhantomData, + } + } +} + +impl LweSecretKey +where + Kind: KeyKind, +{ + /// Returns the size of the secret key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::prelude::LweDimension; + /// let secret_key = LweSecretKey::binary_from_container(vec![true; 256]); + /// assert_eq!(secret_key.key_size(), LweDimension(256)); + /// ``` + pub fn key_size(&self) -> LweDimension + where + Self: AsRefTensor, + { + LweDimension(self.as_tensor().len()) + } + + fn fill_lwe_mask_and_body_for_encryption( + &self, + output_body: &mut LweBody, + output_mask: &mut LweMask, + encoded: &Plaintext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + OutputCont: AsMutSlice, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + // generate a uniformly random mask + generator.fill_tensor_with_random_mask(output_mask); + + // generate an error from the normal distribution described by std_dev + output_body.0 = generator.random_noise(noise_parameters); + + // compute the multisum between the secret key and the mask + output_body.0 = output_body + .0 + .wrapping_add(output_mask.compute_multisum(self)); + + // add the encoded message + output_body.0 = output_body.0.wrapping_add(encoded.0); + } + + /// Encrypts a single ciphertext. + pub fn encrypt_lwe( + &self, + output: &mut LweCiphertext, + encoded: &Plaintext, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + LweCiphertext: AsMutTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + let (output_body, mut output_mask) = output.get_mut_body_and_mask(); + + self.fill_lwe_mask_and_body_for_encryption( + output_body, + &mut output_mask, + encoded, + noise_parameters, + generator, + ); + } + + /// Encrypts a single seeded ciphertext. + pub fn encrypt_seeded_lwe( + &self, + output: &mut LweSeededCiphertext, + encoded: &Plaintext, + noise_parameters: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + // This will be removable when https://github.com/rust-lang/rust/issues/83701 is stabilized + // We currently need to be able to specify concrete types for the generic type parameters + // which cannot be done when some arguments use the `impl Trait` pattern + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder + ?Sized, + { + debug_assert!( + output.lwe_size().to_lwe_dimension() == self.key_size(), + "Output LweSeededCiphertext dimension is not compatible with LweSecretKey dimension" + ); + + // Create the generator for the encryption, seed it with the output seed, pass a seeder so + // that the noise generator is seeded with a private seed + let mut generator = + EncryptionRandomGenerator::::new(output.compression_seed().seed, seeder); + + let mut output_mask = LweMask::from_container(vec![Scalar::ZERO; self.key_size().0]); + let output_body = output.get_mut_body(); + + self.fill_lwe_mask_and_body_for_encryption( + output_body, + &mut output_mask, + encoded, + noise_parameters, + &mut generator, + ); + } + + /// Encrypts a list of ciphertexts. + pub fn encrypt_lwe_list( + &self, + output: &mut LweList, + encoded: &PlaintextList, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + LweList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + { + debug_assert!( + output.count().0 == encoded.count().0, + "Lwe cipher list size and encoded list size are not compatible" + ); + for (mut cipher, message) in output.ciphertext_iter_mut().zip(encoded.plaintext_iter()) { + self.encrypt_lwe(&mut cipher, message, noise_parameters, generator); + } + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_encrypt_lwe_list( + &self, + output: &mut LweList, + encoded: &PlaintextList, + noise_parameters: impl DispersionParameter + Sync, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + LweList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus + Send + Sync, + Gen: ByteRandomGenerator + ParallelByteRandomGenerator, + Cont: Sync, + { + debug_assert!( + output.count().0 == encoded.count().0, + "Lwe cipher list size and encoded list size are not compatible" + ); + let ct_count = LweCiphertextCount(output.count().0); + let ct_size = output.lwe_size; + output + .par_ciphertext_iter_mut() + .zip(encoded.par_plaintext_iter()) + .zip( + generator + .par_fork_lwe_list_to_lwe::(ct_count, ct_size) + .unwrap(), + ) + .for_each(|((mut cipher, message), mut generator)| { + self.encrypt_lwe(&mut cipher, message, noise_parameters, &mut generator); + }) + } + + pub fn encrypt_seeded_lwe_list_with_existing_generator< + OutputCont, + InputCont, + Scalar, + NoiseParameter, + Gen, + >( + &self, + output: &mut LweSeededList, + encoded: &PlaintextList, + noise_parameters: NoiseParameter, + generator: &mut EncryptionRandomGenerator, + ) where + Self: AsRefTensor, + LweSeededList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + // This will be removable when https://github.com/rust-lang/rust/issues/83701 is stabilized + // We currently need to be able to specify concrete types for the generic type parameters + // which cannot be done when some arguments use the `impl Trait` pattern + NoiseParameter: DispersionParameter, + { + let mut mask_tensor = vec![Scalar::ZERO; self.key_size().0]; + let mut output_mask = LweMask::from_container(mask_tensor.as_mut_slice()); + + for (output_body, encoded_message) in output.body_iter_mut().zip(encoded.plaintext_iter()) { + self.fill_lwe_mask_and_body_for_encryption( + output_body, + &mut output_mask, + encoded_message, + noise_parameters, + generator, + ); + } + } + + /// Encrypts a list of seeded ciphertexts. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::prelude::{ + /// CiphertextCount, LogStandardDev, LweDimension, PlaintextCount, + /// }; + /// + /// use tfhe::core_crypto::commons::crypto::encoding::*; + /// use tfhe::core_crypto::commons::crypto::lwe::*; + /// use tfhe::core_crypto::commons::crypto::secret::generators::{ + /// EncryptionRandomGenerator, SecretRandomGenerator, + /// }; + /// use tfhe::core_crypto::commons::crypto::secret::*; + /// use tfhe::core_crypto::commons::crypto::*; + /// use tfhe::core_crypto::commons::math::random::{CompressionSeed, Seed}; + /// + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::{Seeder, UnixSeeder}; + /// + /// let mut secret_generator = SecretRandomGenerator::::new(Seed(0)); + /// + /// let mut seeder = UnixSeeder::new(0); + /// + /// let secret_key = LweSecretKey::generate_binary(LweDimension(256), &mut secret_generator); + /// let noise = LogStandardDev::from_log_standard_dev(-15.); + /// + /// let mut plain_values = PlaintextList::allocate(3u32, PlaintextCount(100)); + /// let mut encrypted_values = LweSeededList::allocate( + /// LweDimension(256), + /// CiphertextCount(100), + /// CompressionSeed { seed: Seed(42) }, + /// ); + /// secret_key.encrypt_seeded_lwe_list::<_, _, _, _, _, SoftwareRandomGenerator>( + /// &mut encrypted_values, + /// &plain_values, + /// noise, + /// &mut seeder, + /// ); + /// ``` + pub fn encrypt_seeded_lwe_list< + OutputCont, + InputCont, + Scalar, + NoiseParameter, + NoiseSeeder, + Gen, + >( + &self, + output: &mut LweSeededList, + encoded: &PlaintextList, + noise_parameters: NoiseParameter, + seeder: &mut NoiseSeeder, + ) where + Self: AsRefTensor, + LweSeededList: AsMutTensor, + PlaintextList: AsRefTensor, + Scalar: UnsignedTorus, + Gen: ByteRandomGenerator, + // This will be removable when https://github.com/rust-lang/rust/issues/83701 is stabilized + // We currently need to be able to specify concrete types for the generic type parameters + // which cannot be done when some arguments use the `impl Trait` pattern + NoiseParameter: DispersionParameter, + NoiseSeeder: Seeder, + { + let mut generator = + EncryptionRandomGenerator::::new(output.get_compression_seed().seed, seeder); + + self.encrypt_seeded_lwe_list_with_existing_generator( + output, + encoded, + noise_parameters, + &mut generator, + ); + } + + /// Decrypts a single ciphertext. + /// + /// See ['encrypt_lwe'] for an example. + pub fn decrypt_lwe( + &self, + output: &mut Plaintext, + cipher: &LweCiphertext, + ) where + Self: AsRefTensor, + LweCiphertext: AsRefTensor, + Scalar: UnsignedTorus, + { + let (body, masks) = cipher.get_body_and_mask(); + // put body inside result + output.0 = body.0; + // subtract the multisum between the key and the mask + output.0 = output.0.wrapping_sub(masks.compute_multisum(self)); + } + + /// Decrypts a list of ciphertexts. + /// + /// See ['encrypt_lwe_list'] for an example. + pub fn decrypt_lwe_list( + &self, + output: &mut PlaintextList, + cipher: &LweList, + ) where + Self: AsRefTensor, + PlaintextList: AsMutTensor, + LweList: AsRefTensor, + Scalar: UnsignedTorus, + { + debug_assert!( + output.count().0 == cipher.count().0, + "Tried to decrypt a list into one with incompatible size.Expected {} found {}", + output.count().0, + cipher.count().0 + ); + for (cipher, output) in cipher.ciphertext_iter().zip(output.plaintext_iter_mut()) { + self.decrypt_lwe(output, &cipher); + } + } +} + +impl AsRefTensor for LweSecretKey +where + Kind: KeyKind, + Cont: AsRefSlice, +{ + type Element = Element; + type Container = Cont; + fn as_tensor(&self) -> &Tensor { + &self.tensor + } +} + +impl AsMutTensor for LweSecretKey +where + Kind: KeyKind, + Cont: AsMutSlice, +{ + type Element = Element; + type Container = Cont; + fn as_mut_tensor(&mut self) -> &mut Tensor<::Container> { + &mut self.tensor + } +} + +impl IntoTensor for LweSecretKey +where + Kind: KeyKind, + Cont: AsRefSlice, +{ + type Element = ::Element; + type Container = Cont; + fn into_tensor(self) -> Tensor { + self.tensor + } +} diff --git a/tfhe/src/core_crypto/commons/crypto/secret/mod.rs b/tfhe/src/core_crypto/commons/crypto/secret/mod.rs new file mode 100644 index 000000000..ab28c9c00 --- /dev/null +++ b/tfhe/src/core_crypto/commons/crypto/secret/mod.rs @@ -0,0 +1,8 @@ +//! Secret keys module. +pub use glwe::*; +pub use lwe::*; + +pub mod generators; + +mod glwe; +mod lwe; diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs new file mode 100644 index 000000000..370083fb3 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -0,0 +1,297 @@ +use crate::core_crypto::commons::math::decomposition::{ + SignedDecompositionIter, TensorSignedDecompositionIter, +}; +use crate::core_crypto::commons::math::tensor::{AsMutTensor, AsRefTensor, Tensor}; +use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger}; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; +use std::marker::PhantomData; + +/// A structure which allows to decompose unsigned integers into a set of smaller terms. +/// +/// See the [module level](super) documentation for a description of the signed decomposition. +#[derive(Debug)] +pub struct SignedDecomposer +where + Scalar: UnsignedInteger, +{ + pub(crate) base_log: usize, + pub(crate) level_count: usize, + integer_type: PhantomData, +} + +impl SignedDecomposer +where + Scalar: UnsignedInteger, +{ + /// Creates a new decomposer. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3)); + /// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn new( + base_log: DecompositionBaseLog, + level_count: DecompositionLevelCount, + ) -> SignedDecomposer { + debug_assert!( + Scalar::BITS > base_log.0 * level_count.0, + "Decomposed bits exceeds the size of the integer to be decomposed" + ); + SignedDecomposer { + base_log: base_log.0, + level_count: level_count.0, + integer_type: PhantomData, + } + } + + /// Returns the logarithm in base two of the base of this decomposer. + /// + /// If the decomposer uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposer. + /// + /// If the decomposer uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Returns the closet value representable by the decomposition. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let closest = decomposer.closest_representable(1_340_987_234_u32); + /// assert_eq!(closest, 1_341_128_704_u32); + /// ``` + #[inline] + pub fn closest_representable(&self, input: Scalar) -> Scalar { + // The closest number representable by the decomposition can be computed by performing + // the rounding at the appropriate bit. + + // We compute the number of least significant bits which can not be represented by the + // decomposition + let non_rep_bit_count: usize = ::BITS - self.level_count * self.base_log; + // We generate a mask which captures the non representable bits + let non_rep_mask = Scalar::ONE << (non_rep_bit_count - 1); + // We retrieve the non representable bits + let non_rep_bits = input & non_rep_mask; + // We extract the msb of the non representable bits to perform the rounding + let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1); + // We remove the non-representable bits and perform the rounding + let res = input >> non_rep_bit_count; + let res = res + non_rep_msb; + res << non_rep_bit_count + } + + /// Fills a mutable tensor-like objects with the closest representable values from another + /// tensor-like object. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// + /// let input = Tensor::allocate(1_340_987_234_u32, 1); + /// let mut closest = Tensor::allocate(0u32, 1); + /// decomposer.fill_tensor_with_closest_representable(&mut closest, &input); + /// assert_eq!(*closest.get_element(0), 1_341_128_704_u32); + /// ``` + pub fn fill_tensor_with_closest_representable(&self, output: &mut O, input: &I) + where + I: AsRefTensor, + O: AsMutTensor, + { + output + .as_mut_tensor() + .fill_with_one(input.as_tensor(), |elmt| self.closest_representable(*elmt)) + } + + /// Generates an iterator over the terms of the decomposition of the input. + /// + /// # Warning + /// + /// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// for term in decomposer.decompose(1_340_987_234_u32) { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// let signed_term = term.value().into_signed(); + /// let half_basis = 2i32.pow(4) / 2i32; + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term < half_basis); + /// } + /// assert_eq!(decomposer.decompose(1).count(), 3); + /// ``` + pub fn decompose(&self, input: Scalar) -> SignedDecompositionIter { + // Note that there would be no sense of making the decomposition on an input which was + // not rounded to the closest representable first. We then perform it before decomposing. + SignedDecompositionIter::new( + self.closest_representable(input), + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + ) + } + + /// Recomposes a decomposed value by summing all the terms. + /// + /// If the input iterator yields $\tilde{\theta}\_i$, this returns + /// $\sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let val = 1_340_987_234_u32; + /// let dec = decomposer.decompose(val); + /// let rec = decomposer.recompose(dec); + /// assert_eq!(decomposer.closest_representable(val), rec.unwrap()); + /// ``` + pub fn recompose(&self, decomp: SignedDecompositionIter) -> Option { + if decomp.is_fresh() { + Some(decomp.fold(Scalar::ZERO, |acc, term| { + acc.wrapping_add(term.to_recomposition_summand()) + })) + } else { + None + } + } + + /// Generates an iterator-like object over tensors of terms of the decomposition of the input + /// tensor. + /// + /// # Warning + /// + /// The returned iterator yields the terms $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ in + /// order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = Tensor::from_container(vec![1_340_987_234_u32, 1_340_987_234_u32]); + /// let mut decomp = decomposer.decompose_tensor(&decomposable); + /// /// + /// let mut count = 0; + /// while let Some(term) = decomp.next_term() { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// for elmt in term.as_tensor().iter() { + /// let signed_term = elmt.into_signed(); + /// let half_basis = 2i32.pow(4) / 2i32; + /// assert!(-half_basis <= signed_term); + /// assert!(signed_term < half_basis); + /// } + /// count += 1; + /// } + /// assert_eq!(count, 3); + /// ``` + pub fn decompose_tensor(&self, input: &I) -> TensorSignedDecompositionIter + where + I: AsRefTensor, + { + // Note that there would be no sense of making the decomposition on an input which was + // not rounded to the closest representable first. We then perform it before decomposing. + let mut rounded = Tensor::allocate(Scalar::ZERO, input.as_tensor().len()); + self.fill_tensor_with_closest_representable(&mut rounded, input); + TensorSignedDecompositionIter::new( + rounded, + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + ) + } + + /// Fills the output tensor with the recomposition of an other tensor. + /// + /// Returns `Some(())` if the decomposition was fresh, and the output was filled with a + /// recomposition, and `None`, if not. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = Tensor::allocate(1_340_987_234_u32, 1); + /// let mut rounded = Tensor::allocate(0u32, 1); + /// decomposer.fill_tensor_with_closest_representable(&mut rounded, &decomposable); + /// let mut decomp = decomposer.decompose_tensor(&rounded); + /// let mut recomposition = Tensor::allocate(0u32, 1); + /// decomposer + /// .fill_tensor_with_recompose(decomp, &mut recomposition) + /// .unwrap(); + /// assert_eq!(recomposition, rounded); + /// ``` + pub fn fill_tensor_with_recompose( + &self, + decomp: TensorSignedDecompositionIter, + output: &mut TLike, + ) -> Option<()> + where + TLike: AsMutTensor, + { + let mut decomp = decomp; + if decomp.is_fresh() { + while let Some(term) = decomp.next_term() { + term.update_tensor_with_recomposition_summand_wrapping_addition(output); + } + Some(()) + } else { + None + } + } +} diff --git a/tfhe/src/core_crypto/commons/math/decomposition/iter.rs b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs new file mode 100644 index 000000000..b7d953927 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/decomposition/iter.rs @@ -0,0 +1,283 @@ +use crate::core_crypto::commons::math::decomposition::{ + DecompositionLevel, DecompositionTerm, DecompositionTermTensor, +}; +use crate::core_crypto::commons::math::tensor::Tensor; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::commons::utils::{zip, zip_args}; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + +/// An iterator-like object that yields the terms of the signed decomposition of a tensor of values. +/// +/// # Note +/// +/// On each call to [`TensorSignedDecompositionIter::next_term`], this structure yields a new +/// [`DecompositionTermTensor`], backed by a `Vec` owned by the structure. This vec is mutated at +/// each call of the `next_term` method, and as such the term must be dropped before `next_term` is +/// called again. +/// +/// Such a pattern can not be implemented with iterators yet (without GATs), which is why this +/// iterator must be explicitly called. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct TensorSignedDecompositionIter +where + Scalar: UnsignedInteger, +{ + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: Scalar, + // The values being decomposed + inputs: Vec, + // The internal states of each decomposition + states: Vec, + // In order to avoid allocating a new Vec every time we yield a decomposition term, we store + // a Vec inside the structure and yield slices pointing to it. + outputs: Vec, + // A flag which stores whether the iterator is a fresh one (for the recompose method). + fresh: bool, +} + +impl TensorSignedDecompositionIter +where + Scalar: UnsignedInteger, +{ + // Creates a new tensor decomposition iterator. + pub(crate) fn new( + input: Tensor>, + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ) -> TensorSignedDecompositionIter { + let len = input.len(); + TensorSignedDecompositionIter { + base_log: base_log.0, + level_count: level.0, + current_level: level.0, + mod_b_mask: (Scalar::ONE << base_log.0) - Scalar::ONE, + inputs: input.clone().into_container(), + outputs: vec![Scalar::ZERO; len], + states: input + .iter() + .map(|i| *i >> (Scalar::BITS - base_log.0 * level.0)) + .collect(), + fresh: true, + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = Tensor::allocate(1_340_987_234_u32, 2); + /// let decomp = decomposer.decompose_tensor(&decomposable); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = Tensor::allocate(1_340_987_234_u32, 2); + /// let decomp = decomposer.decompose_tensor(&decomposable); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } + + /// Yield the next term of the decomposition, if any. + /// + /// # Note + /// + /// Because this function returns a borrowed tensor, owned by the iterator, the term must be + /// dropped before `next_term` is called again. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let decomposable = Tensor::allocate(1_340_987_234_u32, 1); + /// let mut decomp = decomposer.decompose_tensor(&decomposable); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// assert_eq!(*term.as_tensor().get_element(0), 4294967295); + /// ``` + pub fn next_term(&mut self) -> Option> { + // The iterator is not fresh anymore. + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We iterate over the elements of the outputs and decompose + for zip_args!(output_i, state_i) in zip!(self.outputs.iter_mut(), self.states.iter_mut()) { + *output_i = decompose_one_level(self.base_log, state_i, self.mod_b_mask); + } + self.current_level -= 1; + // We return the term tensor. + Some(DecompositionTermTensor::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + Tensor::from_container(self.outputs.as_slice()), + )) + } +} + +/// An iterator that yields the terms of the signed decomposition of an integer. +/// +/// # Warning +/// +/// This iterator yields the decomposition in reverse order. That means that the highest level +/// will be yielded first. +pub struct SignedDecompositionIter +where + T: UnsignedInteger, +{ + // The value being decomposed + input: T, + // The base log of the decomposition + base_log: usize, + // The number of levels of the decomposition + level_count: usize, + // The internal state of the decomposition + state: T, + // The current level + current_level: usize, + // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form: + // ...0001111 + mod_b_mask: T, + // A flag which store whether the iterator is a fresh one (for the recompose method) + fresh: bool, +} + +impl SignedDecompositionIter +where + T: UnsignedInteger, +{ + pub(crate) fn new( + input: T, + base_log: DecompositionBaseLog, + level: DecompositionLevelCount, + ) -> SignedDecompositionIter { + SignedDecompositionIter { + input, + base_log: base_log.0, + level_count: level.0, + state: input >> (T::BITS - base_log.0 * level.0), + current_level: level.0, + mod_b_mask: (T::ONE << base_log.0) - T::ONE, + fresh: true, + } + } + + pub(crate) fn is_fresh(&self) -> bool { + self.fresh + } + + /// Returns the logarithm in base two of the base of this decomposition. + /// + /// If the decomposition uses a base $B=2^b$, this returns $b$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let val = 1_340_987_234_u32; + /// let decomp = decomposer.decompose(val); + /// assert_eq!(decomp.base_log(), DecompositionBaseLog(4)); + /// ``` + pub fn base_log(&self) -> DecompositionBaseLog { + DecompositionBaseLog(self.base_log) + } + + /// Returns the number of levels of this decomposition. + /// + /// If the decomposition uses $l$ levels, this returns $l$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let val = 1_340_987_234_u32; + /// let decomp = decomposer.decompose(val); + /// assert_eq!(decomp.level_count(), DecompositionLevelCount(3)); + /// ``` + pub fn level_count(&self) -> DecompositionLevelCount { + DecompositionLevelCount(self.level_count) + } +} + +impl Iterator for SignedDecompositionIter +where + T: UnsignedInteger, +{ + type Item = DecompositionTerm; + + fn next(&mut self) -> Option { + // The iterator is not fresh anymore + self.fresh = false; + // We check if the decomposition is over + if self.current_level == 0 { + return None; + } + // We decompose the current level + let output = decompose_one_level(self.base_log, &mut self.state, self.mod_b_mask); + self.current_level -= 1; + // We return the output for this level + Some(DecompositionTerm::new( + DecompositionLevel(self.current_level + 1), + DecompositionBaseLog(self.base_log), + output, + )) + } +} + +fn decompose_one_level(base_log: usize, state: &mut S, mod_b_mask: S) -> S { + let res = *state & mod_b_mask; + *state >>= base_log; + let mut carry = (res.wrapping_sub(S::ONE) | *state) & res; + carry >>= base_log - 1; + *state += carry; + res.wrapping_sub(carry << base_log) +} diff --git a/tfhe/src/core_crypto/commons/math/decomposition/mod.rs b/tfhe/src/core_crypto/commons/math/decomposition/mod.rs new file mode 100644 index 000000000..7a3292219 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/decomposition/mod.rs @@ -0,0 +1,41 @@ +//! Signed decomposition of unsigned integers. +//! +//! Multiple homomorphic operations used in Zama's variant of the TFHE scheme use a signed +//! decomposition to reduce the amount of noise. This module contains a [`SignedDecomposer`] which +//! offer a clean api for this decomposition. +//! +//! # Description +//! +//! We assume a number $\theta$ lives in $\mathbb{Z}/q\mathbb{Z}$, with $q$ a power of two. Such +//! a number can also be seen as a signed integer in $[ -\frac{q}{2}; \frac{q}{2}-1]$. Assuming a +//! given base $B=2^{b}$ and a number of levels $l$ such that $B^l\leq q$, such a $\theta$ can be +//! approximately decomposed as: +//! $$ +//! \theta \approx \sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i} +//! $$ +//! With the $\tilde{\theta}\_i\in[-\frac{B}{2}, \frac{B}{2}-1]$. When $B^l = q$, the decomposition +//! is no longer an approximation, and becomes exact. The rationale behind using an approximate +//! decomposition like that, is that when using this decomposition the approximation error will be +//! located in the least significant bits, which are already erroneous. +use std::fmt::Debug; + +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +pub use decomposer::*; +pub use iter::*; +pub use term::*; + +mod decomposer; +mod iter; +mod term; +#[cfg(test)] +mod tests; + +/// The level of a given term of a decomposition. +/// +/// When decomposing an integer over the $l$ levels, this type represent the level (in $[0,l)$) +/// currently manipulated. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub struct DecompositionLevel(pub usize); diff --git a/tfhe/src/core_crypto/commons/math/decomposition/term.rs b/tfhe/src/core_crypto/commons/math/decomposition/term.rs new file mode 100644 index 000000000..f20c46f8b --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/decomposition/term.rs @@ -0,0 +1,212 @@ +use crate::core_crypto::commons::math::decomposition::DecompositionLevel; +use crate::core_crypto::commons::math::tensor::{AsMutTensor, Tensor}; +use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger}; +use crate::core_crypto::prelude::DecompositionBaseLog; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// A member of the decomposition. +/// +/// If we decompose a value $\theta$ as a sum $\sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i}$, this +/// represents a $\tilde{\theta}\_i$. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTerm +where + T: UnsignedInteger, +{ + level: usize, + base_log: usize, + value: T, +} + +impl DecompositionTerm +where + T: UnsignedInteger, +{ + // Creates a new decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + value: T, + ) -> DecompositionTerm { + DecompositionTerm { + level: level.0, + base_log: base_log.0, + value, + } + } + + /// Turns this term into a summand. + /// + /// If our member represents one $\tilde{\theta}\_i$ of the decomposition, this method returns + /// $\tilde{\theta}\_i\frac{q}{B^i}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let output = decomposer.decompose(2u32.pow(19)).next().unwrap(); + /// assert_eq!(output.to_recomposition_summand(), 1048576); + /// ``` + pub fn to_recomposition_summand(&self) -> T { + let shift: usize = ::BITS - self.base_log * self.level; + self.value << shift + } + + /// Returns the value of the term. + /// + /// If our member represents one $\tilde{\theta}\_i$, this returns its actual value. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let output = decomposer.decompose(2u32.pow(19)).next().unwrap(); + /// assert_eq!(output.value(), 1); + /// ``` + pub fn value(&self) -> T { + self.value + } + + /// Returns the level of the term. + /// + /// If our member represents one $\tilde{\theta}\_i$, this returns the value of $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let output = decomposer.decompose(2u32.pow(19)).next().unwrap(); + /// assert_eq!(output.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} + +/// A tensor whose elements are the terms of the decomposition of another tensor. +/// +/// If we decompose each elements of a set of values $(\theta^{(a)})\_{a\in\mathbb{N}}$ as a set of +/// sums $(\sum\_{i=1}^l\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$, this represents a +/// set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$. +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct DecompositionTermTensor<'a, Scalar> +where + Scalar: UnsignedInteger, +{ + level: usize, + base_log: usize, + tensor: Tensor<&'a [Scalar]>, +} + +impl<'a, Scalar> DecompositionTermTensor<'a, Scalar> +where + Scalar: UnsignedInteger, +{ + // Creates a new tensor decomposition term. + pub(crate) fn new( + level: DecompositionLevel, + base_log: DecompositionBaseLog, + tensor: Tensor<&'a [Scalar]>, + ) -> DecompositionTermTensor { + DecompositionTermTensor { + level: level.0, + base_log: base_log.0, + tensor, + } + } + + /// Fills the output tensor with the terms turned to summands. + /// + /// If our term tensor represents a set of $(\tilde{\theta}^{(a)}\_i)\_{a\in\mathbb{N}}$ of the + /// decomposition, this method fills the output tensor with a set of + /// $(\tilde{\theta}^{(a)}\_i\frac{q}{B^i})\_{a\in\mathbb{N}}$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = Tensor::allocate(2u32.pow(19), 1); + /// let mut decomp = decomposer.decompose_tensor(&input); + /// let term = decomp.next_term().unwrap(); + /// let mut output = Tensor::allocate(0, 1); + /// term.fill_tensor_with_recomposition_summand(&mut output); + /// assert_eq!(*output.get_element(0), 1048576); + /// ``` + pub fn fill_tensor_with_recomposition_summand(&self, output: &mut TLike) + where + TLike: AsMutTensor, + { + output.as_mut_tensor().fill_with_one(&self.tensor, |value| { + let shift: usize = ::BITS - self.base_log * self.level; + *value << shift + }); + } + + pub(crate) fn update_tensor_with_recomposition_summand_wrapping_addition( + &self, + output: &mut TLike, + ) where + TLike: AsMutTensor, + { + output + .as_mut_tensor() + .update_with_one(&self.tensor, |out, value| { + let shift: usize = ::BITS - self.base_log * self.level; + *out = out.wrapping_add(*value << shift); + }); + } + + /// Returns a tensor with the values of term. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = Tensor::allocate(2u32.pow(19), 1); + /// let mut decomp = decomposer.decompose_tensor(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(*term.as_tensor().get_element(0), 1); + /// ``` + pub fn as_tensor(&self) -> &Tensor<&'a [Scalar]> { + &self.tensor + } + + /// Returns the level of this decomposition term tensor. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer}; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// use tfhe::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// let input = Tensor::allocate(2u32.pow(19), 1); + /// let mut decomp = decomposer.decompose_tensor(&input); + /// let term = decomp.next_term().unwrap(); + /// assert_eq!(term.level(), DecompositionLevel(3)); + /// ``` + pub fn level(&self) -> DecompositionLevel { + DecompositionLevel(self.level) + } +} diff --git a/tfhe/src/core_crypto/commons/math/decomposition/tests.rs b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs new file mode 100644 index 000000000..980da527a --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/decomposition/tests.rs @@ -0,0 +1,180 @@ +use crate::core_crypto::commons::math::decomposition::SignedDecomposer; +use crate::core_crypto::commons::math::random::{RandomGenerable, Uniform}; +use crate::core_crypto::commons::math::tensor::Tensor; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::numeric::{Numeric, SignedInteger, UnsignedInteger}; +use crate::core_crypto::commons::test_tools::{any_uint, any_usize, random_usize_between}; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount}; +use std::fmt::Debug; + +// Returns a random decomposition valid for the size of the T type. +fn random_decomp() -> SignedDecomposer { + let mut base_log; + let mut level_count; + loop { + base_log = random_usize_between(2..T::BITS); + level_count = random_usize_between(2..T::BITS); + if base_log * level_count < T::BITS { + break; + } + } + SignedDecomposer::new( + DecompositionBaseLog(base_log), + DecompositionLevelCount(level_count), + ) +} + +fn test_decompose_recompose>() +where + ::Signed: Debug + SignedInteger, +{ + // Checks that the decomposing and recomposing a value brings the closest representable + for _ in 0..100_000 { + let decomposer = random_decomp::(); + let input = any_uint::(); + for term in decomposer.decompose(input) { + assert!(1 <= term.level().0); + assert!(term.level().0 <= decomposer.level_count); + let signed_term = term.value().into_signed(); + let half_basis = (T::Signed::ONE << decomposer.base_log) / T::TWO.into_signed(); + assert!(-half_basis <= signed_term); + assert!(signed_term <= half_basis); + } + let closest = decomposer.closest_representable(input); + assert_eq!( + closest, + decomposer.recompose(decomposer.decompose(closest)).unwrap() + ); + } +} + +#[test] +fn test_decompose_recompose_u32() { + test_decompose_recompose::() +} + +#[test] +fn test_decompose_recompose_u64() { + test_decompose_recompose::() +} + +fn test_decompose_recompose_tensor>() +where + ::Signed: Debug + SignedInteger, +{ + // Checks that the decomposing and recomposing a value brings the closest representable + for _ in 0..100_000 { + let decomposer = random_decomp::(); + let input = Tensor::allocate(any_uint::(), 1); + let mut decomp = decomposer.decompose_tensor(&input); + while let Some(term) = decomp.next_term() { + assert!(1 <= term.level().0); + assert!(term.level().0 <= decomposer.level_count); + let signed_term = term.as_tensor().get_element(0).into_signed(); + let half_basis = (T::Signed::ONE << decomposer.base_log) / T::TWO.into_signed(); + assert!(-half_basis <= signed_term); + assert!(signed_term <= half_basis); + } + let mut rounded = Tensor::allocate(T::ZERO, 1); + decomposer.fill_tensor_with_closest_representable(&mut rounded, &input); + let mut recomposition = Tensor::allocate(T::ZERO, 1); + let decomp_iter = decomposer.decompose_tensor(&rounded); + decomposer.fill_tensor_with_recompose(decomp_iter, &mut recomposition); + assert_eq!(rounded, recomposition); + } +} + +#[test] +fn test_decompose_recompose_tensor_u32() { + test_decompose_recompose_tensor::() +} + +#[test] +fn test_decompose_recompose_tensor_u64() { + test_decompose_recompose_tensor::() +} + +fn test_round_to_closest_representable() { + for _ in 0..1000 { + let log_b = any_usize(); + let level_max = any_usize(); + let val = any_uint::(); + let delta = any_uint::(); + let bits = T::BITS; + let log_b = (log_b % ((bits / 4) - 1)) + 1; + let level_max = (level_max % 4) + 1; + let bit: usize = log_b * level_max; + + let val = val << (bits - bit); + let delta = delta >> (bits - (bits - bit - 1)); + + let decomposer = SignedDecomposer::new( + DecompositionBaseLog(log_b), + DecompositionLevelCount(level_max), + ); + + assert_eq!( + val, + decomposer.closest_representable(val.wrapping_add(delta)) + ); + assert_eq!( + val, + decomposer.closest_representable(val.wrapping_sub(delta)) + ); + } +} + +#[test] +fn test_round_to_closest_representable_u32() { + test_round_to_closest_representable::(); +} + +#[test] +fn test_round_to_closest_representable_u64() { + test_round_to_closest_representable::(); +} + +fn test_round_to_closest_twice() { + for _ in 0..1000 { + let decomp = random_decomp(); + let input: T = any_uint(); + + let rounded_once = decomp.closest_representable(input); + let rounded_twice = decomp.closest_representable(rounded_once); + assert_eq!(rounded_once, rounded_twice); + } +} + +#[test] +fn test_round_to_closest_twice_u32() { + test_round_to_closest_twice::(); +} + +#[test] +fn test_round_to_closest_twice_u64() { + test_round_to_closest_twice::(); +} + +fn test_round_tensor_to_closest_twice() { + for _ in 0..1000 { + let decomp = random_decomp(); + let input: T = any_uint(); + let input = Tensor::from_container(vec![input]); + let mut rounded_once = Tensor::from_container(vec![T::ZERO]); + let mut rounded_twice = Tensor::from_container(vec![T::ZERO]); + + decomp.fill_tensor_with_closest_representable(&mut rounded_once, &input); + decomp.fill_tensor_with_closest_representable(&mut rounded_twice, &rounded_once); + assert_eq!(rounded_once, rounded_twice); + } +} + +#[test] +fn test_round_tensor_to_closest_twice_u32() { + test_round_tensor_to_closest_twice::(); +} + +#[test] +fn test_round_tensor_to_closest_twice_u64() { + test_round_tensor_to_closest_twice::(); +} diff --git a/tfhe/src/core_crypto/commons/math/mod.rs b/tfhe/src/core_crypto/commons/math/mod.rs new file mode 100644 index 000000000..417f61ba5 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/mod.rs @@ -0,0 +1,7 @@ +//! A module containing general mathematical tools. + +pub mod decomposition; +pub mod polynomial; +pub mod random; +pub mod tensor; +pub mod torus; diff --git a/tfhe/src/core_crypto/commons/math/polynomial/list.rs b/tfhe/src/core_crypto/commons/math/polynomial/list.rs new file mode 100644 index 000000000..9ceb58959 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/polynomial/list.rs @@ -0,0 +1,360 @@ +use crate::core_crypto::commons::math::tensor::Container; +use std::iter::Iterator; + +use crate::core_crypto::commons::math::tensor::{ + ck_dim_div, tensor_traits, AsMutTensor, AsRefSlice, AsRefTensor, Split, Tensor, +}; + +use super::*; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{MonomialDegree, PolynomialCount, PolynomialSize}; + +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; + +/// A generic polynomial list type. +/// +/// This type represents a set of polynomial of homogeneous degree. +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; +/// use tfhe::core_crypto::prelude::{PolynomialCount, PolynomialSize}; +/// let list = PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); +/// assert_eq!(list.polynomial_count(), PolynomialCount(4)); +/// assert_eq!(list.polynomial_size(), PolynomialSize(2)); +/// ``` +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct PolynomialList { + pub(crate) tensor: Tensor, + pub(crate) poly_size: PolynomialSize, +} + +tensor_traits!(PolynomialList); + +impl PolynomialList> +where + Coef: Copy, +{ + /// Allocates a new polynomial list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; + /// use tfhe::core_crypto::prelude::{PolynomialCount, PolynomialSize}; + /// let list = PolynomialList::allocate(1u8, PolynomialCount(10), PolynomialSize(2)); + /// assert_eq!(list.polynomial_count(), PolynomialCount(10)); + /// assert_eq!(list.polynomial_size(), PolynomialSize(2)); + /// ``` + pub fn allocate(value: Coef, number: PolynomialCount, size: PolynomialSize) -> Self { + PolynomialList { + tensor: Tensor::from_container(vec![value; number.0 * size.0]), + poly_size: size, + } + } +} + +impl PolynomialList { + /// Creates a polynomial list from a list of values. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; + /// use tfhe::core_crypto::prelude::{PolynomialCount, PolynomialSize}; + /// let list = PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); + /// assert_eq!(list.polynomial_count(), PolynomialCount(4)); + /// assert_eq!(list.polynomial_size(), PolynomialSize(2)); + /// ``` + pub fn from_container(cont: Cont, poly_size: PolynomialSize) -> PolynomialList + where + Cont: AsRefSlice, + { + ck_dim_div!(cont.as_slice().len() => poly_size.0); + PolynomialList { + tensor: Tensor::from_container(cont), + poly_size, + } + } + + pub fn into_container(self) -> Cont { + self.tensor.into_container() + } + + pub fn as_view(&self) -> PolynomialList<&'_ [Cont::Element]> + where + Cont: Container, + { + PolynomialList { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + poly_size: self.poly_size, + } + } + + pub fn as_mut_view(&mut self) -> PolynomialList<&'_ mut [Cont::Element]> + where + Cont: Container + AsMut<[Cont::Element]>, + { + PolynomialList { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + poly_size: self.poly_size, + } + } + + /// Returns the number of polynomials in the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; + /// use tfhe::core_crypto::prelude::{PolynomialCount, PolynomialSize}; + /// let list = PolynomialList::allocate(1u8, PolynomialCount(10), PolynomialSize(2)); + /// assert_eq!(list.polynomial_count(), PolynomialCount(10)); + /// ``` + pub fn polynomial_count(&self) -> PolynomialCount + where + Self: AsRefTensor, + { + PolynomialCount(self.as_tensor().len() / self.poly_size.0) + } + + /// Returns the size of the polynomials in the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; + /// use tfhe::core_crypto::prelude::{PolynomialCount, PolynomialSize}; + /// let list = PolynomialList::allocate(1u8, PolynomialCount(10), PolynomialSize(2)); + /// assert_eq!(list.polynomial_size(), PolynomialSize(2)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize { + self.poly_size + } + + /// Returns a reference to the n-th polynomial of the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, PolynomialList}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let list = PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); + /// let poly = list.get_polynomial(2); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 5u8); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 6u8); + /// ``` + pub fn get_polynomial(&self, n: usize) -> Polynomial<&[::Element]> + where + Self: AsRefTensor, + { + Polynomial { + tensor: self + .as_tensor() + .get_sub((n * self.poly_size.0)..(n + 1) * self.poly_size.0), + } + } + + /// Returns a mutable reference to the n-th polynomial of the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, PolynomialList}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut list = + /// PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); + /// let mut poly = list.get_mut_polynomial(2); + /// poly.get_mut_monomial(MonomialDegree(0)) + /// .set_coefficient(10u8); + /// poly.get_mut_monomial(MonomialDegree(1)) + /// .set_coefficient(11u8); + /// let poly = list.get_polynomial(2); + /// assert_eq!( + /// *poly.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 10u8 + /// ); + /// assert_eq!( + /// *poly.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 11u8 + /// ); + /// ``` + pub fn get_mut_polynomial( + &mut self, + n: usize, + ) -> Polynomial<&mut [::Element]> + where + Self: AsMutTensor, + { + let index = (n * self.poly_size.0)..((n + 1) * self.poly_size.0); + Polynomial { + tensor: self.as_mut_tensor().get_sub_mut(index), + } + } + + pub fn into_polynomial_iter(self) -> impl DoubleEndedIterator> + where + Cont: Split, + { + let poly_size = self.polynomial_size(); + self.tensor + .into_container() + .into_chunks(poly_size.0) + .map(Polynomial::from_container) + } + + /// Returns an iterator over references to the polynomials contained in the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::PolynomialList; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut list = + /// PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); + /// for polynomial in list.polynomial_iter() { + /// assert_eq!(polynomial.polynomial_size(), PolynomialSize(2)); + /// } + /// assert_eq!(list.polynomial_iter().count(), 4); + /// ``` + pub fn polynomial_iter( + &self, + ) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(self.poly_size.0) + .map(|sub| Polynomial::from_container(sub.into_container())) + } + + #[cfg(feature = "__commons_parallel")] + pub fn par_polynomial_iter( + &self, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsRefTensor, + ::Element: Sync + Send, + { + self.as_tensor() + .par_subtensor_iter(self.poly_size.0) + .map(|sub| Polynomial::from_container(sub.into_container())) + } + + /// Returns an iterator over mutable references to the polynomials contained in the list. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, PolynomialList}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut list = + /// PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6, 7, 8], PolynomialSize(2)); + /// for mut polynomial in list.polynomial_iter_mut() { + /// polynomial + /// .get_mut_monomial(MonomialDegree(0)) + /// .set_coefficient(10u8); + /// assert_eq!(polynomial.polynomial_size(), PolynomialSize(2)); + /// } + /// for polynomial in list.polynomial_iter() { + /// assert_eq!( + /// *polynomial.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 10u8 + /// ); + /// } + /// assert_eq!(list.polynomial_iter_mut().count(), 4); + /// ``` + pub fn polynomial_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + let chunks_size = self.poly_size.0; + self.as_mut_tensor() + .subtensor_iter_mut(chunks_size) + .map(|sub| Polynomial::from_container(sub.into_container())) + } + + /// Creates an iterator over borrowed sub-lists. + pub fn sublist_iter( + &self, + count: PolynomialCount, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsRefTensor, + { + ck_dim_div!(self.polynomial_count().0 => count.0); + let polynomial_size = self.polynomial_size(); + self.as_tensor() + .subtensor_iter(count.0 * polynomial_size.0) + .map(move |sub| PolynomialList::from_container(sub.into_container(), polynomial_size)) + } + + /// Creates an iterator over mutably borrowed sub-lists. + pub fn sublist_iter_mut( + &mut self, + count: PolynomialCount, + ) -> impl DoubleEndedIterator::Element]>> + where + Self: AsMutTensor, + { + ck_dim_div!(self.polynomial_count().0 => count.0); + let polynomial_size = self.polynomial_size(); + self.as_mut_tensor() + .subtensor_iter_mut(count.0 * polynomial_size.0) + .map(move |sub| PolynomialList::from_container(sub.into_container(), polynomial_size)) + } + + /// Multiplies (mod $(X^N+1)$), all the polynomials of the list with a unit monomial of a + /// given degree. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, PolynomialList}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut list = PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6], PolynomialSize(3)); + /// list.update_with_wrapping_monic_monomial_mul(MonomialDegree(2)); + /// let poly = list.get_polynomial(0); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 254); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 253); + /// assert_eq!(*poly.get_monomial(MonomialDegree(2)).get_coefficient(), 1); + /// ``` + pub fn update_with_wrapping_monic_monomial_mul(&mut self, monomial_degree: MonomialDegree) + where + Self: AsMutTensor, + Coef: UnsignedInteger, + { + for mut poly in self.polynomial_iter_mut() { + poly.update_with_wrapping_monic_monomial_mul(monomial_degree); + } + } + + /// Divides (mod $(X^N+1)$), all the polynomials of the list with a unit monomial of a + /// given degree. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, PolynomialList}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut list = PolynomialList::from_container(vec![1u8, 2, 3, 4, 5, 6], PolynomialSize(3)); + /// list.update_with_wrapping_monic_monomial_div(MonomialDegree(2)); + /// let poly = list.get_polynomial(0); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 3); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 255); + /// assert_eq!(*poly.get_monomial(MonomialDegree(2)).get_coefficient(), 254); + /// ``` + pub fn update_with_wrapping_monic_monomial_div(&mut self, monomial_degree: MonomialDegree) + where + Self: AsMutTensor, + Coef: UnsignedInteger, + { + for mut poly in self.polynomial_iter_mut() { + poly.update_with_wrapping_unit_monomial_div(monomial_degree); + } + } +} diff --git a/tfhe/src/core_crypto/commons/math/polynomial/mod.rs b/tfhe/src/core_crypto/commons/math/polynomial/mod.rs new file mode 100644 index 000000000..1d0997a23 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/polynomial/mod.rs @@ -0,0 +1,23 @@ +//! A module to manipulate polynomials. +//! +//! This module allows to manipulate modular polynomials In particular, we provide three generic +//! types to manipulate such objects: +//! +//! + [`Monomial`], which represents a free monomial term (not bound to a given modular degree) +//! + [`Polynomial`], which represents a dense polynomial of a given degree. +//! + [`PolynomialList`], which represent a set of polynomials with the same degree, on which +//! operations can be performed. + +pub use list::*; +pub use monomial::*; +pub use polynomial::*; + +#[cfg(test)] +mod tests; + +mod list; +mod monomial; +#[allow(clippy::module_inception)] +mod polynomial; + +pub use crate::core_crypto::prelude::MonomialDegree; diff --git a/tfhe/src/core_crypto/commons/math/polynomial/monomial.rs b/tfhe/src/core_crypto/commons/math/polynomial/monomial.rs new file mode 100644 index 000000000..92f74a2a2 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/polynomial/monomial.rs @@ -0,0 +1,154 @@ +use crate::core_crypto::commons::math::tensor::{ + tensor_traits, AsMutElement, AsMutTensor, AsRefElement, AsRefSlice, AsRefTensor, Tensor, +}; +use crate::core_crypto::prelude::MonomialDegree; + +/// A monomial term. +/// +/// This type represents a free monomial term of a given degree. +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::commons::math::polynomial::Monomial; +/// use tfhe::core_crypto::prelude::MonomialDegree; +/// let mono = Monomial::allocate(1u8, MonomialDegree(5)); +/// assert_eq!(*mono.get_coefficient(), 1u8); +/// assert_eq!(mono.degree(), MonomialDegree(5)); +/// ``` +#[derive(PartialEq, Eq)] +pub struct Monomial { + tensor: Tensor, + degree: MonomialDegree, +} + +tensor_traits!(Monomial); + +impl AsRefElement for Monomial +where + Monomial: AsRefTensor, +{ + type Element = as AsRefTensor>::Element; + fn as_element(&self) -> &Self::Element { + self.as_tensor().first() + } +} + +impl AsMutElement for Monomial +where + Monomial: AsMutTensor, +{ + type Element = as AsRefTensor>::Element; + fn as_mut_element(&mut self) -> &mut ::Element { + self.as_mut_tensor().first_mut() + } +} + +impl Monomial> { + /// Allocates a new monomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let mono = Monomial::allocate(1u8, MonomialDegree(5)); + /// assert_eq!(*mono.get_coefficient(), 1u8); + /// assert_eq!(mono.degree(), MonomialDegree(5)); + /// ``` + pub fn allocate(value: Coef, degree: MonomialDegree) -> Monomial> { + Monomial { + tensor: Tensor::from_container(vec![value]), + degree, + } + } +} + +impl Monomial { + /// Creates a new monomial from a value container and a degree. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let vector = vec![1u8]; + /// let mono = Monomial::from_container(vector.as_slice(), MonomialDegree(5)); + /// assert_eq!(*mono.get_coefficient(), 1u8); + /// assert_eq!(mono.degree(), MonomialDegree(5)); + /// ``` + pub fn from_container(cont: Cont, degree: MonomialDegree) -> Monomial + where + Cont: AsRefSlice, + { + debug_assert!( + cont.as_slice().len() == 1, + "Tried to create a monomial with a container of size different than one" + ); + Monomial { + tensor: Tensor::from_container(cont), + degree, + } + } + + /// Returns a reference to the monomial coefficient. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let mono = Monomial::allocate(1u8, MonomialDegree(5)); + /// assert_eq!(*mono.get_coefficient(), 1u8); + /// ``` + pub fn get_coefficient(&self) -> &::Element + where + Self: AsRefElement, + { + self.as_element() + } + + /// Sets the monomial coefficient. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let mut mono = Monomial::allocate(1u8, MonomialDegree(5)); + /// mono.set_coefficient(5u8); + /// assert_eq!(*mono.get_coefficient(), 5u8); + /// ``` + pub fn set_coefficient(&mut self, coefficient: Coef) + where + Self: AsMutElement, + { + *(self.as_mut_element()) = coefficient; + } + + /// Returns a mutable reference to the coefficient. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let mut mono = Monomial::allocate(1u8, MonomialDegree(5)); + /// *mono.get_mut_coefficient() += 1u8; + /// assert_eq!(*mono.get_coefficient(), 2u8); + /// ``` + pub fn get_mut_coefficient(&mut self) -> &mut ::Element + where + Self: AsMutElement, + { + self.as_mut_element() + } + + /// Returns the degree of the monomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{Monomial, MonomialDegree}; + /// let mono = Monomial::allocate(1u8, MonomialDegree(5)); + /// assert_eq!(mono.degree(), MonomialDegree(5)); + /// ``` + pub fn degree(&self) -> MonomialDegree { + self.degree + } +} diff --git a/tfhe/src/core_crypto/commons/math/polynomial/polynomial.rs b/tfhe/src/core_crypto/commons/math/polynomial/polynomial.rs new file mode 100644 index 000000000..4e470dcdf --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/polynomial/polynomial.rs @@ -0,0 +1,900 @@ +use crate::core_crypto::commons::math::tensor::Container; +use std::fmt::Debug; +use std::iter::Iterator; + +use crate::core_crypto::commons::math::tensor::{ + ck_dim_eq, tensor_traits, AsMutSlice, AsMutTensor, AsRefTensor, Tensor, +}; + +use super::*; +use crate::core_crypto::commons::numeric::UnsignedInteger; +use crate::core_crypto::prelude::{MonomialDegree, PolynomialSize}; + +// stop the induction when polynomials have KARATUSBA_STOP elements +const KARATUSBA_STOP: usize = 32; + +/// A dense polynomial. +/// +/// This type represent a dense polynomial in $\mathbb{Z}\_{2^q}\[X\] / $, composed of $N$ +/// integer coefficients encoded on $q$ bits. +/// +/// # Example: +/// +/// ``` +/// use tfhe::core_crypto::commons::math::polynomial::Polynomial; +/// use tfhe::core_crypto::prelude::PolynomialSize; +/// let poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); +/// assert_eq!(poly.polynomial_size(), PolynomialSize(100)); +/// ``` +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct Polynomial { + pub(crate) tensor: Tensor, +} + +tensor_traits!(Polynomial); + +impl Polynomial> +where + Scalar: Copy, +{ + /// Allocates a new polynomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(100)); + /// ``` + pub fn allocate(value: Scalar, coef_count: PolynomialSize) -> Polynomial> { + Polynomial::from_container(vec![value; coef_count.0]) + } +} + +impl Polynomial { + /// Creates a polynomial from a container of values. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let vec = vec![0 as u32; 100]; + /// let poly = Polynomial::from_container(vec.as_slice()); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(100)); + /// ``` + pub fn from_container(cont: Cont) -> Self { + Polynomial { + tensor: Tensor::from_container(cont), + } + } + + pub(crate) fn from_tensor(tensor: Tensor) -> Self { + Polynomial { tensor } + } + + pub fn as_view(&self) -> Polynomial<&'_ [Cont::Element]> + where + Cont: Container, + { + Polynomial { + tensor: Tensor::from_container(self.tensor.as_container().as_ref()), + } + } + + pub fn as_mut_view(&mut self) -> Polynomial<&'_ mut [Cont::Element]> + where + Cont: Container, + Cont: AsMut<[Cont::Element]>, + { + Polynomial { + tensor: Tensor::from_container(self.tensor.as_mut_container().as_mut()), + } + } + + /// Returns the number of coefficients in the polynomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// assert_eq!(poly.polynomial_size(), PolynomialSize(100)); + /// ``` + pub fn polynomial_size(&self) -> PolynomialSize + where + Self: AsRefTensor, + { + PolynomialSize(self.as_tensor().len()) + } + + /// Builds an iterator over `Monomial<&Coef>` elements. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// for monomial in poly.monomial_iter() { + /// assert!(monomial.degree().0 <= 99) + /// } + /// assert_eq!(poly.monomial_iter().count(), 100); + /// ``` + pub fn monomial_iter(&self) -> impl Iterator::Element]>> + where + Self: AsRefTensor, + { + self.as_tensor() + .subtensor_iter(1) + .enumerate() + .map(|(i, coef)| Monomial::from_container(coef.into_container(), MonomialDegree(i))) + } + + /// Builds an iterator over `&Coef` elements, in order of increasing degree. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// for coef in poly.coefficient_iter() { + /// assert_eq!(*coef, 0); + /// } + /// assert_eq!(poly.coefficient_iter().count(), 100); + /// ``` + pub fn coefficient_iter( + &self, + ) -> impl DoubleEndedIterator::Element> + where + Self: AsRefTensor, + { + self.as_tensor().iter() + } + + /// Returns the monomial of a given degree. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let poly = Polynomial::from_container(vec![16_u32, 8, 19, 12, 3]); + /// let mono = poly.get_monomial(MonomialDegree(0)); + /// assert_eq!(*mono.get_coefficient(), 16_u32); + /// let mono = poly.get_monomial(MonomialDegree(2)); + /// assert_eq!(*mono.get_coefficient(), 19_u32); + /// ``` + pub fn get_monomial( + &self, + degree: MonomialDegree, + ) -> Monomial<&[::Element]> + where + Self: AsRefTensor, + { + Monomial::from_container( + self.as_tensor() + .get_sub(degree.0..=degree.0) + .into_container(), + degree, + ) + } + + /// Builds an iterator over `Monomial<&mut Coef>` elements. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// for mut monomial in poly.monomial_iter_mut() { + /// monomial.set_coefficient(monomial.degree().0 as u32); + /// } + /// for (i, monomial) in poly.monomial_iter().enumerate() { + /// assert_eq!(*monomial.get_coefficient(), i as u32); + /// } + /// assert_eq!(poly.monomial_iter_mut().count(), 100); + /// ``` + pub fn monomial_iter_mut( + &mut self, + ) -> impl Iterator::Element]>> + where + Self: AsMutTensor, + { + self.as_mut_tensor() + .subtensor_iter_mut(1) + .enumerate() + .map(|(i, coef)| Monomial::from_container(coef.into_container(), MonomialDegree(i))) + } + + /// Builds an iterator over `&mut Coef` elements, in order of increasing + /// degree. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::Polynomial; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut poly = Polynomial::allocate(0 as u32, PolynomialSize(100)); + /// for mut coef in poly.coefficient_iter_mut() { + /// *coef = 1; + /// } + /// for coef in poly.coefficient_iter() { + /// assert_eq!(*coef, 1); + /// } + /// assert_eq!(poly.coefficient_iter_mut().count(), 100); + /// ``` + pub fn coefficient_iter_mut( + &mut self, + ) -> impl DoubleEndedIterator::Element> + where + Self: AsMutTensor, + { + self.as_mut_tensor().iter_mut() + } + + /// Returns the mutable monomial of a given degree. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let mut poly = Polynomial::from_container(vec![16_u32, 8, 19, 12, 3]); + /// let mut mono = poly.get_mut_monomial(MonomialDegree(0)); + /// mono.set_coefficient(18); + /// let mono = poly.get_monomial(MonomialDegree(0)); + /// assert_eq!(*mono.get_coefficient(), 18); + /// ``` + pub fn get_mut_monomial( + &mut self, + degree: MonomialDegree, + ) -> Monomial<&mut [::Element]> + where + Self: AsMutTensor, + { + Monomial::from_container( + self.as_mut_tensor() + .get_sub_mut(degree.0..=degree.0) + .into_container(), + degree, + ) + } + + /// Fills the current polynomial, with the result of the (slow) product of + /// two polynomials, reduced modulo $(X^N + 1)$. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let lhs = Polynomial::from_container(vec![4_u8, 5, 0]); + /// let rhs = Polynomial::from_container(vec![7_u8, 9, 0]); + /// let mut res = Polynomial::allocate(0 as u8, PolynomialSize(3)); + /// res.fill_with_wrapping_mul(&lhs, &rhs); + /// assert_eq!( + /// *res.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 28 as u8 + /// ); + /// assert_eq!( + /// *res.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 71 as u8 + /// ); + /// assert_eq!( + /// *res.get_monomial(MonomialDegree(2)).get_coefficient(), + /// 45 as u8 + /// ); + /// ``` + pub fn fill_with_wrapping_mul( + &mut self, + lhs: &Polynomial, + rhs: &Polynomial, + ) where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + ck_dim_eq!(self.polynomial_size() => lhs.polynomial_size(), rhs.polynomial_size()); + self.coefficient_iter_mut().for_each(|a| *a = Coef::ZERO); + let degree = lhs.polynomial_size().0 - 1; + for lhsi in lhs.monomial_iter() { + for rhsi in rhs.monomial_iter() { + let target_degree = lhsi.degree().0 + rhsi.degree().0; + if target_degree <= degree { + let element = self.as_mut_tensor().get_element_mut(target_degree); + let new = lhsi.get_coefficient().wrapping_mul(*rhsi.get_coefficient()); + *element = element.wrapping_add(new); + } else { + let element = self + .as_mut_tensor() + .get_element_mut(target_degree % (degree + 1)); + let new = lhsi.get_coefficient().wrapping_mul(*rhsi.get_coefficient()); + *element = element.wrapping_sub(new); + } + } + } + } + + /// Fills the current polynomial, with the result of the product of two + /// polynomials, reduced modulo $(X^N + 1)$ with the Karatsuba algorithm + /// Complexity: N^{1.58} + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let lhs = Polynomial::from_container(vec![1_u32; 128]); + /// let rhs = Polynomial::from_container(vec![2_u32; 128]); + /// let mut res_kara = Polynomial::allocate(0 as u32, PolynomialSize(128)); + /// let mut res_mul = Polynomial::allocate(0 as u32, PolynomialSize(128)); + /// res_kara.fill_with_karatsuba_mul(&lhs, &rhs); + /// res_mul.fill_with_wrapping_mul(&lhs, &rhs); + /// assert_eq!(res_kara, res_mul); + /// ``` + pub fn fill_with_karatsuba_mul( + &mut self, + p: &Polynomial, + q: &Polynomial, + ) where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + // check same dimensions + ck_dim_eq!(self.polynomial_size() => p.polynomial_size(), q.polynomial_size()); + + // check dimensions are a power of 2 + debug_assert!( + f64::abs( + f64::floor(f64::log2(p.polynomial_size().0 as f64)) + - f64::log2(p.polynomial_size().0 as f64) + ) < f64::EPSILON + ); + + let poly_size = self.polynomial_size().0; + + // allocate slices for the rec + let mut a0 = Tensor::allocate(Coef::ZERO, poly_size); + let mut a1 = Tensor::allocate(Coef::ZERO, poly_size); + let mut a2 = Tensor::allocate(Coef::ZERO, poly_size); + let mut input_a2_p = Tensor::allocate(Coef::ZERO, poly_size / 2); + let mut input_a2_q = Tensor::allocate(Coef::ZERO, poly_size / 2); + + // prepare for splitting + let bottom = 0..(poly_size / 2); + let top = (poly_size / 2)..poly_size; + + // induction + induction_karatsuba( + &mut a0.get_sub_mut(..), + &p.as_tensor().get_sub(bottom.clone()), + &q.as_tensor().get_sub(bottom.clone()), + ); + induction_karatsuba( + &mut a1.get_sub_mut(..), + &p.as_tensor().get_sub(top.clone()), + &q.as_tensor().get_sub(top.clone()), + ); + input_a2_p.fill_with_wrapping_add( + &p.as_tensor().get_sub(bottom.clone()), + &p.as_tensor().get_sub(top.clone()), + ); + input_a2_q.fill_with_wrapping_add( + &q.as_tensor().get_sub(bottom.clone()), + &q.as_tensor().get_sub(top.clone()), + ); + induction_karatsuba( + &mut a2.get_sub_mut(..), + &input_a2_p.get_sub(..), + &input_a2_q.get_sub(..), + ); + + // rebuild the result + self.as_mut_tensor().fill_with_wrapping_sub(&a0, &a1); + self.as_mut_tensor() + .get_sub_mut(bottom.clone()) + .update_with_wrapping_sub(&a2.get_sub(top.clone())); + self.as_mut_tensor() + .get_sub_mut(bottom.clone()) + .update_with_wrapping_add(&a0.get_sub(top.clone())); + self.as_mut_tensor() + .get_sub_mut(bottom.clone()) + .update_with_wrapping_add(&a1.get_sub(top.clone())); + self.as_mut_tensor() + .get_sub_mut(top.clone()) + .update_with_wrapping_add(&a2.get_sub(bottom.clone())); + self.as_mut_tensor() + .get_sub_mut(top.clone()) + .update_with_wrapping_sub(&a0.get_sub(bottom.clone())); + self.as_mut_tensor() + .get_sub_mut(top) + .update_with_wrapping_sub(&a1.get_sub(bottom)); + } + + /// Adds the sum of the element-wise product between two lists of integer polynomial to the + /// current polynomial. + /// + /// I.e., if the current polynomial is $C(X)$, for a collection of polynomials $(P\_i(X)))\_i$ + /// and another collection of polynomials $(B\_i(X))\_i$ we perform the operation: + /// $$ + /// C(X) := C(X) + \sum\_i P\_i(X) \times B\_i(X) mod (X^N + 1) + /// $$ + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{ + /// MonomialDegree, Polynomial, PolynomialList, + /// }; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly_list = PolynomialList::from_container(vec![100_u8, 20, 3, 4, 5, 6], PolynomialSize(3)); + /// let bin_poly_list = PolynomialList::from_container(vec![0, 1, 1, 1, 0, 0], PolynomialSize(3)); + /// let mut output = Polynomial::allocate(250, PolynomialSize(3)); + /// output.update_with_wrapping_add_multisum(&poly_list, &bin_poly_list); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 231 + /// ); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 96 + /// ); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(2)).get_coefficient(), + /// 120 + /// ); + /// ``` + pub fn update_with_wrapping_add_multisum( + &mut self, + coef_list: &PolynomialList, + bin_list: &PolynomialList, + ) where + Self: AsMutTensor, + PolynomialList: AsRefTensor, + PolynomialList: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + Coef: UnsignedInteger, + { + for (poly, bin_poly) in coef_list.polynomial_iter().zip(bin_list.polynomial_iter()) { + self.update_with_wrapping_add_mul(&poly, &bin_poly); + } + } + + /// Subtracts the sum of the element-wise product between two lists of integer polynomials, + /// to the current polynomial. + /// + /// I.e., if the current polynomial is $C(X)$, for two lists of polynomials $(P\_i(X)))\_i$ and + /// $(B\_i(X))\_i$ we perform the operation: + /// $$ + /// C(X) := C(X) + \sum\_i P\_i(X) \times B\_i(X) mod (X^N + 1) + /// $$ + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{ + /// MonomialDegree, Polynomial, PolynomialList, + /// }; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let poly_list = + /// PolynomialList::from_container(vec![100 as u8, 20, 3, 4, 5, 6], PolynomialSize(3)); + /// let bin_poly_list = PolynomialList::from_container(vec![0, 1, 1, 1, 0, 0], PolynomialSize(3)); + /// let mut output = Polynomial::allocate(250 as u8, PolynomialSize(3)); + /// output.update_with_wrapping_sub_multisum(&poly_list, &bin_poly_list); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 13 + /// ); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 148 + /// ); + /// assert_eq!( + /// *output.get_monomial(MonomialDegree(2)).get_coefficient(), + /// 124 + /// ); + /// ``` + pub fn update_with_wrapping_sub_multisum( + &mut self, + coef_list: &PolynomialList, + bin_list: &PolynomialList, + ) where + Self: AsMutTensor, + PolynomialList: AsRefTensor, + PolynomialList: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + Coef: UnsignedInteger, + { + for (poly, bin_poly) in coef_list.polynomial_iter().zip(bin_list.polynomial_iter()) { + self.update_with_wrapping_sub_mul(&poly, &bin_poly); + } + } + + /// Adds the result of the product between two integer polynomials, reduced modulo $(X^N+1)$, + /// to the current polynomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let poly_1 = Polynomial::from_container(vec![1_u8, 2, 3]); + /// let poly_2 = Polynomial::from_container(vec![0, 1, 1]); + /// let mut res = Polynomial::from_container(vec![1, 0, 253]); + /// res.update_with_wrapping_add_mul(&poly_1, &poly_2); + /// assert_eq!(*res.get_monomial(MonomialDegree(0)).get_coefficient(), 252); + /// assert_eq!(*res.get_monomial(MonomialDegree(1)).get_coefficient(), 254); + /// assert_eq!(*res.get_monomial(MonomialDegree(2)).get_coefficient(), 0); + /// ``` + pub fn update_with_wrapping_add_mul( + &mut self, + polynomial: &Polynomial, + bin_polynomial: &Polynomial, + ) where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + ck_dim_eq!( + self.polynomial_size() => + polynomial.polynomial_size(), + bin_polynomial.polynomial_size() + ); + let degree = polynomial.polynomial_size().0 - 1; + for lhsi in polynomial.monomial_iter() { + for rhsi in bin_polynomial.monomial_iter() { + let target_degree = lhsi.degree().0 + rhsi.degree().0; + if target_degree <= degree { + let update = self + .as_tensor() + .get_element(target_degree) + .wrapping_add(*lhsi.get_coefficient() * *rhsi.get_coefficient()); + *self.as_mut_tensor().get_element_mut(target_degree) = update; + } else { + let update = self + .as_tensor() + .get_element(target_degree % (degree + 1)) + .wrapping_sub(*lhsi.get_coefficient() * *rhsi.get_coefficient()); + *self + .as_mut_tensor() + .get_element_mut(target_degree % (degree + 1)) = update; + } + } + } + } + + /// Subtracts the result of the product between two integer polynomials, reduced + /// modulo $(X^N+1)$, to the current polynomial. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let poly = Polynomial::from_container(vec![1_u8, 2, 3]); + /// let bin_poly = Polynomial::from_container(vec![0, 1, 1]); + /// let mut res = Polynomial::from_container(vec![255, 255, 1]); + /// res.update_with_wrapping_sub_mul(&poly, &bin_poly); + /// assert_eq!(*res.get_monomial(MonomialDegree(0)).get_coefficient(), 4); + /// assert_eq!(*res.get_monomial(MonomialDegree(1)).get_coefficient(), 1); + /// assert_eq!(*res.get_monomial(MonomialDegree(2)).get_coefficient(), 254); + /// ``` + pub fn update_with_wrapping_sub_mul( + &mut self, + polynomial: &Polynomial, + bin_polynomial: &Polynomial, + ) where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + ck_dim_eq!( + self.polynomial_size() => + polynomial.polynomial_size(), + bin_polynomial.polynomial_size() + ); + let degree = polynomial.polynomial_size().0 - 1; + for lhsi in polynomial.monomial_iter() { + for rhsi in bin_polynomial.monomial_iter() { + let target_degree = lhsi.degree().0 + rhsi.degree().0; + if target_degree <= degree { + let update = self + .as_tensor() + .get_element(target_degree) + .wrapping_sub(*lhsi.get_coefficient() * *rhsi.get_coefficient()); + *self.as_mut_tensor().get_element_mut(target_degree) = update; + } else { + let update = self + .as_tensor() + .get_element(target_degree % (degree + 1)) + .wrapping_add(*lhsi.get_coefficient() * *rhsi.get_coefficient()); + *self + .as_mut_tensor() + .as_mut_slice() + .get_mut(target_degree % (degree + 1)) + .unwrap() = update; + } + } + } + } + + /// Adds a integer polynomial to another one. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let mut first = Polynomial::from_container(vec![1u8, 2, 3]); + /// let second = Polynomial::from_container(vec![255u8, 255, 255]); + /// first.update_with_wrapping_add(&second); + /// assert_eq!(*first.get_monomial(MonomialDegree(0)).get_coefficient(), 0); + /// assert_eq!(*first.get_monomial(MonomialDegree(1)).get_coefficient(), 1); + /// assert_eq!(*first.get_monomial(MonomialDegree(2)).get_coefficient(), 2); + /// ``` + pub fn update_with_wrapping_add(&mut self, other: &Polynomial) + where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + ck_dim_eq!( + self.polynomial_size() => + other.polynomial_size() + ); + self.as_mut_tensor() + .update_with_wrapping_add(other.as_tensor()); + } + + /// Subtracts an integer polynomial to another one. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let mut first = Polynomial::from_container(vec![1u8, 2, 3]); + /// let second = Polynomial::from_container(vec![4u8, 5, 6]); + /// first.update_with_wrapping_sub(&second); + /// assert_eq!( + /// *first.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 253 + /// ); + /// assert_eq!( + /// *first.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 253 + /// ); + /// assert_eq!( + /// *first.get_monomial(MonomialDegree(2)).get_coefficient(), + /// 253 + /// ); + /// ``` + pub fn update_with_wrapping_sub(&mut self, other: &Polynomial) + where + Self: AsMutTensor, + Polynomial: AsRefTensor, + Coef: UnsignedInteger, + { + ck_dim_eq!( + self.polynomial_size() => + other.polynomial_size() + ); + self.as_mut_tensor() + .update_with_wrapping_sub(other.as_tensor()); + } + + /// Multiplies (mod $(X^N+1)$), the current polynomial with a monomial of a given degree, and + /// a coefficient of one. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let mut poly = Polynomial::from_container(vec![1u8, 2, 3]); + /// poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(2)); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 254); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 253); + /// assert_eq!(*poly.get_monomial(MonomialDegree(2)).get_coefficient(), 1); + /// ``` + pub fn update_with_wrapping_monic_monomial_mul(&mut self, monomial_degree: MonomialDegree) + where + Self: AsMutTensor, + Coef: UnsignedInteger, + { + let full_cycles_count = monomial_degree.0 / self.as_tensor().len(); + if full_cycles_count % 2 != 0 { + self.as_mut_tensor() + .as_mut_slice() + .iter_mut() + .for_each(|a| *a = a.wrapping_neg()); + } + let remaining_degree = monomial_degree.0 % self.as_tensor().len(); + self.as_mut_tensor() + .as_mut_slice() + .rotate_right(remaining_degree); + self.as_mut_tensor() + .as_mut_slice() + .iter_mut() + .take(remaining_degree) + .for_each(|a| *a = a.wrapping_neg()); + } + + /// Divides (mod $(X^N+1)$), the current polynomial with a monomial of a given degree, and a + /// coefficient of one. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{MonomialDegree, Polynomial}; + /// let mut poly = Polynomial::from_container(vec![1u8, 2, 3]); + /// poly.update_with_wrapping_unit_monomial_div(MonomialDegree(2)); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 3); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 255); + /// assert_eq!(*poly.get_monomial(MonomialDegree(2)).get_coefficient(), 254); + /// ``` + pub fn update_with_wrapping_unit_monomial_div(&mut self, monomial_degree: MonomialDegree) + where + Self: AsMutTensor, + Coef: UnsignedInteger, + { + let full_cycles_count = monomial_degree.0 / self.as_tensor().len(); + if full_cycles_count % 2 != 0 { + self.as_mut_tensor() + .as_mut_slice() + .iter_mut() + .for_each(|a| *a = a.wrapping_neg()); + } + let remaining_degree = monomial_degree.0 % self.as_tensor().len(); + self.as_mut_tensor() + .as_mut_slice() + .rotate_left(remaining_degree); + self.as_mut_tensor() + .as_mut_slice() + .iter_mut() + .rev() + .take(remaining_degree) + .for_each(|a| *a = a.wrapping_neg()); + } + + /// Adds multiple integer polynomials to the current one. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{ + /// MonomialDegree, Polynomial, PolynomialList, + /// }; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut poly = Polynomial::from_container(vec![1u8, 2, 3]); + /// let poly_list = PolynomialList::from_container(vec![4u8, 5, 6, 7, 8, 9], PolynomialSize(3)); + /// poly.update_with_wrapping_add_several(&poly_list); + /// assert_eq!(*poly.get_monomial(MonomialDegree(0)).get_coefficient(), 12); + /// assert_eq!(*poly.get_monomial(MonomialDegree(1)).get_coefficient(), 15); + /// assert_eq!(*poly.get_monomial(MonomialDegree(2)).get_coefficient(), 18); + /// ``` + pub fn update_with_wrapping_add_several( + &mut self, + coef_list: &PolynomialList, + ) where + Self: AsMutTensor, + PolynomialList: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + Coef: UnsignedInteger, + { + for poly in coef_list.polynomial_iter() { + self.update_with_wrapping_add(&poly); + } + } + + /// Subtracts multiple integer polynomials to the current one. + /// + /// # Examples + /// + /// ``` + /// use tfhe::core_crypto::commons::math::polynomial::{ + /// MonomialDegree, Polynomial, PolynomialList, + /// }; + /// use tfhe::core_crypto::prelude::PolynomialSize; + /// let mut poly = Polynomial::from_container(vec![1u32, 2, 3]); + /// let poly_list = PolynomialList::from_container(vec![4u32, 5, 6, 7, 8, 9], PolynomialSize(3)); + /// poly.update_with_wrapping_sub_several(&poly_list); + /// assert_eq!( + /// *poly.get_monomial(MonomialDegree(0)).get_coefficient(), + /// 4294967286 + /// ); + /// assert_eq!( + /// *poly.get_monomial(MonomialDegree(1)).get_coefficient(), + /// 4294967285 + /// ); + /// assert_eq!( + /// *poly.get_monomial(MonomialDegree(2)).get_coefficient(), + /// 4294967284 + /// ); + /// ``` + pub fn update_with_wrapping_sub_several( + &mut self, + coef_list: &PolynomialList, + ) where + Self: AsMutTensor, + PolynomialList: AsRefTensor, + for<'a> Polynomial<&'a [Coef]>: AsRefTensor, + Coef: UnsignedInteger, + { + for poly in coef_list.polynomial_iter() { + self.update_with_wrapping_sub(&poly); + } + } +} + +/// function used to compute the induction for the karatsuba algorithm +fn induction_karatsuba( + res: &mut Tensor<&mut [Coef]>, + p: &Tensor<&[Coef]>, + q: &Tensor<&[Coef]>, +) where + Coef: UnsignedInteger, +{ + if p.len() == KARATUSBA_STOP { + // schoolbook algorithm + for i in 0..p.len() { + for j in 0..q.len() { + *res.get_element_mut(i + j) = res + .get_element(i + j) + .wrapping_add(p.get_element(i).wrapping_mul(*q.get_element(j))) + } + } + } else { + let poly_size = res.len(); + + // allocate slices for the rec + let mut a0 = Tensor::allocate(Coef::ZERO, poly_size / 2); + let mut a1 = Tensor::allocate(Coef::ZERO, poly_size / 2); + let mut a2 = Tensor::allocate(Coef::ZERO, poly_size / 2); + let mut input_a2_p = Tensor::allocate(Coef::ZERO, poly_size / 4); + let mut input_a2_q = Tensor::allocate(Coef::ZERO, poly_size / 4); + + // prepare for splitting + let bottom = 0..(poly_size / 4); + let top = (poly_size / 4)..(poly_size / 2); + + // rec + induction_karatsuba( + &mut a0.get_sub_mut(..), + &p.get_sub(bottom.clone()), + &q.get_sub(bottom.clone()), + ); + induction_karatsuba( + &mut a1.get_sub_mut(..), + &p.get_sub(top.clone()), + &q.get_sub(top.clone()), + ); + input_a2_p + .as_mut_tensor() + .fill_with_wrapping_add(&p.get_sub(bottom.clone()), &p.get_sub(top.clone())); + input_a2_q + .as_mut_tensor() + .fill_with_wrapping_add(&q.get_sub(bottom), &q.get_sub(top)); + induction_karatsuba( + &mut a2.get_sub_mut(..), + &input_a2_p.get_sub(..), + &input_a2_q.get_sub(..), + ); + + // rebuild the result + res.get_sub_mut((poly_size / 4)..(3 * poly_size / 4)) + .as_mut_tensor() + .fill_with_wrapping_sub(&a2, &a0); + res.get_sub_mut((poly_size / 4)..(3 * poly_size / 4)) + .update_with_wrapping_sub(&a1); + res.get_sub_mut(0..(poly_size / 2)) + .update_with_wrapping_add(&a0); + res.get_sub_mut((poly_size / 2)..poly_size) + .update_with_wrapping_add(&a1); + } +} diff --git a/tfhe/src/core_crypto/commons/math/polynomial/tests.rs b/tfhe/src/core_crypto/commons/math/polynomial/tests.rs new file mode 100644 index 000000000..cd349d234 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/polynomial/tests.rs @@ -0,0 +1,115 @@ +use rand::Rng; + +use crate::core_crypto::prelude::{MonomialDegree, PolynomialSize}; + +use crate::core_crypto::commons::math::polynomial::Polynomial; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::test_tools::*; + +fn test_multiply_divide_unit_monomial() { + //! tests if multiply_by_monomial and divide_by_monomial cancel each other + let mut rng = rand::thread_rng(); + let mut generator = new_random_generator(); + + // settings + let polynomial_size = (rng.gen::() % 2048) + 1; + + // generates a random Torus polynomial + let mut poly = Polynomial::from_container( + generator + .random_uniform_tensor::(polynomial_size) + .into_container(), + ); + + // copy this polynomial + let ground_truth = poly.clone(); + + // generates a random r + let mut r: usize = rng.gen(); + r %= polynomial_size; + + // multiply by X^r and then divides by X^r + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(r)); + poly.update_with_wrapping_unit_monomial_div(MonomialDegree(r)); + + // test + assert_eq!(&poly, &ground_truth); + + // generates a random r_big + let mut r_big: usize = rng.gen(); + r_big = r_big % polynomial_size + 2048; + + // multiply by X^r_big and then divides by X^r_big + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(r_big)); + poly.update_with_wrapping_unit_monomial_div(MonomialDegree(r_big)); + + // test + assert_eq!(&poly, &ground_truth); + + // divides by X^r_big and then multiply by X^r_big + poly.update_with_wrapping_monic_monomial_mul(MonomialDegree(r_big)); + poly.update_with_wrapping_unit_monomial_div(MonomialDegree(r_big)); + + // test + assert_eq!(&poly, &ground_truth); +} + +/// test if we have the same result when using schoolbook or karatsuba +/// for random polynomial multiplication +fn test_multiply_karatsuba() { + // 50 times the test + for _i in 0..50 { + // random source + let mut rng = rand::thread_rng(); + + // random settings settings + let polynomial_log = (rng.gen::() % 7) + 6; + let polynomial_size = PolynomialSize(1 << polynomial_log); + let mut generator = new_random_generator(); + + // generates two random Torus polynomials + let poly_1 = Polynomial::from_container( + generator + .random_uniform_tensor::(polynomial_size.0) + .into_container(), + ); + let poly_2 = Polynomial::from_container( + generator + .random_uniform_tensor::(polynomial_size.0) + .into_container(), + ); + + // copy this polynomial + let mut sb_mul = Polynomial::allocate(T::ZERO, polynomial_size); + let mut ka_mul = Polynomial::allocate(T::ZERO, polynomial_size); + + // compute the schoolbook + sb_mul.fill_with_wrapping_mul(&poly_1, &poly_2); + + // compute the karatsuba + ka_mul.fill_with_karatsuba_mul(&poly_1, &poly_2); + + // test + assert_eq!(&sb_mul, &ka_mul); + } +} + +#[test] +pub fn test_multiply_divide_unit_monomial_u32() { + test_multiply_divide_unit_monomial::() +} + +#[test] +pub fn test_multiply_divide_unit_monomial_u64() { + test_multiply_divide_unit_monomial::() +} + +#[test] +pub fn test_multiply_karatsuba_u32() { + test_multiply_karatsuba::() +} + +#[test] +pub fn test_multiply_karatsuba_u64() { + test_multiply_karatsuba::() +} diff --git a/tfhe/src/core_crypto/commons/math/random/gaussian.rs b/tfhe/src/core_crypto/commons/math/random/gaussian.rs new file mode 100644 index 000000000..0c985a54e --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/gaussian.rs @@ -0,0 +1,85 @@ +use crate::core_crypto::commons::numeric::{CastInto, Numeric}; + +use crate::core_crypto::commons::math::torus::{FromTorus, UnsignedTorus}; + +use super::*; + +/// A distribution type representing random sampling of floating point numbers, following a +/// gaussian distribution. +#[derive(Clone, Copy)] +pub struct Gaussian { + /// The standard deviation of the distribution. + pub std: T, + /// The mean of the distribution. + pub mean: T, +} + +macro_rules! implement_gaussian { + ($T:ty, $S:ty) => { + impl RandomGenerable> for ($T, $T) { + fn generate_one( + generator: &mut RandomGenerator, + Gaussian { std, mean }: Gaussian<$T>, + ) -> Self { + let output: ($T, $T); + let mut uniform_rand = vec![0 as $S; 2]; + loop { + let n_bytes = (<$S as Numeric>::BITS * 2) / 8; + let uniform_rand_bytes = unsafe { + std::slice::from_raw_parts_mut( + uniform_rand.as_mut_ptr() as *mut u8, + n_bytes, + ) + }; + uniform_rand_bytes + .iter_mut() + .for_each(|a| *a = generator.generate_next()); + let size = <$T>::BITS as i32; + let mut u: $T = uniform_rand[0].cast_into(); + u *= <$T>::TWO.powi(-size + 1); + let mut v: $T = uniform_rand[1].cast_into(); + v *= <$T>::TWO.powi(-size + 1); + let s = u.powi(2) + v.powi(2); + if (s > <$T>::ZERO && s < <$T>::ONE) { + let cst = std * (-<$T>::TWO * s.ln() / s).sqrt(); + output = (u * cst + mean, v * cst + mean); + break; + } + } + output + } + } + }; +} + +implement_gaussian!(f32, i32); +implement_gaussian!(f64, i64); + +impl RandomGenerable> for (Torus, Torus) +where + Torus: UnsignedTorus, +{ + fn generate_one( + generator: &mut RandomGenerator, + distribution: Gaussian, + ) -> Self { + let (s1, s2) = <(f64, f64)>::generate_one(generator, distribution); + ( + >::from_torus(s1), + >::from_torus(s2), + ) + } +} + +impl RandomGenerable> for Torus +where + Torus: UnsignedTorus, +{ + fn generate_one( + generator: &mut RandomGenerator, + distribution: Gaussian, + ) -> Self { + let (s1, _) = <(f64, f64)>::generate_one(generator, distribution); + >::from_torus(s1) + } +} diff --git a/tfhe/src/core_crypto/commons/math/random/generator.rs b/tfhe/src/core_crypto/commons/math/random/generator.rs new file mode 100644 index 000000000..a7a27ed53 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/generator.rs @@ -0,0 +1,673 @@ +use crate::core_crypto::commons::math::random::{ + Gaussian, RandomGenerable, Uniform, UniformBinary, UniformLsb, UniformMsb, UniformTernary, + UniformWithZeros, +}; +use crate::core_crypto::commons::math::tensor::{AsMutSlice, AsMutTensor, Tensor}; +use crate::core_crypto::commons::numeric::{FloatingPoint, Numeric}; +use concrete_csprng::generators::{BytesPerChild, ChildrenCount, ForkError}; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; +use std::convert::TryInto; + +#[cfg(feature = "__commons_parallel")] +pub use concrete_csprng::generators::ParallelRandomGenerator as ParallelByteRandomGenerator; +pub use concrete_csprng::generators::RandomGenerator as ByteRandomGenerator; +pub use concrete_csprng::seeders::{Seed, Seeder}; + +/// Module to proxy the serialization for `concrete-csprng::Seed` to avoid adding serde as a +/// dependency to `concrete-csprng` +#[cfg(feature = "__commons_serialization")] +pub mod serialization_proxy { + pub(crate) use concrete_csprng::seeders::Seed; + pub(crate) use serde::{Deserialize, Serialize}; + + // See https://serde.rs/remote-derive.html + // Serde calls this the definition of the remote type. It is just a copy of the remote data + // structure. The `remote` attribute gives the path to the actual type we intend to derive code + // for. This avoids having to introduce serde in concrete-csprng + #[derive(Serialize, Deserialize)] + #[serde(remote = "Seed")] + pub(crate) struct SeedSerdeDef(pub u128); +} + +#[cfg(feature = "__commons_serialization")] +pub(crate) use serialization_proxy::*; + +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub struct CompressionSeed { + #[cfg_attr(feature = "__commons_serialization", serde(with = "SeedSerdeDef"))] + pub seed: Seed, +} + +/// A cryptographically secure random number generator. +/// +/// This csprng is used by every objects that needs sampling in the library. If the proper +/// instructions are available on the machine, it will use an hardware-accelerated variant for +/// the generation. If not, a fallback software version will be used. +/// +/// # Safe multithreaded use +/// +/// When using a csprng in a multithreaded setting, it is important to make sure that the same +/// sequence of bytes is not generated twice on two different threads. This csprng offers a +/// simple way to ensure that: any generator can be _forked_ into several _bounded_ generators, +/// which are able to sample a fixed number of bytes. This forking operation has the effect of +/// shifting the state of the parent generator accordingly. This way, the children generators can be +/// used by the different threads safely: +/// +/// ```rust +/// use concrete_csprng::generators::SoftwareRandomGenerator; +/// use concrete_csprng::seeders::Seed; +/// use tfhe::core_crypto::commons::math::random::RandomGenerator; +/// let mut generator = RandomGenerator::::new(Seed(0)); +/// assert_eq!(generator.remaining_bytes(), None); // The generator is unbounded. +/// let children = generator +/// .try_fork(5, 2) // 5 generators each able to generate 2 bytes. +/// .unwrap() +/// .collect::>(); +/// for child in children.into_iter() { +/// assert_eq!(child.remaining_bytes(), Some(2)); +/// std::thread::spawn(move || { +/// let child = child; +/// // use the prng to generate 2 bytes. +/// // ... +/// }); +/// } +/// // use the parent to generate as many bytes as needed. +/// ``` +pub struct RandomGenerator(G); + +impl RandomGenerator { + pub fn generate_next(&mut self) -> u8 { + self.0.next_byte().unwrap() + } + + /// Generates a new generator, optionally seeding it with the given value. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// ``` + pub fn new(seed: Seed) -> RandomGenerator { + RandomGenerator(G::new(seed)) + } + + /// Returns the number of bytes that can still be generated, if the generator is bounded. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// assert_eq!(generator.remaining_bytes(), None); + /// let mut generator = generator.try_fork(1, 50).unwrap().next().unwrap(); + /// assert_eq!(generator.remaining_bytes(), Some(50)); + /// ``` + pub fn remaining_bytes(&self) -> Option { + >::try_into(self.0.remaining_bytes().0).ok() + } + + /// Tries to fork the current generator into `n_child` generator bounded to `bytes_per_child`. + /// If `n_child*bytes_per_child` exceeds the bound of the current generator, the method + /// returns `None`. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let children = generator.try_fork(5, 50).unwrap().collect::>(); + /// ``` + pub fn try_fork( + &mut self, + n_child: usize, + bytes_per_child: usize, + ) -> Result>, ForkError> { + self.0 + .try_fork(ChildrenCount(n_child), BytesPerChild(bytes_per_child)) + .map(|iter| iter.map(Self)) + } + + /// Generates a random uniform unsigned integer. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// let random = generator.random_uniform::(); + /// ``` + pub fn random_uniform>(&mut self) -> Scalar { + Scalar::generate_one(self, Uniform) + } + + /// Fills an `AsMutTensor` value with random uniform values. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(1000. as u32, 100); + /// generator.fill_tensor_with_random_uniform(&mut tensor); + /// ``` + pub fn fill_tensor_with_random_uniform(&mut self, output: &mut Tensorable) + where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + Scalar::fill_tensor(self, Uniform, output); + } + + /// Generates a tensor of random uniform values of a given size. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_tensor(10); + /// assert_eq!(t.len(), 10); + /// let first_val = t.get_element(0); + /// for i in 1..10 { + /// assert_ne!(first_val, t.get_element(i)); + /// } + /// ``` + pub fn random_uniform_tensor>( + &mut self, + size: usize, + ) -> Tensor> { + Scalar::generate_tensor(self, Uniform, size) + } + + /// Generates a random uniform binary value. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let random: u32 = generator.random_uniform_binary(); + /// ``` + pub fn random_uniform_binary>(&mut self) -> Scalar { + Scalar::generate_one(self, UniformBinary) + } + + /// Fills an `AsMutTensor` value with random binary values. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(1u32, 100); + /// generator.fill_tensor_with_random_uniform_binary(&mut tensor); + /// ``` + pub fn fill_tensor_with_random_uniform_binary( + &mut self, + output: &mut Tensorable, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + Scalar::fill_tensor(self, UniformBinary, output); + } + + /// Generates a tensor of random binary values of a given size. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_binary_tensor(10); + /// assert_eq!(t.len(), 10); + /// ``` + pub fn random_uniform_binary_tensor>( + &mut self, + size: usize, + ) -> Tensor> { + Scalar::generate_tensor(self, UniformBinary, size) + } + + /// Generates a random uniform ternary value. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let random: u32 = generator.random_uniform_ternary(); + /// ``` + pub fn random_uniform_ternary>(&mut self) -> Scalar { + Scalar::generate_one(self, UniformTernary) + } + + /// Fills an `AsMutTensor` value with random ternary values. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(1u32, 100); + /// generator.fill_tensor_with_random_uniform_ternary(&mut tensor); + /// ``` + pub fn fill_tensor_with_random_uniform_ternary( + &mut self, + output: &mut Tensorable, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + Scalar::fill_tensor(self, UniformTernary, output); + } + + /// Generates a tensor of random ternary values of a given size. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_ternary_tensor(10); + /// assert_eq!(t.len(), 10); + /// ``` + pub fn random_uniform_ternary_tensor>( + &mut self, + size: usize, + ) -> Tensor> { + Scalar::generate_tensor(self, UniformTernary, size) + } + + /// Generates an unsigned integer whose n least significant bits are uniformly random, and the + /// other bits are zero. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let random: u8 = generator.random_uniform_n_lsb(3); + /// assert!(random <= 7 as u8); + /// ``` + pub fn random_uniform_n_lsb>( + &mut self, + n: usize, + ) -> Scalar { + Scalar::generate_one(self, UniformLsb { n }) + } + + /// Fills an `AsMutTensor` value with random values whose n lsbs are sampled uniformly. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(0 as u8, 100); + /// generator.fill_tensor_with_random_uniform_n_lsb(&mut tensor, 3); + /// ``` + pub fn fill_tensor_with_random_uniform_n_lsb( + &mut self, + output: &mut Tensorable, + n: usize, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + Scalar::fill_tensor(self, UniformLsb { n }, output); + } + + /// Generates a tensor of random uniform values, whose n lsbs are sampled uniformly. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_n_lsb_tensor(10, 55); + /// assert_eq!(t.len(), 10); + /// let first_val = t.get_element(0); + /// for i in 1..10 { + /// assert_ne!(first_val, t.get_element(i)); + /// } + /// ``` + pub fn random_uniform_n_lsb_tensor>( + &mut self, + size: usize, + n: usize, + ) -> Tensor> { + Scalar::generate_tensor(self, UniformLsb { n }, size) + } + + /// Generates an unsigned integer whose n most significant bits are uniformly random, and the + /// other bits are zero. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let random: u8 = generator.random_uniform_n_msb(3); + /// assert!(random == 0 || random >= 32); + /// ``` + pub fn random_uniform_n_msb>( + &mut self, + n: usize, + ) -> Scalar { + Scalar::generate_one(self, UniformMsb { n }) + } + + /// Fills an `AsMutTensor` value with values whose n msbs are random. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(8 as u8, 100); + /// generator.fill_tensor_with_random_uniform_n_msb(&mut tensor, 5); + /// ``` + pub fn fill_tensor_with_random_uniform_n_msb( + &mut self, + output: &mut Tensorable, + n: usize, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + Scalar::fill_tensor(self, UniformMsb { n }, output) + } + + /// Generates a tensor of random uniform values, whose n msbs are sampled uniformly. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_n_msb_tensor(10, 55); + /// assert_eq!(t.len(), 10); + /// let first_val = t.get_element(0); + /// for i in 1..10 { + /// assert_ne!(first_val, t.get_element(i)); + /// } + /// ``` + pub fn random_uniform_n_msb_tensor>( + &mut self, + size: usize, + n: usize, + ) -> Tensor> { + Scalar::generate_tensor(self, UniformMsb { n }, size) + } + + /// Generates a random uniform unsigned integer with probability `1-prob_zero`, and a zero value + /// with probability `prob_zero`. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let random = generator.random_uniform_with_zeros::(0.5); + /// let random = generator.random_uniform_with_zeros::(0.5); + /// let random = generator.random_uniform_with_zeros::(0.5); + /// let random = generator.random_uniform_with_zeros::(0.5); + /// let random = generator.random_uniform_with_zeros::(0.5); + /// assert_eq!(generator.random_uniform_with_zeros::(1.), 0); + /// assert_ne!(generator.random_uniform_with_zeros::(0.), 0); + /// ``` + pub fn random_uniform_with_zeros>( + &mut self, + prob_zero: f32, + ) -> Scalar { + Scalar::generate_one(self, UniformWithZeros { prob_zero }) + } + + /// Fills an `AsMutTensor` value with random values uniform with probability `prob` and zero + /// with probability `1-prob`. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(10 as u8, 100); + /// generator.fill_tensor_with_random_uniform_with_zeros(&mut tensor, 0.5); + /// ``` + pub fn fill_tensor_with_random_uniform_with_zeros( + &mut self, + output: &mut Tensorable, + prob_zero: f32, + ) where + Scalar: RandomGenerable, + Tensorable: AsMutTensor, + { + output.as_mut_tensor().iter_mut().for_each(|s| { + *s = self.random_uniform_with_zeros(prob_zero); + }); + } + + /// Generates a tensor of a given size, whose coefficients are random uniform with probability + /// `1-prob_zero`, and zero with probability `prob_zero`. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let t: Tensor> = generator.random_uniform_with_zeros_tensor(10, 0.); + /// assert_eq!(t.len(), 10); + /// t.iter().for_each(|a| assert_ne!(*a, 0)); + /// let t: Tensor> = generator.random_uniform_with_zeros_tensor(10, 1.); + /// t.iter().for_each(|a| assert_eq!(*a, 0)); + /// ``` + pub fn random_uniform_with_zeros_tensor>( + &mut self, + size: usize, + prob_zero: f32, + ) -> Tensor> { + (0..size) + .map(|_| self.random_uniform_with_zeros(prob_zero)) + .collect() + } + + /// Generates two floating point values drawn from a gaussian distribution with input mean and + /// standard deviation. + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// // for f32 + /// let (g1, g2): (f32, f32) = generator.random_gaussian(0. as f32, 1. as f32); + /// // check that both samples are in 6 sigmas. + /// assert!(g1.abs() <= 6.); + /// assert!(g2.abs() <= 6.); + /// // for f64 + /// let (g1, g2): (f64, f64) = generator.random_gaussian(0. as f64, 1. as f64); + /// // check that both samples are in 6 sigmas. + /// assert!(g1.abs() <= 6.); + /// assert!(g2.abs() <= 6.); + /// ``` + pub fn random_gaussian(&mut self, mean: Float, std: Float) -> (Scalar, Scalar) + where + Float: FloatingPoint, + (Scalar, Scalar): RandomGenerable>, + { + <(Scalar, Scalar)>::generate_one(self, Gaussian { std, mean }) + } + + /// Fills an `AsMutTensor` value with random gaussian values. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let mut tensor = Tensor::allocate(1000. as f32, 100); + /// generator.fill_tensor_with_random_gaussian(&mut tensor, 0., 1.); + /// tensor.iter().for_each(|t| assert_ne!(*t, 1000.)); + /// ``` + pub fn fill_tensor_with_random_gaussian( + &mut self, + output: &mut Tensorable, + mean: Float, + std: Float, + ) where + Float: FloatingPoint, + (Scalar, Scalar): RandomGenerable>, + Tensorable: AsMutTensor, + { + output + .as_mut_tensor() + .as_mut_slice() + .chunks_mut(2) + .for_each(|s| { + let (g1, g2) = <(Scalar, Scalar)>::generate_one(self, Gaussian { std, mean }); + if let Some(elem) = s.get_mut(0) { + *elem = g1; + } + if let Some(elem) = s.get_mut(1) { + *elem = g2; + } + }); + } + + /// Generates a new tensor of floating point values, randomly sampled from a gaussian + /// distribution: + /// + /// # Example + /// + /// ```rust + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let tensor: Tensor> = generator.random_gaussian_tensor(10_000, 0. as f32, 1. as f32); + /// assert_eq!(tensor.len(), 10_000); + /// tensor.iter().for_each(|a| assert!((*a).abs() <= 6.)); + /// ``` + pub fn random_gaussian_tensor( + &mut self, + size: usize, + mean: Float, + std: Float, + ) -> Tensor> + where + Float: FloatingPoint, + (Scalar, Scalar): RandomGenerable>, + Scalar: Numeric, + { + let mut tensor = Tensor::allocate(Scalar::ZERO, size); + self.fill_tensor_with_random_gaussian(&mut tensor, mean, std); + tensor + } +} + +#[cfg(feature = "__commons_parallel")] +impl RandomGenerator { + /// Tries to fork the current generator into `n_child` generator bounded to `bytes_per_child`, + /// as a parallel iterator. + /// + /// If `n_child*bytes_per_child` exceeds the bound of the current generator, the method + /// returns `None`. + /// + /// # Notes + /// + /// This method necessitates the "__commons_parallel" feature to be used. + /// + /// # Example + /// + /// ``` + /// use concrete_csprng::generators::SoftwareRandomGenerator; + /// use concrete_csprng::seeders::Seed; + /// use tfhe::core_crypto::commons::math::random::RandomGenerator; + /// let mut generator = RandomGenerator::::new(Seed(0)); + /// let children = generator.try_fork(5, 50).unwrap().collect::>(); + /// ``` + pub fn par_try_fork( + &mut self, + n_child: usize, + bytes_per_child: usize, + ) -> Result>, ForkError> { + self.0 + .par_try_fork(ChildrenCount(n_child), BytesPerChild(bytes_per_child)) + .map(|iter| iter.map(Self)) + } +} diff --git a/tfhe/src/core_crypto/commons/math/random/mod.rs b/tfhe/src/core_crypto/commons/math/random/mod.rs new file mode 100644 index 000000000..90839664b --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/mod.rs @@ -0,0 +1,90 @@ +//! A module containing random sampling functions. +//! +//! This module contains a [`RandomGenerator`] type, which exposes methods to sample numeric values +//! randomly according to a given distribution, for instance: +//! +//! + [`RandomGenerator::random_uniform`] samples a random unsigned integer with uniform +//! probability over the set of representable values. +//! + [`RandomGenerator::random_gaussian`] samples a random float with using a gaussian +//! distribution. +//! +//! The implementation relies on the [`RandomGenerable`] trait, which gives a type the ability to +//! be randomly generated according to a given distribution. The module contains multiple +//! implementations of this trait, for different distributions. Note, though, that instead of +//! using the [`RandomGenerable`] methods, you should use the various methods exposed by +//! [`RandomGenerator`] instead. +use crate::core_crypto::commons::math::tensor::{AsMutTensor, Tensor}; +use crate::core_crypto::commons::numeric::FloatingPoint; + +pub use gaussian::*; +pub use generator::*; +pub use uniform::*; +pub use uniform_binary::*; +pub use uniform_lsb::*; +pub use uniform_msb::*; +pub use uniform_ternary::*; +pub use uniform_with_zeros::*; + +#[cfg(test)] +mod tests; + +mod gaussian; +mod generator; +mod uniform; +mod uniform_binary; +mod uniform_lsb; +mod uniform_msb; +mod uniform_ternary; +mod uniform_with_zeros; + +pub trait RandomGenerable +where + Self: Sized, +{ + fn generate_one( + generator: &mut RandomGenerator, + distribution: D, + ) -> Self; + fn generate_tensor( + generator: &mut RandomGenerator, + distribution: D, + size: usize, + ) -> Tensor> { + (0..size) + .map(|_| Self::generate_one(generator, distribution)) + .collect() + } + fn fill_tensor( + generator: &mut RandomGenerator, + distribution: D, + tensor: &mut Tens, + ) where + Tens: AsMutTensor, + { + tensor.as_mut_tensor().iter_mut().for_each(|s| { + *s = Self::generate_one(generator, distribution); + }); + } +} + +/// A marker trait for types representing distributions. +pub trait Distribution: seal::Sealed + Copy {} +mod seal { + use crate::core_crypto::commons::numeric::FloatingPoint; + + pub trait Sealed {} + impl Sealed for super::Uniform {} + impl Sealed for super::UniformMsb {} + impl Sealed for super::UniformLsb {} + impl Sealed for super::UniformWithZeros {} + impl Sealed for super::UniformBinary {} + impl Sealed for super::UniformTernary {} + impl Sealed for super::Gaussian {} +} +impl Distribution for Uniform {} +impl Distribution for UniformMsb {} +impl Distribution for UniformLsb {} +impl Distribution for UniformWithZeros {} +impl Distribution for UniformBinary {} +impl Distribution for UniformTernary {} +impl Distribution for Gaussian {} diff --git a/tfhe/src/core_crypto/commons/math/random/tests.rs b/tfhe/src/core_crypto/commons/math/random/tests.rs new file mode 100644 index 000000000..9ae7b4a1a --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/tests.rs @@ -0,0 +1,86 @@ +use crate::core_crypto::prelude::LogStandardDev; + +use crate::core_crypto::commons::math::tensor::Tensor; +use crate::core_crypto::commons::math::torus::UnsignedTorus; +use crate::core_crypto::commons::test_tools::*; + +fn test_normal_random() { + //! test if the normal random generation with std_dev is below 3*std_dev (99.7%) + + // settings + let std_dev: f64 = f64::powi(2., -20); + let mean: f64 = 0.; + let k = 1_000_000; + let mut generator = new_random_generator(); + + // generates normal random + let mut samples_int = Tensor::allocate(T::ZERO, k); + generator.fill_tensor_with_random_gaussian(&mut samples_int, mean, std_dev); + + // converts into float + let mut samples_float = Tensor::allocate(0f64, k); + samples_float.fill_with_one(&samples_int, |a| a.into_torus()); + for x in samples_float.iter_mut() { + if *x > 0.5 { + *x = 1. - *x; + } + } + + // tests if over 3*std_dev + let mut number_of_samples_outside_confidence_interval: usize = 0; + for s in samples_float.iter() { + if *s > 3. * std_dev || *s < -3. * std_dev { + number_of_samples_outside_confidence_interval += 1; + } + } + + // computes the percentage of samples over 3*std_dev + let proportion_of_samples_outside_confidence_interval: f64 = + (number_of_samples_outside_confidence_interval as f64) / (k as f64); + + // test + assert!( + proportion_of_samples_outside_confidence_interval < 0.003, + "test normal random : proportion = {} ; n = {}", + proportion_of_samples_outside_confidence_interval, + number_of_samples_outside_confidence_interval + ); +} + +#[test] +fn test_normal_random_u32() { + test_normal_random::(); +} + +#[test] +fn test_normal_random_u64() { + test_normal_random::(); +} + +fn test_distribution() { + //! tests gaussianity against the rand crate generation + // settings + let std_dev: f64 = f64::powi(2., -5); + let mean: f64 = 0.; + let k = 10_000_000; + let mut generator = new_random_generator(); + + // generates normal random + let first = Tensor::allocate(T::ZERO, k); + let mut second = Tensor::allocate(T::ZERO, k); + generator.fill_tensor_with_random_gaussian(&mut second, mean, std_dev); + + assert_noise_distribution(&first, &second, LogStandardDev(-5.)); +} + +// // These tests are notoriously flaky + +// #[test] +// fn test_distribution_u32() { +// test_distribution::(); +// } + +// #[test] +// fn test_distribution_u64() { +// test_distribution::(); +// } diff --git a/tfhe/src/core_crypto/commons/math/random/uniform.rs b/tfhe/src/core_crypto/commons/math/random/uniform.rs new file mode 100644 index 000000000..d965d88e1 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform.rs @@ -0,0 +1,35 @@ +use super::*; + +/// A distribution type representing uniform sampling for unsigned integer types. The value is +/// uniformly sampled in `[0, 2^n[` where `n` is the size of the integer type. +#[derive(Copy, Clone)] +pub struct Uniform; + +macro_rules! implement_uniform { + ($T:ty) => { + impl RandomGenerable for $T { + #[allow(unused)] + fn generate_one( + generator: &mut RandomGenerator, + distribution: Uniform, + ) -> Self { + let mut buf = [0; std::mem::size_of::<$T>()]; + buf.iter_mut().for_each(|a| *a = generator.generate_next()); + // We use from_le_bytes as most platforms are low endian, this avoids endianness + // issues + <$T>::from_le_bytes(buf) + } + } + }; +} + +implement_uniform!(u8); +implement_uniform!(u16); +implement_uniform!(u32); +implement_uniform!(u64); +implement_uniform!(u128); +implement_uniform!(i8); +implement_uniform!(i16); +implement_uniform!(i32); +implement_uniform!(i64); +implement_uniform!(i128); diff --git a/tfhe/src/core_crypto/commons/math/random/uniform_binary.rs b/tfhe/src/core_crypto/commons/math/random/uniform_binary.rs new file mode 100644 index 000000000..0cdf93a28 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform_binary.rs @@ -0,0 +1,28 @@ +use super::*; + +/// A distribution type representing uniform sampling for binary type. +#[derive(Clone, Copy)] +pub struct UniformBinary; + +macro_rules! implement_uniform_binary { + ($T:ty) => { + impl RandomGenerable for $T { + #[allow(unused)] + fn generate_one( + generator: &mut RandomGenerator, + distribution: UniformBinary, + ) -> Self { + if generator.generate_next() & 1 == 1 { + 1 + } else { + 0 + } + } + } + }; +} + +implement_uniform_binary!(u8); +implement_uniform_binary!(u16); +implement_uniform_binary!(u32); +implement_uniform_binary!(u64); diff --git a/tfhe/src/core_crypto/commons/math/random/uniform_lsb.rs b/tfhe/src/core_crypto/commons/math/random/uniform_lsb.rs new file mode 100644 index 000000000..70466fa82 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform_lsb.rs @@ -0,0 +1,28 @@ +use super::*; +use crate::core_crypto::commons::numeric::Numeric; +/// A distribution type representing random sampling for unsigned integer type, where the `n` +/// least significant bits are sampled in `[0, 2^n[`. +#[derive(Copy, Clone)] +pub struct UniformLsb { + /// The number of least significant bits that should be set randomly. + pub n: usize, +} + +macro_rules! implement_uniform_some_lsb { + ($T:ty) => { + impl RandomGenerable for $T { + fn generate_one( + generator: &mut RandomGenerator, + UniformLsb { n }: UniformLsb, + ) -> Self { + <$T>::generate_one(generator, Uniform) >> (<$T as Numeric>::BITS - n) + } + } + }; +} + +implement_uniform_some_lsb!(u8); +implement_uniform_some_lsb!(u16); +implement_uniform_some_lsb!(u32); +implement_uniform_some_lsb!(u64); +implement_uniform_some_lsb!(u128); diff --git a/tfhe/src/core_crypto/commons/math/random/uniform_msb.rs b/tfhe/src/core_crypto/commons/math/random/uniform_msb.rs new file mode 100644 index 000000000..8552e4933 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform_msb.rs @@ -0,0 +1,29 @@ +use super::*; +use crate::core_crypto::commons::numeric::Numeric; + +/// A distribution type representing random sampling for unsigned integer types, where the `n` +/// most significant bits are sampled randomly in `[0, 2^n[`. +#[derive(Copy, Clone)] +pub struct UniformMsb { + /// The number of most significant bits that must be randomly set. + pub n: usize, +} + +macro_rules! implement_uniform_some_msb { + ($T:ty) => { + impl RandomGenerable for $T { + fn generate_one( + generator: &mut RandomGenerator, + UniformMsb { n }: UniformMsb, + ) -> Self { + <$T>::generate_one(generator, Uniform) << (<$T as Numeric>::BITS - n) + } + } + }; +} + +implement_uniform_some_msb!(u8); +implement_uniform_some_msb!(u16); +implement_uniform_some_msb!(u32); +implement_uniform_some_msb!(u64); +implement_uniform_some_msb!(u128); diff --git a/tfhe/src/core_crypto/commons/math/random/uniform_ternary.rs b/tfhe/src/core_crypto/commons/math/random/uniform_ternary.rs new file mode 100644 index 000000000..61246ff40 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform_ternary.rs @@ -0,0 +1,31 @@ +use super::*; + +/// A distribution type representing uniform sampling for ternary type. +#[derive(Clone, Copy)] +pub struct UniformTernary; + +macro_rules! implement_uniform_ternary { + ($T:ty) => { + impl RandomGenerable for $T { + #[allow(unused)] + fn generate_one( + generator: &mut RandomGenerator, + distribution: UniformTernary, + ) -> Self { + loop { + match generator.generate_next() & 3 { + 0 => return 0, + 1 => return 1, + 2 => return (0 as $T).wrapping_sub(1), + _ => {} + } + } + } + } + }; +} + +implement_uniform_ternary!(u8); +implement_uniform_ternary!(u16); +implement_uniform_ternary!(u32); +implement_uniform_ternary!(u64); diff --git a/tfhe/src/core_crypto/commons/math/random/uniform_with_zeros.rs b/tfhe/src/core_crypto/commons/math/random/uniform_with_zeros.rs new file mode 100644 index 000000000..9b4691ec9 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/random/uniform_with_zeros.rs @@ -0,0 +1,42 @@ +use super::*; +use crate::core_crypto::commons::numeric::Numeric; + +/// A distribution type that samples a uniform value with probability `1 - prob_zero`, and a zero +/// value with probaibility `prob_zero`. +#[derive(Copy, Clone)] +pub struct UniformWithZeros { + /// The probability of the output being a zero + pub prob_zero: f32, +} + +#[allow(unused_macros)] +macro_rules! implement_uniform_with_zeros { + ($T:ty, $bits:literal) => { + impl RandomGenerable for $T { + #[allow(unused)] + fn generate_one( + generator: &mut RandomGenerator, + UniformWithZeros { prob_zero }: UniformWithZeros, + ) -> Self { + let uniform_u32: u32 = u32::generate_one(generator, Uniform); + let float_sample = uniform_u32 as f32 / u32::MAX as f32; + if float_sample < prob_zero { + <$T>::ZERO + } else { + Self::generate_one(generator, Uniform) + } + } + } + }; +} + +implement_uniform_with_zeros!(u8, 1); +implement_uniform_with_zeros!(u16, 2); +implement_uniform_with_zeros!(u32, 4); +implement_uniform_with_zeros!(u64, 8); +implement_uniform_with_zeros!(u128, 16); +implement_uniform_with_zeros!(i8, 1); +implement_uniform_with_zeros!(i16, 2); +implement_uniform_with_zeros!(i32, 4); +implement_uniform_with_zeros!(i64, 8); +implement_uniform_with_zeros!(i128, 16); diff --git a/tfhe/src/core_crypto/commons/math/tensor/as_element.rs b/tfhe/src/core_crypto/commons/math/tensor/as_element.rs new file mode 100644 index 000000000..dd8bceef8 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/as_element.rs @@ -0,0 +1,15 @@ +/// A trait allowing to treat a value as a reference to an alement of a different type. +pub trait AsRefElement { + /// The element type. + type Element; + /// Returns a reference to the element enclosed in the type. + fn as_element(&self) -> &Self::Element; +} + +/// A trait allowing to treat a value as a mutable reference to an element of a different type. +pub trait AsMutElement: AsRefElement::Element> { + /// The element type. + type Element; + /// Returns a mutable reference to the element enclosed in the type. + fn as_mut_element(&mut self) -> &mut ::Element; +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/as_slice.rs b/tfhe/src/core_crypto/commons/math/tensor/as_slice.rs new file mode 100644 index 000000000..236efdc96 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/as_slice.rs @@ -0,0 +1,115 @@ +/// A trait allowing to extract a slice from a tensor. +/// +/// This trait is one of the two traits which allows to use [`Tensor`](super::Tensor) whith any data +/// container. +/// This trait basically allows to extract a slice `&[T]` out of a container, to implement +/// operations directly on the slice: +/// ```rust,ignore +/// // Implementing AsSlice for Vec +/// impl AsSlice for Vec { +/// type Element = Element; +/// fn as_slice(&self) -> &[Element] { +/// self.as_slice() +/// } +/// } +/// ``` +/// It is akin to the [`std::borrow::Borrow`] trait from the standard library, but it is local to +/// this crate (which makes this little explanation possible), and the scalar type is an associated +/// type. Having an associated type instead of a generic type, tends to make signatures a little +/// leaner. +/// +/// Finally, you should note that we have a blanket implementation that implements `AsView` for +/// `Tensor` where `Cont` is itself `AsView`: +/// ```rust,ignore +/// impl AsView for Tensor +/// where +/// Cont: AsView, +/// { +/// type Scalar = Cont::Scalar; +/// fn as_view(&self) -> &[Self::Scalar] { +/// // implementation +/// } +/// } +/// ``` +/// This is blanket implementation is used by the methods of the `Tensor` structure for instance. +pub trait AsRefSlice { + /// The type of the elements of the collection. + type Element; + /// Returns a slice from the container. + fn as_slice(&self) -> &[Self::Element]; +} + +impl AsRefSlice for Vec { + type Element = Element; + fn as_slice(&self) -> &[Element] { + self.as_slice() + } +} + +impl AsRefSlice for aligned_vec::AVec { + type Element = Element; + fn as_slice(&self) -> &[Element] { + self.as_slice() + } +} + +impl AsRefSlice for [Element; 1] { + type Element = Element; + fn as_slice(&self) -> &[Element] { + &self[..] + } +} + +impl AsRefSlice for &[Element] { + type Element = Element; + fn as_slice(&self) -> &[Element] { + self + } +} + +impl AsRefSlice for &mut [Element] { + type Element = Element; + fn as_slice(&self) -> &[Element] { + self + } +} + +/// A trait allowing to extract a mutable slice from a tensor. +/// +/// The logic is the same as for the `AsRefTensor`, but here, it allows to access mutable slices +/// instead. See the [`AsRefTensor`](super::AsRefTensor) documentation for a more detailed +/// explanation of the logic. +pub trait AsMutSlice: AsRefSlice::Element> { + /// The type of the elements of the collection + type Element; + /// Returns a mutable slice from the container. + fn as_mut_slice(&mut self) -> &mut [::Element]; +} + +impl AsMutSlice for Vec { + type Element = Element; + fn as_mut_slice(&mut self) -> &mut [Element] { + self.as_mut_slice() + } +} + +impl AsMutSlice for aligned_vec::AVec { + type Element = Element; + fn as_mut_slice(&mut self) -> &mut [Element] { + self.as_mut_slice() + } +} + +impl AsMutSlice for [Element; 1] { + type Element = Element; + fn as_mut_slice(&mut self) -> &mut [Element] { + &mut self[..] + } +} + +impl AsMutSlice for &mut [Element] { + type Element = Element; + fn as_mut_slice(&mut self) -> &mut [Element] { + self + } +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/as_tensor.rs b/tfhe/src/core_crypto/commons/math/tensor/as_tensor.rs new file mode 100644 index 000000000..e276a1046 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/as_tensor.rs @@ -0,0 +1,68 @@ +use crate::core_crypto::commons::math::tensor::Tensor; + +use super::{AsMutSlice, AsRefSlice}; + +/// A trait for [`Tensor`]-based types, allowing to borrow the enclosed tensor. +/// +/// This trait is used by the types that build on the `Tensor` type to implement multi-dimensional +/// collections of various kind. In essence, this trait allows to extract a tensor properly +/// qualified to use all the methods of the `Tensor` type: +/// ```rust +/// use tfhe::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor, Tensor}; +/// +/// pub struct Matrix { +/// tensor: Tensor, +/// row_length: usize, +/// } +/// +/// pub struct Row { +/// tensor: Tensor, +/// } +/// +/// impl AsRefTensor for Matrix +/// where +/// Cont: AsRefSlice, +/// { +/// type Element = Cont::Element; +/// type Container = Cont; +/// fn as_tensor(&self) -> &Tensor { +/// &self.tensor +/// } +/// } +/// +/// impl Matrix { +/// // Returns an iterator over the matrix rows. +/// pub fn row_iter(&self) -> impl Iterator::Element]>> +/// where +/// Self: AsRefTensor, +/// { +/// self.as_tensor() // `AsRefTensor` method returning a `&Tensor` +/// .as_slice() // Since `Cont` is `AsView`, so is `Tensor` +/// .chunks(self.row_length) // Split in chunks of the size of the rows. +/// .map(|sub| Row { +/// tensor: Tensor::from_container(sub), +/// }) // Wraps into a row type. +/// } +/// } +/// ``` +pub trait AsRefTensor { + /// The element type. + type Element; + /// The container used by the tensor. + type Container: AsRefSlice::Element>; + /// Returns a reference to the enclosed tensor. + fn as_tensor(&self) -> &Tensor; +} + +/// A trait for [`Tensor`]-based types, allowing to mutably borrow the enclosed tensor. +/// +/// This trait implements the same logic as `AsRefTensor`, but for mutable borrow instead. See the +/// [`AsRefTensor`] documentation for more explanations on the logic. +pub trait AsMutTensor: AsRefTensor::Element> { + /// The element type. + type Element; + /// The container used by the tensor. + type Container: AsMutSlice::Element>; + /// Returns a mutable reference to the enclosed tensor. + fn as_mut_tensor(&mut self) -> &mut Tensor<::Container>; +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/into_tensor.rs b/tfhe/src/core_crypto/commons/math/tensor/into_tensor.rs new file mode 100644 index 000000000..2ac519a89 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/into_tensor.rs @@ -0,0 +1,16 @@ +use crate::core_crypto::commons::math::tensor::Tensor; + +use super::AsRefSlice; + +/// This trait allows to extract the tensor of a [`Tensor`]-based type. +/// +/// This trait allows to consume a value, and extracts the tensor that was wrapped inside, to +/// return it to the caller. +pub trait IntoTensor { + /// The element type of the collection container. + type Element; + /// The type of the collection container. + type Container: AsRefSlice::Element>; + /// Consumes `self` and returns an owned tensor. + fn into_tensor(self) -> Tensor; +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/mod.rs b/tfhe/src/core_crypto/commons/math/tensor/mod.rs new file mode 100644 index 000000000..90f205322 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/mod.rs @@ -0,0 +1,283 @@ +//! Operations on collections of values. +//! +//! This module contains a [`Tensor`] type, central to the whole library. In essence, a tensor +//! wraps a data container, and provides a set of methods to operate with other tensors of +//! the same length: +//! ``` +//! use tfhe::core_crypto::commons::math::tensor::Tensor; +//! // We allocate two tensors of size 10 +//! let mut tensor1 = Tensor::allocate(5u32, 10); +//! let tensor2 = Tensor::allocate(3u32, 10); +//! // We update the values of `tensor1` inplace, by adding it the values of `tensor2`; +//! tensor1.update_with_wrapping_add(&tensor2); +//! ``` +//! +//! The first interest of this type is that it can be backed by several collection containers, +//! such as `Vec`, `&mut [T]` or `&[T]`. Operations can homogeneously be applied to tensors +//! backed by different containers: +//! ``` +//! use tfhe::core_crypto::commons::math::tensor::Tensor; +//! // `allocate` returns a tensor backed by a `Vec` +//! let tensor1: Tensor> = Tensor::allocate(5, 100); +//! // `from_cont` allows you to create a tensor from any container +//! let mut distant_container = vec![4 as u32; 100]; +//! let mut tensor2: Tensor<&mut [u32]> = Tensor::from_container(distant_container.as_mut_slice()); +//! // We update the values of `distant_container` via `tensor2` +//! tensor2.update_with_wrapping_add(&tensor1); +//! ``` +//! +//! It is important to note that the `Tensor` type we have here, is *not* an n-dimmensional array, +//! as is common in scientific computating libraries. It is indexed by a single integral value, +//! and only operations with tensors of the same *length* are authorized. +//! +//! Despite this apparent limitation, `Tensor` are used throughout this library as the backbone of +//! several structures representing multi-dimensional collections. The pattern we use for such +//! structures is pretty simple: +//! ``` +//! use tfhe::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor, Tensor}; +//! +//! // We want to have a matrix structure stored row-major. +//! pub struct Matrix { +//! tensor: Tensor, +//! row_length: usize, +//! } +//! +//! // Our matrix is row-major, so we must be able to iterate over rows. +//! pub struct Row { +//! tensor: Tensor, +//! } +//! +//! impl Matrix { +//! // Returns an iterator over the matrix rows. +//! pub fn row_iter(&self) -> impl Iterator::Element]>> +//! where +//! Self: AsRefTensor, +//! { +//! self.as_tensor() +//! .as_slice() +//! .chunks(self.row_length) +//! .map(|sub| Row { +//! tensor: Tensor::from_container(sub), +//! }) +//! } +//! } +//! ``` +//! +//! You can combine such structures to implement n-dimensional arrays of any size. This approach +//! has the benefit of making the orderinng of the element explicit, and to provide a specific type +//! and a specific set of operations for the different dimensions of your array. This prevents you +//! from shooting yourself in the foot when messing with your data layout, writing new code, or +//! refactoring. + +// This macro implements various traits for a tensor-based object. To work properly, the object in +// question must be a structure with a `tensor` field. +macro_rules! tensor_traits { + ($Type:ident) => { + impl $crate::core_crypto::commons::math::tensor::AsRefTensor for $Type + where + Cont: $crate::core_crypto::commons::math::tensor::AsRefSlice, + { + type Element = Element; + type Container = Cont; + fn as_tensor(&self) -> &Tensor { + &self.tensor + } + } + + impl $crate::core_crypto::commons::math::tensor::AsMutTensor for $Type + where + Cont: $crate::core_crypto::commons::math::tensor::AsMutSlice, + { + type Element = Element; + type Container = Cont; + fn as_mut_tensor( + &mut self, + ) -> &mut Tensor< + ::Container, + > { + &mut self.tensor + } + } + + impl $crate::core_crypto::commons::math::tensor::IntoTensor for $Type + where + Cont: $crate::core_crypto::commons::math::tensor::AsRefSlice, + { + type Element = + ::Element; + type Container = Cont; + fn into_tensor(self) -> Tensor { + self.tensor + } + } + }; +} +pub(crate) use tensor_traits; + +macro_rules! current_func_path { + () => {{ + fn name(_any: T) -> &'static str { + std::any::type_name::() + } + fn t() {} + let output = name(t); + &output[..output.len() - 3] + }}; +} +pub(crate) use current_func_path; + +macro_rules! ck_dim_eq { + ($tensor_size: expr => $($size: expr),* ) => { + let func = $crate::core_crypto::commons::math::tensor::current_func_path!(); + $( + + debug_assert!( + $tensor_size == $size, + "Called operation {} on tensors of incompatible size. {} (={:?}) does not equals \ + {} (={:?}).", + func, + stringify!($size), + $size, + stringify!($tensor_size), + $tensor_size + ); + )* + }; +} +pub(crate) use ck_dim_eq; + +macro_rules! ck_dim_div { + ($tensor_size: expr => $($size: expr),* ) => { + $( + let func = $crate::core_crypto::commons::math::tensor::current_func_path!(); + debug_assert!( + $tensor_size % $size == 0, + "Called operation {} on tensors of incompatible size. {} (={:?}) does not divide \ + {} (={:?})", + func, + stringify!($size), + $size, + stringify!($tensor_size), + $tensor_size + ); + )* + }; +} +pub(crate) use ck_dim_div; + +#[cfg(test)] +mod tests; + +#[allow(clippy::module_inception)] +mod tensor; +pub use tensor::*; + +mod as_slice; +pub use as_slice::*; + +mod as_element; +pub use as_element::*; + +mod as_tensor; +pub use as_tensor::*; + +mod into_tensor; +pub use into_tensor::*; + +pub trait Container: AsRef<[Self::Element]> { + type Element; + + fn container_len(&self) -> usize { + self.as_ref().len() + } +} + +pub trait ContainerOwned: Container + AsMut<[Self::Element]> { + fn collect>(iter: I) -> Self; +} + +impl Container for aligned_vec::ABox<[T]> { + type Element = T; +} + +impl Container for Box<[T]> { + type Element = T; +} + +impl Container for aligned_vec::AVec { + type Element = T; +} + +impl Container for Vec { + type Element = T; +} + +impl ContainerOwned for aligned_vec::ABox<[T]> { + fn collect>(iter: I) -> Self { + aligned_vec::AVec::::from_iter(0, iter).into_boxed_slice() + } +} + +impl<'a, T> Container for &'a [T] { + type Element = T; +} + +impl<'a, T> Container for &'a mut [T] { + type Element = T; +} + +pub trait Split: Sized { + type Chunks: DoubleEndedIterator + ExactSizeIterator; + + fn into_chunks(self, chunk_size: usize) -> Self::Chunks; + fn split_into(self, chunk_count: usize) -> Self::Chunks; + fn split_at(self, mid: usize) -> (Self, Self); +} + +impl<'a, T> Split for &'a [T] { + type Chunks = core::slice::ChunksExact<'a, T>; + + #[inline] + fn into_chunks(self, chunk_size: usize) -> Self::Chunks { + debug_assert_eq!(self.len() % chunk_size, 0); + self.chunks_exact(chunk_size) + } + #[inline] + fn split_into(self, chunk_count: usize) -> Self::Chunks { + if chunk_count == 0 { + debug_assert_eq!(self.len(), 0); + self.chunks_exact(1) + } else { + debug_assert_eq!(self.len() % chunk_count, 0); + self.chunks_exact(self.len() / chunk_count) + } + } + #[inline] + fn split_at(self, mid: usize) -> (Self, Self) { + self.split_at(mid) + } +} + +impl<'a, T> Split for &'a mut [T] { + type Chunks = core::slice::ChunksExactMut<'a, T>; + + #[inline] + fn into_chunks(self, chunk_size: usize) -> Self::Chunks { + debug_assert_eq!(self.len() % chunk_size, 0); + self.chunks_exact_mut(chunk_size) + } + #[inline] + fn split_into(self, chunk_count: usize) -> Self::Chunks { + if chunk_count == 0 { + debug_assert_eq!(self.len(), 0); + self.chunks_exact_mut(1) + } else { + debug_assert_eq!(self.len() % chunk_count, 0); + self.chunks_exact_mut(self.len() / chunk_count) + } + } + #[inline] + fn split_at(self, mid: usize) -> (Self, Self) { + self.split_at_mut(mid) + } +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/tensor.rs b/tfhe/src/core_crypto/commons/math/tensor/tensor.rs new file mode 100644 index 000000000..bfa91352a --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/tensor.rs @@ -0,0 +1,1364 @@ +use std::iter::FromIterator; +use std::ops::{ + BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Shl, ShlAssign, Shr, ShrAssign, +}; +use std::slice::SliceIndex; + +#[cfg(feature = "__commons_parallel")] +use rayon::{iter::IndexedParallelIterator, prelude::*}; +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +use crate::core_crypto::commons::numeric::{CastFrom, UnsignedInteger}; + +use crate::core_crypto::commons::utils::zip; + +use super::{AsMutSlice, AsMutTensor, AsRefSlice, AsRefTensor}; + +/// A generic type to perform operations on collections of scalar values. +/// +/// See the [module-level](`super`) documentation for more explanations on the logic of this type. +/// +/// # Naming convention +/// +/// The methods that may mutate the values of a `Tensor`, follow a convention: +/// +/// + Methods prefixed with `update_with` use the current values of `self` when performing the +/// operation. +/// + Methods prefixed with `fill_with` discard the current vales of `self`, and overwrite it with +/// the result of an operation on other values. +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Tensor(Container); + +impl Tensor> { + /// Allocates a new `Tensor>` whose values are all `value`. + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(tensor.len(), 1000); + /// assert_eq!(*tensor.get_element(0), 9); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn allocate(value: Element, size: usize) -> Self + where + Element: Copy, + { + Tensor(vec![value; size]) + } +} + +macro_rules! fill_with { + ($Trait:ident, $name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + lhs: &Tensor, + rhs: &Tensor, + ) where + Tensor: AsRefSlice, + Tensor: AsRefSlice, + Self: AsMutSlice, + as AsRefSlice>::Element: $Trait< as AsRefSlice>::Element, + Output=::Element>, + as AsRefSlice>::Element: Copy, + as AsRefSlice>::Element: Copy + { + ck_dim_eq!(self.len() => lhs.len()); + ck_dim_eq!(self.len() => rhs.len()); + Tensor::fill_with_two(self, lhs, rhs, $($func)*); + } + }; +} + +macro_rules! fill_with_wrapping { + ($name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + lhs: &Tensor, + rhs: &Tensor, + ) where + Tensor: AsRefSlice, + Tensor: AsRefSlice, + Self: AsMutSlice, + Element: UnsignedInteger + { + ck_dim_eq!(self.len() => lhs.len()); + ck_dim_eq!(self.len() => rhs.len()); + Tensor::fill_with_two(self, lhs, rhs, $($func)*); + } + }; +} + +macro_rules! update_with { + ($Trait:ident, $name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + other: &Tensor, + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + ::Element: $Trait< as AsRefSlice>::Element>, + as AsRefSlice>::Element: Copy + { + ck_dim_eq!(self.len() => other.len()); + self.update_with_one(other, $($func)*); + } + }; +} + +macro_rules! update_with_wrapping { + ($name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + other: &Tensor, + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + Element: UnsignedInteger + { + ck_dim_eq!(self.len() => other.len()); + self.update_with_one(other, $($func)*); + } + }; +} + +macro_rules! update_with_scalar { + ($Trait:ident, $name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + element: &Element, + ) where + Self: AsMutSlice, + ::Element: $Trait, + Element: Copy + { + self.update_with_element(element, $($func)*); + } + }; +} + +macro_rules! update_with_wrapping_scalar { + ($name: ident, $($func:tt)*) => { + pub fn $name( + &mut self, + element: &Element, + ) where + Self: AsMutSlice, + Element: UnsignedInteger + { + self.update_with_element(element, $($func)*); + } + }; +} + +impl Tensor { + /// Creates a new tensor from a container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let vec = vec![9 as u8; 1000]; + /// let view = vec.as_slice(); + /// let tensor = Tensor::from_container(view); + /// assert_eq!(tensor.len(), 1000); + /// assert_eq!(*tensor.get_element(0), 9); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn from_container(cont: Container) -> Self { + Tensor(cont) + } + + /// Consumes a tensor and returns its container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// let vec = tensor.into_container(); + /// assert_eq!(vec.len(), 1000); + /// assert_eq!(vec[0], 9); + /// assert_eq!(vec[1], 9); + /// ``` + pub fn into_container(self) -> Container { + self.0 + } + + /// Returns a reference to the tensor container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// let vecref: &Vec<_> = tensor.as_container(); + /// ``` + pub fn as_container(&self) -> &Container { + &self.0 + } + + /// Returns a mutable reference to the tensor container. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// let vecmut: &mut Vec<_> = tensor.as_mut_container(); + /// ``` + pub fn as_mut_container(&mut self) -> &mut Container { + &mut self.0 + } + + /// Returns the length of the tensor. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(tensor.len(), 1000); + /// ``` + pub fn len(&self) -> usize + where + Self: AsRefSlice, + { + self.as_slice().len() + } + + /// Returns whether the tensor is empty. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(tensor.is_empty(), false); + /// ``` + pub fn is_empty(&self) -> bool + where + Self: AsRefSlice, + { + self.as_slice().len() == 0 + } + + /// Returns an iterator over `&Scalar` elements. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 9); + /// } + /// ``` + pub fn iter( + &self, + ) -> impl DoubleEndedIterator::Element> + ExactSizeIterator + where + Self: AsRefSlice, + { + self.as_slice().iter() + } + + /// Returns a parallel iterator over `&Scalar` elements. + /// + /// # Notes: + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.par_iter().for_each(|scalar| { + /// assert_eq!(*scalar, 9); + /// }); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_iter(&self) -> impl IndexedParallelIterator::Element> + where + Self: AsRefSlice, + ::Element: Sync, + { + self.as_slice().as_parallel_slice().par_iter() + } + + /// Returns an iterator over `&mut T` elements. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// for mut scalar in tensor.iter_mut() { + /// *scalar = 8; + /// } + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 8); + /// } + /// ``` + pub fn iter_mut( + &mut self, + ) -> impl DoubleEndedIterator::Element> + ExactSizeIterator + where + Self: AsMutSlice, + { + self.as_mut_slice().iter_mut() + } + + /// Returns a parallel iterator over `&mut T` elements. + /// + /// # Notes: + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.iter_mut().for_each(|mut scalar| { + /// *scalar = 8; + /// }); + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 8); + /// } + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_iter_mut( + &mut self, + ) -> impl IndexedParallelIterator::Element> + where + Self: AsMutSlice, + ::Element: Sync + Send, + { + self.as_mut_slice().as_parallel_slice_mut().par_iter_mut() + } + + /// Returns an iterator over sub tensors `Tensor<&[Scalar]>`. + /// + /// # Note: + /// The length of the sub-tensors must divide the size of the tensor. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// for sub in tensor.subtensor_iter(10) { + /// assert_eq!(sub.len(), 10); + /// } + /// ``` + pub fn subtensor_iter( + &self, + size: usize, + ) -> impl DoubleEndedIterator::Element]>> + ExactSizeIterator + where + Self: AsRefSlice, + { + debug_assert!(self.as_slice().len() % size == 0, "Uneven chunks size"); + self.as_slice().chunks(size).map(Tensor::from_container) + } + + /// Returns a parallel iterator over sub tensors `Tensor<&[Scalar]>`. + /// + /// # Note: + /// The length of the sub-tensors must divide the size of the tensor. + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example: + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.par_subtensor_iter(10).for_each(|sub| { + /// assert_eq!(sub.len(), 10); + /// }); + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_subtensor_iter( + &self, + size: usize, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsRefSlice, + ::Element: Sync, + { + debug_assert!(self.as_slice().len() % size == 0, "Uneven chunks size"); + self.as_slice().par_chunks(size).map(Tensor::from_container) + } + + /// Returns an iterator over mutable sub tensors `Tensor<&mut [Scalar]>`. + /// + /// # Note: + /// The length of the sub-tensors must divide the size of the tensor. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// for mut sub in tensor.subtensor_iter_mut(10) { + /// assert_eq!(sub.len(), 10); + /// *sub.get_element_mut(0) = 1; + /// } + /// for sub in tensor.subtensor_iter(20) { + /// assert_eq!(*sub.get_element(0), 1); + /// assert_eq!(*sub.get_element(10), 1); + /// } + /// ``` + pub fn subtensor_iter_mut( + &mut self, + size: usize, + ) -> impl DoubleEndedIterator::Element]>> + ExactSizeIterator + where + Self: AsMutSlice, + { + debug_assert!(self.as_slice().len() % size == 0, "Uneven chunks size"); + self.as_mut_slice() + .chunks_mut(size) + .map(Tensor::from_container) + } + + /// Returns a parallel iterator over mutable sub tensors `Tensor<&mut [Scalar]>`. + /// + /// # Note: + /// + /// The length of the sub-tensors must divide the size of the tensor. + /// This iterator is hidden behind the "__commons_parallel" feature gate. + /// + /// # Example: + /// ``` + /// use rayon::iter::ParallelIterator; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.par_subtensor_iter_mut(10).for_each(|mut sub| { + /// assert_eq!(sub.len(), 10); + /// *sub.get_element_mut(0) = 1; + /// }); + /// for sub in tensor.subtensor_iter(20) { + /// assert_eq!(*sub.get_element(0), 1); + /// assert_eq!(*sub.get_element(10), 1); + /// } + /// ``` + #[cfg(feature = "__commons_parallel")] + pub fn par_subtensor_iter_mut( + &mut self, + size: usize, + ) -> impl IndexedParallelIterator::Element]>> + where + Self: AsMutSlice, + ::Element: Sync + Send, + { + debug_assert!(self.as_slice().len() % size == 0, "Uneven chunks size"); + self.as_mut_slice() + .par_chunks_mut(size) + .map(Tensor::from_container) + } + + /// Returns a reference to the first element. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(*tensor.first(), 9); + /// ``` + pub fn first(&self) -> &Element + where + Self: AsRefSlice, + { + self.as_slice().first().unwrap() + } + + /// Returns a reference to the last element. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(*tensor.last(), 9); + /// ``` + pub fn last(&self) -> &Element + where + Self: AsRefSlice, + { + self.as_slice().last().unwrap() + } + + /// Returns a mutable reference to the first element. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// *tensor.first_mut() = 8; + /// assert_eq!(*tensor.get_element(0), 8); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn first_mut(&mut self) -> &mut Element + where + Self: AsMutSlice, + { + self.as_mut_slice().first_mut().unwrap() + } + + /// Returns a mutable reference to the last element. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// *tensor.last_mut() = 8; + /// assert_eq!(*tensor.get_element(999), 8); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn last_mut(&mut self) -> &mut Element + where + Self: AsMutSlice, + { + self.as_mut_slice().last_mut().unwrap() + } + + /// Returns a reference to the first element, and a ref tensor for the rest of the values. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// let (first, end) = tensor.split_first(); + /// assert_eq!(*first, 9); + /// assert_eq!(end.len(), 999); + /// ``` + pub fn split_first(&self) -> (&Element, Tensor<&[Element]>) + where + Self: AsRefSlice, + { + self.as_slice() + .split_first() + .map(|(f, r)| (f, Tensor(r))) + .unwrap() + } + + /// Returns a reference to the last element, and a ref tensor to the rest of the values. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// let (last, beginning) = tensor.split_last(); + /// assert_eq!(*last, 9); + /// assert_eq!(beginning.len(), 999); + /// ``` + pub fn split_last(&self) -> (&Element, Tensor<&[Element]>) + where + Self: AsRefSlice, + { + self.as_slice() + .split_last() + .map(|(f, r)| (f, Tensor(r))) + .unwrap() + } + + /// Returns a mutable reference to the first element, and a mut tensor for the rest of the + /// values. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// let (mut first, mut end) = tensor.split_first_mut(); + /// *first = 8; + /// *end.get_element_mut(0) = 7; + /// assert_eq!(*tensor.get_element(0), 8); + /// assert_eq!(*tensor.get_element(1), 7); + /// assert_eq!(*tensor.get_element(2), 9); + /// ``` + pub fn split_first_mut(&mut self) -> (&mut Element, Tensor<&mut [Element]>) + where + Self: AsMutSlice, + { + self.as_mut_slice() + .split_first_mut() + .map(|(f, r)| (f, Tensor(r))) + .unwrap() + } + + /// Returns a mutable reference to the last element, and a mut tensor for the rest of the + /// values. + /// + /// # Note: + /// + /// Panics if the tensor is empty. + /// + /// # Example: + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// let (mut last, mut beginning) = tensor.split_last_mut(); + /// *last = 8; + /// *beginning.get_element_mut(0) = 7; + /// assert_eq!(*tensor.get_element(0), 7); + /// assert_eq!(*tensor.get_element(999), 8); + /// assert_eq!(*tensor.get_element(2), 9); + /// ``` + pub fn split_last_mut(&mut self) -> (&mut Element, Tensor<&mut [Element]>) + where + Self: AsMutSlice, + { + self.as_mut_slice() + .split_last_mut() + .map(|(f, r)| (f, Tensor(r))) + .unwrap() + } + + /// Returns a sub tensor from a range of indices. + /// + /// # Note: + /// + /// Panics if the indices are out of range. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// let sub = tensor.get_sub(0..3); + /// assert_eq!(sub.len(), 3); + /// ``` + pub fn get_sub(&self, index: Index) -> Tensor<&[::Element]> + where + Self: AsRefSlice, + Index: + SliceIndex<[::Element], Output = [::Element]>, + { + Tensor(&self.as_slice()[index]) + } + + /// Returns a mutable sub tensor from a range of indices. + /// + /// # Note: + /// + /// Panics if the indices are out of range. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// let mut sub = tensor.get_sub_mut(0..3); + /// sub.fill_with_element(0); + /// assert_eq!(*tensor.get_element(0), 0); + /// assert_eq!(*tensor.get_element(3), 9); + /// ``` + pub fn get_sub_mut( + &mut self, + index: Index, + ) -> Tensor<&mut [::Element]> + where + Self: AsMutSlice, + Index: + SliceIndex<[::Element], Output = [::Element]>, + { + Tensor(&mut self.as_mut_slice()[index]) + } + + /// Returns a reference to an element from an index. + /// + /// # Note: + /// + /// Panics if the index is out of range. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let tensor = Tensor::allocate(9 as u8, 1000); + /// assert_eq!(*tensor.get_element(0), 9); + /// ``` + pub fn get_element(&self, index: usize) -> &::Element + where + Self: AsRefSlice, + { + &self.as_slice()[index] + } + + /// Returns a mutable reference to an element from an index. + /// + /// # Note: + /// + /// Panics if the index is out of range. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// *tensor.get_element_mut(0) = 8; + /// assert_eq!(*tensor.get_element(0), 8); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn get_element_mut(&mut self, index: usize) -> &mut ::Element + where + Self: AsMutSlice, + { + &mut self.as_mut_slice()[index] + } + + /// Sets the value of an element at a given index. + /// + /// # Note: + /// + /// Panics if the index is out of range. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// *tensor.get_element_mut(0) = 8; + /// assert_eq!(*tensor.get_element(0), 8); + /// assert_eq!(*tensor.get_element(1), 9); + /// ``` + pub fn set_element(&mut self, index: usize, val: ::Element) + where + Self: AsMutSlice, + { + self.as_mut_slice()[index] = val; + } + + /// Fills a tensor with the values of another tensor, using memcpy. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor1 = Tensor::allocate(9 as u8, 1000); + /// let tensor2 = Tensor::allocate(10 as u8, 1000); + /// tensor1.fill_with_copy(&tensor2); + /// assert_eq!(*tensor2.get_element(0), 10); + /// ``` + pub fn fill_with_copy(&mut self, other: &Tensor) + where + Self: AsMutSlice, + Tensor: AsRefSlice, + Element: Copy, + { + ck_dim_eq!(self.len() => other.len()); + self.as_mut_slice().copy_from_slice(other.as_slice()); + } + + /// Fills two tensors with the result of the operation on a single one. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor1 = Tensor::allocate(9 as u8, 1000); + /// let mut tensor2 = Tensor::allocate(9 as u8, 1000); + /// let tensor3 = Tensor::allocate(10 as u8, 1000); + /// Tensor::fill_two_with_one(&mut tensor1, &mut tensor2, &tensor3, |a| (*a, *a)); + /// assert_eq!(*tensor1.get_element(0), 10); + /// assert_eq!(*tensor2.get_element(0), 10); + /// ``` + pub fn fill_two_with_one( + first: &mut Self, + second: &mut Tensor, + one: &Tensor, + ope: impl Fn( + & as AsRefSlice>::Element, + ) -> ( + ::Element, + as AsRefSlice>::Element, + ), + ) where + Self: AsMutSlice, + Tensor: AsMutSlice, + Tensor: AsRefSlice, + { + ck_dim_eq!(first.len() => one.len()); + ck_dim_eq!(second.len() => one.len()); + for (first_i, (second_i, one_i)) in zip!(first.iter_mut(), second.iter_mut(), one.iter()) { + let (f, s) = ope(one_i); + *first_i = f; + *second_i = s; + } + } + + /// Fills a mutable tensor with the result of an element-wise operation on two other tensors of + /// the same size + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(1 as u8, 1000); + /// let t3 = Tensor::allocate(2 as u8, 1000); + /// t1.fill_with_two(&t2, &t3, |t2, t3| t3 + t2); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 3); + /// } + /// ``` + pub fn fill_with_two( + &mut self, + lhs: &Tensor, + rhs: &Tensor, + ope: impl Fn( + & as AsRefSlice>::Element, + & as AsRefSlice>::Element, + ) -> ::Element, + ) where + Tensor: AsRefSlice, + Tensor: AsRefSlice, + Self: AsMutSlice, + { + ck_dim_eq!(self.len() => lhs.len()); + ck_dim_eq!(self.len() => rhs.len()); + for (output_i, (lhs_i, rhs_i)) in zip!( + self.iter_mut(), + lhs.as_slice().iter(), + rhs.as_slice().iter() + ) { + *output_i = ope(lhs_i, rhs_i); + } + } + + /// Fills a mutable tensor with the result of an element-wise operation on one other tensor of + /// the same size + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(2 as u8, 1000); + /// t1.fill_with_one(&t2, |t2| t2.pow(2)); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 4); + /// } + /// ``` + pub fn fill_with_one( + &mut self, + other: &Tensor, + ope: impl Fn(& as AsRefSlice>::Element) -> ::Element, + ) where + Tensor: AsRefSlice, + Self: AsMutSlice, + { + ck_dim_eq!(self.len() => other.len()); + for (output_i, other_i) in zip!(self.iter_mut(), other.as_slice().iter()) { + *output_i = ope(other_i); + } + } + + /// Fills a mutable tensor with an element. + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.fill_with_element(8); + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 8); + /// } + /// ``` + pub fn fill_with_element(&mut self, element: ::Element) + where + Self: AsMutSlice, + ::Element: Copy, + { + for output_i in self.iter_mut() { + *output_i = element; + } + } + + /// Fills a mutable tensor by repeatedly calling a closure. + /// + /// ``` + /// use std::cell::RefCell; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u16, 1000); + /// let mut boxed = RefCell::from(0); + /// tensor.fill_with(|| { + /// *boxed.borrow_mut() += 1; + /// *boxed.borrow() + /// }); + /// assert_eq!(*tensor.get_element(0), 1); + /// assert_eq!(*tensor.get_element(1), 2); + /// assert_eq!(*tensor.get_element(2), 3); + /// ``` + pub fn fill_with(&mut self, ope: impl Fn() -> ::Element) + where + Self: AsMutSlice, + { + for output_i in self.iter_mut() { + *output_i = ope(); + } + } + + /// Fills a mutable tensor by casting elements of another one. + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u16, 1000); + /// let mut t2 = Tensor::allocate(8. as f32, 1000); + /// t1.fill_with_cast(&t2); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 8); + /// } + /// ``` + pub fn fill_with_cast(&mut self, other: &Tensor) + where + Self: AsMutSlice, + Tensor: AsRefSlice, + ::Element: CastFrom< as AsRefSlice>::Element>, + as AsRefSlice>::Element: Copy, + { + ck_dim_eq!(self.len() => other.len()); + self.fill_with_one(other, |a| ::Element::cast_from(*a)); + } + + fill_with!(BitAnd, fill_with_bit_and, |l, r| *l & *r); + fill_with!(BitOr, fill_with_bit_or, |l, r| *l | *r); + fill_with!(BitXor, fill_with_bit_xor, |l, r| *l ^ *r); + fill_with!(Shl, fill_with_bit_shl, |l, r| *l << *r); + fill_with!(Shr, fill_with_bit_shr, |l, r| *l >> *r); + + fill_with_wrapping!(fill_with_wrapping_add, |l, r| l.wrapping_add(*r)); + fill_with_wrapping!(fill_with_wrapping_sub, |l, r| l.wrapping_sub(*r)); + fill_with_wrapping!(fill_with_wrapping_mul, |l, r| l.wrapping_mul(*r)); + fill_with_wrapping!(fill_with_wrapping_div, |l, r| l.wrapping_div(*r)); + + /// Updates two tensors with the result of the operation with a single one. + /// + /// # Example: + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor1 = Tensor::allocate(9 as u8, 1000); + /// let mut tensor2 = Tensor::allocate(9 as u8, 1000); + /// let tensor3 = Tensor::allocate(10 as u8, 1000); + /// Tensor::update_two_with_one(&mut tensor1, &mut tensor2, &tensor3, |a, b, c| { + /// *a += *c; + /// *b += *c + 1; + /// }); + /// assert_eq!(*tensor1.get_element(0), 19); + /// assert_eq!(*tensor2.get_element(0), 20); + /// ``` + pub fn update_two_with_one( + first: &mut Self, + second: &mut Tensor, + one: &Tensor, + ope: impl Fn( + &mut ::Element, + &mut as AsRefSlice>::Element, + & as AsRefSlice>::Element, + ), + ) where + Self: AsMutSlice, + Tensor: AsMutSlice, + Tensor: AsRefSlice, + { + ck_dim_eq!(first.len() => one.len()); + ck_dim_eq!(second.len() => one.len()); + for (first_i, (second_i, one_i)) in zip!(first.iter_mut(), second.iter_mut(), one.iter()) { + ope(first_i, second_i, one_i); + } + } + + /// Updates a mutable tensor with the result of an element-wise operation with two other + /// tensors of the same size. + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(1 as u8, 1000); + /// let t3 = Tensor::allocate(2 as u8, 1000); + /// t1.update_with_two(&t2, &t3, |t1, t2, t3| *t1 += t3 + t2); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 12); + /// } + /// ``` + pub fn update_with_two( + &mut self, + first: &Tensor, + second: &Tensor, + ope: impl Fn( + &mut ::Element, + & as AsRefSlice>::Element, + & as AsRefSlice>::Element, + ), + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + Tensor: AsRefSlice, + { + ck_dim_eq!(self.len() => first.len()); + ck_dim_eq!(self.len() => second.len()); + for (self_i, (first_i, second_i)) in zip!( + self.iter_mut(), + first.as_slice().iter(), + second.as_slice().iter() + ) { + ope(self_i, first_i, second_i); + } + } + + /// Updates a mutable tensor with the result of an element-wise operation with one other tensor + /// of the same size + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(2 as u8, 1000); + /// t1.update_with_one(&t2, |t1, t2| *t1 += t2.pow(2)); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 13); + /// } + /// ``` + pub fn update_with_one( + &mut self, + other: &Tensor, + ope: impl Fn(&mut ::Element, & as AsRefSlice>::Element), + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + { + ck_dim_eq!(self.len() => other.len()); + for (self_i, other_i) in zip!(self.iter_mut(), other.as_slice().iter()) { + ope(self_i, other_i); + } + } + + /// Updates a mutable tensor with an element. + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.update_with_element(8, |t, s| *t += s); + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 17); + /// } + /// ``` + pub fn update_with_element( + &mut self, + scalar: Element, + ope: impl Fn(&mut ::Element, Element), + ) where + Self: AsMutSlice, + Element: Copy, + { + for self_i in self.iter_mut() { + ope(self_i, scalar); + } + } + + /// Updates a mutable tensor by repeatedly calling a closure. + /// + /// ``` + /// use std::cell::RefCell; + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u16, 1000); + /// let mut boxed = RefCell::from(0); + /// tensor.update_with(|t| { + /// *boxed.borrow_mut() += 1; + /// *t += *boxed.borrow() + /// }); + /// assert_eq!(*tensor.get_element(0), 10); + /// assert_eq!(*tensor.get_element(1), 11); + /// assert_eq!(*tensor.get_element(2), 12); + /// ``` + pub fn update_with(&mut self, ope: impl Fn(&mut ::Element)) + where + Self: AsMutSlice, + { + for self_i in self.iter_mut() { + ope(self_i); + } + } + + update_with!(BitAndAssign, update_with_and, |s, a| *s &= *a); + update_with!(BitOrAssign, update_with_or, |s, a| *s |= *a); + update_with!(BitXorAssign, update_with_xor, |s, a| *s ^= *a); + update_with!(ShlAssign, update_with_shl, |s, a| *s <<= *a); + update_with!(ShrAssign, update_with_shr, |s, a| *s >>= *a); + + update_with_wrapping!(update_with_wrapping_add, |s, a| *s = s.wrapping_add(*a)); + update_with_wrapping!(update_with_wrapping_sub, |s, a| *s = s.wrapping_sub(*a)); + update_with_wrapping!(update_with_wrapping_mul, |s, a| *s = s.wrapping_mul(*a)); + update_with_wrapping!(update_with_wrapping_div, |s, a| *s = s.wrapping_div(*a)); + + update_with_scalar!(BitAndAssign, update_with_scalar_and, |s, a| *s &= *a); + update_with_scalar!(BitOrAssign, update_with_scalar_or, |s, a| *s |= *a); + update_with_scalar!(BitXorAssign, update_with_scalar_xor, |s, a| *s ^= *a); + update_with_scalar!(ShlAssign, update_with_scalar_shl, |s, a| *s <<= *a); + update_with_scalar!(ShrAssign, update_with_scalar_shr, |s, a| *s >>= *a); + + update_with_wrapping_scalar!(update_with_wrapping_scalar_add, |s, a| *s = + s.wrapping_add(*a)); + update_with_wrapping_scalar!(update_with_wrapping_scalar_sub, |s, a| *s = + s.wrapping_sub(*a)); + update_with_wrapping_scalar!(update_with_wrapping_scalar_mul, |s, a| *s = + s.wrapping_mul(*a)); + update_with_wrapping_scalar!(update_with_wrapping_scalar_div, |s, a| *s = + s.wrapping_div(*a)); + + /// Sets each value of `self` to its own wrapping opposite. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::allocate(9 as u8, 1000); + /// tensor.update_with_wrapping_neg(); + /// for scalar in tensor.iter() { + /// assert_eq!(*scalar, 247); + /// } + /// ``` + pub fn update_with_wrapping_neg(&mut self) + where + Self: AsMutSlice, + ::Element: UnsignedInteger, + { + self.update_with(|a| *a = a.wrapping_neg()); + } + + /// Fills a mutable tensor with the result of the wrapping multiplication of elements of + /// another tensor by an element. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(3 as u8, 1000); + /// t1.fill_with_wrapping_element_mul(&t2, 250); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 238); + /// } + /// ``` + pub fn fill_with_wrapping_element_mul( + &mut self, + tensor: &Tensor, + element: Element, + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + Element: UnsignedInteger, + { + ck_dim_eq!(self.len() => tensor.len()); + self.fill_with_one(tensor, |t| t.wrapping_mul(element)); + } + + /// Updates the values of a mutable tensor by wrap-subtracting the wrapping product of the + /// elements of another tensor and an element. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(2 as u8, 1000); + /// t1.update_with_wrapping_sub_element_mul(&t2, 250); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 21); + /// } + /// ``` + pub fn update_with_wrapping_sub_element_mul( + &mut self, + tensor: &Tensor, + scalar: Element, + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + Element: UnsignedInteger, + { + ck_dim_eq!(self.len() => tensor.len()); + self.update_with_one(tensor, |s, t| *s = s.wrapping_sub(t.wrapping_mul(scalar))); + } + + /// Updates the values of a mutable tensor by wrap-adding the wrapping product of the elements + /// of another tensor and an element. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut t1 = Tensor::allocate(9 as u8, 1000); + /// let t2 = Tensor::allocate(2 as u8, 1000); + /// t1.update_with_wrapping_add_element_mul(&t2, 250); + /// for scalar in t1.iter() { + /// assert_eq!(*scalar, 253); + /// } + /// ``` + pub fn update_with_wrapping_add_element_mul( + &mut self, + tensor: &Tensor, + element: Element, + ) where + Self: AsMutSlice, + Tensor: AsRefSlice, + Element: UnsignedInteger, + { + ck_dim_eq!(self.len() => tensor.len()); + self.update_with_one(tensor, |s, t| *s = s.wrapping_add(t.wrapping_mul(element))); + } + + /// Computes a value by folding a tensor with another. + /// + /// # Example + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let t1 = Tensor::allocate(10 as u16, 10); + /// let t2 = Tensor::allocate(2 as u16, 10); + /// let val = t1.fold_with_one(&t2, 0, |mut a, t1, t2| { + /// a += t1 + t2; + /// a + /// }); + /// assert_eq!(val, 120); + /// ``` + pub fn fold_with_one( + &self, + other: &Tensor, + acc: Output, + ope: impl Fn( + Output, + &::Element, + & as AsRefSlice>::Element, + ) -> Output, + ) -> Output + where + Self: AsRefSlice, + Tensor: AsRefSlice, + { + ck_dim_eq!(self.len() => other.len()); + self.iter() + .zip(other.as_slice().iter()) + .fold(acc, |acc, (s_i, o_i)| ope(acc, s_i, o_i)) + } + + /// Reverses the elements of the tensor inplace. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::from_container(vec![1u8, 2, 3, 4]); + /// tensor.reverse(); + /// assert_eq!(*tensor.get_element(0), 4); + /// ``` + pub fn reverse(&mut self) + where + Self: AsMutSlice, + { + self.as_mut_slice().reverse() + } + + /// Rotates the elements of the tensor to the right, inplace. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::from_container(vec![1u8, 2, 3, 4]); + /// tensor.rotate_right(2); + /// assert_eq!(*tensor.get_element(0), 3); + /// ``` + pub fn rotate_right(&mut self, n: usize) + where + Self: AsMutSlice, + { + self.as_mut_slice().rotate_right(n) + } + + /// Rotates the elements of the tensor to the left, inplace. + /// + /// # Example + /// + /// ``` + /// use tfhe::core_crypto::commons::math::tensor::Tensor; + /// let mut tensor = Tensor::from_container(vec![1u8, 2, 3, 4]); + /// tensor.rotate_left(2); + /// assert_eq!(*tensor.get_element(0), 3); + /// ``` + pub fn rotate_left(&mut self, n: usize) + where + Self: AsMutSlice, + { + self.as_mut_slice().rotate_left(n) + } +} + +impl FromIterator for Tensor> { + fn from_iter>(iter: I) -> Self { + let mut v = Vec::new(); + for i in iter { + v.push(i); + } + Tensor(v) + } +} + +impl AsRefSlice for Tensor +where + Cont: AsRefSlice, +{ + type Element = Cont::Element; + fn as_slice(&self) -> &[Self::Element] { + self.0.as_slice() + } +} + +impl AsMutSlice for Tensor +where + Cont: AsMutSlice, +{ + type Element = ::Element; + fn as_mut_slice(&mut self) -> &mut [::Element] { + self.0.as_mut_slice() + } +} + +impl AsRefTensor for Tensor +where + Cont: AsRefSlice, +{ + type Element = Cont::Element; + type Container = Cont; + fn as_tensor(&self) -> &Tensor { + self + } +} + +impl AsMutTensor for Tensor +where + Cont: AsMutSlice, +{ + type Element = ::Element; + type Container = Cont; + fn as_mut_tensor(&mut self) -> &mut Tensor { + self + } +} diff --git a/tfhe/src/core_crypto/commons/math/tensor/tests.rs b/tfhe/src/core_crypto/commons/math/tensor/tests.rs new file mode 100644 index 000000000..88710fc2c --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/tensor/tests.rs @@ -0,0 +1,1257 @@ +use super::Tensor; + +#[test] +fn test_add_u32() { + let t_1 = Tensor::from_container(vec![ + 2_614_422_625_u32, + 1_347_010_255_u32, + 1_755_118_555_u32, + 3_348_067_670_u32, + 3_896_589_259_u32, + 97_617_327_u32, + 1_545_053_739_u32, + 1_211_085_433_u32, + 2_684_538_667_u32, + 202_832_626_u32, + 1_638_508_087_u32, + 879_523_200_u32, + 2_456_511_176_u32, + 2_648_745_580_u32, + 967_205_272_u32, + 54_854_762_u32, + 2_609_115_771_u32, + 1_725_392_344_u32, + 2_314_671_715_u32, + 1_840_995_902_u32, + 4_041_278_880_u32, + 275_079_767_u32, + 2_300_142_423_u32, + 2_333_095_686_u32, + 3_026_580_357_u32, + 21_931_374_u32, + 372_535_067_u32, + 6_439_834_u32, + 762_787_515_u32, + 2_734_668_397_u32, + 3_013_991_526_u32, + 579_324_780_u32, + 916_175_967_u32, + 850_321_436_u32, + 978_826_112_u32, + 1_360_938_704_u32, + 2_363_410_736_u32, + 353_572_296_u32, + 2_196_029_604_u32, + 1_676_698_573_u32, + 71_702_920_u32, + 433_353_586_u32, + 3_336_662_792_u32, + 3_815_644_954_u32, + 2_974_299_797_u32, + 1_990_548_820_u32, + 1_683_843_869_u32, + 2_152_628_932_u32, + 3_625_450_751_u32, + 2_366_853_676_u32, + 1_798_342_904_u32, + 2_869_368_979_u32, + 1_185_695_639_u32, + 3_173_469_147_u32, + 1_531_916_725_u32, + 3_326_214_024_u32, + 2_067_990_523_u32, + 976_120_805_u32, + 3_535_693_006_u32, + 4_223_913_473_u32, + 2_143_410_133_u32, + 187_637_181_u32, + 2_370_649_336_u32, + 3_155_284_399_u32, + 3_282_898_811_u32, + 3_068_767_567_u32, + 3_033_732_496_u32, + 3_278_852_653_u32, + 1_988_815_405_u32, + 3_318_268_258_u32, + 402_934_292_u32, + 3_162_645_643_u32, + 2_103_209_800_u32, + 4_253_170_701_u32, + 2_489_673_789_u32, + 2_224_135_091_u32, + 1_848_398_457_u32, + 3_159_326_514_u32, + 3_865_725_686_u32, + 674_027_046_u32, + 3_191_092_214_u32, + 356_413_912_u32, + 682_734_067_u32, + 2_368_555_344_u32, + 614_314_161_u32, + 3_515_266_737_u32, + 949_414_245_u32, + 2_046_032_417_u32, + 1_495_462_201_u32, + 2_307_315_576_u32, + 1_960_455_472_u32, + 917_911_666_u32, + 1_518_075_072_u32, + 2_925_772_427_u32, + 298_590_050_u32, + 1_441_972_928_u32, + 666_987_301_u32, + 2_167_997_170_u32, + 3_413_359_382_u32, + 3_526_531_810_u32, + ]); + let t_2 = Tensor::from_container(vec![ + 3_428_858_567_u32, + 827_447_270_u32, + 959_110_479_u32, + 4_184_350_429_u32, + 1_820_415_259_u32, + 2_322_099_741_u32, + 1_328_906_591_u32, + 1_664_312_159_u32, + 549_610_931_u32, + 2_945_591_302_u32, + 295_342_634_u32, + 1_589_486_080_u32, + 1_359_822_125_u32, + 1_285_568_394_u32, + 1_881_925_871_u32, + 3_058_045_327_u32, + 1_773_709_235_u32, + 3_813_730_789_u32, + 823_940_101_u32, + 2_480_100_080_u32, + 3_639_129_118_u32, + 759_351_495_u32, + 1_301_750_125_u32, + 1_054_832_776_u32, + 3_245_556_275_u32, + 2_800_997_186_u32, + 1_256_287_364_u32, + 2_573_603_461_u32, + 2_328_221_582_u32, + 1_633_069_253_u32, + 102_853_950_u32, + 2_716_685_335_u32, + 503_267_884_u32, + 2_202_048_416_u32, + 1_602_161_938_u32, + 1_927_466_558_u32, + 858_392_614_u32, + 956_183_465_u32, + 4_135_389_917_u32, + 951_071_347_u32, + 2_318_567_902_u32, + 2_004_258_446_u32, + 1_797_038_763_u32, + 1_610_761_714_u32, + 3_236_519_313_u32, + 316_586_765_u32, + 307_967_731_u32, + 3_588_485_359_u32, + 1_947_118_682_u32, + 2_002_927_095_u32, + 1_136_304_281_u32, + 1_065_157_362_u32, + 3_714_003_080_u32, + 786_946_775_u32, + 2_441_787_699_u32, + 832_944_437_u32, + 3_651_539_633_u32, + 798_050_864_u32, + 669_130_367_u32, + 2_000_552_570_u32, + 2_700_875_050_u32, + 1_118_200_520_u32, + 708_688_839_u32, + 3_463_285_323_u32, + 2_270_978_169_u32, + 237_144_352_u32, + 660_096_080_u32, + 4_230_221_095_u32, + 3_280_398_471_u32, + 3_354_541_293_u32, + 833_549_915_u32, + 1_136_697_343_u32, + 1_096_316_847_u32, + 2_476_951_406_u32, + 3_971_141_200_u32, + 12_788_335_u32, + 1_580_197_299_u32, + 2_444_226_867_u32, + 387_836_375_u32, + 4_067_693_575_u32, + 3_918_490_972_u32, + 3_704_639_326_u32, + 664_901_677_u32, + 1_847_150_125_u32, + 476_514_752_u32, + 466_866_141_u32, + 704_667_620_u32, + 2_652_242_441_u32, + 3_683_680_188_u32, + 2_589_696_574_u32, + 571_587_908_u32, + 953_792_331_u32, + 3_677_654_843_u32, + 2_056_915_677_u32, + 3_272_850_239_u32, + 1_669_788_385_u32, + 2_311_731_086_u32, + 162_748_026_u32, + 907_347_406_u32, + 2_760_784_143_u32, + ]); + let ground_truth_t_3 = Tensor::from_container(vec![ + 1_748_313_896_u32, + 2_174_457_525_u32, + 2_714_229_034_u32, + 3_237_450_803_u32, + 1_422_037_222_u32, + 2_419_717_068_u32, + 2_873_960_330_u32, + 2_875_397_592_u32, + 3_234_149_598_u32, + 3_148_423_928_u32, + 1_933_850_721_u32, + 2_469_009_280_u32, + 3_816_333_301_u32, + 3_934_313_974_u32, + 2_849_131_143_u32, + 3_112_900_089_u32, + 87_857_710_u32, + 1_244_155_837_u32, + 3_138_611_816_u32, + 26_128_686_u32, + 3_385_440_702_u32, + 1_034_431_262_u32, + 3_601_892_548_u32, + 3_387_928_462_u32, + 1_977_169_336_u32, + 2_822_928_560_u32, + 1_628_822_431_u32, + 2_580_043_295_u32, + 3_091_009_097_u32, + 72_770_354_u32, + 3_116_845_476_u32, + 3_296_010_115_u32, + 1_419_443_851_u32, + 3_052_369_852_u32, + 2_580_988_050_u32, + 3_288_405_262_u32, + 3_221_803_350_u32, + 1_309_755_761_u32, + 2_036_452_225_u32, + 2_627_769_920_u32, + 2_390_270_822_u32, + 2_437_612_032_u32, + 838_734_259_u32, + 1_131_439_372_u32, + 1_915_851_814_u32, + 2_307_135_585_u32, + 1_991_811_600_u32, + 1_446_146_995_u32, + 1_277_602_137_u32, + 74_813_475_u32, + 2_934_647_185_u32, + 3_934_526_341_u32, + 604_731_423_u32, + 3_960_415_922_u32, + 3_973_704_424_u32, + 4_159_158_461_u32, + 1_424_562_860_u32, + 1_774_171_669_u32, + 4_204_823_373_u32, + 1_929_498_747_u32, + 549_317_887_u32, + 1_305_837_701_u32, + 3_079_338_175_u32, + 2_323_602_426_u32, + 1_258_909_684_u32, + 3_305_911_919_u32, + 3_693_828_576_u32, + 3_214_106_452_u32, + 974_246_580_u32, + 2_377_842_255_u32, + 1_236_484_207_u32, + 4_375_690_u32, + 3_199_526_647_u32, + 2_435_154_811_u32, + 2_165_847_693_u32, + 2_236_923_426_u32, + 3_428_595_756_u32, + 1_308_586_085_u32, + 4_253_562_061_u32, + 446_753_325_u32, + 2_814_615_890_u32, + 4_061_053_238_u32, + 1_347_635_744_u32, + 4_215_705_469_u32, + 1_090_828_913_u32, + 3_982_132_878_u32, + 1_654_081_865_u32, + 403_307_562_u32, + 884_175_093_u32, + 602_044_854_u32, + 2_532_043_380_u32, + 1_871_703_997_u32, + 900_762_619_u32, + 687_720_808_u32, + 3_571_440_289_u32, + 3_111_761_313_u32, + 2_978_718_387_u32, + 2_330_745_196_u32, + 25_739_492_u32, + 1_992_348_657_u32, + ]); + let mut t_3 = Tensor::allocate(0_u32, 100); + t_3.fill_with_wrapping_add(&t_1, &t_2); + + assert_eq!(t_3, ground_truth_t_3, "we are testing addition"); +} + +#[test] +fn test_sub_u64() { + let t_1 = Tensor::from_container(vec![ + 5_682_232_049_849_203_449_u64, + 1_744_272_140_419_931_610_u64, + 6_524_694_120_235_710_248_u64, + 501_685_223_587_207_708_u64, + 7_454_825_121_449_404_861_u64, + 1_452_049_147_138_516_728_u64, + 3_744_089_800_795_951_655_u64, + 2_900_714_251_440_266_072_u64, + 2_885_003_742_441_599_873_u64, + 5_127_037_330_303_939_263_u64, + 3_942_793_256_559_402_137_u64, + 2_938_215_163_794_025_597_u64, + 3_194_270_088_293_124_907_u64, + 3_798_617_854_173_374_109_u64, + 2_281_550_512_455_919_685_u64, + 1_378_021_925_594_404_903_u64, + 6_273_819_789_066_539_195_u64, + 5_891_518_315_759_560_031_u64, + 6_569_862_020_994_290_872_u64, + 2_312_304_860_409_402_175_u64, + 3_768_205_285_282_560_447_u64, + 2_813_718_090_332_844_130_u64, + 4_741_992_406_149_366_296_u64, + 2_862_912_615_999_044_257_u64, + 2_711_698_756_636_236_379_u64, + 3_105_025_607_153_753_493_u64, + 3_280_659_296_609_069_569_u64, + 1_621_356_564_053_932_659_u64, + 244_394_454_277_671_115_u64, + 1_370_168_407_221_172_838_u64, + 384_807_778_723_441_456_u64, + 5_421_384_837_838_501_695_u64, + 3_524_866_043_795_281_573_u64, + 273_224_951_302_481_390_u64, + 8_874_399_707_947_016_287_u64, + 5_042_853_686_974_107_712_u64, + 8_593_762_746_401_730_055_u64, + 4_298_169_213_116_104_086_u64, + 1_043_682_735_183_771_811_u64, + 8_271_963_865_357_943_237_u64, + 2_866_933_850_832_375_526_u64, + 3_680_273_731_625_120_587_u64, + 5_594_513_115_859_518_166_u64, + 1_643_917_283_539_244_290_u64, + 3_172_178_086_476_235_900_u64, + 6_964_486_530_272_725_036_u64, + 6_025_940_910_517_479_800_u64, + 8_277_718_434_101_601_483_u64, + 8_184_281_612_310_786_511_u64, + 5_373_031_274_997_880_981_u64, + 443_782_149_988_086_463_u64, + 9_185_207_564_855_550_126_u64, + 3_175_405_486_723_930_612_u64, + 538_795_803_601_238_624_u64, + 1_842_522_998_755_997_387_u64, + 756_815_213_533_913_513_u64, + 4_792_029_986_473_993_888_u64, + 4_782_811_555_589_976_751_u64, + 4_765_160_184_182_081_015_u64, + 6_870_421_860_884_204_987_u64, + 6_644_609_928_302_751_438_u64, + 9_205_665_417_060_638_521_u64, + 4_422_362_498_965_857_329_u64, + 3_911_541_231_075_340_397_u64, + 714_780_100_332_094_572_u64, + 854_285_349_422_025_761_u64, + 7_998_144_870_496_815_069_u64, + 4_601_820_771_957_226_501_u64, + 4_668_015_978_555_069_529_u64, + 3_107_134_174_330_286_017_u64, + 8_556_643_770_851_938_756_u64, + 7_603_022_701_719_395_789_u64, + 9_061_759_085_783_731_100_u64, + 335_871_293_124_179_717_u64, + 578_609_166_965_025_587_u64, + 8_344_077_009_920_132_885_u64, + 5_890_072_533_484_701_885_u64, + 4_572_233_892_255_728_435_u64, + 6_510_971_065_603_537_789_u64, + 2_119_489_420_934_143_588_u64, + 7_384_712_968_731_389_043_u64, + 5_631_603_782_423_650_945_u64, + 2_426_176_736_130_500_836_u64, + 8_725_885_473_278_349_136_u64, + 6_998_312_559_885_650_695_u64, + 1_747_649_994_418_612_192_u64, + 5_557_047_201_979_978_882_u64, + 4_330_564_741_999_955_015_u64, + 1_423_746_095_735_226_283_u64, + 6_729_353_041_636_611_170_u64, + 3_912_555_288_358_270_774_u64, + 6_236_800_801_119_694_988_u64, + 1_119_102_165_244_657_550_u64, + 5_444_700_680_136_175_568_u64, + 6_107_520_479_033_799_392_u64, + 6_092_621_673_178_322_094_u64, + 2_613_801_610_897_795_471_u64, + 7_958_414_627_268_864_059_u64, + 1_701_360_089_741_291_949_u64, + 8_900_744_252_335_003_997_u64, + ]); + let t_2 = Tensor::from_container(vec![ + 8_256_890_089_369_290_096_u64, + 6_729_858_587_364_974_993_u64, + 7_847_985_733_087_156_225_u64, + 4_256_288_592_723_368_540_u64, + 1_794_053_349_452_132_041_u64, + 6_010_968_597_662_138_399_u64, + 6_274_700_101_275_637_475_u64, + 3_672_569_542_766_325_569_u64, + 7_783_627_030_003_669_629_u64, + 249_357_646_255_069_879_u64, + 5_557_476_119_820_039_974_u64, + 8_042_948_614_404_456_368_u64, + 4_654_915_497_230_252_172_u64, + 7_722_972_477_579_752_886_u64, + 258_964_119_735_943_544_u64, + 3_661_700_972_414_689_603_u64, + 5_780_010_438_965_763_305_u64, + 5_399_007_971_131_851_993_u64, + 9_009_523_661_328_089_448_u64, + 670_837_492_260_568_551_u64, + 8_553_265_509_497_774_774_u64, + 475_007_578_406_922_623_u64, + 1_656_958_878_392_217_405_u64, + 3_145_284_643_778_286_187_u64, + 6_211_468_814_998_169_736_u64, + 7_898_586_816_448_146_424_u64, + 6_385_644_578_140_856_445_u64, + 6_278_113_144_098_235_027_u64, + 5_508_031_993_944_422_488_u64, + 2_541_351_454_611_805_754_u64, + 253_476_817_899_518_218_u64, + 4_042_272_828_677_076_320_u64, + 6_273_812_701_503_178_622_u64, + 7_154_361_991_326_158_245_u64, + 4_812_968_649_666_322_424_u64, + 8_058_877_626_669_330_796_u64, + 2_570_559_734_648_418_432_u64, + 3_260_085_933_573_705_643_u64, + 1_282_517_144_950_793_850_u64, + 1_370_863_113_856_127_345_u64, + 7_751_961_484_782_528_551_u64, + 2_576_515_167_053_557_195_u64, + 6_023_795_786_532_458_230_u64, + 6_726_030_942_349_870_732_u64, + 7_466_418_281_703_253_736_u64, + 8_567_435_608_821_064_654_u64, + 1_678_961_003_340_349_987_u64, + 3_502_334_064_042_353_274_u64, + 3_731_187_845_427_012_882_u64, + 5_317_359_253_712_576_816_u64, + 6_534_183_265_395_520_755_u64, + 3_251_278_594_118_653_876_u64, + 8_455_470_979_973_987_894_u64, + 1_134_450_355_974_787_411_u64, + 2_087_289_461_344_972_800_u64, + 898_091_164_345_629_933_u64, + 1_383_688_945_649_969_441_u64, + 6_412_373_125_771_730_589_u64, + 3_137_727_871_406_467_282_u64, + 2_531_450_854_507_130_283_u64, + 8_942_523_860_499_484_955_u64, + 3_053_185_116_942_316_003_u64, + 7_573_298_098_522_728_453_u64, + 7_850_035_594_752_589_513_u64, + 7_609_365_690_458_693_792_u64, + 3_979_440_714_450_645_544_u64, + 8_679_308_680_362_097_737_u64, + 3_937_728_290_719_953_722_u64, + 3_848_494_478_551_479_774_u64, + 3_384_383_891_744_980_023_u64, + 7_516_977_367_724_693_326_u64, + 435_538_850_065_011_084_u64, + 2_232_114_847_229_197_016_u64, + 8_939_199_010_658_684_319_u64, + 2_450_683_567_053_287_115_u64, + 7_734_458_772_215_536_274_u64, + 8_218_782_583_431_213_252_u64, + 8_553_066_689_779_351_731_u64, + 3_832_186_301_178_121_773_u64, + 6_381_512_211_621_916_311_u64, + 1_300_796_182_487_056_551_u64, + 8_878_587_019_650_156_826_u64, + 8_211_502_017_418_832_896_u64, + 8_604_481_095_678_650_971_u64, + 5_587_624_902_285_300_211_u64, + 3_426_416_825_207_801_687_u64, + 6_489_160_959_956_510_743_u64, + 7_128_903_513_419_063_730_u64, + 4_040_914_739_727_604_681_u64, + 2_688_309_576_123_447_655_u64, + 6_492_809_116_044_763_762_u64, + 7_116_289_826_504_895_799_u64, + 3_531_910_189_811_123_524_u64, + 789_365_259_912_898_159_u64, + 6_469_517_349_990_767_948_u64, + 1_189_741_323_354_180_502_u64, + 1_445_291_182_187_601_512_u64, + 8_147_987_973_554_022_701_u64, + 2_888_875_140_678_677_703_u64, + 7_057_288_198_514_267_233_u64, + ]); + let ground_truth_t_3 = Tensor::from_container(vec![ + 15_872_086_034_189_464_969_u64, + 13_461_157_626_764_508_233_u64, + 17_123_452_460_858_105_639_u64, + 14_692_140_704_573_390_784_u64, + 5_660_771_771_997_272_820_u64, + 13_887_824_623_185_929_945_u64, + 15_916_133_773_229_865_796_u64, + 17_674_888_782_383_492_119_u64, + 13_548_120_786_147_481_860_u64, + 4_877_679_684_048_869_384_u64, + 16_832_061_210_448_913_779_u64, + 13_342_010_623_099_120_845_u64, + 16_986_098_664_772_424_351_u64, + 14_522_389_450_303_172_839_u64, + 2_022_586_392_719_976_141_u64, + 16_163_065_026_889_266_916_u64, + 493_809_350_100_775_890_u64, + 492_510_344_627_708_038_u64, + 16_007_082_433_375_753_040_u64, + 1_641_467_368_148_833_624_u64, + 13_661_683_849_494_337_289_u64, + 2_338_710_511_925_921_507_u64, + 3_085_033_527_757_148_891_u64, + 18_164_372_045_930_309_686_u64, + 14_946_974_015_347_618_259_u64, + 13_653_182_864_415_158_685_u64, + 15_341_758_792_177_764_740_u64, + 13_789_987_493_665_249_248_u64, + 13_183_106_534_042_800_243_u64, + 17_275_561_026_318_918_700_u64, + 131_330_960_823_923_238_u64, + 1_379_112_009_161_425_375_u64, + 15_697_797_416_001_654_567_u64, + 11_565_607_033_685_874_761_u64, + 4_061_431_058_280_693_863_u64, + 15_430_720_134_014_328_532_u64, + 6_023_203_011_753_311_623_u64, + 1_038_083_279_542_398_443_u64, + 18_207_909_663_942_529_577_u64, + 6_901_100_751_501_815_892_u64, + 13_561_716_439_759_398_591_u64, + 1_103_758_564_571_563_392_u64, + 18_017_461_403_036_611_552_u64, + 13_364_630_414_898_925_174_u64, + 14_152_503_878_482_533_780_u64, + 16_843_794_995_161_211_998_u64, + 4_346_979_907_177_129_813_u64, + 4_775_384_370_059_248_209_u64, + 4_453_093_766_883_773_629_u64, + 55_672_021_285_304_165_u64, + 12_356_342_958_302_117_324_u64, + 5_933_928_970_736_896_250_u64, + 13_166_678_580_459_494_334_u64, + 17_851_089_521_336_002_829_u64, + 18_201_977_611_120_576_203_u64, + 18_305_468_122_897_835_196_u64, + 3_408_341_040_824_024_447_u64, + 16_817_182_503_527_797_778_u64, + 1_627_432_312_775_613_733_u64, + 4_338_971_006_377_074_704_u64, + 16_148_830_141_512_818_099_u64, + 6_152_480_300_118_322_518_u64, + 15_295_808_474_152_680_492_u64, + 14_508_249_710_032_302_500_u64, + 11_552_158_483_582_952_396_u64, + 15_321_588_708_680_931_833_u64, + 17_765_580_263_844_268_948_u64, + 664_092_481_237_272_779_u64, + 819_521_500_003_589_755_u64, + 18_169_494_356_294_857_610_u64, + 1_039_666_403_127_245_430_u64, + 7_167_483_851_654_384_705_u64, + 6_829_644_238_554_534_084_u64, + 9_843_416_356_175_047_014_u64, + 16_574_669_673_621_290_088_u64, + 609_618_237_704_596_611_u64, + 16_118_034_023_763_040_249_u64, + 14_465_911_276_185_928_320_u64, + 2_678_784_764_425_416_016_u64, + 14_184_721_283_021_778_893_u64, + 6_083_916_786_244_332_492_u64, + 15_199_760_836_483_045_735_u64, + 12_661_418_792_421_219_556_u64, + 121_404_377_599_698_165_u64, + 1_410_687_657_600_350_484_u64, + 16_767_977_242_920_362_121_u64, + 17_514_630_315_733_019_755_u64, + 15_648_405_302_290_442_901_u64, + 15_829_575_429_717_173_218_u64, + 4_041_043_465_513_163_515_u64, + 15_866_490_246_023_058_628_u64, + 17_567_255_048_324_350_805_u64, + 16_033_936_049_143_085_642_u64, + 4_655_335_420_223_277_409_u64, + 18_084_747_202_752_583_060_u64, + 4_902_880_349_824_141_592_u64, + 1_168_510_428_710_193_959_u64, + 18_257_170_727_424_392_974_u64, + 17_259_229_022_772_165_862_u64, + 1_843_456_053_820_736_764_u64, + ]); + let mut t_3 = Tensor::allocate(0_u64, 100); + t_3.fill_with_wrapping_sub(&t_1, &t_2); + + assert_eq!(t_3, ground_truth_t_3, "we are testing u-64 sub "); +} + +#[test] +fn test_sub_u32() { + let t_1 = Tensor::from_container(vec![ + 2_614_422_625_u32, + 1_347_010_255_u32, + 1_755_118_555_u32, + 3_348_067_670_u32, + 3_896_589_259_u32, + 97_617_327_u32, + 1_545_053_739_u32, + 1_211_085_433_u32, + 2_684_538_667_u32, + 202_832_626_u32, + 1_638_508_087_u32, + 879_523_200_u32, + 2_456_511_176_u32, + 2_648_745_580_u32, + 967_205_272_u32, + 54_854_762_u32, + 2_609_115_771_u32, + 1_725_392_344_u32, + 2_314_671_715_u32, + 1_840_995_902_u32, + 4_041_278_880_u32, + 275_079_767_u32, + 2_300_142_423_u32, + 2_333_095_686_u32, + 3_026_580_357_u32, + 21_931_374_u32, + 372_535_067_u32, + 6_439_834_u32, + 762_787_515_u32, + 2_734_668_397_u32, + 3_013_991_526_u32, + 579_324_780_u32, + 916_175_967_u32, + 850_321_436_u32, + 978_826_112_u32, + 1_360_938_704_u32, + 2_363_410_736_u32, + 353_572_296_u32, + 2_196_029_604_u32, + 1_676_698_573_u32, + 71_702_920_u32, + 433_353_586_u32, + 3_336_662_792_u32, + 3_815_644_954_u32, + 2_974_299_797_u32, + 1_990_548_820_u32, + 1_683_843_869_u32, + 2_152_628_932_u32, + 3_625_450_751_u32, + 2_366_853_676_u32, + 1_798_342_904_u32, + 2_869_368_979_u32, + 1_185_695_639_u32, + 3_173_469_147_u32, + 1_531_916_725_u32, + 3_326_214_024_u32, + 2_067_990_523_u32, + 976_120_805_u32, + 3_535_693_006_u32, + 4_223_913_473_u32, + 2_143_410_133_u32, + 187_637_181_u32, + 2_370_649_336_u32, + 3_155_284_399_u32, + 3_282_898_811_u32, + 3_068_767_567_u32, + 3_033_732_496_u32, + 3_278_852_653_u32, + 1_988_815_405_u32, + 3_318_268_258_u32, + 402_934_292_u32, + 3_162_645_643_u32, + 2_103_209_800_u32, + 4_253_170_701_u32, + 2_489_673_789_u32, + 2_224_135_091_u32, + 1_848_398_457_u32, + 3_159_326_514_u32, + 3_865_725_686_u32, + 674_027_046_u32, + 3_191_092_214_u32, + 356_413_912_u32, + 682_734_067_u32, + 2_368_555_344_u32, + 614_314_161_u32, + 3_515_266_737_u32, + 949_414_245_u32, + 2_046_032_417_u32, + 1_495_462_201_u32, + 2_307_315_576_u32, + 1_960_455_472_u32, + 917_911_666_u32, + 1_518_075_072_u32, + 2_925_772_427_u32, + 298_590_050_u32, + 1_441_972_928_u32, + 666_987_301_u32, + 2_167_997_170_u32, + 3_413_359_382_u32, + 3_526_531_810_u32, + ]); + let t_2 = Tensor::from_container(vec![ + 3_428_858_567_u32, + 827_447_270_u32, + 959_110_479_u32, + 4_184_350_429_u32, + 1_820_415_259_u32, + 2_322_099_741_u32, + 1_328_906_591_u32, + 1_664_312_159_u32, + 549_610_931_u32, + 2_945_591_302_u32, + 295_342_634_u32, + 1_589_486_080_u32, + 1_359_822_125_u32, + 1_285_568_394_u32, + 1_881_925_871_u32, + 3_058_045_327_u32, + 1_773_709_235_u32, + 3_813_730_789_u32, + 823_940_101_u32, + 2_480_100_080_u32, + 3_639_129_118_u32, + 759_351_495_u32, + 1_301_750_125_u32, + 1_054_832_776_u32, + 3_245_556_275_u32, + 2_800_997_186_u32, + 1_256_287_364_u32, + 2_573_603_461_u32, + 2_328_221_582_u32, + 1_633_069_253_u32, + 102_853_950_u32, + 2_716_685_335_u32, + 503_267_884_u32, + 2_202_048_416_u32, + 1_602_161_938_u32, + 1_927_466_558_u32, + 858_392_614_u32, + 956_183_465_u32, + 4_135_389_917_u32, + 951_071_347_u32, + 2_318_567_902_u32, + 2_004_258_446_u32, + 1_797_038_763_u32, + 1_610_761_714_u32, + 3_236_519_313_u32, + 316_586_765_u32, + 307_967_731_u32, + 3_588_485_359_u32, + 1_947_118_682_u32, + 2_002_927_095_u32, + 1_136_304_281_u32, + 1_065_157_362_u32, + 3_714_003_080_u32, + 786_946_775_u32, + 2_441_787_699_u32, + 832_944_437_u32, + 3_651_539_633_u32, + 798_050_864_u32, + 669_130_367_u32, + 2_000_552_570_u32, + 2_700_875_050_u32, + 1_118_200_520_u32, + 708_688_839_u32, + 3_463_285_323_u32, + 2_270_978_169_u32, + 237_144_352_u32, + 660_096_080_u32, + 4_230_221_095_u32, + 3_280_398_471_u32, + 3_354_541_293_u32, + 833_549_915_u32, + 1_136_697_343_u32, + 1_096_316_847_u32, + 2_476_951_406_u32, + 3_971_141_200_u32, + 12_788_335_u32, + 1_580_197_299_u32, + 2_444_226_867_u32, + 387_836_375_u32, + 4_067_693_575_u32, + 3_918_490_972_u32, + 3_704_639_326_u32, + 664_901_677_u32, + 1_847_150_125_u32, + 476_514_752_u32, + 466_866_141_u32, + 704_667_620_u32, + 2_652_242_441_u32, + 3_683_680_188_u32, + 2_589_696_574_u32, + 571_587_908_u32, + 953_792_331_u32, + 3_677_654_843_u32, + 2_056_915_677_u32, + 3_272_850_239_u32, + 1_669_788_385_u32, + 2_311_731_086_u32, + 162_748_026_u32, + 907_347_406_u32, + 2_760_784_143_u32, + ]); + let ground_truth_t_3 = Tensor::from_container(vec![ + 3_480_531_354_u32, + 519_562_985_u32, + 796_008_076_u32, + 3_458_684_537_u32, + 2_076_174_000_u32, + 2_070_484_882_u32, + 216_147_148_u32, + 3_841_740_570_u32, + 2_134_927_736_u32, + 1_552_208_620_u32, + 1_343_165_453_u32, + 3_585_004_416_u32, + 1_096_689_051_u32, + 1_363_177_186_u32, + 3_380_246_697_u32, + 1_291_776_731_u32, + 835_406_536_u32, + 2_206_628_851_u32, + 1_490_731_614_u32, + 3_655_863_118_u32, + 402_149_762_u32, + 3_810_695_568_u32, + 998_392_298_u32, + 1_278_262_910_u32, + 4_075_991_378_u32, + 1_515_901_484_u32, + 3_411_214_999_u32, + 1_727_803_669_u32, + 2_729_533_229_u32, + 1_101_599_144_u32, + 2_911_137_576_u32, + 2_157_606_741_u32, + 412_908_083_u32, + 2_943_240_316_u32, + 3_671_631_470_u32, + 3_728_439_442_u32, + 1_505_018_122_u32, + 3_692_356_127_u32, + 2_355_606_983_u32, + 725_627_226_u32, + 2_048_102_314_u32, + 2_724_062_436_u32, + 1_539_624_029_u32, + 2_204_883_240_u32, + 4_032_747_780_u32, + 1_673_962_055_u32, + 1_375_876_138_u32, + 2_859_110_869_u32, + 1_678_332_069_u32, + 363_926_581_u32, + 662_038_623_u32, + 1_804_211_617_u32, + 1_766_659_855_u32, + 2_386_522_372_u32, + 3_385_096_322_u32, + 2_493_269_587_u32, + 2_711_418_186_u32, + 178_069_941_u32, + 2_866_562_639_u32, + 2_223_360_903_u32, + 3_737_502_379_u32, + 3_364_403_957_u32, + 1_661_960_497_u32, + 3_986_966_372_u32, + 1_011_920_642_u32, + 2_831_623_215_u32, + 2_373_636_416_u32, + 3_343_598_854_u32, + 3_003_384_230_u32, + 4_258_694_261_u32, + 3_864_351_673_u32, + 2_025_948_300_u32, + 1_006_892_953_u32, + 1_776_219_295_u32, + 2_813_499_885_u32, + 2_211_346_756_u32, + 268_201_158_u32, + 715_099_647_u32, + 3_477_889_311_u32, + 901_300_767_u32, + 3_567_568_538_u32, + 946_741_882_u32, + 17_832_390_u32, + 521_405_219_u32, + 137_799_409_u32, + 3_048_400_596_u32, + 244_746_625_u32, + 3_688_757_272_u32, + 2_106_749_309_u32, + 4_012_586_298_u32, + 1_388_867_564_u32, + 4_259_086_631_u32, + 2_135_387_525_u32, + 868_856_750_u32, + 1_320_707_107_u32, + 4_067_151_839_u32, + 2_650_223_511_u32, + 2_005_249_144_u32, + 2_506_011_976_u32, + 765_747_667_u32, + ]); + let mut t_3 = Tensor::allocate(0_u32, 100); + t_3.fill_with_wrapping_sub(&t_1, &t_2); + + assert_eq!(t_3, ground_truth_t_3, "we are testing substraction"); +} + +#[test] +fn test_add_u64() { + let t_1 = Tensor::from_container(vec![ + 5_682_232_049_849_203_449_u64, + 1_744_272_140_419_931_610_u64, + 6_524_694_120_235_710_248_u64, + 501_685_223_587_207_708_u64, + 7_454_825_121_449_404_861_u64, + 1_452_049_147_138_516_728_u64, + 3_744_089_800_795_951_655_u64, + 2_900_714_251_440_266_072_u64, + 2_885_003_742_441_599_873_u64, + 5_127_037_330_303_939_263_u64, + 3_942_793_256_559_402_137_u64, + 2_938_215_163_794_025_597_u64, + 3_194_270_088_293_124_907_u64, + 3_798_617_854_173_374_109_u64, + 2_281_550_512_455_919_685_u64, + 1_378_021_925_594_404_903_u64, + 6_273_819_789_066_539_195_u64, + 5_891_518_315_759_560_031_u64, + 6_569_862_020_994_290_872_u64, + 2_312_304_860_409_402_175_u64, + 3_768_205_285_282_560_447_u64, + 2_813_718_090_332_844_130_u64, + 4_741_992_406_149_366_296_u64, + 2_862_912_615_999_044_257_u64, + 2_711_698_756_636_236_379_u64, + 3_105_025_607_153_753_493_u64, + 3_280_659_296_609_069_569_u64, + 1_621_356_564_053_932_659_u64, + 244_394_454_277_671_115_u64, + 1_370_168_407_221_172_838_u64, + 384_807_778_723_441_456_u64, + 5_421_384_837_838_501_695_u64, + 3_524_866_043_795_281_573_u64, + 273_224_951_302_481_390_u64, + 8_874_399_707_947_016_287_u64, + 5_042_853_686_974_107_712_u64, + 8_593_762_746_401_730_055_u64, + 4_298_169_213_116_104_086_u64, + 1_043_682_735_183_771_811_u64, + 8_271_963_865_357_943_237_u64, + 2_866_933_850_832_375_526_u64, + 3_680_273_731_625_120_587_u64, + 5_594_513_115_859_518_166_u64, + 1_643_917_283_539_244_290_u64, + 3_172_178_086_476_235_900_u64, + 6_964_486_530_272_725_036_u64, + 6_025_940_910_517_479_800_u64, + 8_277_718_434_101_601_483_u64, + 8_184_281_612_310_786_511_u64, + 5_373_031_274_997_880_981_u64, + 443_782_149_988_086_463_u64, + 9_185_207_564_855_550_126_u64, + 3_175_405_486_723_930_612_u64, + 538_795_803_601_238_624_u64, + 1_842_522_998_755_997_387_u64, + 756_815_213_533_913_513_u64, + 4_792_029_986_473_993_888_u64, + 4_782_811_555_589_976_751_u64, + 4_765_160_184_182_081_015_u64, + 6_870_421_860_884_204_987_u64, + 6_644_609_928_302_751_438_u64, + 9_205_665_417_060_638_521_u64, + 4_422_362_498_965_857_329_u64, + 3_911_541_231_075_340_397_u64, + 714_780_100_332_094_572_u64, + 854_285_349_422_025_761_u64, + 7_998_144_870_496_815_069_u64, + 4_601_820_771_957_226_501_u64, + 4_668_015_978_555_069_529_u64, + 3_107_134_174_330_286_017_u64, + 8_556_643_770_851_938_756_u64, + 7_603_022_701_719_395_789_u64, + 9_061_759_085_783_731_100_u64, + 335_871_293_124_179_717_u64, + 578_609_166_965_025_587_u64, + 8_344_077_009_920_132_885_u64, + 5_890_072_533_484_701_885_u64, + 4_572_233_892_255_728_435_u64, + 6_510_971_065_603_537_789_u64, + 2_119_489_420_934_143_588_u64, + 7_384_712_968_731_389_043_u64, + 5_631_603_782_423_650_945_u64, + 2_426_176_736_130_500_836_u64, + 8_725_885_473_278_349_136_u64, + 6_998_312_559_885_650_695_u64, + 1_747_649_994_418_612_192_u64, + 5_557_047_201_979_978_882_u64, + 4_330_564_741_999_955_015_u64, + 1_423_746_095_735_226_283_u64, + 6_729_353_041_636_611_170_u64, + 3_912_555_288_358_270_774_u64, + 6_236_800_801_119_694_988_u64, + 1_119_102_165_244_657_550_u64, + 5_444_700_680_136_175_568_u64, + 6_107_520_479_033_799_392_u64, + 6_092_621_673_178_322_094_u64, + 2_613_801_610_897_795_471_u64, + 7_958_414_627_268_864_059_u64, + 1_701_360_089_741_291_949_u64, + 8_900_744_252_335_003_997_u64, + ]); + let t_2 = Tensor::from_container(vec![ + 8_256_890_089_369_290_096_u64, + 6_729_858_587_364_974_993_u64, + 7_847_985_733_087_156_225_u64, + 4_256_288_592_723_368_540_u64, + 1_794_053_349_452_132_041_u64, + 6_010_968_597_662_138_399_u64, + 6_274_700_101_275_637_475_u64, + 3_672_569_542_766_325_569_u64, + 7_783_627_030_003_669_629_u64, + 249_357_646_255_069_879_u64, + 5_557_476_119_820_039_974_u64, + 8_042_948_614_404_456_368_u64, + 4_654_915_497_230_252_172_u64, + 7_722_972_477_579_752_886_u64, + 258_964_119_735_943_544_u64, + 3_661_700_972_414_689_603_u64, + 5_780_010_438_965_763_305_u64, + 5_399_007_971_131_851_993_u64, + 9_009_523_661_328_089_448_u64, + 670_837_492_260_568_551_u64, + 8_553_265_509_497_774_774_u64, + 475_007_578_406_922_623_u64, + 1_656_958_878_392_217_405_u64, + 3_145_284_643_778_286_187_u64, + 6_211_468_814_998_169_736_u64, + 7_898_586_816_448_146_424_u64, + 6_385_644_578_140_856_445_u64, + 6_278_113_144_098_235_027_u64, + 5_508_031_993_944_422_488_u64, + 2_541_351_454_611_805_754_u64, + 253_476_817_899_518_218_u64, + 4_042_272_828_677_076_320_u64, + 6_273_812_701_503_178_622_u64, + 7_154_361_991_326_158_245_u64, + 4_812_968_649_666_322_424_u64, + 8_058_877_626_669_330_796_u64, + 2_570_559_734_648_418_432_u64, + 3_260_085_933_573_705_643_u64, + 1_282_517_144_950_793_850_u64, + 1_370_863_113_856_127_345_u64, + 7_751_961_484_782_528_551_u64, + 2_576_515_167_053_557_195_u64, + 6_023_795_786_532_458_230_u64, + 6_726_030_942_349_870_732_u64, + 7_466_418_281_703_253_736_u64, + 8_567_435_608_821_064_654_u64, + 1_678_961_003_340_349_987_u64, + 3_502_334_064_042_353_274_u64, + 3_731_187_845_427_012_882_u64, + 5_317_359_253_712_576_816_u64, + 6_534_183_265_395_520_755_u64, + 3_251_278_594_118_653_876_u64, + 8_455_470_979_973_987_894_u64, + 1_134_450_355_974_787_411_u64, + 2_087_289_461_344_972_800_u64, + 898_091_164_345_629_933_u64, + 1_383_688_945_649_969_441_u64, + 6_412_373_125_771_730_589_u64, + 3_137_727_871_406_467_282_u64, + 2_531_450_854_507_130_283_u64, + 8_942_523_860_499_484_955_u64, + 3_053_185_116_942_316_003_u64, + 7_573_298_098_522_728_453_u64, + 7_850_035_594_752_589_513_u64, + 7_609_365_690_458_693_792_u64, + 3_979_440_714_450_645_544_u64, + 8_679_308_680_362_097_737_u64, + 3_937_728_290_719_953_722_u64, + 3_848_494_478_551_479_774_u64, + 3_384_383_891_744_980_023_u64, + 7_516_977_367_724_693_326_u64, + 435_538_850_065_011_084_u64, + 2_232_114_847_229_197_016_u64, + 8_939_199_010_658_684_319_u64, + 2_450_683_567_053_287_115_u64, + 7_734_458_772_215_536_274_u64, + 8_218_782_583_431_213_252_u64, + 8_553_066_689_779_351_731_u64, + 3_832_186_301_178_121_773_u64, + 6_381_512_211_621_916_311_u64, + 1_300_796_182_487_056_551_u64, + 8_878_587_019_650_156_826_u64, + 8_211_502_017_418_832_896_u64, + 8_604_481_095_678_650_971_u64, + 5_587_624_902_285_300_211_u64, + 3_426_416_825_207_801_687_u64, + 6_489_160_959_956_510_743_u64, + 7_128_903_513_419_063_730_u64, + 4_040_914_739_727_604_681_u64, + 2_688_309_576_123_447_655_u64, + 6_492_809_116_044_763_762_u64, + 7_116_289_826_504_895_799_u64, + 3_531_910_189_811_123_524_u64, + 789_365_259_912_898_159_u64, + 6_469_517_349_990_767_948_u64, + 1_189_741_323_354_180_502_u64, + 1_445_291_182_187_601_512_u64, + 8_147_987_973_554_022_701_u64, + 2_888_875_140_678_677_703_u64, + 7_057_288_198_514_267_233_u64, + ]); + let ground_truth_t_3 = Tensor::from_container(vec![ + 13_939_122_139_218_493_545_u64, + 8_474_130_727_784_906_603_u64, + 14_372_679_853_322_866_473_u64, + 4_757_973_816_310_576_248_u64, + 9_248_878_470_901_536_902_u64, + 7_463_017_744_800_655_127_u64, + 10_018_789_902_071_589_130_u64, + 6_573_283_794_206_591_641_u64, + 10_668_630_772_445_269_502_u64, + 5_376_394_976_559_009_142_u64, + 9_500_269_376_379_442_111_u64, + 10_981_163_778_198_481_965_u64, + 7_849_185_585_523_377_079_u64, + 11_521_590_331_753_126_995_u64, + 2_540_514_632_191_863_229_u64, + 5_039_722_898_009_094_506_u64, + 12_053_830_228_032_302_500_u64, + 11_290_526_286_891_412_024_u64, + 15_579_385_682_322_380_320_u64, + 2_983_142_352_669_970_726_u64, + 12_321_470_794_780_335_221_u64, + 3_288_725_668_739_766_753_u64, + 6_398_951_284_541_583_701_u64, + 6_008_197_259_777_330_444_u64, + 8_923_167_571_634_406_115_u64, + 11_003_612_423_601_899_917_u64, + 9_666_303_874_749_926_014_u64, + 7_899_469_708_152_167_686_u64, + 5_752_426_448_222_093_603_u64, + 3_911_519_861_832_978_592_u64, + 638_284_596_622_959_674_u64, + 9_463_657_666_515_578_015_u64, + 9_798_678_745_298_460_195_u64, + 7_427_586_942_628_639_635_u64, + 13_687_368_357_613_338_711_u64, + 13_101_731_313_643_438_508_u64, + 11_164_322_481_050_148_487_u64, + 7_558_255_146_689_809_729_u64, + 2_326_199_880_134_565_661_u64, + 9_642_826_979_214_070_582_u64, + 10_618_895_335_614_904_077_u64, + 6_256_788_898_678_677_782_u64, + 11_618_308_902_391_976_396_u64, + 8_369_948_225_889_115_022_u64, + 10_638_596_368_179_489_636_u64, + 15_531_922_139_093_789_690_u64, + 7_704_901_913_857_829_787_u64, + 11_780_052_498_143_954_757_u64, + 11_915_469_457_737_799_393_u64, + 10_690_390_528_710_457_797_u64, + 6_977_965_415_383_607_218_u64, + 12_436_486_158_974_204_002_u64, + 11_630_876_466_697_918_506_u64, + 1_673_246_159_576_026_035_u64, + 3_929_812_460_100_970_187_u64, + 1_654_906_377_879_543_446_u64, + 6_175_718_932_123_963_329_u64, + 11_195_184_681_361_707_340_u64, + 7_902_888_055_588_548_297_u64, + 9_401_872_715_391_335_270_u64, + 15_587_133_788_802_236_393_u64, + 12_258_850_534_002_954_524_u64, + 11_995_660_597_488_585_782_u64, + 11_761_576_825_827_929_910_u64, + 8_324_145_790_790_788_364_u64, + 4_833_726_063_872_671_305_u64, + 16_677_453_550_858_912_806_u64, + 8_539_549_062_677_180_223_u64, + 8_516_510_457_106_549_303_u64, + 6_491_518_066_075_266_040_u64, + 16_073_621_138_576_632_082_u64, + 8_038_561_551_784_406_873_u64, + 11_293_873_933_012_928_116_u64, + 9_275_070_303_782_864_036_u64, + 3_029_292_734_018_312_702_u64, + 16_078_535_782_135_669_159_u64, + 14_108_855_116_915_915_137_u64, + 13_125_300_582_035_080_166_u64, + 10_343_157_366_781_659_562_u64, + 8_501_001_632_556_059_899_u64, + 8_685_509_151_218_445_594_u64, + 14_510_190_802_073_807_771_u64, + 10_637_678_753_549_333_732_u64, + 17_330_366_568_957_000_107_u64, + 12_585_937_462_170_950_906_u64, + 5_174_066_819_626_413_879_u64, + 12_046_208_161_936_489_625_u64, + 11_459_468_255_419_018_745_u64, + 5_464_660_835_462_830_964_u64, + 9_417_662_617_760_058_825_u64, + 10_405_364_404_403_034_536_u64, + 13_353_090_627_624_590_787_u64, + 4_651_012_355_055_781_074_u64, + 6_234_065_940_049_073_727_u64, + 12_577_037_829_024_567_340_u64, + 7_282_362_996_532_502_596_u64, + 4_059_092_793_085_396_983_u64, + 16_106_402_600_822_886_760_u64, + 4_590_235_230_419_969_652_u64, + 15_958_032_450_849_271_230_u64, + ]); + let mut t_3 = Tensor::allocate(0_u64, 100); + t_3.fill_with_wrapping_add(&t_1, &t_2); + + assert_eq!(t_3, ground_truth_t_3, "we are testing u64 add"); +} diff --git a/tfhe/src/core_crypto/commons/math/torus/mod.rs b/tfhe/src/core_crypto/commons/math/torus/mod.rs new file mode 100644 index 000000000..cb91c9ad8 --- /dev/null +++ b/tfhe/src/core_crypto/commons/math/torus/mod.rs @@ -0,0 +1,102 @@ +//! Converting to torus values. +//! +//! The theory behind some of the homomorphic operators of the library, uses the real torus +//! $\mathbb{T} = \mathbb{R} / \mathbb{Z}$, or the set or real numbers modulo 1 (elements of the +//! torus are in $[0,1)$). In practice, floating-point number are not well suited to performing +//! operations on the torus, and we prefer to use unsigned integer values to represent them. +//! Indeed, unsigned integer can be used to encode the decimal part of the torus element with a +//! fixed precision. +//! +//! Still, in some cases, we may need to represent an unsigned integer as a torus value in +//! floating point representation. For this reason we provide the [`IntoTorus`] and [`FromTorus`] +//! traits which allow to go back and forth between an unsigned integer representation and a +//! floating point representation. + +use crate::core_crypto::commons::math::random::{ + Gaussian, RandomGenerable, Uniform, UniformBinary, UniformTernary, +}; +pub use crate::core_crypto::commons::numeric::{CastInto, FloatingPoint, Numeric, UnsignedInteger}; +use crate::core_crypto::prelude::LogStandardDev; +use std::fmt::{Debug, Display}; + +/// A trait that converts a torus element in unsigned integer representation to the closest +/// torus element in floating point representation. +pub trait IntoTorus: Sized +where + F: FloatingPoint, + Self: UnsignedInteger, +{ + /// Consumes `self` and returns its closest floating point representation. + fn into_torus(self) -> F; +} + +/// A trait that converts a torus element in floating point representation into the closest torus +/// element in unsigned integer representation. +pub trait FromTorus: Sized +where + F: FloatingPoint, + Self: UnsignedInteger, +{ + /// Consumes `input` and returns its closest unsigned integer representation. + fn from_torus(input: F) -> Self; +} + +macro_rules! implement { + ($Type: tt) => { + impl IntoTorus for $Type + where + F: FloatingPoint + CastInto, + Self: CastInto, + { + #[inline] + fn into_torus(self) -> F { + let self_f: F = self.cast_into(); + return self_f * (F::TWO.powi(-(::BITS as i32))); + } + } + impl FromTorus for $Type + where + F: FloatingPoint + CastInto + CastInto, + Self: CastInto, + { + #[inline] + fn from_torus(input: F) -> Self { + let mut fract = input - F::round(input); + fract *= F::TWO.powi(::BITS as i32); + fract = F::round(fract); + let signed: Self::Signed = fract.cast_into(); + return signed.cast_into(); + } + } + }; +} + +implement!(u8); +implement!(u16); +implement!(u32); +implement!(u64); +implement!(u128); + +/// A marker trait for unsigned integer types that can be used in ciphertexts, keys etc. +pub trait UnsignedTorus: + UnsignedInteger + + FromTorus + + IntoTorus + + RandomGenerable> + + RandomGenerable + + RandomGenerable + + RandomGenerable + + Display + + Debug +{ + /// The log standard deviation used to sample gaussian keys in this precision. + const GAUSSIAN_KEY_LOG_STD: LogStandardDev; +} + +impl UnsignedTorus for u32 { + const GAUSSIAN_KEY_LOG_STD: LogStandardDev = LogStandardDev(-30.32192809488736); +} + +impl UnsignedTorus for u64 { + const GAUSSIAN_KEY_LOG_STD: LogStandardDev = LogStandardDev(-62.32192809488736); +} diff --git a/tfhe/src/core_crypto/commons/mod.rs b/tfhe/src/core_crypto/commons/mod.rs new file mode 100644 index 000000000..92be6f037 --- /dev/null +++ b/tfhe/src/core_crypto/commons/mod.rs @@ -0,0 +1,256 @@ +#![allow(dead_code, deprecated)] // For the time being + +#[allow(unused_macros)] +macro_rules! assert_delta { + ($A:expr, $B:expr, $d:expr) => { + for (x, y) in $A.iter().zip($B) { + assert!((*x as i64 - y as i64).abs() <= $d, "{} != {} ", *x, y); + } + }; +} + +#[allow(unused_macros)] +macro_rules! assert_delta_scalar { + ($A:expr, $B:expr, $d:expr) => { + assert!( + ($A as i64 - $B as i64).abs() <= $d, + "{} != {} +- {}", + $A, + $B, + $d + ); + }; +} + +#[allow(unused_macros)] +macro_rules! assert_delta_scalar_float { + ($A:expr, $B:expr, $d:expr) => { + assert!(($A - $B).abs() <= $d, "{} != {} +- {}", $A, $B, $d); + }; +} + +#[allow(unused_macros)] +macro_rules! modular_distance { + ($A:expr, $B:expr) => { + ($A.wrapping_sub($B)).min($B.wrapping_sub($A)) + }; +} + +pub mod crypto; +pub mod math; +pub mod numeric; +pub mod utils; + +#[doc(hidden)] +#[cfg(test)] +pub mod test_tools { + use rand::Rng; + + use crate::core_crypto::commons::crypto::secret::generators::{ + EncryptionRandomGenerator, SecretRandomGenerator, + }; + use crate::core_crypto::commons::math::random::{RandomGenerable, RandomGenerator, Uniform}; + use crate::core_crypto::commons::math::tensor::{AsRefSlice, AsRefTensor}; + use crate::core_crypto::commons::math::torus::UnsignedTorus; + use crate::core_crypto::commons::numeric::UnsignedInteger; + use crate::core_crypto::prelude::{ + CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, + GlweDimension, LweDimension, PlaintextCount, PolynomialSize, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + use concrete_csprng::seeders::{Seed, Seeder}; + + fn modular_distance(first: T, other: T) -> T { + let d0 = first.wrapping_sub(other); + let d1 = other.wrapping_sub(first); + std::cmp::min(d0, d1) + } + + fn torus_modular_distance(first: T, other: T) -> f64 { + let d0 = first.wrapping_sub(other); + let d1 = other.wrapping_sub(first); + if d0 < d1 { + let d: f64 = d0.cast_into(); + d / 2_f64.powi(T::BITS as i32) + } else { + let d: f64 = d1.cast_into(); + -d / 2_f64.powi(T::BITS as i32) + } + } + + pub fn new_random_generator() -> RandomGenerator { + RandomGenerator::new(random_seed()) + } + + pub fn new_secret_random_generator() -> SecretRandomGenerator { + SecretRandomGenerator::new(random_seed()) + } + + pub fn new_encryption_random_generator() -> EncryptionRandomGenerator { + EncryptionRandomGenerator::new(random_seed(), &mut UnsafeRandSeeder) + } + + pub fn random_seed() -> Seed { + Seed(rand::thread_rng().gen()) + } + + pub struct UnsafeRandSeeder; + + impl Seeder for UnsafeRandSeeder { + fn seed(&mut self) -> Seed { + Seed(rand::thread_rng().gen()) + } + + fn is_available() -> bool { + true + } + } + + pub fn assert_delta_std_dev( + first: &First, + second: &Second, + dist: impl DispersionParameter, + ) where + First: AsRefTensor, + Second: AsRefTensor, + Element: UnsignedTorus, + { + for (x, y) in first.as_tensor().iter().zip(second.as_tensor().iter()) { + println!("{:?}, {:?}", *x, *y); + println!("{}", dist.get_standard_dev()); + let distance: f64 = modular_distance(*x, *y).cast_into(); + let torus_distance = distance / 2_f64.powi(Element::BITS as i32); + assert!( + torus_distance <= 5. * dist.get_standard_dev(), + "{} != {} ", + x, + y + ); + } + } + + pub fn assert_noise_distribution( + first: &First, + second: &Second, + dist: impl DispersionParameter, + ) where + First: AsRefTensor, + Second: AsRefTensor, + Element: UnsignedTorus, + { + use crate::core_crypto::commons::math::tensor::Tensor; + use rand::distributions::{Distribution, Normal}; + + let std_dev = dist.get_standard_dev(); + let confidence = 0.95; + let n_slots = first.as_tensor().len(); + + // allocate 2 slices: one for the error samples obtained, the second for fresh samples + // according to the std_dev computed + let mut sdk_samples = Tensor::allocate(0.0_f64, n_slots); + + // recover the errors from each ciphertexts + sdk_samples.fill_with_two(first.as_tensor(), second.as_tensor(), |a, b| { + torus_modular_distance(*a, *b) + }); + + // fill the theoretical sample vector according to std_dev using the rand crate + let mut theoretical_samples: Vec = Vec::with_capacity(n_slots); + let normal = Normal::new(0.0, std_dev); + for _i in 0..n_slots { + theoretical_samples.push(normal.sample(&mut rand::thread_rng())); + } + + // compute the kolmogorov smirnov test + let result = kolmogorov_smirnov::test_f64( + sdk_samples.as_slice(), + theoretical_samples.as_slice(), + confidence, + ); + assert!( + !result.is_rejected, + "Not the same distribution with a probability of {}", + result.reject_probability + ); + } + + /// Returns a random plaintext count in [1;max]. + pub fn random_plaintext_count(max: usize) -> PlaintextCount { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + PlaintextCount((rng.gen::() % (max - 1)) + 1) + } + + /// Returns a random ciphertext count in [1;max]. + pub fn random_ciphertext_count(max: usize) -> CiphertextCount { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + CiphertextCount((rng.gen::() % (max - 1)) + 1) + } + + /// Returns a random LWE dimension in [1;max]. + pub fn random_lwe_dimension(max: usize) -> LweDimension { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + LweDimension((rng.gen::() % (max - 1)) + 1) + } + + /// Returns a random GLWE dimension in [1;max]. + pub fn random_glwe_dimension(max: usize) -> GlweDimension { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + GlweDimension((rng.gen::() % (max - 1)) + 1) + } + + /// Returns a random polynomial size in [2;max]. + pub fn random_polynomial_size(max: usize) -> PolynomialSize { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + PolynomialSize((rng.gen::() % (max - 2)) + 2) + } + + /// Returns a random base log in [2;max]. + pub fn random_base_log(max: usize) -> DecompositionBaseLog { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + DecompositionBaseLog((rng.gen::() % (max - 2)) + 2) + } + + /// Returns a random level count in [2;max]. + pub fn random_level_count(max: usize) -> DecompositionLevelCount { + assert_ne!(max, 0, "Max cannot be 0"); + let mut rng = rand::thread_rng(); + DecompositionLevelCount((rng.gen::() % (max - 2)) + 2) + } + + pub fn random_i32_between(range: std::ops::Range) -> i32 { + use rand::distributions::{Distribution, Uniform}; + let between = Uniform::from(range); + let mut rng = rand::thread_rng(); + between.sample(&mut rng) + } + + pub fn random_usize_between(range: std::ops::Range) -> usize { + use rand::distributions::{Distribution, Uniform}; + let between = Uniform::from(range); + let mut rng = rand::thread_rng(); + between.sample(&mut rng) + } + + pub fn any_usize() -> usize { + random_usize_between(0..usize::MAX) + } + + pub fn random_uint_between>( + range: std::ops::Range, + ) -> T { + let mut generator = new_random_generator(); + let val: T = generator.random_uniform(); + val % (range.end - range.start) + range.start + } + + pub fn any_uint>() -> T { + let mut generator = new_random_generator(); + generator.random_uniform() + } +} diff --git a/tfhe/src/core_crypto/commons/numeric/float.rs b/tfhe/src/core_crypto/commons/numeric/float.rs new file mode 100644 index 000000000..743cd77ac --- /dev/null +++ b/tfhe/src/core_crypto/commons/numeric/float.rs @@ -0,0 +1,140 @@ +use super::Numeric; +use std::ops::{ + Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign, +}; + +/// A trait shared by all the floating point types. +pub trait FloatingPoint: + Numeric + + Neg + + Add + + AddAssign + + Div + + DivAssign + + Mul + + MulAssign + + Rem + + RemAssign + + Sub + + SubAssign +{ + /// Raises a float to an integer power. + #[must_use] + fn powi(self, power: i32) -> Self; + + /// Rounds the float to the closest integer. + #[must_use] + fn round(self) -> Self; + + /// Keeps the fractional part of the number. + #[must_use] + fn fract(self) -> Self; + + /// Remainder of the euclidean division. + #[must_use] + fn rem_euclid(self, rhs: Self) -> Self; + + /// Returns the square root of the input float. + #[must_use] + fn sqrt(self) -> Self; + + /// Returns the natural logarithm of the input float. + #[must_use] + fn ln(self) -> Self; + + /// Returns the absolute value of the input float. + #[must_use] + fn abs(self) -> Self; + + /// Returns the floor value of the input float. + #[must_use] + fn floor(self) -> Self; + + /// Returns a bit representation of the float, with the sign, exponent, and mantissa bits + /// separated by whitespaces for increased readability. + fn to_bit_string(&self) -> String; +} + +macro_rules! implement { + ($Type: tt, $bits:expr) => { + impl Numeric for $Type { + const BITS: usize = $bits; + const ZERO: Self = 0.; + const ONE: Self = 1.; + const TWO: Self = 2.; + const MAX: Self = <$Type>::MAX; + } + impl FloatingPoint for $Type { + #[inline] + fn powi(self, power: i32) -> Self { + self.powi(power) + } + #[inline] + fn round(self) -> Self { + self.round() + } + #[inline] + fn fract(self) -> Self { + self.fract() + } + #[inline] + fn rem_euclid(self, rhs: Self) -> Self { + self.rem_euclid(rhs) + } + #[inline] + fn sqrt(self) -> Self { + self.sqrt() + } + #[inline] + fn ln(self) -> Self { + self.ln() + } + #[inline] + fn abs(self) -> Self { + self.abs() + } + #[inline] + fn floor(self) -> Self { + self.floor() + } + fn to_bit_string(&self) -> String { + if Self::BITS == 32 { + let mut bit_string = format!("{:032b}", self.to_bits()); + bit_string.insert(1, ' '); + bit_string.insert(10, ' '); + format!("{}", bit_string) + } else { + let mut bit_string = format!("{:064b}", self.to_bits()); + bit_string.insert(1, ' '); + bit_string.insert(13, ' '); + format!("{}", bit_string) + } + } + } + }; +} + +implement!(f64, 64); +implement!(f32, 32); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_f64_binary_rep() { + let a = 1123214.4321432_f64; + let b = a.to_bit_string(); + assert_eq!( + b, + "0 10000010011 0001001000111000111001101110101000001110111111001111".to_string() + ); + } + + #[test] + fn test_f32_binary_rep() { + let a = -1.276_663_9e27_f32; + let b = a.to_bit_string(); + assert_eq!(b, "1 11011001 00001000000000100000011".to_string()); + } +} diff --git a/tfhe/src/core_crypto/commons/numeric/mod.rs b/tfhe/src/core_crypto/commons/numeric/mod.rs new file mode 100644 index 000000000..9ece5979d --- /dev/null +++ b/tfhe/src/core_crypto/commons/numeric/mod.rs @@ -0,0 +1,106 @@ +//! Generic numeric types. +//! +//! This module contains types and traits to manipulate numeric types in a generic manner. For +//! instance, in the standard library, the `f32` and `f64` trait share a lot of methods of the +//! same name and same semantics. Still, it is not possible to use them generically. This module +//! provides the [`FloatingPoint`] trait, implemented by both of those type, to remedy the +//! situation. +//! +//! # Note +//! +//! The current implementation of those traits does not strive to be general, in the sense that +//! not all the common methods of the same kind of types are exposed. Only were included the ones +//! that are used in the rest of the library. + +pub use float::*; +pub use signed::*; +pub use unsigned::*; + +mod float; +mod signed; +mod unsigned; + +/// A trait implemented by any generic numeric type suitable for computations. +pub trait Numeric: Sized + Copy + PartialEq + PartialOrd + 'static { + /// This size of the type in bits. + const BITS: usize; + + /// The null element of the type. + const ZERO: Self; + + /// The identity element of the type. + const ONE: Self; + + /// A value of two. + const TWO: Self; + + /// The largest value that can be encoded by the type. + const MAX: Self; +} + +/// A trait that allows to generically cast one type from another. +/// +/// This type is similar to the [`std::convert::From`] trait, but the conversion between the two +/// types is deferred to the individual `as` casting. If in doubt about the semantics of such a +/// casting, refer to +/// [the rust reference](https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions). +pub trait CastFrom { + fn cast_from(input: Input) -> Self; +} + +/// A trait that allows to generically cast one type into another. +/// +/// This type is similar to the [`std::convert::Into`] trait, but the conversion between the two +/// types is deferred to the individual `as` casting. If in doubt about the semantics of such a +/// casting, refer to +/// [the rust reference](https://doc.rust-lang.org/reference/expressions/operator-expr.html#type-cast-expressions). +pub trait CastInto { + fn cast_into(self) -> Output; +} + +impl CastInto for Input +where + Output: CastFrom, +{ + fn cast_into(self) -> Output { + Output::cast_from(self) + } +} + +macro_rules! implement_cast { + ($Input:ty, {$($Output:ty),*}) => { + $( + impl CastFrom<$Input> for $Output { + #[inline] + fn cast_from(input: $Input) -> $Output { + input as $Output + } + } + )* + }; + ($Input: ty) => { + implement_cast!($Input, {f32, f64, usize, u8, u16, u32, u64, u128, isize, i8, i16, i32, + i64, i128}); + }; + ($($Input: ty),*) => { + $( + implement_cast!($Input); + )* + } +} + +implement_cast!(f32, f64, u8, u16, u32, u64, u128, i8, i16, i32, i64, i128, usize, isize); + +impl CastFrom for Num +where + Num: Numeric, +{ + #[inline] + fn cast_from(input: bool) -> Num { + if input { + Num::ONE + } else { + Num::ZERO + } + } +} diff --git a/tfhe/src/core_crypto/commons/numeric/signed.rs b/tfhe/src/core_crypto/commons/numeric/signed.rs new file mode 100644 index 000000000..e134724ca --- /dev/null +++ b/tfhe/src/core_crypto/commons/numeric/signed.rs @@ -0,0 +1,135 @@ +use super::{CastFrom, CastInto, Numeric, UnsignedInteger}; +use std::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, + Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, +}; + +/// A trait shared by all the unsigned integer types. +pub trait SignedInteger: + Numeric + + Neg + + Add + + AddAssign + + Div + + DivAssign + + Mul + + MulAssign + + Rem + + RemAssign + + Sub + + SubAssign + + BitAnd + + BitAndAssign + + BitOr + + BitOrAssign + + BitXor + + BitXorAssign + + Not + + Shl + + ShlAssign + + Shr + + ShrAssign + + CastFrom + + CastInto +{ + /// The unsigned type of the same precicion + type Unsigned: UnsignedInteger + CastFrom; + + /// Returns the casting of the current value to the unsigned type of the same size. + fn into_unsigned(self) -> Self::Unsigned; + + /// Returns a bit representation of the integer, where blocks of length `block_length` are + /// separated by whitespaces to increase the readability. + fn to_bits_string(&self, block_length: usize) -> String; +} + +macro_rules! implement { + ($Type: tt, $UnsignedType:ty, $bits:expr) => { + impl Numeric for $Type { + const BITS: usize = $bits; + const ZERO: Self = 0; + const ONE: Self = 1; + const TWO: Self = 2; + const MAX: Self = <$Type>::MAX; + } + impl SignedInteger for $Type { + type Unsigned = $UnsignedType; + #[inline] + fn into_unsigned(self) -> Self::Unsigned { + Self::Unsigned::cast_from(self) + } + fn to_bits_string(&self, break_every: usize) -> String { + let mut strn = match <$Type as Numeric>::BITS { + 8 => format!("{:08b}", self), + 16 => format!("{:016b}", self), + 32 => format!("{:032b}", self), + 64 => format!("{:064b}", self), + 128 => format!("{:0128b}", self), + _ => unreachable!(), + }; + for i in (1..(<$Type as Numeric>::BITS / break_every)).rev() { + strn.insert(i * break_every, ' '); + } + strn + } + } + }; +} + +implement!(i8, u8, 8); +implement!(i16, u16, 16); +implement!(i32, u32, 32); +implement!(i64, u64, 64); +implement!(i128, u128, 128); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_sint8_binary_rep() { + let a: i8 = -100; + let b = a.to_bits_string(4); + assert_eq!(b, "1001 1100".to_string()); + } + + #[test] + fn test_sint16_binary_rep() { + let a: i16 = -25702; + let b = a.to_bits_string(4); + assert_eq!(b, "1001 1011 1001 1010".to_string()); + } + + #[test] + fn test_sint32_binary_rep() { + let a: i32 = -1684411356; + let b = a.to_bits_string(4); + assert_eq!(b, "1001 1011 1001 1001 1110 1100 0010 0100".to_string()); + } + + #[test] + fn test_sint64_binary_rep() { + let a: i64 = -7_234_491_689_707_068_824; + let b = a.to_bits_string(4); + assert_eq!( + b, + "1001 1011 1001 1001 1110 1100 0010 0011 \ + 0110 0000 0111 1110 1010 0010 0110 1000" + .to_string() + ); + } + + #[test] + fn test_sint128_binary_rep() { + let a: i128 = -124_282_366_920_938_463_463_374_121_543_098_288_434; + let b = a.to_bits_string(4); + assert_eq!( + b, + "1010 0010 1000 0000 0001 0110 0011 1000 \ + 0111 0001 1001 1101 1111 1010 0100 1111 \ + 0100 0111 1100 1111 1110 1111 0110 1001 \ + 1100 0101 1001 0010 0011 0110 1100 1110" + .to_string() + ); + } +} diff --git a/tfhe/src/core_crypto/commons/numeric/unsigned.rs b/tfhe/src/core_crypto/commons/numeric/unsigned.rs new file mode 100644 index 000000000..1b78fa65a --- /dev/null +++ b/tfhe/src/core_crypto/commons/numeric/unsigned.rs @@ -0,0 +1,190 @@ +use super::{CastFrom, CastInto, Numeric, SignedInteger}; +use std::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, + Mul, MulAssign, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, +}; + +/// A trait shared by all the unsigned integer types. +pub trait UnsignedInteger: + Numeric + + Ord + + Eq + + Add + + AddAssign + + Div + + DivAssign + + Mul + + MulAssign + + Rem + + RemAssign + + Sub + + SubAssign + + BitAnd + + BitAndAssign + + BitOr + + BitOrAssign + + BitXor + + BitXorAssign + + Not + + Shl + + ShlAssign + + Shr + + ShrAssign + + CastFrom + + CastInto +{ + /// The signed type of the same precision. + type Signed: SignedInteger + CastFrom; + /// Compute an addition, modulo the max of the type. + #[must_use] + fn wrapping_add(self, other: Self) -> Self; + /// Compute a subtraction, modulo the max of the type. + #[must_use] + fn wrapping_sub(self, other: Self) -> Self; + /// Compute a division, modulo the max of the type. + #[must_use] + fn wrapping_div(self, other: Self) -> Self; + /// Compute a multiplication, modulo the max of the type. + #[must_use] + fn wrapping_mul(self, other: Self) -> Self; + /// Compute a negation, modulo the max of the type. + #[must_use] + fn wrapping_neg(self) -> Self; + /// Compute an exponentiation, modulo the max of the type. + #[must_use] + fn wrapping_pow(self, exp: u32) -> Self; + /// Panic free shift-left operation. + #[must_use] + fn wrapping_shl(self, rhs: u32) -> Self; + /// Panic free shift-right operation. + #[must_use] + fn wrapping_shr(self, rhs: u32) -> Self; + /// Returns the casting of the current value to the signed type of the same size. + fn into_signed(self) -> Self::Signed; + /// Returns a bit representation of the integer, where blocks of length `block_length` are + /// separated by whitespaces to increase the readability. + fn to_bits_string(&self, block_length: usize) -> String; +} + +macro_rules! implement { + ($Type: tt, $SignedType:ty, $bits:expr) => { + impl Numeric for $Type { + const BITS: usize = $bits; + const ZERO: Self = 0; + const ONE: Self = 1; + const TWO: Self = 2; + const MAX: Self = <$Type>::MAX; + } + impl UnsignedInteger for $Type { + type Signed = $SignedType; + #[inline] + fn into_signed(self) -> Self::Signed { + Self::Signed::cast_from(self) + } + fn to_bits_string(&self, break_every: usize) -> String { + let mut strn = match <$Type as Numeric>::BITS { + 8 => format!("{:08b}", self), + 16 => format!("{:016b}", self), + 32 => format!("{:032b}", self), + 64 => format!("{:064b}", self), + 128 => format!("{:0128b}", self), + _ => unreachable!(), + }; + for i in (1..(<$Type as Numeric>::BITS / break_every)).rev() { + strn.insert(i * break_every, ' '); + } + strn + } + #[inline] + fn wrapping_add(self, other: Self) -> Self { + self.wrapping_add(other) + } + #[inline] + fn wrapping_sub(self, other: Self) -> Self { + self.wrapping_sub(other) + } + #[inline] + fn wrapping_div(self, other: Self) -> Self { + self.wrapping_div(other) + } + #[inline] + fn wrapping_mul(self, other: Self) -> Self { + self.wrapping_mul(other) + } + #[inline] + fn wrapping_neg(self) -> Self { + self.wrapping_neg() + } + #[inline] + fn wrapping_shl(self, rhs: u32) -> Self { + self.wrapping_shl(rhs) + } + #[inline] + fn wrapping_shr(self, rhs: u32) -> Self { + self.wrapping_shr(rhs) + } + #[inline] + fn wrapping_pow(self, exp: u32) -> Self { + self.wrapping_pow(exp) + } + } + }; +} + +implement!(u8, i8, 8); +implement!(u16, i16, 16); +implement!(u32, i32, 32); +implement!(u64, i64, 64); +implement!(u128, i128, 128); + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_uint8_binary_rep() { + let a: u8 = 100; + let b = a.to_bits_string(4); + assert_eq!(b, "0110 0100".to_string()); + } + + #[test] + fn test_uint16_binary_rep() { + let a: u16 = 25702; + let b = a.to_bits_string(4); + assert_eq!(b, "0110 0100 0110 0110".to_string()); + } + + #[test] + fn test_uint32_binary_rep() { + let a: u32 = 1684411356; + let b = a.to_bits_string(4); + assert_eq!(b, "0110 0100 0110 0110 0001 0011 1101 1100".to_string()); + } + + #[test] + fn test_uint64_binary_rep() { + let a: u64 = 7_234_491_689_707_068_824; + let b = a.to_bits_string(4); + assert_eq!( + b, + "0110 0100 0110 0110 0001 0011 1101 1100 \ + 1001 1111 1000 0001 0101 1101 1001 1000" + .to_string() + ); + } + + #[test] + fn test_uint128_binary_rep() { + let a: u128 = 124_282_366_920_938_463_463_374_121_543_098_288_434; + let b = a.to_bits_string(4); + assert_eq!( + b, + "0101 1101 0111 1111 1110 1001 1100 0111 \ + 1000 1110 0110 0010 0000 0101 1011 0000 \ + 1011 1000 0011 0000 0001 0000 1001 0110 \ + 0011 1010 0110 1101 1100 1001 0011 0010" + .to_string() + ); + } +} diff --git a/tfhe/src/core_crypto/commons/utils.rs b/tfhe/src/core_crypto/commons/utils.rs new file mode 100644 index 000000000..f41a9a4b8 --- /dev/null +++ b/tfhe/src/core_crypto/commons/utils.rs @@ -0,0 +1,128 @@ +//! Utilities for the library. + +/// This macro is used in tandem with the [`zip_args`] macro, to allow to zip iterators and access +/// them in an non-nested fashion. This makes large zip iterators easier to write, but also, +/// makes the code faster, as zipped-flatten iterators are hard to optimize for the compiler. +macro_rules! zip { + ($($iterator:expr),*) => { + $crate::core_crypto::commons::utils::zip!(@zip $($iterator),*) + }; + (@zip $first:expr, $($iterator:expr),* ) => { + $first.zip($crate::core_crypto::commons::utils::zip!(@zip $($iterator),*)) + }; + (@zip $first:expr) => { + $first + }; +} +pub(crate) use zip; + +/// Companion macro to flatten the iterators made with the [`zip`] +macro_rules! zip_args { + ($($iterator:pat),*) => { + $crate::core_crypto::commons::utils::zip_args!(@zip $($iterator),*) + }; + (@zip $first:pat, $second:pat) => { + ($first, $second) + }; + (@zip $first:pat, $($iterator:pat),*) => { + ($first, $crate::core_crypto::commons::utils::zip_args!(@zip $($iterator),*)) + }; +} +pub(crate) use zip_args; + +#[inline] +fn assert_same_len(a: (usize, Option), b: (usize, Option)) { + debug_assert_eq!(a.1, Some(a.0)); + debug_assert_eq!(b.1, Some(b.0)); + debug_assert_eq!(a.0, b.0); +} + +/// Returns a Zip iterator, but checks that the two components have the same length. +pub trait ZipChecked: IntoIterator + Sized { + #[inline] + fn zip_checked( + self, + b: B, + ) -> core::iter::Zip<::IntoIter, ::IntoIter> { + let a = self.into_iter(); + let b = b.into_iter(); + assert_same_len(a.size_hint(), b.size_hint()); + core::iter::zip(a, b) + } +} + +impl ZipChecked for A {} + +// https://docs.rs/itertools/0.7.8/src/itertools/lib.rs.html#247-269 +#[allow(unused_macros)] +macro_rules! izip { + // eg. __izip_closure!(((a, b), c) => (a, b, c) , dd , ee ) + (@ __closure @ $p:pat => $tup:expr) => { + |$p| $tup + }; + + // The "b" identifier is a different identifier on each recursion level thanks to hygiene. + (@ __closure @ $p:pat => ( $($tup:tt)* ) , $_iter:expr $( , $tail:expr )*) => { + $crate::core_crypto::commons::utils::izip!(@ __closure @ ($p, b) => ( $($tup)*, b ) $( , $tail )*) + }; + + ( $first:expr $(,)?) => { + { + #[allow(unused_imports)] + use $crate::core_crypto::commons::utils::ZipChecked; + ::core::iter::IntoIterator::into_iter($first) + } + }; + ( $first:expr, $($rest:expr),+ $(,)?) => { + { + #[allow(unused_imports)] + use $crate::core_crypto::commons::utils::ZipChecked; + ::core::iter::IntoIterator::into_iter($first) + $(.zip_checked($rest))* + .map($crate::core_crypto::commons::utils::izip!(@ __closure @ a => (a) $( , $rest )*)) + } + }; +} + +#[allow(unused_imports)] +pub(crate) use izip; + +#[cfg(test)] +mod test { + #![allow(clippy::many_single_char_names)] + + #[test] + fn test_zip() { + let a = vec![1, 2, 3]; + let b = vec![4, 5, 6]; + let c = vec![7, 8, 9]; + let d = vec![10, 11, 12]; + let e = vec![13, 14, 15]; + let f = vec![16, 17, 18]; + let g = vec![19, 20, 21]; + for zip_args!(a, b, c) in zip!(a.iter(), b.iter(), c.iter()) { + println!("{},{},{}", a, b, c); + } + let mut iterator = zip!( + a.into_iter(), + b.into_iter(), + c.into_iter(), + d.into_iter(), + e.into_iter(), + f.into_iter(), + g.into_iter() + ); + assert_eq!( + iterator.next().unwrap(), + (1, (4, (7, (10, (13, (16, 19)))))) + ); + assert_eq!( + iterator.next().unwrap(), + (2, (5, (8, (11, (14, (17, 20)))))) + ); + assert_eq!( + iterator.next().unwrap(), + (3, (6, (9, (12, (15, (18, 21)))))) + ); + } +} diff --git a/tfhe/src/core_crypto/mod.rs b/tfhe/src/core_crypto/mod.rs new file mode 100644 index 000000000..473ebdc05 --- /dev/null +++ b/tfhe/src/core_crypto/mod.rs @@ -0,0 +1,44 @@ +#![deny(rustdoc::broken_intra_doc_links)] +#![cfg_attr(docsrs, feature(doc_cfg))] +//! Welcome to the tfhe.rs `core_crypto` module documentation! +//! +//! This library contains a set of low-level primitives which can be used to implement *Fully +//! Homomorphically Encrypted* (FHE) programs. In a nutshell, fully homomorphic encryption makes it +//! possible to perform arbitrary computations over encrypted data. With FHE, you can perform +//! computations without putting your trust on third-party computation providers. +//! +//! # Audience +//! +//! This library is geared towards people who already know their way around FHE. It gives the user +//! freedom of choice over a breadth of parameters, which can lead to less than 128 bits of security +//! if chosen incorrectly +//! +//! # Architecture +//! +//! `core_crypto` is modular which makes it possible to use different backends to perform FHE +//! operations. Its design revolves around two modules: +//! +//! + The [`specification`](crate::core_crypto::specification) module contains a specification (in +//! the form of traits) of Zama's variant of the TFHE scheme. It describes the FHE objects and +//! operators, which are exposed by the library. +//! + The [`backends`](crate::core_crypto::backends) module contains various backends implementing +//! all or a part of this scheme. These different backends can be activated by feature flags, each +//! making use of different hardware or system libraries to make the operations faster. +//! +//! # Activating backends +//! +//! The different backends can be activated using the feature flags `backend_*`. The `backend_core` +//! contains an engine executing operations on a single thread of the cpu. It is activated by +//! default. +//! +//! # Navigating the code +//! +//! If this is your first time looking at the `core_crypto` module code, it may be simpler for you +//! to first have a look at the [`specification`](crate::core_crypto::specification) module, which +//! contains explanations on the abstract API, and navigate from there. + +pub mod backends; +#[doc(hidden)] +pub mod commons; +pub mod prelude; +pub mod specification; diff --git a/tfhe/src/core_crypto/prelude.rs b/tfhe/src/core_crypto/prelude.rs new file mode 100644 index 000000000..bdd6b357c --- /dev/null +++ b/tfhe/src/core_crypto/prelude.rs @@ -0,0 +1,40 @@ +#![doc(hidden)] + +// ----------------------------------------------------------------------------------- SPECIFICATION +pub use super::specification::engines::*; +pub use super::specification::entities::*; + +// --------------------------------------------------------------------------------- DEFAULT BACKEND +#[cfg(feature = "backend_default")] +pub use super::backends::default::engines::*; +#[cfg(feature = "backend_default")] +pub use super::backends::default::entities::*; + +// --------------------------------------------------------------------------------- FFT BACKEND +#[cfg(feature = "backend_fft")] +pub use super::backends::fft::engines::*; +#[cfg(feature = "backend_fft")] +pub use super::backends::fft::entities::*; + +// ------------------------------------------------------------------------------------ CUDA BACKEND +#[cfg(feature = "backend_cuda")] +pub use super::backends::cuda::engines::*; +#[cfg(feature = "backend_cuda")] +pub use super::backends::cuda::entities::*; + +// -------------------------------------------------------------------------------- COMMONS REEXPORT +pub use super::specification::dispersion::*; +pub use super::specification::key_kinds::*; +pub use super::specification::parameters::*; +pub use super::specification::*; + +// --------------------------------------------------------------------------------- CSPRNG REEXPORT +// Re-export the different seeders of the `concrete-csprng` crate, which are needed to construct +// default engines. +#[cfg(target_os = "macos")] +pub use concrete_csprng::seeders::AppleSecureEnclaveSeeder; +#[cfg(feature = "seeder_x86_64_rdseed")] +pub use concrete_csprng::seeders::RdseedSeeder; +pub use concrete_csprng::seeders::Seeder; +#[cfg(feature = "seeder_unix")] +pub use concrete_csprng::seeders::UnixSeeder; diff --git a/tfhe/src/core_crypto/specification/dispersion.rs b/tfhe/src/core_crypto/specification/dispersion.rs new file mode 100644 index 000000000..30a08a23c --- /dev/null +++ b/tfhe/src/core_crypto/specification/dispersion.rs @@ -0,0 +1,198 @@ +//! Noise distribution +//! +//! When dealing with noise, we tend to use different representation for the same value. In +//! general, the noise is specified by the standard deviation of a gaussian distribution, which +//! is of the form $\sigma = 2^p$, with $p$ a negative integer. Depending on the use case though, +//! we rely on different representations for this quantity: +//! +//! + $\sigma$ can be encoded in the [`StandardDev`] type. +//! + $p$ can be encoded in the [`LogStandardDev`] type. +//! + $\sigma^2$ can be encoded in the [`Variance`] type. +//! +//! In any of those cases, the corresponding type implements the `DispersionParameter` trait, +//! which makes if possible to use any of those representations generically when noise must be +//! defined. + +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// A trait for types representing distribution parameters, for a given unsigned integer type. +// Warning: +// DispersionParameter type should ONLY wrap a single native type. +// As long as Variance wraps a native type (f64) it is ok to derive it from Copy instead of +// Clone because f64 is itself Copy and stored in register. +pub trait DispersionParameter: Copy { + /// Returns the standard deviation of the distribution, i.e. $\sigma = 2^p$. + fn get_standard_dev(&self) -> f64; + /// Returns the variance of the distribution, i.e. $\sigma^2 = 2^{2p}$. + fn get_variance(&self) -> f64; + /// Returns base 2 logarithm of the standard deviation of the distribution, i.e. + /// $\log\_2(\sigma)=p$ + fn get_log_standard_dev(&self) -> f64; + /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{q-p}$. + fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64; + + /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $2^{2(q-p)}$. + fn get_modular_variance(&self, log2_modulus: u32) -> f64; + + /// For a `Uint` type representing $\mathbb{Z}/2^q\mathbb{Z}$, we return $q-p$. + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64; +} + +/// A distribution parameter that uses the base-2 logarithm of the standard deviation as +/// representation. +/// +/// # Example: +/// +/// ``` +/// use tfhe::core_crypto::prelude::{DispersionParameter, LogStandardDev}; +/// let params = LogStandardDev::from_log_standard_dev(-25.); +/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev(), -25.); +/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); +/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.),); +/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!( +/// params.get_modular_variance(32), +/// 2_f64.powf(32. - 25.).powi(2) +/// ); +/// +/// let modular_params = LogStandardDev::from_modular_log_standard_dev(22., 32); +/// assert_eq!(modular_params.get_standard_dev(), 2_f64.powf(-10.)); +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct LogStandardDev(pub f64); + +impl LogStandardDev { + pub fn from_log_standard_dev(log_std: f64) -> LogStandardDev { + LogStandardDev(log_std) + } + + pub fn from_modular_log_standard_dev(log_std: f64, log2_modulus: u32) -> LogStandardDev { + LogStandardDev(log_std - log2_modulus as f64) + } +} + +impl DispersionParameter for LogStandardDev { + fn get_standard_dev(&self) -> f64 { + f64::powf(2., self.0) + } + fn get_variance(&self) -> f64 { + f64::powf(2., self.0 * 2.) + } + fn get_log_standard_dev(&self) -> f64 { + self.0 + } + fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { + f64::powf(2., log2_modulus as f64 + self.0) + } + fn get_modular_variance(&self, log2_modulus: u32) -> f64 { + f64::powf(2., (log2_modulus as f64 + self.0) * 2.) + } + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { + log2_modulus as f64 + self.0 + } +} + +/// A distribution parameter that uses the standard deviation as representation. +/// +/// # Example: +/// +/// ``` +/// use tfhe::core_crypto::prelude::{DispersionParameter, StandardDev}; +/// let params = StandardDev::from_standard_dev(2_f64.powf(-25.)); +/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev(), -25.); +/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); +/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.)); +/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!( +/// params.get_modular_variance(32), +/// 2_f64.powf(32. - 25.).powi(2) +/// ); +/// ``` +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct StandardDev(pub f64); + +impl StandardDev { + pub fn from_standard_dev(std: f64) -> StandardDev { + StandardDev(std) + } + + pub fn from_modular_standard_dev(std: f64, log2_modulus: u32) -> StandardDev { + StandardDev(std / 2_f64.powf(log2_modulus as f64)) + } +} + +impl DispersionParameter for StandardDev { + fn get_standard_dev(&self) -> f64 { + self.0 + } + fn get_variance(&self) -> f64 { + self.0.powi(2) + } + fn get_log_standard_dev(&self) -> f64 { + self.0.log2() + } + fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { + 2_f64.powf(log2_modulus as f64 + self.0.log2()) + } + fn get_modular_variance(&self, log2_modulus: u32) -> f64 { + 2_f64.powf(2. * (log2_modulus as f64 + self.0.log2())) + } + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { + log2_modulus as f64 + self.0.log2() + } +} + +/// A distribution parameter that uses the variance as representation +/// +/// # Example: +/// +/// ``` +/// use tfhe::core_crypto::prelude::{DispersionParameter, Variance}; +/// let params = Variance::from_variance(2_f64.powi(-50)); +/// assert_eq!(params.get_standard_dev(), 2_f64.powf(-25.)); +/// assert_eq!(params.get_log_standard_dev(), -25.); +/// assert_eq!(params.get_variance(), 2_f64.powf(-25.).powi(2)); +/// assert_eq!(params.get_modular_standard_dev(32), 2_f64.powf(32. - 25.)); +/// assert_eq!(params.get_modular_log_standard_dev(32), 32. - 25.); +/// assert_eq!( +/// params.get_modular_variance(32), +/// 2_f64.powf(32. - 25.).powi(2) +/// ); +/// ``` +#[derive(Debug, Copy, Clone, PartialEq, PartialOrd)] +pub struct Variance(pub f64); + +impl Variance { + pub fn from_variance(var: f64) -> Variance { + Variance(var) + } + + pub fn from_modular_variance(var: f64, log2_modulus: u32) -> Variance { + Variance(var / 2_f64.powf(log2_modulus as f64 * 2.)) + } +} + +impl DispersionParameter for Variance { + fn get_standard_dev(&self) -> f64 { + self.0.sqrt() + } + fn get_variance(&self) -> f64 { + self.0 + } + fn get_log_standard_dev(&self) -> f64 { + self.0.sqrt().log2() + } + fn get_modular_standard_dev(&self, log2_modulus: u32) -> f64 { + 2_f64.powf(log2_modulus as f64 + self.0.sqrt().log2()) + } + fn get_modular_variance(&self, log2_modulus: u32) -> f64 { + 2_f64.powf(2. * (log2_modulus as f64 + self.0.sqrt().log2())) + } + fn get_modular_log_standard_dev(&self, log2_modulus: u32) -> f64 { + log2_modulus as f64 + self.0.sqrt().log2() + } +} diff --git a/tfhe/src/core_crypto/specification/engines/cleartext_creation.rs b/tfhe/src/core_crypto/specification/engines/cleartext_creation.rs new file mode 100644 index 000000000..63d6fd4b7 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/cleartext_creation.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::CleartextEntity; + +engine_error! { + CleartextCreationError for CleartextCreationEngine @ +} + +/// A trait for engines creating cleartexts from arbitrary values. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a cleartext from the `value` +/// arbitrary value. By arbitrary here, we mean that `Value` can be any type that suits the backend +/// implementor (an integer, a struct wrapping integers, a struct wrapping foreign data or any other +/// thing). +/// +/// # Formal Definition +pub trait CleartextCreationEngine: AbstractEngine +where + Cleartext: CleartextEntity, +{ + /// Creates a cleartext from an arbitrary value. + fn create_cleartext_from( + &mut self, + value: &Value, + ) -> Result>; + + /// Unsafely creates a cleartext from an arbitrary value. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`CleartextCreationError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn create_cleartext_from_unchecked(&mut self, value: &Value) -> Cleartext; +} diff --git a/tfhe/src/core_crypto/specification/engines/entity_deserialization.rs b/tfhe/src/core_crypto/specification/engines/entity_deserialization.rs new file mode 100644 index 000000000..844746a1e --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/entity_deserialization.rs @@ -0,0 +1,32 @@ +use super::engine_error; +use crate::core_crypto::prelude::AbstractEntity; +use crate::core_crypto::specification::engines::AbstractEngine; + +engine_error! { + EntityDeserializationError for EntityDeserializationEngine @ +} + +/// A trait for engines deserializing entities. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates an entity containing the +/// deserialization of the `serialized` type. +pub trait EntityDeserializationEngine: AbstractEngine +where + Entity: AbstractEntity, +{ + /// Deserializes an entity. + fn deserialize( + &mut self, + serialized: Serialized, + ) -> Result>; + + /// Unsafely deserializes an entity. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`EntityDeserializationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn deserialize_unchecked(&mut self, serialized: Serialized) -> Entity; +} diff --git a/tfhe/src/core_crypto/specification/engines/entity_serialization.rs b/tfhe/src/core_crypto/specification/engines/entity_serialization.rs new file mode 100644 index 000000000..1b7b610bf --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/entity_serialization.rs @@ -0,0 +1,32 @@ +use super::engine_error; +use crate::core_crypto::prelude::AbstractEntity; +use crate::core_crypto::specification::engines::AbstractEngine; + +engine_error! { + EntitySerializationError for EntitySerializationEngine @ +} + +/// A trait for engines serializing entities. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a value containing the serialization +/// of `entity`. +pub trait EntitySerializationEngine: AbstractEngine +where + Entity: AbstractEntity, +{ + /// Serializes an entity. + fn serialize( + &mut self, + entity: &Entity, + ) -> Result>; + + /// Unsafely serializes an entity. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`EntitySerializationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn serialize_unchecked(&mut self, entity: &Entity) -> Serialized; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_consuming_retrieval.rs b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_consuming_retrieval.rs new file mode 100644 index 000000000..e14ecefe2 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_consuming_retrieval.rs @@ -0,0 +1,37 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::GlweCiphertextEntity; + +engine_error! { + GlweCiphertextConsumingRetrievalError for GlweCiphertextConsumingRetrievalEngine @ +} + +/// A trait for engines retrieving the content of the container from a GLWE ciphertext consuming it +/// in the process. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation retrieves the content of the container from the +/// `input` GLWE ciphertext consuming it in the process. +pub trait GlweCiphertextConsumingRetrievalEngine: AbstractEngine +where + Ciphertext: GlweCiphertextEntity, +{ + /// Retrieves the content of the container from a GLWE ciphertext, consuming it in the process. + fn consume_retrieve_glwe_ciphertext( + &mut self, + ciphertext: Ciphertext, + ) -> Result>; + + /// Unsafely retrieves the content of the container from a GLWE ciphertext, consuming it in the + /// process. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweCiphertextConsumingRetrievalError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn consume_retrieve_glwe_ciphertext_unchecked( + &mut self, + ciphertext: Ciphertext, + ) -> Container; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_conversion.rs b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_conversion.rs new file mode 100644 index 000000000..97f0464a3 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_conversion.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::GlweCiphertextEntity; + +engine_error! { + GlweCiphertextConversionError for GlweCiphertextConversionEngine @ +} + +/// A trait for engines converting GLWE ciphertexts. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a GLWE ciphertext containing the +/// conversion of the `input` GLWE ciphertext to a type with a different representation (for +/// instance from cpu to gpu memory). +/// +/// # Formal Definition +pub trait GlweCiphertextConversionEngine: AbstractEngine +where + Input: GlweCiphertextEntity, + Output: GlweCiphertextEntity, +{ + /// Converts a GLWE ciphertext. + fn convert_glwe_ciphertext( + &mut self, + input: &Input, + ) -> Result>; + + /// Unsafely converts a GLWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweCiphertextConversionError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn convert_glwe_ciphertext_unchecked(&mut self, input: &Input) -> Output; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_creation.rs b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_creation.rs new file mode 100644 index 000000000..d49c0ed72 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_creation.rs @@ -0,0 +1,62 @@ +use super::engine_error; +use crate::core_crypto::prelude::PolynomialSize; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::GlweCiphertextEntity; + +engine_error! { + GlweCiphertextCreationError for GlweCiphertextCreationEngine @ + EmptyContainer => "The container used to create the GLWE ciphertext is of length 0!", + InvalidContainerSize => "The length of the container used to create the GLWE ciphertext \ + needs to be a multiple of `polynomial_size`." +} + +impl GlweCiphertextCreationError { + /// Validates the inputs, the container is expected to have a length of + /// glwe_size * polynomial_size, during construction we only get the container and the + /// polynomial size so we check the length is consistent, the GLWE size is deduced by the + /// ciphertext implementation from the container and the polynomial size. + pub fn perform_generic_checks( + container_length: usize, + polynomial_size: PolynomialSize, + ) -> Result<(), Self> { + if container_length == 0 { + return Err(Self::EmptyContainer); + } + if container_length % polynomial_size.0 != 0 { + return Err(Self::InvalidContainerSize); + } + + Ok(()) + } +} + +/// A trait for engines creating a GLWE ciphertext from an arbitrary container. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation creates a GLWE ciphertext from the abitrary +/// `container`. By arbitrary here, we mean that `Container` can be any type that allows to +/// instantiate a `GlweCiphertextEntity`. +pub trait GlweCiphertextCreationEngine: AbstractEngine +where + Ciphertext: GlweCiphertextEntity, +{ + /// Creates a GLWE ciphertext from an arbitrary container. + fn create_glwe_ciphertext_from( + &mut self, + container: Container, + polynomial_size: PolynomialSize, + ) -> Result>; + + /// Unsafely creates a GLWE ciphertext from an arbitrary container. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweCiphertextCreationError`]. For safety concerns _specific_ to an engine, refer + /// to the implementer safety section. + unsafe fn create_glwe_ciphertext_from_unchecked( + &mut self, + container: Container, + polynomial_size: PolynomialSize, + ) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_trivial_encryption.rs b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_trivial_encryption.rs new file mode 100644 index 000000000..57379953a --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_ciphertext_trivial_encryption.rs @@ -0,0 +1,48 @@ +use super::engine_error; +use crate::core_crypto::prelude::{GlweSize, PlaintextVectorEntity}; + +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::GlweCiphertextEntity; + +engine_error! { + GlweCiphertextTrivialEncryptionError for GlweCiphertextTrivialEncryptionEngine @ +} + +/// A trait for engines trivially encrypting GLWE ciphertext. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a GLWE ciphertext containing the +/// trivial encryption of the `input` plaintext vector with the requested `glwe_size`. +/// +/// # Formal Definition +/// +/// A trivial encryption uses a zero mask and no noise. +/// It is absolutely not secure, as the body contains a direct copy of the plaintext. +/// However, it is useful for some FHE algorithms taking public information as input. For +/// example, a trivial GLWE encryption of a public lookup table is used in the bootstrap. +pub trait GlweCiphertextTrivialEncryptionEngine: + AbstractEngine +where + PlaintextVector: PlaintextVectorEntity, + Ciphertext: GlweCiphertextEntity, +{ + /// Trivially encrypts a plaintext vector into a GLWE ciphertext. + fn trivially_encrypt_glwe_ciphertext( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector, + ) -> Result>; + + /// Unsafely creates the trivial GLWE encryption of the plaintext vector. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweCiphertextTrivialEncryptionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn trivially_encrypt_glwe_ciphertext_unchecked( + &mut self, + glwe_size: GlweSize, + input: &PlaintextVector, + ) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_secret_key_generation.rs b/tfhe/src/core_crypto/specification/engines/glwe_secret_key_generation.rs new file mode 100644 index 000000000..0a5a0c9b5 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_secret_key_generation.rs @@ -0,0 +1,67 @@ +use super::engine_error; +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::GlweSecretKeyEntity; + +engine_error! { + GlweSecretKeyGenerationError for GlweSecretKeyGenerationEngine @ + NullGlweDimension => "The secret key GLWE dimension must be greater than zero.", + NullPolynomialSize => "The secret key polynomial size must be greater than zero.", + SizeOnePolynomial => "The secret key polynomial size must be greater than one. Otherwise you \ + should prefer the LWE scheme." +} + +impl GlweSecretKeyGenerationError { + /// Validates the inputs + pub fn perform_generic_checks( + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> Result<(), Self> { + if glwe_dimension.0 == 0 { + return Err(Self::NullGlweDimension); + } + + if polynomial_size.0 == 0 { + return Err(Self::NullPolynomialSize); + } + + if polynomial_size.0 == 1 { + return Err(Self::SizeOnePolynomial); + } + + Ok(()) + } +} + +/// A trait for engines generating new GLWE secret keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new GLWE secret key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::GlweSecretKeyEntity`) +pub trait GlweSecretKeyGenerationEngine: AbstractEngine +where + SecretKey: GlweSecretKeyEntity, +{ + /// Generates a new GLWE secret key. + fn generate_new_glwe_secret_key( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> Result>; + + /// Unsafely generates a new GLWE secret key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweSecretKeyGenerationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn generate_new_glwe_secret_key_unchecked( + &mut self, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + ) -> SecretKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/glwe_to_lwe_secret_key_transformation.rs b/tfhe/src/core_crypto/specification/engines/glwe_to_lwe_secret_key_transformation.rs new file mode 100644 index 000000000..1904e8c0e --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/glwe_to_lwe_secret_key_transformation.rs @@ -0,0 +1,39 @@ +use super::engine_error; +use crate::core_crypto::prelude::AbstractEngine; + +use crate::core_crypto::specification::entities::{GlweSecretKeyEntity, LweSecretKeyEntity}; + +engine_error! { + GlweToLweSecretKeyTransformationError for GlweToLweSecretKeyTransformationEngine @ +} + +/// A trait for engines transforming GLWE secret keys into LWE secret keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation moves the existing GLWE into a fresh LWE secret +/// key. +/// +/// # Formal Definition +pub trait GlweToLweSecretKeyTransformationEngine: AbstractEngine +where + InputKey: GlweSecretKeyEntity, + OutputKey: LweSecretKeyEntity, +{ + /// Does the transformation of the GLWE secret key into an LWE secret key + fn transform_glwe_secret_key_to_lwe_secret_key( + &mut self, + glwe_secret_key: InputKey, + ) -> Result>; + + /// Unsafely transforms a GLWE secret key into an LWE secret key + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`GlweToLweSecretKeyTransformationError`]. + /// For safety concerns _specific_ to an engine, refer to the implementer safety section. + unsafe fn transform_glwe_secret_key_to_lwe_secret_key_unchecked( + &mut self, + glwe_secret_key: InputKey, + ) -> OutputKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_conversion.rs new file mode 100644 index 000000000..c2165f766 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_conversion.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweBootstrapKeyEntity; + +engine_error! { + LweBootstrapKeyConversionError for LweBootstrapKeyConversionEngine @ +} + +/// A trait for engines converting LWE bootstrap keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a LWE bootstrap key containing the +/// conversion of the `input` bootstrap key to a type with a different representation (for instance +/// from cpu to gpu memory). +/// +/// # Formal Definition +pub trait LweBootstrapKeyConversionEngine: AbstractEngine +where + InputKey: LweBootstrapKeyEntity, + OutputKey: LweBootstrapKeyEntity, +{ + /// Converts an LWE bootstrap key. + fn convert_lwe_bootstrap_key( + &mut self, + input: &InputKey, + ) -> Result>; + + /// Unsafely converts an LWE bootstrap key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweBootstrapKeyConversionError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn convert_lwe_bootstrap_key_unchecked(&mut self, input: &InputKey) -> OutputKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_generation.rs b/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_generation.rs new file mode 100644 index 000000000..130a407b0 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_bootstrap_key_generation.rs @@ -0,0 +1,76 @@ +use super::engine_error; +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, Variance}; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + GlweSecretKeyEntity, LweBootstrapKeyEntity, LweSecretKeyEntity, +}; + +engine_error! { + LweBootstrapKeyGenerationError for LweBootstrapKeyGenerationEngine @ + NullDecompositionBaseLog => "The key decomposition base log must be greater than zero.", + NullDecompositionLevelCount => "The key decomposition level count must be greater than zero.", + DecompositionTooLarge => "The decomposition precision (base log * level count) must not exceed \ + the precision of the ciphertext." +} + +impl LweBootstrapKeyGenerationError { + pub fn perform_generic_checks( + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + ciphertext_modulus_log: usize, + ) -> Result<(), Self> { + if decomposition_base_log.0 == 0 { + return Err(Self::NullDecompositionBaseLog); + } + if decomposition_level_count.0 == 0 { + return Err(Self::NullDecompositionLevelCount); + } + if decomposition_base_log.0 * decomposition_level_count.0 > ciphertext_modulus_log { + return Err(Self::DecompositionTooLarge); + } + Ok(()) + } +} + +/// A trait for engines generating new LWE bootstrap keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new LWE bootstrap key from the +/// `input_key` LWE secret key, and the `output_key` GLWE secret key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::LweBootstrapKeyEntity`) +pub trait LweBootstrapKeyGenerationEngine: + AbstractEngine +where + BootstrapKey: LweBootstrapKeyEntity, + LweSecretKey: LweSecretKeyEntity, + GlweSecretKey: GlweSecretKeyEntity, +{ + /// Generates a new LWE bootstrap key. + fn generate_new_lwe_bootstrap_key( + &mut self, + input_key: &LweSecretKey, + output_key: &GlweSecretKey, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result>; + + /// Unsafely generates a new LWE bootstrap key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweBootstrapKeyGenerationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn generate_new_lwe_bootstrap_key_unchecked( + &mut self, + input_key: &LweSecretKey, + output_key: &GlweSecretKey, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> BootstrapKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_cleartext_fusing_multiplication.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_cleartext_fusing_multiplication.rs new file mode 100644 index 000000000..d780e63b9 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_cleartext_fusing_multiplication.rs @@ -0,0 +1,44 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{CleartextEntity, LweCiphertextEntity}; + +engine_error! { + LweCiphertextCleartextFusingMultiplicationError for LweCiphertextCleartextFusingMultiplicationEngine @ +} + +/// A trait for engines multiplying (fusing) LWE ciphertexts by cleartexts. +/// +/// # Semantics +/// +/// This [fusing](super#operation-semantics) operation multiply the `output` LWE ciphertext with +/// the `input` cleartext. +/// +/// # Formal Definition +/// +/// cf +/// [`here`](`crate::core_crypto::specification::engines::LweCiphertextCleartextFusingMultiplicationEngine`) +pub trait LweCiphertextCleartextFusingMultiplicationEngine: + AbstractEngine +where + Cleartext: CleartextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Multiply an LWE ciphertext with a cleartext. + fn fuse_mul_lwe_ciphertext_cleartext( + &mut self, + output: &mut Ciphertext, + input: &Cleartext, + ) -> Result<(), LweCiphertextCleartextFusingMultiplicationError>; + + /// Unsafely multiply an LWE ciphertext with a cleartext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextCleartextFusingMultiplicationError`]. For safety concerns _specific_ to + /// an engine, refer to the implementer safety section. + unsafe fn fuse_mul_lwe_ciphertext_cleartext_unchecked( + &mut self, + output: &mut Ciphertext, + input: &Cleartext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_consuming_retrieval.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_consuming_retrieval.rs new file mode 100644 index 000000000..0e14eff6e --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_consuming_retrieval.rs @@ -0,0 +1,37 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextConsumingRetrievalError for LweCiphertextConsumingRetrievalEngine @ +} + +/// A trait for engines retrieving the content of the container from an LWE ciphertext consuming it +/// in the process. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation retrieves the content of the container from the +/// `input` LWE ciphertext consuming it in the process. +pub trait LweCiphertextConsumingRetrievalEngine: AbstractEngine +where + Ciphertext: LweCiphertextEntity, +{ + /// Retrieves the content of the container from an LWE ciphertext, consuming it in the process. + fn consume_retrieve_lwe_ciphertext( + &mut self, + ciphertext: Ciphertext, + ) -> Result>; + + /// Unsafely retrieves the content of the container from an LWE ciphertext, consuming it in the + /// process. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextConsumingRetrievalError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn consume_retrieve_lwe_ciphertext_unchecked( + &mut self, + ciphertext: Ciphertext, + ) -> Container; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_conversion.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_conversion.rs new file mode 100644 index 000000000..a41d8dfd6 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_conversion.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextConversionError for LweCiphertextConversionEngine @ +} + +/// A trait for engines converting LWE ciphertexts. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a LWE ciphertext containing the +/// conversion of the `input` LWE ciphertext to a type with a different representation (for instance +/// from cpu to gpu memory). +/// +/// # Formal Definition +pub trait LweCiphertextConversionEngine: AbstractEngine +where + Input: LweCiphertextEntity, + Output: LweCiphertextEntity, +{ + /// Converts a LWE ciphertext. + fn convert_lwe_ciphertext( + &mut self, + input: &Input, + ) -> Result>; + + /// Unsafely converts a LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextConversionError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn convert_lwe_ciphertext_unchecked(&mut self, input: &Input) -> Output; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_creation.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_creation.rs new file mode 100644 index 000000000..fbd25cb3b --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_creation.rs @@ -0,0 +1,44 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextCreationError for LweCiphertextCreationEngine @ + EmptyContainer => "The container used to create the LWE ciphertext is of length 0!" +} + +impl LweCiphertextCreationError { + /// Validates the inputs + pub fn perform_generic_checks(container_length: usize) -> Result<(), Self> { + if container_length == 0 { + return Err(Self::EmptyContainer); + } + Ok(()) + } +} + +/// A trait for engines creating an LWE ciphertext from an arbitrary container. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation creates an LWE ciphertext from the abitrary +/// `container`. By arbitrary here, we mean that `Container` can be any type that allows to +/// instantiate an `LweCiphertextEntity`. +pub trait LweCiphertextCreationEngine: AbstractEngine +where + Ciphertext: LweCiphertextEntity, +{ + /// Creates an LWE ciphertext from an arbitrary container. + fn create_lwe_ciphertext_from( + &mut self, + container: Container, + ) -> Result>; + + /// Unsafely creates an LWE ciphertext from an arbitrary container. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextCreationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn create_lwe_ciphertext_from_unchecked(&mut self, container: Container) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_decryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_decryption.rs new file mode 100644 index 000000000..5b67e0eff --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_decryption.rs @@ -0,0 +1,60 @@ +use super::engine_error; + +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + LweCiphertextEntity, LweSecretKeyEntity, PlaintextEntity, +}; + +engine_error! { + LweCiphertextDecryptionError for LweCiphertextDecryptionEngine @ +} + +/// A trait for engines decrypting LWE ciphertexts. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates an plaintext containing the +/// decryption of the `input` LWE ciphertext, under the `key` secret key. +/// +/// # Formal Definition +/// +/// ## LWE Decryption +/// ###### inputs: +/// - $\mathsf{ct} = \left( \vec{a} , b\right) \in \mathsf{LWE}^n\_{\vec{s}}( \mathsf{pt} )\subseteq +/// \mathbb{Z}\_q^{(n+1)}$: an LWE ciphertext +/// - $\vec{s}\in\mathbb{Z}\_q^n$: a secret key +/// +/// ###### outputs: +/// - $\mathsf{pt}\in\mathbb{Z}\_q$: a plaintext +/// +/// ###### algorithm: +/// 1. compute $\mathsf{pt} = b - \left\langle \vec{a} , \vec{s} \right\rangle \in\mathbb{Z}\_q$ +/// 3. output $\mathsf{pt}$ +/// +/// **Remark:** Observe that the decryption is followed by a decoding phase that will contain a +/// rounding. +pub trait LweCiphertextDecryptionEngine: AbstractEngine +where + SecretKey: LweSecretKeyEntity, + Ciphertext: LweCiphertextEntity, + Plaintext: PlaintextEntity, +{ + /// Decrypts an LWE ciphertext. + fn decrypt_lwe_ciphertext( + &mut self, + key: &SecretKey, + input: &Ciphertext, + ) -> Result>; + + /// Unsafely decrypts an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDecryptionError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn decrypt_lwe_ciphertext_unchecked( + &mut self, + key: &SecretKey, + input: &Ciphertext, + ) -> Plaintext; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_addition.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_addition.rs new file mode 100644 index 000000000..4e54040c9 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_addition.rs @@ -0,0 +1,100 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextDiscardingAdditionError for LweCiphertextDiscardingAdditionEngine @ + LweDimensionMismatch => "All the ciphertext LWE dimensions must be the same." +} + +impl LweCiphertextDiscardingAdditionError { + /// Validates the inputs + pub fn perform_generic_checks( + output: &OutputCiphertext, + input_1: &InputCiphertext, + input_2: &InputCiphertext, + ) -> Result<(), Self> + where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, + { + if output.lwe_dimension() != input_1.lwe_dimension() + || output.lwe_dimension() != input_2.lwe_dimension() + { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines adding (discarding) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the addition of the `input_1` LWE ciphertext and the `input_2` LWE ciphertext. +/// +/// # Formal Definition +/// +/// ## LWE homomorphic addition +/// +/// It is a specification of the GLWE homomorphic addition described below. +/// +/// ## GLWE homomorphic addition +/// [`GLWE ciphertexts`](`crate::core_crypto::specification::entities::GlweCiphertextEntity`) +/// are homomorphic with +/// respect to the addition. +/// Let two GLWE ciphertexts +/// $$ +/// \begin{cases} +/// \mathsf{CT}\_1 = \left( \vec{A}\_1, B\_1\right) \in \mathsf{GLWE}\_{\vec{S}} \left( +/// \mathsf{PT}\_1 \right) \subseteq \mathcal{R}\_q^{k+1} \\ \mathsf{CT}\_2 = \left( \vec{A}\_2, +/// B\_2\right) \in \mathsf{GLWE}\_{\vec{S}} \left( \mathsf{PT}\_2 \right) \subseteq +/// \mathcal{R}\_q^{k+1} \end{cases} $$ +/// encrypted under the same +/// [`GLWE secret key`](`crate::core_crypto::specification::entities::GlweSecretKeyEntity`) +/// $\vec{S} \in \mathcal{R}\_q^k$. We can add these ciphertexts homomorhically and obtain as a +/// result a new GLWE ciphertext encrypting the sum of the two plaintexts $\mathsf{PT}\_1 + +/// \mathsf{PT}\_2$. +/// +/// ###### inputs: +/// - $\mathsf{CT}\_1 = \left( \vec{A}\_1, B\_1\right) \in \mathsf{GLWE}\_{\vec{S}} \left( +/// \mathsf{PT}\_1 \right) \subseteq \mathcal{R}\_q^{k+1}$: a GLWE ciphertext +/// - $\mathsf{CT}\_2 = \left( \vec{A}\_2, B\_2\right) \in \mathsf{GLWE}\_{\vec{S}} \left( +/// \mathsf{PT}\_2 \right) \subseteq \mathcal{R}\_q^{k+1}$: a GLWE ciphertext +/// +/// ###### outputs: +/// - $\mathsf{CT} = \left( \vec{A} , B \right) \in \mathsf{GLWE}\_{\vec{S}}( \mathsf{PT}\_1 + +/// \mathsf{PT}\_2 )\subseteq \mathcal{R}\_q^{k+1}$: an GLWE ciphertext +/// +/// ###### algorithm: +/// 1. Compute $\vec{A} = \vec{A}\_1 + \vec{A}\_2 \in\mathcal{R}^k\_q$ +/// 2. Compute $B = B\_1 + B\_2 \in\mathcal{R}\_q$ +/// 3. Output $\left( \vec{A} , B \right)$ +pub trait LweCiphertextDiscardingAdditionEngine: + AbstractEngine +where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, +{ + /// Adds two LWE ciphertexts. + fn discard_add_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertext, + input_1: &InputCiphertext, + input_2: &InputCiphertext, + ) -> Result<(), LweCiphertextDiscardingAdditionError>; + + /// Unsafely adds two LWE ciphertexts. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingAdditionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertext, + input_1: &InputCiphertext, + input_2: &InputCiphertext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bit_extraction.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bit_extraction.rs new file mode 100644 index 000000000..1a1ea0897 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bit_extraction.rs @@ -0,0 +1,127 @@ +use super::engine_error; +use crate::core_crypto::prelude::{ + AbstractEngine, CiphertextModulusLog, DeltaLog, ExtractedBitsCount, LweBootstrapKeyEntity, + LweCiphertextEntity, LweCiphertextVectorEntity, LweKeyswitchKeyEntity, +}; + +engine_error! { + LweCiphertextDiscardingBitExtractError for LweCiphertextDiscardingBitExtractEngine @ + InputLweDimensionMismatch => "The input ciphertext and bootstrap key LWE dimension must be the \ + same.", + InputKeyswitchKeyLweDimensionMismatch => "The input ciphertext LWE dimension must be the same \ + as the keyswitch key input LWE dimension.", + OutputLweDimensionMismatch => "The output ciphertext vector LWE dimension must be the same \ + as the output LWE dimension of the keyswitch key.", + ExtractedBitsCountMismatch => "The output LWE ciphertext vector count must be the same as \ + the number of bits to extract.", + KeyDimensionMismatch => "The keyswitch key output LWE dimension must be the same as the \ + bootstrap key input LWE dimension.", + NotEnoughBitsToExtract => "The number of bits to extract, starting from the bit at index \ + delta_log towards the most significant bits, should not exceed the \ + total number of available bits in the ciphertext." +} + +impl LweCiphertextDiscardingBitExtractError { + /// Validates the inputs + pub fn perform_generic_checks< + BootstrapKey, + KeyswitchKey, + InputCiphertext, + OutputCiphertextVector, + >( + output: &OutputCiphertextVector, + input: &InputCiphertext, + bsk: &BootstrapKey, + ksk: &KeyswitchKey, + extracted_bits_count: ExtractedBitsCount, + ciphertext_modulus_log: CiphertextModulusLog, + delta_log: DeltaLog, + ) -> Result<(), Self> + where + BootstrapKey: LweBootstrapKeyEntity, + KeyswitchKey: LweKeyswitchKeyEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertextVector: LweCiphertextVectorEntity, + { + if input.lwe_dimension() != bsk.output_lwe_dimension() { + return Err(Self::InputLweDimensionMismatch); + } + if input.lwe_dimension() != ksk.input_lwe_dimension() { + return Err(Self::InputKeyswitchKeyLweDimensionMismatch); + } + if output.lwe_dimension() != ksk.output_lwe_dimension() { + return Err(Self::OutputLweDimensionMismatch); + } + if output.lwe_ciphertext_count().0 != extracted_bits_count.0 { + return Err(Self::ExtractedBitsCountMismatch); + } + if ksk.output_lwe_dimension() != bsk.input_lwe_dimension() { + return Err(Self::KeyDimensionMismatch); + } + if ciphertext_modulus_log.0 < extracted_bits_count.0 + delta_log.0 { + return Err(Self::NotEnoughBitsToExtract); + } + Ok(()) + } +} + +/// A trait for engines doing a (discarding) bit extract over LWE ciphertexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext vector +/// with the bit extraction of the `input` LWE ciphertext, extracting `number_of_bits_to_extract` +/// bits starting from the bit at index `delta_log` (0-indexed) included, and going towards the +/// most significant bits. +/// +/// Output bits are ordered from the MSB to the LSB. Each one of them is output in a distinct LWE +/// ciphertext, containing the encryption of the bit scaled by q/2 (i.e., the most significant bit +/// in the plaintext representation). +/// +/// # Formal Definition +/// +/// This function takes as input an [`LWE ciphertext`] +/// (crate::core_crypto::specification::entities::LweCiphertextEntity) +/// $$\mathsf{ct\} = \mathsf{LWE}^n\_{\vec{s}}( \mathsf{m}) \subseteq \mathbb{Z}\_q^{(n+1)}$$ +/// which encrypts some message `m`. We extract bits $m\_i$ of this message into individual LWE +/// ciphertexts. Each of these ciphertexts contains an encryption of $m\_i \cdot q/2$, i.e. +/// $$\mathsf{ct\_i} = \mathsf{LWE}^n\_{\vec{s}}( \mathsf{m\_i} \cdot q/2 )$$. The number of +/// output LWE ciphertexts is determined by the `number_of_bits_to_extract` input parameter. +pub trait LweCiphertextDiscardingBitExtractEngine< + BootstrapKey, + KeyswitchKey, + InputCiphertext, + OutputCiphertextVector, +>: AbstractEngine where + BootstrapKey: LweBootstrapKeyEntity, + KeyswitchKey: LweKeyswitchKeyEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertextVector: LweCiphertextVectorEntity, +{ + /// Extract bits of an LWE ciphertext. + fn discard_extract_bits_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertextVector, + input: &InputCiphertext, + bsk: &BootstrapKey, + ksk: &KeyswitchKey, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ) -> Result<(), LweCiphertextDiscardingBitExtractError>; + + /// Unsafely extract bits of an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingBitExtractError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_extract_bits_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertextVector, + input: &InputCiphertext, + bsk: &BootstrapKey, + ksk: &KeyswitchKey, + extracted_bits_count: ExtractedBitsCount, + delta_log: DeltaLog, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bootstrap.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bootstrap.rs new file mode 100644 index 000000000..6d6856872 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_bootstrap.rs @@ -0,0 +1,145 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; + +use crate::core_crypto::specification::entities::{ + GlweCiphertextEntity, LweBootstrapKeyEntity, LweCiphertextEntity, +}; + +engine_error! { + LweCiphertextDiscardingBootstrapError for LweCiphertextDiscardingBootstrapEngine @ + InputLweDimensionMismatch => "The input ciphertext and key LWE dimension must be the same.", + OutputLweDimensionMismatch => "The output ciphertext dimension and key size (dimension * \ + polynomial size) must be the same.", + AccumulatorPolynomialSizeMismatch => "The accumulator and key polynomial sizes must be the same.", + AccumulatorGlweDimensionMismatch => "The accumulator and key GLWE dimensions must be the same." +} + +impl LweCiphertextDiscardingBootstrapError { + /// Validates the inputs + pub fn perform_generic_checks( + output: &OutputCiphertext, + input: &InputCiphertext, + acc: &Accumulator, + bsk: &BootstrapKey, + ) -> Result<(), Self> + where + BootstrapKey: LweBootstrapKeyEntity, + Accumulator: GlweCiphertextEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, + { + if input.lwe_dimension() != bsk.input_lwe_dimension() { + return Err(Self::InputLweDimensionMismatch); + } + if acc.polynomial_size() != bsk.polynomial_size() { + return Err(Self::AccumulatorPolynomialSizeMismatch); + } + if acc.glwe_dimension() != bsk.glwe_dimension() { + return Err(Self::AccumulatorGlweDimensionMismatch); + } + if output.lwe_dimension() != bsk.output_lwe_dimension() { + return Err(Self::OutputLweDimensionMismatch); + } + + Ok(()) + } +} + +/// A trait for engines bootstrapping (discarding) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the bootstrap of the `input` LWE ciphertext, using the `acc` accumulator as lookup-table, and +/// the `bsk` bootstrap key. +/// +/// # Formal Definition +/// +/// ## Programmable Bootstrapping +/// +/// This homomorphic procedure allows to both reduce the noise of a ciphertext and to evaluate a +/// Look-Up Table (LUT) on the encrypted plaintext at the same time, i.e., it transforms an input +/// [`LWE ciphertext`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// $\mathsf{ct}\_{\mathsf{in}} = \left( +/// \vec{a}\_{\mathsf{in}} , b\_{\mathsf{in}}\right) \in +/// \mathsf{LWE}^{n\_{\mathsf{in}}}\_{\vec{s}\_{\mathsf{in}}}( \mathsf{pt} ) \subseteq +/// \mathbb{Z}\_q^{(n\_{\mathsf{in}}+1)}$ into an output +/// [`LWE ciphertext`](`LweCiphertextEntity`) +/// $\mathsf{ct}\_{\mathsf{out}} = \left( \vec{a}\_{\mathsf{out}} , +/// b\_{\mathsf{out}}\right) \in \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}}( +/// \mathsf{LUT(pt)} )\subseteq \mathbb{Z}\_q^{(n\_{\mathsf{out}}+1)}$ where $n\_{\mathsf{in}} = +/// |\vec{s}\_{\mathsf{in}}|$ and $n\_{\mathsf{out}} = |\vec{s}\_{\mathsf{out}}|$, such that the +/// noise in this latter is set to a fixed (reduced) amount. It requires a +/// [`bootstrapping key`](`LweBootstrapKeyEntity`). +/// +/// The input ciphertext is encrypted under the +/// [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\mathsf{in}}$ and the +/// output ciphertext is encrypted under the +/// [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\mathsf{out}}$. +/// +/// $$\mathsf{ct}\_{\mathsf{in}} \in \mathsf{LWE}^{n\_{\mathsf{in}}}\_{\vec{s}\_{\mathsf{in}}}( +/// \mathsf{pt} ) ~~~~~~~~~~\mathsf{BSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow +/// \vec{S}\_{\mathsf{out}}}$$ $$ \mathsf{PBS}\left(\mathsf{ct}\_{\mathsf{in}} , \mathsf{BSK} +/// \right) \rightarrow \mathsf{ct}\_{\mathsf{out}} \in +/// \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}} \left( \mathsf{pt} \right)$$ +/// +/// ## Algorithm +/// ###### inputs: +/// - $\mathsf{ct}\_{\mathsf{in}} = \left( \vec{a}\_{\mathsf{in}} , b\_{\mathsf{in}}\right) \in +/// \mathsf{LWE}^{n\_{\mathsf{in}}}\_{\vec{s}\_{\mathsf{in}}}( \mathsf{pt} )$: an [`LWE +/// ciphertext`](`LweCiphertextEntity`) with $\vec{a}\_{\mathsf{in}}=\left(a\_0, \cdots +/// a\_{n\_{\mathsf{in}}-1}\right)$ +/// - $\mathsf{BSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow \vec{S}\_{\mathsf{out}}}$: a bootstrapping +/// key as defined above +/// - $\mathsf{LUT} \in \mathcal{R}\_q$: a LUT represented as a polynomial \_with redundancy\_ +/// +/// ###### outputs: +/// - $\mathsf{ct}\_{\mathsf{out}} \in \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}} +/// \left( \mathsf{LUT(pt)} \right)$: an [`LWE +/// ciphertext`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// +/// ###### algorithm: +/// 1. Compute $\tilde{a}\_i \in \mathbb{Z}\_{2N\_{\mathsf{out}}} \leftarrow \lfloor \frac{2 +/// N\_{\mathsf{out}} \cdot a\_i}{q} \rceil$, for $i= 0, 1, \ldots, n\_{\mathsf{in}-1}$ 2. Compute +/// $\tilde{b}\_\mathsf{in} \in \mathbb{Z}\_{2N\_{\mathsf{out}}} \leftarrow \lfloor \frac{2 +/// N\_{\mathsf{out}} \cdot b\_\mathsf{in}}{q} \rceil$ 3. Set $\mathsf{ACC} = (0, \ldots, 0, +/// \mathsf{LUT} \cdot X^{-\tilde{b}\_\mathsf{in}})$ 4. Compute $\mathsf{ACC} = +/// \mathsf{CMux}(\overline{\overline{\mathsf{CT}\_i}}, \mathsf{ACC} \cdot X^{\tilde{a}\_i}, +/// \mathsf{ACC})$, for $i= 0, 1, \ldots, n\_{\mathsf{in}-1}$ 5. Output $\mathsf{ct}\_{\mathsf{out}} +/// \leftarrow \mathsf{SampleExtract}(\mathsf{ACC})$ +pub trait LweCiphertextDiscardingBootstrapEngine< + BootstrapKey, + Accumulator, + InputCiphertext, + OutputCiphertext, +>: AbstractEngine where + BootstrapKey: LweBootstrapKeyEntity, + Accumulator: GlweCiphertextEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, +{ + /// Bootstrap an LWE ciphertext . + fn discard_bootstrap_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + acc: &Accumulator, + bsk: &BootstrapKey, + ) -> Result<(), LweCiphertextDiscardingBootstrapError>; + + /// Unsafely bootstrap an LWE ciphertext . + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingBootstrapError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_bootstrap_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + acc: &Accumulator, + bsk: &BootstrapKey, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_conversion.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_conversion.rs new file mode 100644 index 000000000..67c15b43d --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_conversion.rs @@ -0,0 +1,56 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextDiscardingConversionError for LweCiphertextDiscardingConversionEngine @ + LweDimensionMismatch => "All the ciphertext LWE dimensions must be the same." +} + +impl LweCiphertextDiscardingConversionError { + /// Validates the inputs + pub fn perform_generic_checks(output: &Output, input: &Input) -> Result<(), Self> + where + Input: LweCiphertextEntity, + Output: LweCiphertextEntity, + { + if input.lwe_dimension() != output.lwe_dimension() { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines converting (discarding) LWE ciphertexts . +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the conversion of the `input` LWE ciphertext to a type with a different representation (for +/// instance from cpu to gpu memory). +/// +/// # Formal Definition +pub trait LweCiphertextDiscardingConversionEngine: AbstractEngine +where + Input: LweCiphertextEntity, + Output: LweCiphertextEntity, +{ + /// Converts a LWE ciphertext . + fn discard_convert_lwe_ciphertext( + &mut self, + output: &mut Output, + input: &Input, + ) -> Result<(), LweCiphertextDiscardingConversionError>; + + /// Unsafely converts a LWE ciphertext . + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingConversionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_convert_lwe_ciphertext_unchecked( + &mut self, + output: &mut Output, + input: &Input, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_encryption.rs new file mode 100644 index 000000000..80b5c4556 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_encryption.rs @@ -0,0 +1,70 @@ +use super::engine_error; + +use crate::core_crypto::prelude::Variance; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + LweCiphertextEntity, LweSecretKeyEntity, PlaintextEntity, +}; + +engine_error! { + LweCiphertextDiscardingEncryptionError for LweCiphertextDiscardingEncryptionEngine @ + LweDimensionMismatch => "The secret key and ciphertext LWE dimensions must be the same." +} + +impl LweCiphertextDiscardingEncryptionError { + /// Validates the inputs + pub fn perform_generic_checks( + key: &SecretKey, + output: &Ciphertext, + ) -> Result<(), Self> + where + SecretKey: LweSecretKeyEntity, + Ciphertext: LweCiphertextEntity, + { + if key.lwe_dimension() != output.lwe_dimension() { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines encrypting (discarding) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the encryption of the `input` plaintext, under the `key` secret key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::engines::LweCiphertextEncryptionEngine`) +pub trait LweCiphertextDiscardingEncryptionEngine: + AbstractEngine +where + SecretKey: LweSecretKeyEntity, + Plaintext: PlaintextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Encrypts an LWE ciphertext. + fn discard_encrypt_lwe_ciphertext( + &mut self, + key: &SecretKey, + output: &mut Ciphertext, + input: &Plaintext, + noise: Variance, + ) -> Result<(), LweCiphertextDiscardingEncryptionError>; + + /// Unsafely encrypts an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingEncryptionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &SecretKey, + output: &mut Ciphertext, + input: &Plaintext, + noise: Variance, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_keyswitch.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_keyswitch.rs new file mode 100644 index 000000000..e6ba533c2 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_keyswitch.rs @@ -0,0 +1,123 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; + +use crate::core_crypto::specification::entities::{LweCiphertextEntity, LweKeyswitchKeyEntity}; + +engine_error! { + LweCiphertextDiscardingKeyswitchError for LweCiphertextDiscardingKeyswitchEngine @ + InputLweDimensionMismatch => "The input ciphertext LWE dimension and keyswitch key input LWE \ + dimensions must be the same.", + OutputLweDimensionMismatch => "The output ciphertext LWE dimension and keyswitch output LWE \ + dimensions must be the same." +} + +impl LweCiphertextDiscardingKeyswitchError { + /// Validates the inputs + pub fn perform_generic_checks( + output: &OutputCiphertext, + input: &InputCiphertext, + ksk: &KeyswitchKey, + ) -> Result<(), Self> + where + KeyswitchKey: LweKeyswitchKeyEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, + { + if input.lwe_dimension() != ksk.input_lwe_dimension() { + return Err(Self::InputLweDimensionMismatch); + } + if output.lwe_dimension() != ksk.output_lwe_dimension() { + return Err(Self::OutputLweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines keyswitching (discarding) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the keyswitch of the `input` LWE ciphertext, using the `ksk` LWE keyswitch key. +/// +/// # Formal Definition +/// +/// ## LWE Keyswitch +/// +/// This homomorphic procedure transforms an input +/// [`LWE ciphertext`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// $\mathsf{ct}\_{\mathsf{in}} = +/// \left( \vec{a}\_{\mathsf{in}} , b\_{\mathsf{in}}\right) \in \mathsf{LWE}^{n\_{\mathsf{in}}}\_ +/// {\vec{s}\_{\mathsf{in}}}( \mathsf{pt} ) \subseteq \mathbb{Z}\_q^{(n\_{\mathsf{in}}+1)}$ into an +/// output [`LWE +/// ciphertext`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// $\mathsf{ct}\_{\mathsf{out}} = +/// \left( \vec{a}\_{\mathsf{out}} , b\_{\mathsf{out}}\right) \in +/// \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}}( \mathsf{pt} )\subseteq +/// \mathbb{Z}\_q^{(n\_{\mathsf{out}}+1)}$ where $n\_{\mathsf{in}} = |\vec{s}\_{\mathsf{in}}|$ and +/// $n\_{\mathsf{out}} = |\vec{s}\_{\mathsf{out}}|$. It requires a +/// [`key switching +/// key`](`crate::core_crypto::specification::entities::LweKeyswitchKeyEntity`). +/// The input ciphertext is encrypted under the +/// [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\mathsf{in}}$ and the output ciphertext is +/// encrypted under the [`LWE secret +/// key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) $\vec{s}\_{\ +/// mathsf{out}}$. +/// +/// $$\mathsf{ct}\_{\mathsf{in}} \in \mathsf{LWE}^{n\_{\mathsf{in}}}\_{\vec{s}\_{\mathsf{in}}}( +/// \mathsf{pt} ) ~~~~~~~~~~\mathsf{KSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow +/// \vec{s}\_{\mathsf{out}}}$$ $$ \mathsf{keyswitch}\left(\mathsf{ct}\_{\mathsf{in}} , \mathsf{KSK} +/// \right) \rightarrow \mathsf{ct}\_{\mathsf{out}} \in +/// \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}} \left( \mathsf{pt} \right)$$ +/// +/// ## Algorithm +/// ###### inputs: +/// - $\mathsf{ct}\_{\mathsf{in}} = \left( \vec{a}\_{\mathsf{in}} , b\_{\mathsf{in}}\right) \in +/// \mathsf{LWE}^{n\_{\mathsf{in}}}\_{\vec{s}\_{\mathsf{in}}}( \mathsf{pt} )$: an [`LWE +/// ciphertext`](`LweCiphertextEntity`) with $\vec{a}\_{\mathsf{in}}=\left(a\_0, \cdots +/// a\_{n\_{\mathsf{in}}-1}\right)$ +/// - $\mathsf{KSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow \vec{s}\_{\mathsf{out}}}$: a +/// [`key switching +/// key`](`crate::core_crypto::specification::entities::LweKeyswitchKeyEntity`) +/// +/// ###### outputs: +/// - $\mathsf{ct}\_{\mathsf{out}} \in \mathsf{LWE}^{n\_{\mathsf{out}}}\_{\vec{s}\_{\mathsf{out}}} +/// \left( \mathsf{pt} \right)$: an +/// [`LWE ciphertext`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// +/// ###### algorithm: +/// 1. set $\mathsf{ct}=\left( 0 , \cdots , 0 , b\_{\mathsf{in}} \right) \in +/// \mathbb{Z}\_q^{(n\_{\mathsf{out}}+1)}$ +/// 2. compute $\mathsf{ct}\_{\mathsf{out}} = \mathsf{ct} - +/// \sum\_{i=0}^{n\_{\mathsf{in}}-1} \mathsf{decompProduct}\left( a\_i , \overline{\mathsf{ct}\_i} +/// \right)$ +/// 3. output $\mathsf{ct}\_{\mathsf{out}}$ +pub trait LweCiphertextDiscardingKeyswitchEngine: + AbstractEngine +where + KeyswitchKey: LweKeyswitchKeyEntity, + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, +{ + /// Keyswitch an LWE ciphertext. + fn discard_keyswitch_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ksk: &KeyswitchKey, + ) -> Result<(), LweCiphertextDiscardingKeyswitchError>; + + /// Unsafely keyswitch an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingKeyswitchError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn discard_keyswitch_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ksk: &KeyswitchKey, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_public_key_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_public_key_encryption.rs new file mode 100644 index 000000000..65249cb98 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_discarding_public_key_encryption.rs @@ -0,0 +1,65 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + LweCiphertextEntity, LwePublicKeyEntity, PlaintextEntity, +}; + +engine_error! { + LweCiphertextDiscardingPublicKeyEncryptionError for LweCiphertextDiscardingPublicKeyEncryptionEngine @ + LweDimensionMismatch => "The public key and ciphertext LWE dimensions must be the same." +} + +impl LweCiphertextDiscardingPublicKeyEncryptionError { + /// Validates the inputs + pub fn perform_generic_checks( + key: &PublicKey, + output: &Ciphertext, + ) -> Result<(), Self> + where + PublicKey: LwePublicKeyEntity, + Ciphertext: LweCiphertextEntity, + { + if key.lwe_dimension() != output.lwe_dimension() { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines encrypting (discarding) LWE ciphertexts with a public key. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` LWE ciphertext with +/// the encryption of the `input` plaintext, using the public `key`. The ciphertext can be decrypted +/// by the secret key used to generate the public key. +/// +/// # Formal Definition +pub trait LweCiphertextDiscardingPublicKeyEncryptionEngine: + AbstractEngine +where + PublicKey: LwePublicKeyEntity, + Plaintext: PlaintextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Encrypts an LWE ciphertext using a public key. + fn discard_encrypt_lwe_ciphertext_with_public_key( + &mut self, + key: &PublicKey, + output: &mut Ciphertext, + input: &Plaintext, + ) -> Result<(), LweCiphertextDiscardingPublicKeyEncryptionError>; + + /// Unsafely encrypts an LWE ciphertext using a public key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextDiscardingPublicKeyEncryptionError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn discard_encrypt_lwe_ciphertext_with_public_key_unchecked( + &mut self, + key: &PublicKey, + output: &mut Ciphertext, + input: &Plaintext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_encryption.rs new file mode 100644 index 000000000..304126cfd --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_encryption.rs @@ -0,0 +1,63 @@ +use super::engine_error; + +use crate::core_crypto::prelude::Variance; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + LweCiphertextEntity, LweSecretKeyEntity, PlaintextEntity, +}; + +engine_error! { + LweCiphertextEncryptionError for LweCiphertextEncryptionEngine @ +} + +/// A trait for engines encrypting LWE ciphertexts. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates an LWE ciphertext containing the +/// encryption of the `input` plaintext under the `key` secret key. +/// +/// # Formal Definition +/// +/// ## LWE Encryption +/// ###### inputs: +/// - $\mathsf{pt}\in\mathbb{Z}\_q$: a plaintext +/// - $\vec{s}\in\mathbb{Z}\_q^n$: a secret key +/// - $\mathcal{D\_{\sigma^2,\mu}}$: a normal distribution of variance $\sigma^2$ and a mean $\mu$ +/// +/// ###### outputs: +/// - $\mathsf{ct} = \left( \vec{a} , b\right) \in \mathsf{LWE}^n\_{\vec{s}}( \mathsf{pt} )\subseteq +/// \mathbb{Z}\_q^{(n+1)}$: an LWE ciphertext +/// +/// ###### algorithm: +/// 1. uniformly sample a vector $\vec{a}\in\mathbb{Z}\_q^n$ +/// 2. sample an integer error term $e \hookleftarrow \mathcal{D\_{\sigma^2,\mu}}$ +/// 3. compute $b = \left\langle \vec{a} , \vec{s} \right\rangle + \mathsf{pt} + e \in\mathbb{Z}\_q$ +/// 4. output $\left( \vec{a} , b\right)$ +pub trait LweCiphertextEncryptionEngine: AbstractEngine +where + SecretKey: LweSecretKeyEntity, + Plaintext: PlaintextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Encrypts an LWE ciphertext. + fn encrypt_lwe_ciphertext( + &mut self, + key: &SecretKey, + input: &Plaintext, + noise: Variance, + ) -> Result>; + + /// Unsafely encrypts an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextEncryptionError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &SecretKey, + input: &Plaintext, + noise: Variance, + ) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_addition.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_addition.rs new file mode 100644 index 000000000..ed930418d --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_addition.rs @@ -0,0 +1,61 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextFusingAdditionError for LweCiphertextFusingAdditionEngine @ + LweDimensionMismatch => "The input and output LWE dimensions must be the same." +} + +impl LweCiphertextFusingAdditionError { + /// Validates the inputs + pub fn perform_generic_checks( + output: &OutputCiphertext, + input: &InputCiphertext, + ) -> Result<(), Self> + where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, + { + if output.lwe_dimension() != input.lwe_dimension() { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines adding (fusing) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [fusing](super#operation-semantics) operation adds the `input` LWE ciphertext to the +/// `output` LWE ciphertext. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::engines::LweCiphertextDiscardingAdditionEngine`) +pub trait LweCiphertextFusingAdditionEngine: + AbstractEngine +where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, +{ + /// Adds an LWE ciphertext to an other. + fn fuse_add_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ) -> Result<(), LweCiphertextFusingAdditionError>; + + /// Unsafely add an LWE ciphertext to an other. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextFusingAdditionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn fuse_add_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_opposite.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_opposite.rs new file mode 100644 index 000000000..412f51602 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_opposite.rs @@ -0,0 +1,34 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextFusingOppositeError for LweCiphertextFusingOppositeEngine @ +} + +/// A trait for engines computing the opposite (fusing) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [fusing](super#operation-semantics) operation computes the opposite of the `input` LWE +/// ciphertext. +/// +/// # Formal Definition +pub trait LweCiphertextFusingOppositeEngine: AbstractEngine +where + Ciphertext: LweCiphertextEntity, +{ + /// Computes the opposite of an LWE ciphertext. + fn fuse_opp_lwe_ciphertext( + &mut self, + input: &mut Ciphertext, + ) -> Result<(), LweCiphertextFusingOppositeError>; + + /// Unsafely computes the opposite of an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextFusingOppositeError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn fuse_opp_lwe_ciphertext_unchecked(&mut self, input: &mut Ciphertext); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_subtraction.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_subtraction.rs new file mode 100644 index 000000000..be1deccc7 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_fusing_subtraction.rs @@ -0,0 +1,59 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextEntity; + +engine_error! { + LweCiphertextFusingSubtractionError for LweCiphertextFusingSubtractionEngine @ + LweDimensionMismatch => "The input and output LWE dimensions must be the same." +} + +impl LweCiphertextFusingSubtractionError { + /// Validates the inputs + pub fn perform_generic_checks( + output: &OutputCiphertext, + input: &InputCiphertext, + ) -> Result<(), Self> + where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, + { + if output.lwe_dimension() != input.lwe_dimension() { + return Err(Self::LweDimensionMismatch); + } + Ok(()) + } +} + +/// A trait for engines subtracting (fusing) LWE ciphertexts. +/// +/// # Semantics +/// +/// This [fusing](super#operation-semantics) operation subtracts the `input` LWE ciphertext to the +/// `output` LWE ciphertext. +/// +/// # Formal Definition +pub trait LweCiphertextFusingSubtractionEngine: + AbstractEngine +where + InputCiphertext: LweCiphertextEntity, + OutputCiphertext: LweCiphertextEntity, +{ + /// Subtracts an LWE ciphertext to an other. + fn fuse_sub_lwe_ciphertext( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ) -> Result<(), LweCiphertextFusingSubtractionError>; + + /// Unsafely subtracts an LWE ciphertext to another. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextFusingSubtractionError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn fuse_sub_lwe_ciphertext_unchecked( + &mut self, + output: &mut OutputCiphertext, + input: &InputCiphertext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_plaintext_fusing_addition.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_plaintext_fusing_addition.rs new file mode 100644 index 000000000..0427fef2f --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_plaintext_fusing_addition.rs @@ -0,0 +1,41 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{LweCiphertextEntity, PlaintextEntity}; + +engine_error! { + LweCiphertextPlaintextFusingAdditionError for LweCiphertextPlaintextFusingAdditionEngine @ +} + +/// A trait for engines adding (fusing) plaintexts to LWE ciphertexts. +/// +/// # Semantics +/// +/// This [fusing](super#operation-semantics) operation adds the `input` plaintext to the `output` +/// LWE ciphertext. +/// +/// # Formal Definition +pub trait LweCiphertextPlaintextFusingAdditionEngine: + AbstractEngine +where + Plaintext: PlaintextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Add a plaintext to an LWE ciphertext. + fn fuse_add_lwe_ciphertext_plaintext( + &mut self, + output: &mut Ciphertext, + input: &Plaintext, + ) -> Result<(), LweCiphertextPlaintextFusingAdditionError>; + + /// Unsafely add a plaintext to an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextPlaintextFusingAdditionError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn fuse_add_lwe_ciphertext_plaintext_unchecked( + &mut self, + output: &mut Ciphertext, + input: &Plaintext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_trivial_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_trivial_encryption.rs new file mode 100644 index 000000000..8ddf1ab56 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_trivial_encryption.rs @@ -0,0 +1,47 @@ +use super::engine_error; +use crate::core_crypto::prelude::LweSize; + +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{LweCiphertextEntity, PlaintextEntity}; + +engine_error! { + LweCiphertextTrivialEncryptionError for LweCiphertextTrivialEncryptionEngine @ +} + +/// A trait for engines trivially encrypting LWE ciphertext. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates anLWE ciphertext containing the +/// trivial encryption of the `input` plaintext with the requested `lwe_size`. +/// +/// # Formal Definition +/// +/// A trivial encryption uses a zero mask and no noise. +/// It is absolutely not secure, as the body contains a direct copy of the plaintext. +/// However, it is useful for some FHE algorithms taking public information as input. For +/// example, a trivial GLWE encryption of a public lookup table is used in the bootstrap. +pub trait LweCiphertextTrivialEncryptionEngine: AbstractEngine +where + Plaintext: PlaintextEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Trivially encrypts an LWE ciphertext. + fn trivially_encrypt_lwe_ciphertext( + &mut self, + lwe_size: LweSize, + input: &Plaintext, + ) -> Result>; + + /// Unsafely creates the trivial LWE encryption of the plaintext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextTrivialEncryptionError ]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn trivially_encrypt_lwe_ciphertext_unchecked( + &mut self, + lwe_size: LweSize, + input: &Plaintext, + ) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_consuming_retrieval.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_consuming_retrieval.rs new file mode 100644 index 000000000..9894c0c5d --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_consuming_retrieval.rs @@ -0,0 +1,39 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextVectorEntity; + +engine_error! { + LweCiphertextVectorConsumingRetrievalError for LweCiphertextVectorConsumingRetrievalEngine @ +} + +/// A trait for engines retrieving the content of the container from an LWE ciphertext +/// vector consuming it in the process. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation retrieves the content of the container from the +/// `input` LWE ciphertext vector consuming it in the process. +pub trait LweCiphertextVectorConsumingRetrievalEngine: + AbstractEngine +where + CiphertextVector: LweCiphertextVectorEntity, +{ + /// Retrieves the content of the container from an LWE ciphertext vector, consuming it in the + /// process. + fn consume_retrieve_lwe_ciphertext_vector( + &mut self, + ciphertext: CiphertextVector, + ) -> Result>; + + /// Unsafely retrieves the content of the container from an LWE ciphertext vector, consuming + /// it in the process. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextVectorConsumingRetrievalError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn consume_retrieve_lwe_ciphertext_vector_unchecked( + &mut self, + ciphertext: CiphertextVector, + ) -> Container; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_creation.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_creation.rs new file mode 100644 index 000000000..c63368272 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_creation.rs @@ -0,0 +1,51 @@ +use super::engine_error; +use crate::core_crypto::prelude::LweSize; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweCiphertextVectorEntity; + +engine_error! { + LweCiphertextVectorCreationError for LweCiphertextVectorCreationEngine @ + EmptyContainer => "The container used to create the LWE ciphertext vector is of length 0!" +} + +impl LweCiphertextVectorCreationError { + /// Validates the inputs + pub fn perform_generic_checks(container_length: usize) -> Result<(), Self> { + if container_length == 0 { + return Err(Self::EmptyContainer); + } + Ok(()) + } +} + +/// A trait for engines creating an LWE ciphertext vector from an arbitrary container. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation creates an LWE ciphertext vector from the +/// arbitrary `container`. By arbitrary here, we mean that `Container` can be any type that +/// allows to +/// instantiate an `LweCiphertextVectorEntity`. +pub trait LweCiphertextVectorCreationEngine: AbstractEngine +where + CiphertextVector: LweCiphertextVectorEntity, +{ + /// Creates an LWE ciphertext from an arbitrary container. + fn create_lwe_ciphertext_vector_from( + &mut self, + container: Container, + lwe_size: LweSize, + ) -> Result>; + + /// Unsafely creates an LWE ciphertext vector from an arbitrary container. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextVectorCreationError`]. For safety concerns _specific_ to an engine, + /// refer to the implementer safety section. + unsafe fn create_lwe_ciphertext_vector_from_unchecked( + &mut self, + container: Container, + lwe_size: LweSize, + ) -> CiphertextVector; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs new file mode 100644 index 000000000..480413e1c --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing.rs @@ -0,0 +1,189 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{ + LweBootstrapKeyEntity, LweCiphertextVectorEntity, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, PlaintextVectorEntity, +}; +use crate::core_crypto::specification::parameters::{ + DecompositionBaseLog, DecompositionLevelCount, +}; + +engine_error! { + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError for + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingEngine @ + NullDecompositionBaseLog => "The circuit bootstrap decomposition base log must be greater \ + than zero.", + NullDecompositionLevelCount => "The circuit bootstrap decomposition level count must be \ + greater than zero.", + DecompositionTooLarge => "The decomposition precision (base log * level count) must not exceed \ + the precision of the ciphertext.", + KeysLweDimensionMismatch => "The bootstrap key output LWE dimension must be the same as the \ + input LWE dimension of the circuit bootstrap private functional \ + packing keyswitch keys.", + InputLweDimensionMismatch => "The input ciphertexts LWE dimension must be the same as the \ + bootstrap key input LWE dimension.", + OutputLweDimensionMismatch => "The output ciphertexts LWE dimension must be the same as the \ + `cbs_pfpksk` output GLWE dimension times its output polynomial \ + size.", + MalformedLookUpTables => "The input `luts` must have a size divisible by the circuit bootstrap \ + private functional packing keyswitch keys output polynomial size \ + times the number of output ciphertexts. This is required to get \ + small look-up tables of polynomials of the same size for each \ + output ciphertext.", + InvalidSmallLookUpTableSize => "The number of polynomials times the polynomial size in a small \ + look-up table must be equal to 2 to the power the number of \ + input ciphertexts encrypting bits." +} + +impl + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError +{ + /// Validates the inputs + #[allow(clippy::too_many_arguments)] + pub fn perform_generic_checks< + Input: LweCiphertextVectorEntity, + Output: LweCiphertextVectorEntity, + BootstrapKey: LweBootstrapKeyEntity, + LUTs: PlaintextVectorEntity, + CBSPFPKSK: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, + >( + input: &Input, + output: &Output, + bsk: &BootstrapKey, + luts: &LUTs, + cbs_decomposition_level_count: DecompositionLevelCount, + cbs_decomposition_base_log: DecompositionBaseLog, + cbs_pfpksk: &CBSPFPKSK, + ciphertext_modulus_log: usize, + ) -> Result<(), Self> { + if bsk.output_lwe_dimension() != cbs_pfpksk.input_lwe_dimension() { + return Err(Self::KeysLweDimensionMismatch); + } + if input.lwe_dimension() != bsk.input_lwe_dimension() { + return Err(Self::InputLweDimensionMismatch); + } + if output.lwe_dimension().0 + != cbs_pfpksk.output_glwe_dimension().0 * cbs_pfpksk.output_polynomial_size().0 + { + return Err(Self::OutputLweDimensionMismatch); + } + + let lut_polynomial_size = cbs_pfpksk.output_polynomial_size().0; + if luts.plaintext_count().0 % (lut_polynomial_size * output.lwe_ciphertext_count().0) != 0 { + return Err(Self::MalformedLookUpTables); + } + + let small_lut_size = luts.plaintext_count().0 / output.lwe_ciphertext_count().0; + if small_lut_size < lut_polynomial_size { + return Err(Self::InvalidSmallLookUpTableSize); + } + + if cbs_decomposition_level_count.0 == 0 { + return Err(Self::NullDecompositionBaseLog); + } + if cbs_decomposition_level_count.0 == 0 { + return Err(Self::NullDecompositionLevelCount); + } + if cbs_decomposition_base_log.0 * cbs_decomposition_level_count.0 > ciphertext_modulus_log { + return Err(Self::DecompositionTooLarge); + } + Ok(()) + } +} + +/// A trait for engines performing a (discarding) boolean circuit bootstrapping followed by a +/// vertical packing on LWE ciphertext vectors. The term "boolean" refers to the fact the input +/// ciphertexts encrypt a single bit of message. +/// +/// The provided "big" `luts` look-up table is expected to be divisible into the same number of +/// chunks of polynomials as there are ciphertexts in the `output` LweCiphertextVector. Each chunk +/// of polynomials is used as a look-up table to evaluate during the vertical packing operation to +/// fill an output ciphertext. +/// +/// Note that there should be enough polynomials provided in each chunk to perform the vertical +/// packing given the number of boolean input ciphertexts. The number of boolean input ciphertexts +/// is in fact a number of bits. For this example let's say we have 16 input ciphertexts +/// representing 16 bits and want to output 4 ciphertexts. The "big" `luts` will need to be +/// divisible into 4 chunks of equal size. If the polynomial size used is $1024 = 2^{10}$ then each +/// chunk must contain $2^6 = 64$ polynomials ($2^6 * 2^{10} = 2^{16}$) to match the amount of +/// values representable by the 16 input ciphertexts each encrypting a bit. The "big" `luts` then +/// has a layout looking as follows: +/// +/// ```text +/// small lut for 1st output ciphertext|...|small lut for 4th output ciphertext +/// |[polynomial 1] ... [polynomial 64]|...|[polynomial 1] ... [polynomial 64]| +/// ``` +/// +/// The polynomials in the above representation are not necessarily the same, this is just for +/// illustration purposes. +/// +/// It is also possible in the above example to have a single polynomial of size $2^{16} = 65 536$ +/// for each chunk if the polynomial size is supported for computation (which is not the case for 65 +/// 536 at the moment for implemented backends). Chunks containing a single polynomial of size +/// $2^{10} = 1024$ would work for example for 10 input ciphertexts as that polynomial size is +/// supported for computations. The "big" `luts` layout would then look as follows for that 10 bits +/// example (still with 4 output ciphertexts): +/// +/// ```text +/// small lut for 1st output ciphertext|...|small lut for 4th output ciphertext +/// |[ polynomial 1 ]|...|[ polynomial 1 ]| +/// ``` +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation first performs the circuit bootstrapping +/// on all boolean (i.e. containing only 1 bit of message) input LWE ciphertexts from the `input` +/// vector. It then fills the `output` LWE ciphertext vector with the result of the vertical packing +/// operation applied on the output of the circuit bootstrapping, using the provided look-up table. +/// +/// # Formal Definition +pub trait LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingEngine< + Input, + Output, + BootstrapKey, + LUTs, + CirctuiBootstrapFunctionalPackingKeyswitchKeys, +>: AbstractEngine where + Input: LweCiphertextVectorEntity, + Output: LweCiphertextVectorEntity, + BootstrapKey: LweBootstrapKeyEntity, + LUTs: PlaintextVectorEntity, + CirctuiBootstrapFunctionalPackingKeyswitchKeys: + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, +{ + /// Performs the circuit bootstrapping on all boolean input LWE ciphertexts followed by vertical + /// packing using the provided look-up table. + #[allow(clippy::too_many_arguments)] + fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + &mut self, + output: &mut Output, + input: &Input, + bsk: &BootstrapKey, + luts: &LUTs, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &CirctuiBootstrapFunctionalPackingKeyswitchKeys, + ) -> Result< + (), + LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError, + >; + + /// Unsafely performs the circuit bootstrapping on all boolean input LWE ciphertexts followed by + /// vertical packing using the provided look-up table. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextVectorDiscardingCircuitBootstrapBooleanVerticalPackingError`]. For safety + /// concerns _specific_ to an engine, refer to the implementer safety section. + #[allow(clippy::too_many_arguments)] + unsafe fn discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector_unchecked( + &mut self, + output: &mut Output, + input: &Input, + bsk: &BootstrapKey, + luts: &LUTs, + cbs_level_count: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + cbs_pfpksk: &CirctuiBootstrapFunctionalPackingKeyswitchKeys, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_zero_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_zero_encryption.rs new file mode 100644 index 000000000..2d57e1d54 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_vector_zero_encryption.rs @@ -0,0 +1,58 @@ +use super::engine_error; +use crate::core_crypto::prelude::{LweCiphertextCount, Variance}; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{LweCiphertextVectorEntity, LweSecretKeyEntity}; + +engine_error! { + LweCiphertextVectorZeroEncryptionError for LweCiphertextVectorZeroEncryptionEngine @ + NullCiphertextCount => "The ciphertext count must be greater than zero." +} + +impl LweCiphertextVectorZeroEncryptionError { + /// Validates the inputs + pub fn perform_generic_checks(count: LweCiphertextCount) -> Result<(), Self> { + if count.0 == 0 { + return Err(Self::NullCiphertextCount); + } + Ok(()) + } +} + +/// A trait for engines encrypting zero in LWE ciphertext vectors. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates an LWE ciphertext vector containing +/// encryptions of zeros, under the `key` secret key. +/// +/// # Formal Definition +/// +/// This generates a vector of [`LWE encryption`] +/// (`crate::core_crypto::specification::engines::LweCiphertextEncryptionEngine`) of zero. +pub trait LweCiphertextVectorZeroEncryptionEngine: + AbstractEngine +where + SecretKey: LweSecretKeyEntity, + CiphertextVector: LweCiphertextVectorEntity, +{ + /// Encrypts zeros in an LWE ciphertext vector. + fn zero_encrypt_lwe_ciphertext_vector( + &mut self, + key: &SecretKey, + noise: Variance, + count: LweCiphertextCount, + ) -> Result>; + + /// Unsafely encrypts zeros in an LWE ciphertext vector. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextVectorZeroEncryptionError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn zero_encrypt_lwe_ciphertext_vector_unchecked( + &mut self, + key: &SecretKey, + noise: Variance, + count: LweCiphertextCount, + ) -> CiphertextVector; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_zero_encryption.rs b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_zero_encryption.rs new file mode 100644 index 000000000..48d931520 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_ciphertext_zero_encryption.rs @@ -0,0 +1,45 @@ +use super::engine_error; + +use crate::core_crypto::prelude::Variance; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{LweCiphertextEntity, LweSecretKeyEntity}; + +engine_error! { + LweCiphertextZeroEncryptionError for LweCiphertextZeroEncryptionEngine @ +} + +/// A trait for engines encrypting zero in LWE ciphertexts. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates an LWE ciphertext containing an +/// encryption of zero, under the `key` secret key. +/// +/// # Formal Definition +/// +/// This generates a [`LWE encryption`] +/// (`crate::core_crypto::specification::engines::LweCiphertextEncryptionEngine`) of zero. +pub trait LweCiphertextZeroEncryptionEngine: AbstractEngine +where + SecretKey: LweSecretKeyEntity, + Ciphertext: LweCiphertextEntity, +{ + /// Encrypts zero into an LWE ciphertext. + fn zero_encrypt_lwe_ciphertext( + &mut self, + key: &SecretKey, + noise: Variance, + ) -> Result>; + + /// Safely encrypts zero into an LWE ciphertext. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCiphertextZeroEncryptionError`]. For safety concerns _specific_ to an engine, refer + /// to the implementer safety section. + unsafe fn zero_encrypt_lwe_ciphertext_unchecked( + &mut self, + key: &SecretKey, + noise: Variance, + ) -> Ciphertext; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs b/tfhe/src/core_crypto/specification/engines/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs new file mode 100644 index 000000000..09a70ec56 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation.rs @@ -0,0 +1,87 @@ +use super::engine_error; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweSecretKeyEntity, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, LweSecretKeyEntity, Variance, +}; +use crate::core_crypto::specification::engines::AbstractEngine; + +engine_error! { + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError for + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine @ + NullDecompositionBaseLog => "The key decomposition base log must be greater than zero.", + NullDecompositionLevelCount => "The key decomposition level count must be greater than zero.", + DecompositionTooLarge => "The decomposition precision (base log * level count) must not exceed \ + the precision of the ciphertext." +} + +impl + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError +{ + /// Validates the inputs + pub fn perform_generic_checks( + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + ciphertext_modulus_log: usize, + ) -> Result<(), Self> { + if decomposition_base_log.0 == 0 { + return Err(Self::NullDecompositionBaseLog); + } + + if decomposition_level_count.0 == 0 { + return Err(Self::NullDecompositionLevelCount); + } + + if decomposition_level_count.0 * decomposition_base_log.0 > ciphertext_modulus_log { + return Err(Self::DecompositionTooLarge); + } + + Ok(()) + } +} + +/// A trait for engines generating new LWE functional packing keyswitch keys used in a circuit +/// bootstrapping. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new set of LWE private functional +/// packing keyswitch key required to perform a circuit bootstrapping. +/// +/// # Formal Definition +pub trait LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationEngine< + InputLweSecretKey, + OutputGlweSecretKey, + CBSFPKSK, +>: AbstractEngine where + InputLweSecretKey: LweSecretKeyEntity, + OutputGlweSecretKey: GlweSecretKeyEntity, + CBSFPKSK: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity, +{ + /// Generate a new LWE CBSFPKSK. + fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &mut self, + input_lwe_key: &InputLweSecretKey, + output_glwe_key: &OutputGlweSecretKey, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> Result< + CBSFPKSK, + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError, + >; + + /// Unsafely generate a new LWE CBSFPKSK. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysGenerationError`]. For safety + /// concerns _specific_ to an engine, refer to the implementer safety section. + unsafe fn generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_unchecked( + &mut self, + input_lwe_key: &InputLweSecretKey, + output_glwe_key: &OutputGlweSecretKey, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + noise: Variance, + ) -> CBSFPKSK; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_conversion.rs b/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_conversion.rs new file mode 100644 index 000000000..aa228ee50 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_conversion.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweKeyswitchKeyEntity; + +engine_error! { + LweKeyswitchKeyConversionError for LweKeyswitchKeyConversionEngine @ +} + +/// A trait for engines converting LWE keyswitch keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a LWE keyswitch key containing the +/// conversion of the `input` LWE keyswitch key to a type with a different representation (for +/// instance from cpu to gpu memory). +/// +/// # Formal Definition +pub trait LweKeyswitchKeyConversionEngine: AbstractEngine +where + Input: LweKeyswitchKeyEntity, + Output: LweKeyswitchKeyEntity, +{ + /// Converts a LWE keyswitch key. + fn convert_lwe_keyswitch_key( + &mut self, + input: &Input, + ) -> Result>; + + /// Unsafely converts a LWE keyswitch key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweKeyswitchKeyConversionError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn convert_lwe_keyswitch_key_unchecked(&mut self, input: &Input) -> Output; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_generation.rs b/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_generation.rs new file mode 100644 index 000000000..231a2097f --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_keyswitch_key_generation.rs @@ -0,0 +1,79 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; + +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, Variance}; +use crate::core_crypto::specification::entities::{LweKeyswitchKeyEntity, LweSecretKeyEntity}; + +engine_error! { + LweKeyswitchKeyGenerationError for LweKeyswitchKeyGenerationEngine @ + NullDecompositionBaseLog => "The key decomposition base log must be greater than zero.", + NullDecompositionLevelCount => "The key decomposition level count must be greater than zero.", + DecompositionTooLarge => "The decomposition precision (base log * level count) must not exceed \ + the precision of the ciphertext." +} + +impl LweKeyswitchKeyGenerationError { + /// Validates the inputs + pub fn perform_generic_checks( + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + ciphertext_modulus_log: usize, + ) -> Result<(), Self> { + if decomposition_base_log.0 == 0 { + return Err(Self::NullDecompositionBaseLog); + } + + if decomposition_level_count.0 == 0 { + return Err(Self::NullDecompositionLevelCount); + } + + if decomposition_level_count.0 * decomposition_base_log.0 > ciphertext_modulus_log { + return Err(Self::DecompositionTooLarge); + } + + Ok(()) + } +} + +/// A trait for engines generating new LWE keyswitch keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new LWE keyswitch key allowing to +/// switch from the `input_key` LWE secret key to the `output_key` LWE secret key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::LweKeyswitchKeyEntity`) +pub trait LweKeyswitchKeyGenerationEngine: + AbstractEngine +where + InputSecretKey: LweSecretKeyEntity, + OutputSecretKey: LweSecretKeyEntity, + KeyswitchKey: LweKeyswitchKeyEntity, +{ + /// Generates a new LWE keyswitch key. + fn generate_new_lwe_keyswitch_key( + &mut self, + input_key: &InputSecretKey, + output_key: &OutputSecretKey, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> Result>; + + /// Unsafely generates a new LWE keyswitch key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweKeyswitchKeyGenerationError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn generate_new_lwe_keyswitch_key_unchecked( + &mut self, + input_key: &InputSecretKey, + output_key: &OutputSecretKey, + decomposition_level_count: DecompositionLevelCount, + decomposition_base_log: DecompositionBaseLog, + noise: Variance, + ) -> KeyswitchKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_public_key_generation.rs b/tfhe/src/core_crypto/specification/engines/lwe_public_key_generation.rs new file mode 100644 index 000000000..b5800c9e5 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_public_key_generation.rs @@ -0,0 +1,57 @@ +use crate::core_crypto::prelude::{LwePublicKeyZeroEncryptionCount, Variance}; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::{LwePublicKeyEntity, LweSecretKeyEntity}; + +engine_error! { + LwePublicKeyGenerationError for LwePublicKeyGenerationEngine @ + NullPublicKeyZeroEncryptionCount => "The number of LWE encryptions of zero in the LWE public \ + key must be greater than zero." +} + +impl LwePublicKeyGenerationError { + /// Validates the inputs + pub fn perform_generic_checks( + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result<(), Self> { + if lwe_public_key_zero_encryption_count.0 == 0 { + return Err(Self::NullPublicKeyZeroEncryptionCount); + } + Ok(()) + } +} + +/// A trait for engines generating new LWE public keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new LWE public key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::LwePublicKeyEntity`) +pub trait LwePublicKeyGenerationEngine: AbstractEngine +where + SecretKey: LweSecretKeyEntity, + PublicKey: LwePublicKeyEntity, +{ + /// Generates a new LWE public key. + fn generate_new_lwe_public_key( + &mut self, + lwe_secret_key: &SecretKey, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> Result>; + + /// Unsafely generates a new LWE public key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LwePublicKeyGenerationError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn generate_new_lwe_public_key_unchecked( + &mut self, + lwe_secret_key: &SecretKey, + noise: Variance, + lwe_public_key_zero_encryption_count: LwePublicKeyZeroEncryptionCount, + ) -> PublicKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/lwe_secret_key_generation.rs b/tfhe/src/core_crypto/specification/engines/lwe_secret_key_generation.rs new file mode 100644 index 000000000..e9790d05c --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/lwe_secret_key_generation.rs @@ -0,0 +1,49 @@ +use crate::core_crypto::prelude::LweDimension; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::LweSecretKeyEntity; + +engine_error! { + LweSecretKeyGenerationError for LweSecretKeyGenerationEngine @ + NullLweDimension => "The LWE dimension must be greater than zero." +} + +impl LweSecretKeyGenerationError { + /// Validates the inputs + pub fn perform_generic_checks(lwe_dimension: LweDimension) -> Result<(), Self> { + if lwe_dimension.0 == 0 { + return Err(Self::NullLweDimension); + } + Ok(()) + } +} + +/// A trait for engines generating new LWE secret keys. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a new LWE secret key. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +pub trait LweSecretKeyGenerationEngine: AbstractEngine +where + SecretKey: LweSecretKeyEntity, +{ + /// Generates a new LWE secret key. + fn generate_new_lwe_secret_key( + &mut self, + lwe_dimension: LweDimension, + ) -> Result>; + + /// Unsafely generates a new LWE secret key. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`LweSecretKeyGenerationError`]. For safety concerns _specific_ to an + /// engine, refer to the implementer safety section. + unsafe fn generate_new_lwe_secret_key_unchecked( + &mut self, + lwe_dimension: LweDimension, + ) -> SecretKey; +} diff --git a/tfhe/src/core_crypto/specification/engines/mod.rs b/tfhe/src/core_crypto/specification/engines/mod.rs new file mode 100644 index 000000000..b3c3eb360 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/mod.rs @@ -0,0 +1,235 @@ +//! A module containing specifications of FHE engines. +//! +//! In essence, __engines__ are types which can be used to perform operations on fhe entities. These +//! engines contain all the side-resources needed to execute the operations they declare. +//! An engine must implement at least the [`AbstractEngine`] super-trait, and can implement any +//! number of `*Engine` traits. +//! +//! Every fhe operation is defined by a `*Engine` operation trait which always expose two entry +//! points: +//! +//! + A safe entry point, returning a result, with an [operation-dedicated](#engine-errors) error. +//! When using this entry point, the user relies on the backend to check that the necessary +//! preconditions are verified by the inputs, at the cost of a small overhead. +//! + An unsafe entry point, returning the raw result if any. When using this entry point, it is the +//! user responsibility to ensure that the necessary preconditions are verified by the inputs. +//! Breaking one of those preconditions will result in either a panic, or an FHE UB. +//! +//! # Engine errors +//! +//! Implementing the [`AbstractEngine`] trait for a given type implies specifying an associated +//! [`EngineError`](`AbstractEngine::EngineError`) which should be able to represent all the +//! possible error cases specific to this engine. +//! +//! Each `*Engine` trait is associated with a specialized `*Error` type (for example +//! [`LweCiphertextDiscardingKeyswitchError`] is associated with +//! [`LweCiphertextDiscardingKeyswitchEngine`]), which contains: +//! +//! + Multiple __general__ error variants which can be potentially produced by any backend +//! (see the +//! [`LweCiphertextDiscardingKeyswitchError::InputLweDimensionMismatch`] variant for an example) +//! + One __specific__ variant which encapsulate the generic argument error `E` +//! (see the [`Engine`](`LweCiphertextDiscardingKeyswitchError::Engine`) variant for an example) +//! +//! When implementing a particular `*Engine` trait, this `E` argument will be forced to be the +//! [`EngineError`](`AbstractEngine::EngineError`) from the [`AbstractEngine`] super-trait, by the +//! signature of the operation entry point +//! (see [`LweCiphertextDiscardingKeyswitchEngine::discard_keyswitch_lwe_ciphertext`] for instance). +//! +//! This design makes it possible for each operation, to match the error exhaustively against both +//! general error variants, and backend-related error variants. +//! +//! # A word about Generation and Creation engines +//! +//! We have two families of engines to make entities: +//! - Generation engines which generate new entities with non trivial algorithms, e.g. a bootstrap +//! key generation +//! - Creation engines which wrap/re-interpret data to create entities from them without involving +//! non trivial algorithms, like creating a `Cleartext64` from a `u64` by simply wrapping the +//! value. +//! +//! # Operation semantics +//! +//! For each possible operation, we try to support the three following semantics: +//! +//! + __Pure operations__ take their inputs as arguments, allocate an object +//! holding the result, and return it (example: [`LweCiphertextEncryptionEngine`]). They usually +//! require more resources than other, because of the allocation. +//! + __Discarding operations__ take both their inputs and outputs as arguments +//! (example: [`LweCiphertextDiscardingAdditionEngine`]). In those operations, the data originally +//! available in the outputs is not used for the computation. They are usually the fastest ones. +//! + __Fusing operations__ take both their inputs and outputs as arguments +//! (example: [`LweCiphertextFusingAdditionEngine`]). In those operations though, the data +//! originally contained in the output is used for computation. + +// This makes it impossible for types outside this crate to implement operations. +pub(crate) mod sealed { + pub trait AbstractEngineSeal {} +} + +/// A top-level abstraction for engines. +/// +/// An `AbstractEngine` is nothing more than a type with an associated error type +/// [`EngineError`](`AbstractEngine::EngineError`) and a default constructor. +/// +/// The associated error type is expected to encode all the failure cases which can occur while +/// using an engine. +pub trait AbstractEngine: sealed::AbstractEngineSeal { + // # Why put the error type in an abstract super trait ? + // + // This error is supposed to be reduced to only engine related errors, and not ones related to + // the operations. For this reason, it is better for an engine to only have one error shared + // among all the operations. If a variant of this error can only be triggered for a single + // operation implemented by the engine, then it should probably be moved upstream, in the + // operation-dedicated error. + + /// The error associated to the engine. + type EngineError: std::error::Error; + + /// The constructor parameters type. + type Parameters; + + /// A constructor for the engine. + fn new(parameter: Self::Parameters) -> Result + where + Self: Sized; +} + +macro_rules! engine_error { + ($name:ident for $trait:ident @) => { + #[doc=concat!("An error used with the [`", stringify!($trait), "`] trait.")] + #[non_exhaustive] + #[derive(Debug, Clone, Eq, PartialEq)] + pub enum $name { + #[doc="_Specific_ error to the implementing engine."] + Engine(EngineError), + } + impl std::fmt::Display for $name{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Engine(error) => write!(f, "Error occurred in the engine: {}", error), + } + } + } + impl std::error::Error for $name{} + }; + ($name:ident for $trait:ident @ $($variants:ident => $messages:literal),*) => { + #[doc=concat!("An error used with the [`", stringify!($trait), "`] trait.")] + #[doc=""] + #[doc="This type provides a "] + #[doc=concat!("[`", stringify!($name), "::perform_generic_checks`] ")] + #[doc="function that does error checking for the general cases, returning an `Ok(())` "] + #[doc="if the inputs are valid, meaning that engine implementors would then only "] + #[doc="need to check for their own specific errors."] + #[doc="Otherwise an `Err(..)` with the proper error variant is returned."] + #[non_exhaustive] + #[derive(Debug, Clone, Eq, PartialEq)] + pub enum $name { + $( + #[doc="_Generic_ error: "] + #[doc=$messages] + $variants, + )* + #[doc="_Specific_ error to the implementing engine."] + Engine(EngineError), + } + impl std::fmt::Display for $name{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + $( + Self::$variants => write!(f, $messages), + )* + Self::Engine(error) => write!(f, "Error occurred in the engine: {}", error), + } + } + } + impl std::error::Error for $name{} + }; +} +pub(crate) use engine_error; + +mod cleartext_creation; +mod entity_deserialization; +mod entity_serialization; +mod glwe_ciphertext_consuming_retrieval; +mod glwe_ciphertext_conversion; +mod glwe_ciphertext_creation; +mod glwe_ciphertext_trivial_encryption; +mod glwe_secret_key_generation; +mod glwe_to_lwe_secret_key_transformation; +mod lwe_bootstrap_key_conversion; +mod lwe_bootstrap_key_generation; +mod lwe_ciphertext_cleartext_fusing_multiplication; +mod lwe_ciphertext_consuming_retrieval; +mod lwe_ciphertext_conversion; +mod lwe_ciphertext_creation; +mod lwe_ciphertext_decryption; +mod lwe_ciphertext_discarding_addition; +mod lwe_ciphertext_discarding_bit_extraction; +mod lwe_ciphertext_discarding_bootstrap; +mod lwe_ciphertext_discarding_conversion; +mod lwe_ciphertext_discarding_encryption; +mod lwe_ciphertext_discarding_keyswitch; +mod lwe_ciphertext_discarding_public_key_encryption; +mod lwe_ciphertext_encryption; +mod lwe_ciphertext_fusing_addition; +mod lwe_ciphertext_fusing_opposite; +mod lwe_ciphertext_fusing_subtraction; +mod lwe_ciphertext_plaintext_fusing_addition; +mod lwe_ciphertext_trivial_encryption; +mod lwe_ciphertext_vector_consuming_retrieval; +mod lwe_ciphertext_vector_creation; +mod lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing; +mod lwe_ciphertext_vector_zero_encryption; +mod lwe_ciphertext_zero_encryption; +mod lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation; +mod lwe_keyswitch_key_conversion; +mod lwe_keyswitch_key_generation; +mod lwe_public_key_generation; +mod lwe_secret_key_generation; +mod plaintext_creation; +mod plaintext_discarding_retrieval; +mod plaintext_vector_creation; + +pub use cleartext_creation::*; +pub use entity_deserialization::*; +pub use entity_serialization::*; +pub use glwe_ciphertext_consuming_retrieval::*; +pub use glwe_ciphertext_conversion::*; +pub use glwe_ciphertext_creation::*; +pub use glwe_ciphertext_trivial_encryption::*; +pub use glwe_secret_key_generation::*; +pub use glwe_to_lwe_secret_key_transformation::*; +pub use lwe_bootstrap_key_conversion::*; +pub use lwe_bootstrap_key_generation::*; +pub use lwe_ciphertext_cleartext_fusing_multiplication::*; +pub use lwe_ciphertext_consuming_retrieval::*; +pub use lwe_ciphertext_conversion::*; +pub use lwe_ciphertext_creation::*; +pub use lwe_ciphertext_decryption::*; +pub use lwe_ciphertext_discarding_addition::*; +pub use lwe_ciphertext_discarding_bit_extraction::*; +pub use lwe_ciphertext_discarding_bootstrap::*; +pub use lwe_ciphertext_discarding_conversion::*; +pub use lwe_ciphertext_discarding_encryption::*; +pub use lwe_ciphertext_discarding_keyswitch::*; +pub use lwe_ciphertext_discarding_public_key_encryption::*; +pub use lwe_ciphertext_encryption::*; +pub use lwe_ciphertext_fusing_addition::*; +pub use lwe_ciphertext_fusing_opposite::*; +pub use lwe_ciphertext_fusing_subtraction::*; +pub use lwe_ciphertext_plaintext_fusing_addition::*; +pub use lwe_ciphertext_trivial_encryption::*; +pub use lwe_ciphertext_vector_consuming_retrieval::*; +pub use lwe_ciphertext_vector_creation::*; +pub use lwe_ciphertext_vector_discarding_circuit_bootstrap_boolean_vertical_packing::*; +pub use lwe_ciphertext_vector_zero_encryption::*; +pub use lwe_ciphertext_zero_encryption::*; +pub use lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_generation::*; +pub use lwe_keyswitch_key_conversion::*; +pub use lwe_keyswitch_key_generation::*; +pub use lwe_public_key_generation::*; +pub use lwe_secret_key_generation::*; +pub use plaintext_creation::*; +pub use plaintext_discarding_retrieval::*; +pub use plaintext_vector_creation::*; diff --git a/tfhe/src/core_crypto/specification/engines/plaintext_creation.rs b/tfhe/src/core_crypto/specification/engines/plaintext_creation.rs new file mode 100644 index 000000000..0bfd3f8f1 --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/plaintext_creation.rs @@ -0,0 +1,36 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::PlaintextEntity; + +engine_error! { + PlaintextCreationError for PlaintextCreationEngine @ +} + +/// A trait for engines creating plaintexts from an arbitrary value. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a plaintext from the `value` +/// arbitrary value. By arbitrary here, we mean that `Value` can be any type that suits the backend +/// implementor (an integer, a struct wrapping integers, a struct wrapping foreign data or any other +/// thing). +/// +/// # Formal Definition +pub trait PlaintextCreationEngine: AbstractEngine +where + Plaintext: PlaintextEntity, +{ + /// Creates a plaintext from an arbitrary value. + fn create_plaintext_from( + &mut self, + value: &Value, + ) -> Result>; + + /// Unsafely creates a plaintext from an arbitrary value. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`PlaintextCreationError`]. For safety concerns _specific_ to an engine, refer to the + /// implementer safety section. + unsafe fn create_plaintext_from_unchecked(&mut self, value: &Value) -> Plaintext; +} diff --git a/tfhe/src/core_crypto/specification/engines/plaintext_discarding_retrieval.rs b/tfhe/src/core_crypto/specification/engines/plaintext_discarding_retrieval.rs new file mode 100644 index 000000000..68cc2f4ce --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/plaintext_discarding_retrieval.rs @@ -0,0 +1,41 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::PlaintextEntity; + +engine_error! { + PlaintextDiscardingRetrievalError for PlaintextDiscardingRetrievalEngine @ +} + +/// A trait for engines retrieving (discarding) arbitrary values from plaintexts. +/// +/// # Semantics +/// +/// This [discarding](super#operation-semantics) operation fills the `output` arbitrary value with +/// the retrieval of the `input` plaintext value. By arbitrary here, we mean that `Value` can be any +/// type that suits the backend implementor (an integer, a struct wrapping integers, a struct +/// wrapping foreign data or any other thing). +/// +/// # Formal Definition +pub trait PlaintextDiscardingRetrievalEngine: AbstractEngine +where + Plaintext: PlaintextEntity, +{ + /// Retrieves an arbitrary value from a plaintext inplace. + fn discard_retrieve_plaintext( + &mut self, + output: &mut Value, + input: &Plaintext, + ) -> Result<(), PlaintextDiscardingRetrievalError>; + + /// Unsafely retrieves an arbitrary value from a plaintext inplace. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`PlaintextDiscardingRetrievalError`]. For safety concerns _specific_ to an engine, refer + /// to the implementer safety section. + unsafe fn discard_retrieve_plaintext_unchecked( + &mut self, + output: &mut Value, + input: &Plaintext, + ); +} diff --git a/tfhe/src/core_crypto/specification/engines/plaintext_vector_creation.rs b/tfhe/src/core_crypto/specification/engines/plaintext_vector_creation.rs new file mode 100644 index 000000000..13e874d0f --- /dev/null +++ b/tfhe/src/core_crypto/specification/engines/plaintext_vector_creation.rs @@ -0,0 +1,50 @@ +use super::engine_error; +use crate::core_crypto::specification::engines::AbstractEngine; +use crate::core_crypto::specification::entities::PlaintextVectorEntity; + +engine_error! { + PlaintextVectorCreationError for PlaintextVectorCreationEngine @ + EmptyInput => "The input slice must not be empty." +} + +impl PlaintextVectorCreationError { + /// Validates the inputs + pub fn perform_generic_checks(values: &[Value]) -> Result<(), Self> { + if values.is_empty() { + return Err(Self::EmptyInput); + } + Ok(()) + } +} + +/// A trait for engines creating plaintext vectors from arbitrary values. +/// +/// # Semantics +/// +/// This [pure](super#operation-semantics) operation generates a plaintext vector from the `values` +/// slice of arbitrary values. By arbitrary here, we mean that `Value` can be any type that suits +/// the backend implementor (an integer, a struct wrapping integers, a struct wrapping foreign data +/// or any other thing). +/// +/// # Formal Definition +pub trait PlaintextVectorCreationEngine: AbstractEngine +where + PlaintextVector: PlaintextVectorEntity, +{ + /// Creates a plaintext vector from a slice of arbitrary values. + fn create_plaintext_vector_from( + &mut self, + values: &[Value], + ) -> Result>; + + /// Unsafely creates a plaintext vector from a slice of arbitrary values. + /// + /// # Safety + /// For the _general_ safety concerns regarding this operation, refer to the different variants + /// of [`PlaintextVectorCreationError`]. For safety concerns _specific_ to an engine, refer to + /// the implementer safety section. + unsafe fn create_plaintext_vector_from_unchecked( + &mut self, + values: &[Value], + ) -> PlaintextVector; +} diff --git a/tfhe/src/core_crypto/specification/entities/cleartext.rs b/tfhe/src/core_crypto/specification/entities/cleartext.rs new file mode 100644 index 000000000..9cef3c660 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/cleartext.rs @@ -0,0 +1,7 @@ +use crate::core_crypto::specification::entities::markers::CleartextKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a cleartext entity. +/// +/// # Formal Definition +pub trait CleartextEntity: AbstractEntity {} diff --git a/tfhe/src/core_crypto/specification/entities/glwe_ciphertext.rs b/tfhe/src/core_crypto/specification/entities/glwe_ciphertext.rs new file mode 100644 index 000000000..859e4e7f6 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/glwe_ciphertext.rs @@ -0,0 +1,40 @@ +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; +use crate::core_crypto::specification::entities::markers::GlweCiphertextKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a GLWE ciphertext. +/// +/// **Remark:** GLWE ciphertexts generalize LWE ciphertexts by definition, however in this library, +/// GLWE +/// ciphertext entities do not generalize LWE ciphertexts, i.e., polynomial size cannot be 1. +/// +/// # Formal Definition +/// +/// ## GLWE Ciphertext +/// +/// A GLWE ciphertext is an encryption of a polynomial plaintext. +/// It is secure under the hardness assumption called General Learning With Errors (GLWE). +/// It is a generalization of both +/// [`LWE ciphertexts`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +/// and RLWE ciphertexts. GLWE requires a cyclotomic ring. +/// We use the notation $\mathcal{R}\_q$ for the following cyclotomic ring: +/// $\mathbb{Z}\_q\[X\]/\left\langle X^N + 1\right\rangle$ where $N\in\mathbb{N}$ is a power of two. +/// +/// We call $q$ the ciphertext modulus and $N$ the ring dimension. +/// +/// We indicate a GLWE ciphertext of a plaintext $\mathsf{PT} \in\mathcal{R}\_q^{k+1}$ as the +/// following couple: $$\mathsf{CT} = \left( \vec{A}, B\right) = \left( A\_0, \ldots, A\_{k-1}, +/// B\right) \in \mathsf{GLWE}\_{\vec{S}} \left( \mathsf{PT} \right) \subseteq +/// \mathcal{R}\_q^{k+1}$$ +/// +/// ## Generalisation of LWE and RLWE +/// +/// When we set $k=1$ a GLWE ciphertext becomes an RLWE ciphertext. +/// When we set $N=1$ a GLWE ciphertext becomes an LWE ciphertext with $n=k$. +pub trait GlweCiphertextEntity: AbstractEntity { + /// Returns the GLWE dimension of the ciphertext. + fn glwe_dimension(&self) -> GlweDimension; + + /// Returns the polynomial size of the ciphertext. + fn polynomial_size(&self) -> PolynomialSize; +} diff --git a/tfhe/src/core_crypto/specification/entities/glwe_secret_key.rs b/tfhe/src/core_crypto/specification/entities/glwe_secret_key.rs new file mode 100644 index 000000000..993187fca --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/glwe_secret_key.rs @@ -0,0 +1,22 @@ +use crate::core_crypto::prelude::{GlweDimension, PolynomialSize}; +use crate::core_crypto::specification::entities::markers::GlweSecretKeyKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a GLWE secret key. +/// +/// # Formal Definition +/// +/// ## GLWE Secret Key +/// +/// We consider a secret key: +/// $$\vec{S} =\left( S\_0, \ldots, S\_{k-1}\right) \in \mathcal{R}^{k}$$ +/// The $k$ polynomials composing $\vec{S}$ contain each $N$ integers coefficients that have been +/// sampled from some distribution which is either uniformly binary, uniformly ternary, gaussian or +/// even uniform. +pub trait GlweSecretKeyEntity: AbstractEntity { + /// Returns the GLWE dimension of the key. + fn glwe_dimension(&self) -> GlweDimension; + + /// Returns the polynomial size of the key. + fn polynomial_size(&self) -> PolynomialSize; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_bootstrap_key.rs b/tfhe/src/core_crypto/specification/entities/lwe_bootstrap_key.rs new file mode 100644 index 000000000..89648f82e --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_bootstrap_key.rs @@ -0,0 +1,60 @@ +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, +}; +use crate::core_crypto::specification::entities::markers::LweBootstrapKeyKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE bootstrap key. +/// +/// # Formal Definition +/// +/// ## Bootstrapping Key +/// A bootstrapping key is a vector of GGSW ciphertexts. It encrypts the coefficients of the +/// [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\ mathsf{in}}$ under the +/// [GLWE secret key](`crate::core_crypto::specification::entities::GlweSecretKeyEntity`) +/// $\vec{S}\_{\ mathsf{out}}$. +/// +/// $$\mathsf{BSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow \vec{S}\_{\mathsf{out}}} = \left( +/// \overline{\overline{\mathsf{CT}\_0}}, \cdots , +/// \overline{\overline{\mathsf{CT}\_{n\_{\mathsf{in}}-1}}}\right) \subseteq +/// \mathbb{Z}\_q^{(n\_{\mathsf{out}}+1)\cdot n\_{\mathsf{in}}}$$ +/// +/// where $\vec{s}\_{\mathsf{in}} = \left( s\_0 , \cdots , s\_{\mathsf{in}-1} \right)$ and for all +/// $0\le i { + /// Returns the GLWE dimension of the key. + fn glwe_dimension(&self) -> GlweDimension; + + /// Returns the polynomial size of the key. + fn polynomial_size(&self) -> PolynomialSize; + + /// Returns the input LWE dimension of the key. + fn input_lwe_dimension(&self) -> LweDimension; + + /// Returns the output LWE dimension of the key. + fn output_lwe_dimension(&self) -> LweDimension { + LweDimension(self.glwe_dimension().0 * self.polynomial_size().0) + } + + /// Returns the number of decomposition levels of the key. + fn decomposition_base_log(&self) -> DecompositionBaseLog; + + /// Returns the logarithm of the base used in the key. + fn decomposition_level_count(&self) -> DecompositionLevelCount; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_ciphertext.rs b/tfhe/src/core_crypto/specification/entities/lwe_ciphertext.rs new file mode 100644 index 000000000..eb532fe19 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_ciphertext.rs @@ -0,0 +1,29 @@ +use crate::core_crypto::prelude::LweDimension; +use crate::core_crypto::specification::entities::markers::LweCiphertextKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE ciphertext. +/// +/// # Formal Definition +/// +/// ## LWE Ciphertext +/// +/// An LWE ciphertext is an encryption of a plaintext. +/// It is secure under the hardness assumption called Learning With Errors (LWE). +/// It is a specialization of +/// [`GLWE ciphertext`](`crate::core_crypto::specification::entities::GlweCiphertextEntity`). +/// +/// We indicate an LWE ciphertext of a plaintext $\mathsf{pt} \in\mathbb{Z}\_q$ as the following +/// couple: $$\mathsf{ct} = \left( \vec{a} , b\right) \in \mathsf{LWE}^n\_{\vec{s}}( \mathsf{pt} +/// )\subseteq \mathbb{Z}\_q^{(n+1)}$$ We call $q$ the ciphertext modulus and $n$ the LWE dimension. +/// +/// ## LWE dimension +/// It corresponds to the number of element in the LWE secret key. +/// In an LWE ciphertext, it is the length of the vector $\vec{a}$. +/// At [`encryption`](`crate::core_crypto::specification::engines::LweCiphertextEncryptionEngine`) +/// time, it is the number of uniformly random +/// integers generated. +pub trait LweCiphertextEntity: AbstractEntity { + /// Returns the LWE dimension of the ciphertext. + fn lwe_dimension(&self) -> LweDimension; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_ciphertext_vector.rs b/tfhe/src/core_crypto/specification/entities/lwe_ciphertext_vector.rs new file mode 100644 index 000000000..3c61ffe90 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_ciphertext_vector.rs @@ -0,0 +1,16 @@ +use crate::core_crypto::prelude::{LweCiphertextCount, LweDimension}; +use crate::core_crypto::specification::entities::markers::LweCiphertextVectorKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE ciphertext vector. +/// +/// # Formal Definition +/// +/// cf [`here`](`crate::core_crypto::specification::entities::LweCiphertextEntity`) +pub trait LweCiphertextVectorEntity: AbstractEntity { + /// Returns the LWE dimension of the ciphertexts. + fn lwe_dimension(&self) -> LweDimension; + + /// Returns the number of ciphertexts contained in the vector. + fn lwe_ciphertext_count(&self) -> LweCiphertextCount; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs b/tfhe/src/core_crypto/specification/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs new file mode 100644 index 000000000..5d340dce6 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys.rs @@ -0,0 +1,32 @@ +use crate::core_crypto::prelude::markers::LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysKind; +use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, + GlweDimension, LweDimension, PolynomialSize, +}; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a private functional packing keyswitch key vector used +/// for circuit bootstrapping. +/// +/// # Formal Definition +pub trait LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysEntity: + AbstractEntity +{ + /// Returns the input LWE dimension of the keys. + fn input_lwe_dimension(&self) -> LweDimension; + + /// Returns the output GLWE dimension of the keys. + fn output_glwe_dimension(&self) -> GlweDimension; + + /// Returns the output polynomial degree of the keys. + fn output_polynomial_size(&self) -> PolynomialSize; + + /// Returns the number of decomposition levels of the keys. + fn decomposition_level_count(&self) -> DecompositionLevelCount; + + /// Returns the logarithm of the base used in the keys. + fn decomposition_base_log(&self) -> DecompositionBaseLog; + + /// Returns the number of keys contained in the vector. + fn key_count(&self) -> FunctionalPackingKeyswitchKeyCount; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_keyswitch_key.rs b/tfhe/src/core_crypto/specification/entities/lwe_keyswitch_key.rs new file mode 100644 index 000000000..32fdf91fc --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_keyswitch_key.rs @@ -0,0 +1,37 @@ +use crate::core_crypto::prelude::{DecompositionBaseLog, DecompositionLevelCount, LweDimension}; +use crate::core_crypto::specification::entities::markers::LweKeyswitchKeyKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE keyswitch key. +/// +/// # Formal Definition +/// +/// ## Key Switching Key +/// +/// A key switching key is a vector of Lev ciphertexts. +/// It encrypts the coefficient of +/// the [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\mathsf{in}}$ under the +/// [`LWE secret key`](`crate::core_crypto::specification::entities::LweSecretKeyEntity`) +/// $\vec{s}\_{\mathsf{out}}$. +/// +/// $$\mathsf{KSK}\_{\vec{s}\_{\mathsf{in}}\rightarrow \vec{s}\_{\mathsf{out}}} = \left( +/// \overline{\mathsf{ct}\_0}, \cdots , \overline{\mathsf{ct}\_{n\_{\mathsf{in}}-1}}\right) +/// \subseteq \mathbb{Z}\_q^{(n\_{\mathsf{out}}+1)\cdot n\_{\mathsf{in}}}$$ +/// +/// where $\vec{s}\_{\mathsf{in}} = \left( s\_0 , \cdots , s\_{\mathsf{in}-1} \right)$ and for all +/// $0\le i { + /// Returns the input LWE dimension of the key. + fn input_lwe_dimension(&self) -> LweDimension; + + /// Returns the output lew dimension of the key. + fn output_lwe_dimension(&self) -> LweDimension; + + /// Returns the number of decomposition levels of the key. + fn decomposition_level_count(&self) -> DecompositionLevelCount; + + /// Returns the logarithm of the base used in the key. + fn decomposition_base_log(&self) -> DecompositionBaseLog; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_public_key.rs b/tfhe/src/core_crypto/specification/entities/lwe_public_key.rs new file mode 100644 index 000000000..58f06fd30 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_public_key.rs @@ -0,0 +1,20 @@ +use crate::core_crypto::prelude::{LweDimension, LwePublicKeyZeroEncryptionCount}; +use crate::core_crypto::specification::entities::markers::LwePublicKeyKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE public key. +/// +/// # Formal Definition +/// +/// ## LWE Public Key +/// +/// An LWE public key contains $m$ LWE encryptions of 0 under a secret key +/// $\vec{s}\in\mathbb{Z}\_q^n$ where $n$ is the LWE dimension of the ciphertexts contained in the +/// public key. +pub trait LwePublicKeyEntity: AbstractEntity { + /// Returns the LWE dimension of the key. + fn lwe_dimension(&self) -> LweDimension; + + /// Returns the number of LWE encryption of 0 in the key. + fn lwe_zero_encryption_count(&self) -> LwePublicKeyZeroEncryptionCount; +} diff --git a/tfhe/src/core_crypto/specification/entities/lwe_secret_key.rs b/tfhe/src/core_crypto/specification/entities/lwe_secret_key.rs new file mode 100644 index 000000000..396aed647 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/lwe_secret_key.rs @@ -0,0 +1,18 @@ +use crate::core_crypto::prelude::LweDimension; +use crate::core_crypto::specification::entities::markers::LweSecretKeyKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying an LWE secret key. +/// +/// # Formal Definition +/// +/// ## LWE Secret Key +/// +/// We consider a secret key: +/// $$\vec{s} \in \mathbb{Z}^n$$ +/// This vector contains $n$ integers that have been sampled for some distribution which is either +/// uniformly binary, uniformly ternary, gaussian or even uniform. +pub trait LweSecretKeyEntity: AbstractEntity { + /// Returns the LWE dimension of the key. + fn lwe_dimension(&self) -> LweDimension; +} diff --git a/tfhe/src/core_crypto/specification/entities/markers.rs b/tfhe/src/core_crypto/specification/entities/markers.rs new file mode 100644 index 000000000..45692746b --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/markers.rs @@ -0,0 +1,58 @@ +//! A module containing various marker traits used for entities. +use std::fmt::Debug; + +/// A trait implemented by marker types encoding the __kind__ of an FHE entity in +/// the type system. +/// +/// By _kind_ here, we mean the _what_, the abstract nature of an FHE entity. +/// +/// # Note +/// +/// [`EntityKindMarker`] types are only defined in the specification part of the library, and +/// can not be defined by a backend. +pub trait EntityKindMarker: seal::EntityKindMarkerSealed {} +macro_rules! entity_kind_marker { + (@ $name: ident => $doc: literal)=>{ + #[doc=$doc] + #[derive(Debug, Clone, Copy)] + pub struct $name{} + impl seal::EntityKindMarkerSealed for $name{} + impl EntityKindMarker for $name{} + }; + ($($name: ident => $doc: literal),+) =>{ + $( + entity_kind_marker!(@ $name => $doc); + )+ + } +} +entity_kind_marker! { + PlaintextKind + => "An empty type representing the plaintext kind in the type system.", + PlaintextVectorKind + => "An empty type representing the plaintext vector kind in the type system", + CleartextKind + => "An empty type representing the cleartext kind in the type system.", + LweCiphertextKind + => "An empty type representing the LWE ciphertext kind in the type system.", + LweCiphertextVectorKind + => "An empty type representing the LWE ciphertext vector kind in the type system.", + GlweCiphertextKind + => "An empty type representing the GLWE ciphertext kind in the type system.", + LwePublicKeyKind + => "An empty type representing the LWE public key kind in the type system.", + LweSecretKeyKind + => "An empty type representing the LWE secret key kind in the type system.", + GlweSecretKeyKind + => "An empty type representing the GLWE secret key kind in the type system.", + LweKeyswitchKeyKind + => "An empty type representing the LWE keyswitch key kind in the type system.", + LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeysKind + => "An empty type representing the private functional packing keyswitch key vector \ + used for a circuit bootstrap in the type system.", + LweBootstrapKeyKind + => "An empty type representing the LWE bootstrap key kind in the type system." +} + +pub(crate) mod seal { + pub trait EntityKindMarkerSealed {} +} diff --git a/tfhe/src/core_crypto/specification/entities/mod.rs b/tfhe/src/core_crypto/specification/entities/mod.rs new file mode 100644 index 000000000..80130920b --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/mod.rs @@ -0,0 +1,61 @@ +//! A module containing specifications of FHE entities. +//! +//! In practice, __Entities__ are types which implement: +//! +//! + The [`AbstractEntity`] super-trait. +//! + One of the `*Entity` traits. + +pub mod markers; + +use markers::*; +use std::fmt::Debug; + +/// A top-level abstraction for entities. +/// +/// An `AbstractEntity` type is nothing more than a type with an associated +/// [`Kind`](`AbstractEntity::Kind`) marker type (implementing the [`EntityKindMarker`] trait), +/// which encodes in the type system, the abstract nature of the object. +pub trait AbstractEntity: Debug { + // # Why associated types and not generic parameters ? + // + // With generic parameters you can have one type implement a variety of abstract entity. With + // associated types, a type can only implement one abstract entity. Hence, using generic + // parameters, would encourage broadly generic types representing various entities (say an + // array) while using associated types encourages narrowly defined types representing a single + // entity. We think it is preferable for the user if the backends expose narrowly defined + // types, as it makes the api cleaner and the signatures leaner. The downside is probably a bit + // more boilerplate though. + // + // Also, this prevents a single type to implement different downstream traits (a type being both + // a GGSW ciphertext vector and an LWE bootstrap key). Again, I think this is for the best, as + // it will help us design better backend-level apis. + + /// The _kind_ of the entity. + type Kind: EntityKindMarker; +} + +mod cleartext; +mod glwe_ciphertext; +mod glwe_secret_key; +mod lwe_bootstrap_key; +mod lwe_ciphertext; +mod lwe_ciphertext_vector; +mod lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys; +mod lwe_keyswitch_key; +mod lwe_public_key; +mod lwe_secret_key; +mod plaintext; +mod plaintext_vector; + +pub use cleartext::*; +pub use glwe_ciphertext::*; +pub use glwe_secret_key::*; +pub use lwe_bootstrap_key::*; +pub use lwe_ciphertext::*; +pub use lwe_ciphertext_vector::*; +pub use lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys::*; +pub use lwe_keyswitch_key::*; +pub use lwe_public_key::*; +pub use lwe_secret_key::*; +pub use plaintext::*; +pub use plaintext_vector::*; diff --git a/tfhe/src/core_crypto/specification/entities/plaintext.rs b/tfhe/src/core_crypto/specification/entities/plaintext.rs new file mode 100644 index 000000000..c60b57d5c --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/plaintext.rs @@ -0,0 +1,7 @@ +use crate::core_crypto::specification::entities::markers::PlaintextKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a plaintext. +/// +/// # Formal Definition +pub trait PlaintextEntity: AbstractEntity {} diff --git a/tfhe/src/core_crypto/specification/entities/plaintext_vector.rs b/tfhe/src/core_crypto/specification/entities/plaintext_vector.rs new file mode 100644 index 000000000..e93d0b290 --- /dev/null +++ b/tfhe/src/core_crypto/specification/entities/plaintext_vector.rs @@ -0,0 +1,11 @@ +use crate::core_crypto::prelude::PlaintextCount; +use crate::core_crypto::specification::entities::markers::PlaintextVectorKind; +use crate::core_crypto::specification::entities::AbstractEntity; + +/// A trait implemented by types embodying a plaintext vector. +/// +/// # Formal Definition +pub trait PlaintextVectorEntity: AbstractEntity { + /// Returns the number of plaintext contained in the vector. + fn plaintext_count(&self) -> PlaintextCount; +} diff --git a/tfhe/src/core_crypto/specification/key_kinds.rs b/tfhe/src/core_crypto/specification/key_kinds.rs new file mode 100644 index 000000000..ed31f97fd --- /dev/null +++ b/tfhe/src/core_crypto/specification/key_kinds.rs @@ -0,0 +1,44 @@ +//! This module contains types to manage the different kinds of secret keys. +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// This type is a marker for keys using binary elements as scalar. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct BinaryKeyKind; +/// This type is a marker for keys using ternary elements as scalar. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct TernaryKeyKind; +/// This type is a marker for keys using normaly sampled elements as scalar. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GaussianKeyKind; +/// This type is a marker for keys using uniformly sampled elements as scalar. +#[derive(Clone, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct UniformKeyKind; + +#[derive(Clone)] +/// This type is a marker for keys filled with zeros (used for testing) +pub struct ZeroKeyKind; + +/// Secret keys can be based on different kinds of scalar values (put aside the +/// data type eventually used to store it in memory). This trait is implemented by marker types, +/// which are used to specify in the type system, what kind of keys we are currently using. +pub trait KeyKind: seal::SealedKeyKind + Sync + Clone {} + +impl KeyKind for BinaryKeyKind {} +impl KeyKind for TernaryKeyKind {} +impl KeyKind for GaussianKeyKind {} +impl KeyKind for UniformKeyKind {} +impl KeyKind for ZeroKeyKind {} + +mod seal { + pub trait SealedKeyKind {} + impl SealedKeyKind for super::BinaryKeyKind {} + impl SealedKeyKind for super::TernaryKeyKind {} + impl SealedKeyKind for super::GaussianKeyKind {} + impl SealedKeyKind for super::UniformKeyKind {} + impl SealedKeyKind for super::ZeroKeyKind {} +} diff --git a/tfhe/src/core_crypto/specification/mod.rs b/tfhe/src/core_crypto/specification/mod.rs new file mode 100644 index 000000000..263aba310 --- /dev/null +++ b/tfhe/src/core_crypto/specification/mod.rs @@ -0,0 +1,29 @@ +//! A module containing the specification for the backends of the implemented FHE scheme. +//! +//! A backend is expected to provide access to two different families of objects: +//! +//! + __Entities__ which are FHE objects you can manipulate with the library (the data). +//! + __Engines__ which are types you can use to operate on entities (the operators). +//! +//! The specification contains traits for both entities and engines which are then implemented in +//! the backend modules. +//! +//! This module also contains common tools for the crate +//! +//! # Dispersion +//! This module contains the functions used to compute the variance, standard +//! deviation, etc. +//! +//! # Key kinds +//! This module contains types to manage the different kinds of secret keys. +//! +//! # Parameters +//! This module contains structures that wrap unsigned integer parameters like the ciphertext +//! dimension or the polynomial degree. + +pub mod engines; +pub mod entities; + +pub mod dispersion; +pub mod key_kinds; +pub mod parameters; diff --git a/tfhe/src/core_crypto/specification/parameters.rs b/tfhe/src/core_crypto/specification/parameters.rs new file mode 100644 index 000000000..51347cbbc --- /dev/null +++ b/tfhe/src/core_crypto/specification/parameters.rs @@ -0,0 +1,210 @@ +#![allow(deprecated)] +#[cfg(feature = "__commons_serialization")] +use serde::{Deserialize, Serialize}; + +/// The number plaintexts in a plaintext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct PlaintextCount(pub usize); + +/// The number encoder in an encoder list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct EncoderCount(pub usize); + +/// The number messages in a messages list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct CleartextCount(pub usize); + +/// The number of ciphertexts in a ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct CiphertextCount(pub usize); + +/// The number of ciphertexts in an lwe ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LweCiphertextCount(pub usize); + +/// The index of a ciphertext in an lwe ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LweCiphertextIndex(pub usize); + +/// The range of indices of multiple contiguous ciphertexts in an lwe ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LweCiphertextRange(pub usize, pub usize); + +impl LweCiphertextRange { + pub fn is_ordered(&self) -> bool { + self.1 <= self.0 + } +} + +/// The number of ciphertexts in a glwe ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GlweCiphertextCount(pub usize); + +/// The number of ciphertexts in a gsw ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GswCiphertextCount(pub usize); + +/// The number of ciphertexts in a ggsw ciphertext list. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GgswCiphertextCount(pub usize); + +/// The number of scalars in an LWE ciphertext, i.e. the number of scalar in an LWE mask plus one. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LweSize(pub usize); + +impl LweSize { + /// Returns the associated [`LweDimension`]. + pub fn to_lwe_dimension(&self) -> LweDimension { + LweDimension(self.0 - 1) + } +} + +/// The number of scalar in an LWE mask, or the length of an LWE secret key. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LweDimension(pub usize); + +impl LweDimension { + /// Returns the associated [`LweSize`]. + pub fn to_lwe_size(&self) -> LweSize { + LweSize(self.0 + 1) + } +} + +/// The number of LWE encryptions of 0 in an LWE public key. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LwePublicKeyZeroEncryptionCount(pub usize); + +/// The number of polynomials in a GLWE ciphertext, i.e. the number of polynomials in a GLWE mask +/// plus one. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GlweSize(pub usize); + +impl GlweSize { + /// Returns the associated [`GlweDimension`]. + pub fn to_glwe_dimension(&self) -> GlweDimension { + GlweDimension(self.0 - 1) + } +} + +/// The number of polynomials of an GLWE mask, or the size of an GLWE secret key. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct GlweDimension(pub usize); + +impl GlweDimension { + /// Returns the associated [`GlweSize`]. + pub fn to_glwe_size(&self) -> GlweSize { + GlweSize(self.0 + 1) + } +} + +/// The number of coefficients of a polynomial. +/// +/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns $N$. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct PolynomialSize(pub usize); + +impl PolynomialSize { + /// Returns the associated [`PolynomialSizeLog`]. + pub fn log2(&self) -> PolynomialSizeLog { + PolynomialSizeLog((self.0 as f64).log2().ceil() as usize) + } +} + +/// The logarithm of the number of coefficients of a polynomial. +/// +/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns $\log\_2(N)$. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct PolynomialSizeLog(pub usize); + +impl PolynomialSizeLog { + /// Returns the associated [`PolynomialSizeLog`]. + pub fn to_polynomial_size(&self) -> PolynomialSize { + PolynomialSize(1 << self.0) + } +} + +/// The number of polynomials in a polynomial list. +/// +/// Assuming a polynomial list, this return the number of polynomials. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct PolynomialCount(pub usize); + +/// The degree of a monomial. +/// +/// Assuming a monomial $aX^N$, this returns the $N$ value. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +#[deprecated(note = "MonomialDegree is not used anymore in the API. You should not use it.")] +pub struct MonomialDegree(pub usize); + +/// The index of a monomial in a polynomial. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct MonomialIndex(pub usize); + +/// The logarithm of the base used in a decomposition. +/// +/// When decomposing an integer over powers of the $2^B$ basis, this type represents the $B$ value. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct DecompositionBaseLog(pub usize); + +/// The number of levels used in a decomposition. +/// +/// When decomposing an integer over the $l$ largest powers of the basis, this type represents +/// the $l$ value. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct DecompositionLevelCount(pub usize); + +/// The logarithm of the number of LUT evaluated in a PBS. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct LutCountLog(pub usize); + +/// The number of MSB shifted in a Modulus Switch. +/// +/// When performing a Modulus Switch, this type represents the number of MSB that will be +/// discarded. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct ModulusSwitchOffset(pub usize); + +/// The base 2 logarithm of the scaling factor (generally written $\Delta$) used to store the +/// message in the MSB of ciphertexts. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct DeltaLog(pub usize); + +/// The number of bits to extract in a bit extraction. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct ExtractedBitsCount(pub usize); + +/// The number of functional packing keyswitch key in a functional packing keyswitch key list. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct FunctionalPackingKeyswitchKeyCount(pub usize); + +/// The number of bits used for the mask coefficients and the body of a ciphertext +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "__commons_serialization", derive(Serialize, Deserialize))] +pub struct CiphertextModulusLog(pub usize); diff --git a/tfhe/src/js_on_wasm_api/boolean.rs b/tfhe/src/js_on_wasm_api/boolean.rs new file mode 100644 index 000000000..9117c6f70 --- /dev/null +++ b/tfhe/src/js_on_wasm_api/boolean.rs @@ -0,0 +1,194 @@ +use bincode; +use wasm_bindgen::prelude::*; + +use super::js_wasm_seeder; + +use std::panic::set_hook; + +#[wasm_bindgen] +pub struct BooleanCiphertext(pub(crate) crate::boolean::ciphertext::Ciphertext); + +#[wasm_bindgen] +pub struct BooleanClientKey(pub(crate) crate::boolean::client_key::ClientKey); + +#[wasm_bindgen] +pub struct BooleanPublicKey(pub(crate) crate::boolean::public_key::PublicKey); + +#[wasm_bindgen] +pub struct Boolean {} + +#[wasm_bindgen] +pub struct BooleanParameters(pub(crate) crate::boolean::parameters::BooleanParameters); + +#[wasm_bindgen] +pub enum BooleanParameterSet { + Default, + TfheLib, +} + +impl TryFrom for BooleanParameterSet { + type Error = String; + + fn try_from(value: u32) -> Result { + match value { + 0 => Ok(BooleanParameterSet::Default), + 1 => Ok(BooleanParameterSet::TfheLib), + _ => Err(format!( + "Invalid value '{value}' for BooleansParametersSet, use \ + BooleanParameterSet constants" + )), + } + } +} + +#[wasm_bindgen] +impl Boolean { + #[wasm_bindgen] + pub fn get_boolean_parameters(parameter_choice: u32) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + let parameter_choice = BooleanParameterSet::try_from(parameter_choice) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str()))?; + + match parameter_choice { + BooleanParameterSet::Default => Ok(crate::boolean::parameters::DEFAULT_PARAMETERS), + BooleanParameterSet::TfheLib => Ok(crate::boolean::parameters::TFHE_LIB_PARAMETERS), + } + .map(BooleanParameters) + } + + #[wasm_bindgen] + pub fn new_boolean_parameters( + lwe_dimension: usize, + glwe_dimension: usize, + polynomial_size: usize, + lwe_modular_std_dev: f64, + glwe_modular_std_dev: f64, + pbs_base_log: usize, + pbs_level: usize, + ks_base_log: usize, + ks_level: usize, + ) -> BooleanParameters { + set_hook(Box::new(console_error_panic_hook::hook)); + use crate::core_crypto::prelude::*; + BooleanParameters(crate::boolean::parameters::BooleanParameters { + lwe_dimension: LweDimension(lwe_dimension), + glwe_dimension: GlweDimension(glwe_dimension), + polynomial_size: PolynomialSize(polynomial_size), + lwe_modular_std_dev: StandardDev(lwe_modular_std_dev), + glwe_modular_std_dev: StandardDev(glwe_modular_std_dev), + pbs_base_log: DecompositionBaseLog(pbs_base_log), + pbs_level: DecompositionLevelCount(pbs_level), + ks_base_log: DecompositionBaseLog(ks_base_log), + ks_level: DecompositionLevelCount(ks_level), + }) + } + + #[wasm_bindgen] + pub fn new_client_key_from_seed_and_parameters( + seed_high_bytes: u64, + seed_low_bytes: u64, + parameters: &BooleanParameters, + ) -> BooleanClientKey { + set_hook(Box::new(console_error_panic_hook::hook)); + let seed_high_bytes: u128 = seed_high_bytes.into(); + let seed_low_bytes: u128 = seed_low_bytes.into(); + let seed: u128 = (seed_high_bytes << 64) | seed_low_bytes; + + let constant_seeder = Box::new(js_wasm_seeder::ConstantSeeder::new( + crate::core_crypto::commons::math::random::Seed(seed), + )); + + let mut tmp_boolean_engine = + crate::boolean::engine::CpuBooleanEngine::new_from_seeder(constant_seeder); + + BooleanClientKey(tmp_boolean_engine.create_client_key(parameters.0.to_owned())) + } + + #[wasm_bindgen] + pub fn new_client_key(parameters: &BooleanParameters) -> BooleanClientKey { + set_hook(Box::new(console_error_panic_hook::hook)); + BooleanClientKey(crate::boolean::client_key::ClientKey::new(¶meters.0)) + } + + #[wasm_bindgen] + pub fn new_public_key(client_key: &BooleanClientKey) -> BooleanPublicKey { + set_hook(Box::new(console_error_panic_hook::hook)); + + BooleanPublicKey(crate::boolean::public_key::PublicKey::new(&client_key.0)) + } + + #[wasm_bindgen] + pub fn encrypt(client_key: &BooleanClientKey, message: bool) -> BooleanCiphertext { + set_hook(Box::new(console_error_panic_hook::hook)); + BooleanCiphertext(client_key.0.encrypt(message)) + } + + #[wasm_bindgen] + pub fn encrypt_with_public_key( + public_key: &BooleanPublicKey, + message: bool, + ) -> BooleanCiphertext { + set_hook(Box::new(console_error_panic_hook::hook)); + + BooleanCiphertext(public_key.0.encrypt(message)) + } + + #[wasm_bindgen] + pub fn trivial_encrypt(&mut self, message: bool) -> BooleanCiphertext { + set_hook(Box::new(console_error_panic_hook::hook)); + BooleanCiphertext(crate::boolean::ciphertext::Ciphertext::Trivial(message)) + } + + #[wasm_bindgen] + pub fn decrypt(client_key: &BooleanClientKey, ct: &BooleanCiphertext) -> bool { + set_hook(Box::new(console_error_panic_hook::hook)); + client_key.0.decrypt(&ct.0) + } + + #[wasm_bindgen] + pub fn serialize_boolean_ciphertext( + ciphertext: &BooleanCiphertext, + ) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&ciphertext.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_boolean_ciphertext(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(BooleanCiphertext) + } + + #[wasm_bindgen] + pub fn serialize_boolean_client_key(client_key: &BooleanClientKey) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&client_key.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_boolean_client_key(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(BooleanClientKey) + } + + #[wasm_bindgen] + pub fn serialize_boolean_public_key(public_key: &BooleanPublicKey) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&public_key.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_boolean_public_key(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(BooleanPublicKey) + } +} diff --git a/tfhe/src/js_on_wasm_api/mod.rs b/tfhe/src/js_on_wasm_api/mod.rs new file mode 100644 index 000000000..ea102e1a6 --- /dev/null +++ b/tfhe/src/js_on_wasm_api/mod.rs @@ -0,0 +1,39 @@ +#[cfg(feature = "shortint-client-js-wasm-api")] +pub mod shortint; +#[cfg(feature = "shortint-client-js-wasm-api")] +pub use shortint::*; + +#[cfg(feature = "boolean-client-js-wasm-api")] +pub mod boolean; +#[cfg(feature = "boolean-client-js-wasm-api")] +pub use boolean::*; + +pub(self) mod js_wasm_seeder { + use crate::core_crypto::commons::math::random::Seed; + use crate::core_crypto::prelude::Seeder; + + const SEED_BYTES_COUNT: usize = 16; + + pub struct ConstantSeeder { + seed: Seed, + } + + impl ConstantSeeder { + pub fn new(seed: Seed) -> Self { + Self { seed } + } + } + + impl Seeder for ConstantSeeder { + fn seed(&mut self) -> Seed { + self.seed + } + + fn is_available() -> bool + where + Self: Sized, + { + true + } + } +} diff --git a/tfhe/src/js_on_wasm_api/shortint.rs b/tfhe/src/js_on_wasm_api/shortint.rs new file mode 100644 index 000000000..6929438dd --- /dev/null +++ b/tfhe/src/js_on_wasm_api/shortint.rs @@ -0,0 +1,260 @@ +use bincode; +use wasm_bindgen::prelude::*; + +use super::js_wasm_seeder; + +use std::panic::set_hook; + +#[wasm_bindgen] +pub struct ShortintCiphertext(pub(crate) crate::shortint::ciphertext::Ciphertext); + +#[wasm_bindgen] +pub struct ShortintClientKey(pub(crate) crate::shortint::ClientKey); + +#[wasm_bindgen] +pub struct ShortintPublicKey(pub(crate) crate::shortint::PublicKey); + +#[wasm_bindgen] +pub struct ShortintServerKey(pub(crate) crate::shortint::ServerKey); + +#[wasm_bindgen] +pub struct Shortint {} + +#[wasm_bindgen] +pub struct ShortintParameters(pub(crate) crate::shortint::Parameters); + +#[wasm_bindgen] +impl Shortint { + #[wasm_bindgen] + pub fn get_shortint_parameters( + message_bits: usize, + carry_bits: usize, + ) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + match (message_bits, carry_bits) { + (1, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_0), + (1, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_1), + (2, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_0), + (1, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_2), + (2, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_1), + (3, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_0), + (1, 3) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_3), + (2, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2), + (3, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_1), + (4, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_0), + (1, 4) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_4), + (2, 3) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_3), + (3, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_2), + (4, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_1), + (5, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_0), + (1, 5) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_5), + (2, 4) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_4), + (3, 3) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_3), + (4, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_2), + (5, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_1), + (6, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_0), + (1, 6) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_6), + (2, 5) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_5), + (3, 4) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_4), + (4, 3) => Ok(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_3), + (5, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_2), + (6, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_1), + (7, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_7_CARRY_0), + (1, 7) => Ok(crate::shortint::parameters::PARAM_MESSAGE_1_CARRY_7), + (2, 6) => Ok(crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_6), + (3, 5) => Ok(crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_5), + (4, 4) => Ok(crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_4), + (5, 3) => Ok(crate::shortint::parameters::PARAM_MESSAGE_5_CARRY_3), + (6, 2) => Ok(crate::shortint::parameters::PARAM_MESSAGE_6_CARRY_2), + (7, 1) => Ok(crate::shortint::parameters::PARAM_MESSAGE_7_CARRY_1), + (8, 0) => Ok(crate::shortint::parameters::PARAM_MESSAGE_8_CARRY_0), + _ => Err(wasm_bindgen::JsError::new( + format!( + "No parameters for {message_bits} bits of message and {carry_bits} bits of carry" + ) + .as_str(), + )), + } + .map(ShortintParameters) + } + + #[wasm_bindgen] + pub fn new_shortint_parameters( + lwe_dimension: usize, + glwe_dimension: usize, + polynomial_size: usize, + lwe_modular_std_dev: f64, + glwe_modular_std_dev: f64, + pbs_base_log: usize, + pbs_level: usize, + ks_base_log: usize, + ks_level: usize, + pfks_level: usize, + pfks_base_log: usize, + pfks_modular_std_dev: f64, + cbs_level: usize, + cbs_base_log: usize, + message_modulus: usize, + carry_modulus: usize, + ) -> ShortintParameters { + set_hook(Box::new(console_error_panic_hook::hook)); + use crate::core_crypto::prelude::*; + ShortintParameters(crate::shortint::Parameters { + lwe_dimension: LweDimension(lwe_dimension), + glwe_dimension: GlweDimension(glwe_dimension), + polynomial_size: PolynomialSize(polynomial_size), + lwe_modular_std_dev: StandardDev(lwe_modular_std_dev), + glwe_modular_std_dev: StandardDev(glwe_modular_std_dev), + pbs_base_log: DecompositionBaseLog(pbs_base_log), + pbs_level: DecompositionLevelCount(pbs_level), + ks_base_log: DecompositionBaseLog(ks_base_log), + ks_level: DecompositionLevelCount(ks_level), + pfks_level: DecompositionLevelCount(pfks_level), + pfks_base_log: DecompositionBaseLog(pfks_base_log), + pfks_modular_std_dev: StandardDev(pfks_modular_std_dev), + cbs_level: DecompositionLevelCount(cbs_level), + cbs_base_log: DecompositionBaseLog(cbs_base_log), + message_modulus: crate::shortint::parameters::MessageModulus(message_modulus), + carry_modulus: crate::shortint::parameters::CarryModulus(carry_modulus), + }) + } + + #[wasm_bindgen] + pub fn new_client_key_from_seed_and_parameters( + seed_high_bytes: u64, + seed_low_bytes: u64, + parameters: &ShortintParameters, + ) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + let seed_high_bytes: u128 = seed_high_bytes.into(); + let seed_low_bytes: u128 = seed_low_bytes.into(); + let seed: u128 = (seed_high_bytes << 64) | seed_low_bytes; + + let constant_seeder = Box::new(js_wasm_seeder::ConstantSeeder::new( + crate::core_crypto::commons::math::random::Seed(seed), + )); + + let mut tmp_shortint_engine = + crate::shortint::engine::ShortintEngine::new_from_seeder(constant_seeder); + + tmp_shortint_engine + .new_client_key(parameters.0.to_owned()) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(ShortintClientKey) + } + + #[wasm_bindgen] + pub fn new_client_key(parameters: &ShortintParameters) -> ShortintClientKey { + set_hook(Box::new(console_error_panic_hook::hook)); + + ShortintClientKey(crate::shortint::client_key::ClientKey::new( + parameters.0.to_owned(), + )) + } + + #[wasm_bindgen] + pub fn new_public_key(client_key: &ShortintClientKey) -> ShortintPublicKey { + set_hook(Box::new(console_error_panic_hook::hook)); + + ShortintPublicKey(crate::shortint::public_key::PublicKey::new(&client_key.0)) + } + + #[wasm_bindgen] + pub fn new_server_key(client_key: &ShortintClientKey) -> ShortintServerKey { + set_hook(Box::new(console_error_panic_hook::hook)); + + ShortintServerKey(crate::shortint::server_key::ServerKey::new(&client_key.0)) + } + + #[wasm_bindgen] + pub fn encrypt(client_key: &ShortintClientKey, message: u64) -> ShortintCiphertext { + set_hook(Box::new(console_error_panic_hook::hook)); + + ShortintCiphertext(client_key.0.encrypt(message)) + } + + #[wasm_bindgen] + pub fn encrypt_with_public_key( + public_key: &ShortintPublicKey, + server_key: &ShortintServerKey, + message: u64, + ) -> ShortintCiphertext { + set_hook(Box::new(console_error_panic_hook::hook)); + + ShortintCiphertext(public_key.0.encrypt(&server_key.0, message)) + } + + #[wasm_bindgen] + pub fn decrypt(client_key: &ShortintClientKey, ct: &ShortintCiphertext) -> u64 { + set_hook(Box::new(console_error_panic_hook::hook)); + client_key.0.decrypt(&ct.0) + } + + #[wasm_bindgen] + pub fn serialize_shortint_ciphertext( + ciphertext: &ShortintCiphertext, + ) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&ciphertext.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_shortint_ciphertext(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(ShortintCiphertext) + } + + #[wasm_bindgen] + pub fn serialize_shortint_client_key( + client_key: &ShortintClientKey, + ) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&client_key.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_shortint_client_key(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(ShortintClientKey) + } + + #[wasm_bindgen] + pub fn serialize_shortint_public_key( + public_key: &ShortintPublicKey, + ) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&public_key.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_shortint_public_key(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(ShortintPublicKey) + } + + #[wasm_bindgen] + pub fn serialize_shortint_server_key( + server_key: &ShortintServerKey, + ) -> Result, JsError> { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::serialize(&server_key.0) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + } + + #[wasm_bindgen] + pub fn deserialize_shortint_server_key(buffer: &[u8]) -> Result { + set_hook(Box::new(console_error_panic_hook::hook)); + bincode::deserialize(buffer) + .map_err(|e| wasm_bindgen::JsError::new(format!("{:?}", e).as_str())) + .map(ShortintServerKey) + } +} diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs new file mode 100644 index 000000000..c0105b3d1 --- /dev/null +++ b/tfhe/src/lib.rs @@ -0,0 +1,29 @@ +#![cfg_attr(feature = "__wasm_api", allow(dead_code))] +#![cfg_attr( + feature = "backend_fft_nightly_avx512", + feature(stdsimd, avx512_target_feature) +)] + +#[cfg(feature = "__c_api")] +pub mod c_api; + +#[cfg(feature = "boolean")] +/// cbindgen:ignore +pub mod boolean; +/// cbindgen:ignore +pub mod core_crypto; +#[cfg(feature = "shortint")] +/// cbindgen:ignore +pub mod shortint; + +#[cfg(feature = "__wasm_api")] +/// cbindgen:ignore +pub mod js_on_wasm_api; +#[cfg(feature = "__wasm_api")] +pub use js_on_wasm_api::*; + +#[cfg(any(feature = "boolean", feature = "shortint"))] +pub(crate) mod seeders; + +#[cfg(all(doctest, feature = "shortint", feature = "boolean"))] +mod test_user_docs; diff --git a/tfhe/src/seeders.rs b/tfhe/src/seeders.rs new file mode 100644 index 000000000..334593994 --- /dev/null +++ b/tfhe/src/seeders.rs @@ -0,0 +1,86 @@ +use crate::core_crypto::commons::math::random::Seeder; +#[cfg(target_os = "macos")] +use concrete_csprng::seeders::AppleSecureEnclaveSeeder; +#[cfg(feature = "seeder_x86_64_rdseed")] +use concrete_csprng::seeders::RdseedSeeder; +#[cfg(feature = "seeder_unix")] +use concrete_csprng::seeders::UnixSeeder; + +#[cfg(feature = "__wasm_api")] +mod wasm_seeder { + use crate::core_crypto::commons::math::random::{Seed, Seeder}; + // This is used for web interfaces + use getrandom::getrandom; + + pub(super) struct WasmSeeder {} + + impl Seeder for WasmSeeder { + fn seed(&mut self) -> Seed { + let mut buffer = [0u8; 16]; + getrandom(&mut buffer).unwrap(); + + Seed(u128::from_le_bytes(buffer)) + } + + fn is_available() -> bool + where + Self: Sized, + { + true + } + } +} + +pub fn new_seeder() -> Box { + let mut seeder: Option> = None; + + let err_msg; + + #[cfg(not(feature = "__wasm_api"))] + { + #[cfg(feature = "seeder_x86_64_rdseed")] + { + if RdseedSeeder::is_available() { + seeder = Some(Box::new(RdseedSeeder)); + } + } + + // This Seeder is normally always available on macOS, so we enable it by default when on + // that platform + #[cfg(target_os = "macos")] + { + if seeder.is_none() && AppleSecureEnclaveSeeder::is_available() { + seeder = Some(Box::new(AppleSecureEnclaveSeeder)) + } + } + + #[cfg(feature = "seeder_unix")] + { + if seeder.is_none() && UnixSeeder::is_available() { + seeder = Some(Box::new(UnixSeeder::new(0))); + } + } + + #[cfg(not(feature = "__c_api"))] + { + err_msg = "Unable to instantiate a seeder, make sure to enable a seeder feature \ + like seeder_unix for example on unix platforms."; + } + + #[cfg(feature = "__c_api")] + { + err_msg = "No compatible seeder for current machine found."; + } + } + + #[cfg(feature = "__wasm_api")] + { + if seeder.is_none() && wasm_seeder::WasmSeeder::is_available() { + seeder = Some(Box::new(wasm_seeder::WasmSeeder {})) + } + + err_msg = "No compatible seeder found. Consider changing browser or dev environment"; + } + + seeder.expect(err_msg) +} diff --git a/tfhe/src/shortint/ciphertext/mod.rs b/tfhe/src/shortint/ciphertext/mod.rs new file mode 100644 index 000000000..793ba1872 --- /dev/null +++ b/tfhe/src/shortint/ciphertext/mod.rs @@ -0,0 +1,140 @@ +//! Module with the definition of a short-integer ciphertext. +use crate::core_crypto::prelude::*; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::cmp; +use std::fmt::Debug; + +/// This indicates the number of operations that has been done. +/// +/// For instances, computing and addition increases this number by 1, whereas a multiplication by +/// a constant $\lambda$ increases it by $\lambda$. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct Degree(pub usize); + +impl Degree { + pub(crate) fn after_bitxor(&self, other: Degree) -> Degree { + let max = cmp::max(self.0, other.0); + let min = cmp::min(self.0, other.0); + let mut result = max; + + //Try every possibility to find the worst case + for i in 0..min + 1 { + if max ^ i > result { + result = max ^ i; + } + } + + Degree(result) + } + + pub(crate) fn after_bitor(&self, other: Degree) -> Degree { + let max = cmp::max(self.0, other.0); + let min = cmp::min(self.0, other.0); + let mut result = max; + + for i in 0..min + 1 { + if max | i > result { + result = max | i; + } + } + + Degree(result) + } + + pub(crate) fn after_bitand(&self, other: Degree) -> Degree { + Degree(cmp::min(self.0, other.0)) + } + + pub(crate) fn after_left_shift(&self, shift: u8, modulus: usize) -> Degree { + let mut result = 0; + + for i in 0..self.0 + 1 { + let tmp = (i << shift) % modulus; + if tmp > result { + result = tmp; + } + } + + Degree(result) + } + + #[allow(dead_code)] + pub(crate) fn after_pbs(&self, f: F) -> Degree + where + F: Fn(usize) -> usize, + { + let mut result = 0; + + for i in 0..self.0 + 1 { + let tmp = f(i); + if tmp > result { + result = tmp; + } + } + + Degree(result) + } +} + +/// A structure representing a short-integer ciphertext. +/// It is used to evaluate a short-integer circuits homomorphically. +/// Internally, it uses a LWE ciphertext. +#[derive(Clone)] +pub struct Ciphertext { + pub ct: LweCiphertext64, + pub degree: Degree, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, +} + +#[derive(Serialize, Deserialize)] +struct SerializableCiphertext { + data: Vec, + pub degree: Degree, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, +} + +impl Serialize for Ciphertext { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let data = ser_eng + .serialize(&self.ct) + .map_err(serde::ser::Error::custom)?; + + SerializableCiphertext { + data, + degree: self.degree, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Ciphertext { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = SerializableCiphertext::deserialize(deserializer)?; + + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + let ct = de_eng + .deserialize(thing.data.as_slice()) + .map_err(serde::de::Error::custom)?; + + Ok(Self { + ct, + degree: thing.degree, + message_modulus: thing.message_modulus, + carry_modulus: thing.carry_modulus, + }) + } +} diff --git a/tfhe/src/shortint/client_key/mod.rs b/tfhe/src/shortint/client_key/mod.rs new file mode 100644 index 000000000..30803dfed --- /dev/null +++ b/tfhe/src/shortint/client_key/mod.rs @@ -0,0 +1,379 @@ +//! Module with the definition of the ClientKey. + +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Ciphertext; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::parameters::{MessageModulus, Parameters}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::Debug; + +/// A structure containing the client key, which must be kept secret. +/// +/// In more details, it contains: +/// * `lwe_secret_key` - an LWE secret key, used to encrypt the inputs and decrypt the outputs. +/// This secret key is also used in the generation of bootstrapping and key switching keys. +/// * `glwe_secret_key` - a GLWE secret key, used to generate the bootstrapping keys and key +/// switching keys. +/// * `parameters` - the cryptographic parameter set. +#[derive(Clone, Debug, PartialEq)] +pub struct ClientKey { + /// The actual encryption / decryption key + pub(crate) lwe_secret_key: LweSecretKey64, + pub(crate) glwe_secret_key: GlweSecretKey64, + /// Key used as the output of the keyswitch operation + pub(crate) lwe_secret_key_after_ks: LweSecretKey64, + pub parameters: Parameters, +} + +impl ClientKey { + /// Generates a client key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::client_key::ClientKey; + /// use tfhe::shortint::parameters::Parameters; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(Parameters::default()); + /// ``` + pub fn new(parameters: Parameters) -> ClientKey { + ShortintEngine::with_thread_local_mut(|engine| engine.new_client_key(parameters).unwrap()) + } + + /// Encrypts a small integer message using the client key. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::ClientKey; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encryption of one message that is within the encrypted message modulus: + /// let msg = 3; + /// let ct = cks.encrypt(msg); + /// + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg, dec); + /// + /// // Encryption of one message that is outside the encrypted message modulus: + /// let msg = 5; + /// let ct = cks.encrypt(msg); + /// + /// let dec = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(msg % modulus, dec); + /// ``` + pub fn encrypt(&self, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.encrypt(self, message).unwrap()) + } + + /// Encrypts a small integer message using the client key with a specific message modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::MessageModulus; + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt_with_message_modulus(msg, MessageModulus(6)); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_with_message_modulus( + &self, + message: u64, + message_modulus: MessageModulus, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_with_message_modulus(self, message, message_modulus) + .unwrap() + }) + } + + /// Encrypts an integer without reducing the input message modulus the message space + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 7; + /// let ct = cks.unchecked_encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 1 1 | + /// + /// let dec = cks.decrypt_message_and_carry(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn unchecked_encrypt(&self, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_encrypt(self, message).unwrap() + }) + } + + /// Decrypts a ciphertext encrypting an integer message and carries using the client key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt(msg); + /// + /// // Decryption: + /// let dec = cks.decrypt_message_and_carry(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_message_and_carry(&self, ct: &Ciphertext) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine.decrypt_message_and_carry(self, ct).unwrap() + }) + } + + /// Decrypts a ciphertext encrypting a message using the client key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt(msg); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt(&self, ct: &Ciphertext) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| engine.decrypt(self, ct).unwrap()) + } + + /// Encrypts a small integer message using the client key without padding bit. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::ClientKey; + /// + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encryption of one message that is within the encrypted message modulus: + /// let msg = 6; + /// let ct = cks.encrypt_without_padding(msg); + /// + /// let dec = cks.decrypt_message_and_carry_without_padding(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_without_padding(&self, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.encrypt_without_padding(self, message).unwrap() + }) + } + + /// Decrypts a ciphertext encrypting an integer message and carries using the client key, + /// where the ciphertext is assumed to not have any padding bit. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::ClientKey; + /// + /// // Generate the client key + /// let cks = ClientKey::new(PARAM_MESSAGE_1_CARRY_1); + /// + /// let msg = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt_without_padding(msg); + /// + /// // Decryption: + /// let dec = cks.decrypt_message_and_carry_without_padding(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn decrypt_message_and_carry_without_padding(&self, ct: &Ciphertext) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .decrypt_message_and_carry_without_padding(self, ct) + .unwrap() + }) + } + + /// Decrypts a ciphertext encrypting an integer message using the client key, + /// where the ciphertext is assumed to not have any padding bit. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 7; + /// let modulus = 4; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt_without_padding(msg); + /// + /// // Decryption: + /// let dec = cks.decrypt_without_padding(&ct); + /// assert_eq!(msg % modulus, dec); + /// ``` + pub fn decrypt_without_padding(&self, ct: &Ciphertext) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine.decrypt_without_padding(self, ct).unwrap() + }) + } + + /// Encrypts a small integer message using the client key without padding bit with some modulus. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 2; + /// let modulus = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt_native_crt(msg, modulus); + /// + /// // Decryption: + /// let dec = cks.decrypt_message_native_crt(&ct, modulus); + /// assert_eq!(msg, dec % modulus as u64); + /// ``` + pub fn encrypt_native_crt(&self, message: u64, message_modulus: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_native_crt(self, message, message_modulus) + .unwrap() + }) + } + + /// Decrypts a ciphertext encrypting an integer message using the client key, + /// where the ciphertext is assumed to not have any padding bit and is related to some modulus. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters}; + /// + /// // Generate the client key + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let msg = 1; + /// let modulus = 3; + /// + /// // Encryption of one message: + /// let ct = cks.encrypt_native_crt(msg, modulus); + /// + /// // Decryption: + /// let dec = cks.decrypt_message_native_crt(&ct, modulus); + /// assert_eq!(msg, dec % modulus as u64); + /// ``` + pub fn decrypt_message_native_crt(&self, ct: &Ciphertext, message_modulus: u8) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .decrypt_message_native_crt(self, ct, message_modulus as u64) + .unwrap() + }) + } +} + +#[derive(Serialize, Deserialize)] +struct SerializableClientKey { + lwe_secret_key: Vec, + glwe_secret_key: Vec, + lwe_secret_key_after_ks: Vec, + parameters: Parameters, +} + +impl Serialize for ClientKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let lwe_secret_key = ser_eng + .serialize(&self.lwe_secret_key) + .map_err(serde::ser::Error::custom)?; + let glwe_secret_key = ser_eng + .serialize(&self.glwe_secret_key) + .map_err(serde::ser::Error::custom)?; + let lwe_secret_key_after_ks = ser_eng + .serialize(&self.lwe_secret_key_after_ks) + .map_err(serde::ser::Error::custom)?; + + SerializableClientKey { + lwe_secret_key, + glwe_secret_key, + lwe_secret_key_after_ks, + parameters: self.parameters, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ClientKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + SerializableClientKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(Self { + lwe_secret_key: de_eng + .deserialize(thing.lwe_secret_key.as_slice()) + .map_err(serde::de::Error::custom)?, + glwe_secret_key: de_eng + .deserialize(thing.glwe_secret_key.as_slice()) + .map_err(serde::de::Error::custom)?, + lwe_secret_key_after_ks: de_eng + .deserialize(thing.lwe_secret_key_after_ks.as_slice()) + .map_err(serde::de::Error::custom)?, + parameters: thing.parameters, + }) + } +} diff --git a/tfhe/src/shortint/engine/client_side.rs b/tfhe/src/shortint/engine/client_side.rs new file mode 100644 index 000000000..e48ca28f4 --- /dev/null +++ b/tfhe/src/shortint/engine/client_side.rs @@ -0,0 +1,259 @@ +//! All the `ShortintEngine` method related to client side (encrypt / decrypt) +use super::{EngineResult, ShortintEngine}; +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use crate::shortint::{Ciphertext, ClientKey, Parameters}; + +impl ShortintEngine { + pub fn new_client_key(&mut self, parameters: Parameters) -> EngineResult { + // generate the lwe secret key + let small_lwe_secret_key: LweSecretKey64 = self + .engine + .generate_new_lwe_secret_key(parameters.lwe_dimension)?; + + // generate the rlwe secret key + let glwe_secret_key: GlweSecretKey64 = self + .engine + .generate_new_glwe_secret_key(parameters.glwe_dimension, parameters.polynomial_size)?; + + let large_lwe_secret_key = self + .engine + .transform_glwe_secret_key_to_lwe_secret_key(glwe_secret_key.clone())?; + + // pack the keys in the client key set + Ok(ClientKey { + lwe_secret_key: large_lwe_secret_key, + glwe_secret_key, + lwe_secret_key_after_ks: small_lwe_secret_key, + parameters, + }) + } + + pub fn encrypt(&mut self, client_key: &ClientKey, message: u64) -> EngineResult { + self.encrypt_with_message_modulus( + client_key, + message, + client_key.parameters.message_modulus, + ) + } + + pub(crate) fn encrypt_with_message_modulus( + &mut self, + client_key: &ClientKey, + message: u64, + message_modulus: MessageModulus, + ) -> EngineResult { + //This ensures that the space message_modulus*carry_modulus < param.message_modulus * + // param.carry_modulus + let carry_modulus = (client_key.parameters.message_modulus.0 + * client_key.parameters.carry_modulus.0) + / message_modulus.0; + + //The delta is the one defined by the parameters + let delta = (1_u64 << 63) + / (client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0) + as u64; + + //The input is reduced modulus the message_modulus + let m = message % message_modulus.0 as u64; + + let shifted_message = m * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // convert into a variance + let var = Variance(client_key.parameters.lwe_modular_std_dev.get_variance()); + + // encryption + let ct = self + .engine + .encrypt_lwe_ciphertext(&client_key.lwe_secret_key, &plain, var)?; + + Ok(Ciphertext { + ct, + degree: Degree(message_modulus.0 - 1), + message_modulus, + carry_modulus: CarryModulus(carry_modulus), + }) + } + + pub(crate) fn unchecked_encrypt( + &mut self, + client_key: &ClientKey, + message: u64, + ) -> EngineResult { + let delta = (1_u64 << 63) + / (client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0) + as u64; + let shifted_message = message * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // convert into a variance + let var = Variance(client_key.parameters.lwe_modular_std_dev.get_variance()); + + // encryption + let ct = self + .engine + .encrypt_lwe_ciphertext(&client_key.lwe_secret_key, &plain, var)?; + Ok(Ciphertext { + ct, + degree: Degree( + client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0 - 1, + ), + message_modulus: client_key.parameters.message_modulus, + carry_modulus: client_key.parameters.carry_modulus, + }) + } + + pub(crate) fn decrypt_message_and_carry( + &mut self, + client_key: &ClientKey, + ct: &Ciphertext, + ) -> EngineResult { + // decryption + let decrypted = self + .engine + .decrypt_lwe_ciphertext(&client_key.lwe_secret_key, &ct.ct)?; + + let mut decrypted_u64: u64 = 0; + self.engine + .discard_retrieve_plaintext(&mut decrypted_u64, &decrypted)?; + + let delta = (1_u64 << 63) + / (client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0) + as u64; + + //The bit before the message + let rounding_bit = delta >> 1; + + //compute the rounding bit + let rounding = (decrypted_u64 & rounding_bit) << 1; + + Ok((decrypted_u64.wrapping_add(rounding)) / delta) + } + + pub fn decrypt(&mut self, client_key: &ClientKey, ct: &Ciphertext) -> EngineResult { + self.decrypt_message_and_carry(client_key, ct) + .map(|message_and_carry| message_and_carry % ct.message_modulus.0 as u64) + } + + pub(crate) fn encrypt_without_padding( + &mut self, + client_key: &ClientKey, + message: u64, + ) -> EngineResult { + //Multiply by 2 to reshift and exclude the padding bit + let delta = ((1_u64 << 63) + / (client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0) + as u64) + * 2; + + let shifted_message = message * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // convert into a variance + let var = Variance(client_key.parameters.lwe_modular_std_dev.get_variance()); + + // encryption + let ct = self + .engine + .encrypt_lwe_ciphertext(&client_key.lwe_secret_key, &plain, var)?; + + Ok(Ciphertext { + ct, + degree: Degree(client_key.parameters.message_modulus.0 - 1), + message_modulus: client_key.parameters.message_modulus, + carry_modulus: client_key.parameters.carry_modulus, + }) + } + + pub(crate) fn decrypt_message_and_carry_without_padding( + &mut self, + client_key: &ClientKey, + ct: &Ciphertext, + ) -> EngineResult { + // decryption + let decrypted = self + .engine + .decrypt_lwe_ciphertext(&client_key.lwe_secret_key, &ct.ct)?; + + let mut decrypted_u64: u64 = 0; + self.engine + .discard_retrieve_plaintext(&mut decrypted_u64, &decrypted)?; + + let delta = ((1_u64 << 63) + / (client_key.parameters.message_modulus.0 * client_key.parameters.carry_modulus.0) + as u64) + * 2; + + //The bit before the message + let rounding_bit = delta >> 1; + + //compute the rounding bit + let rounding = (decrypted_u64 & rounding_bit) << 1; + + Ok((decrypted_u64.wrapping_add(rounding)) / delta) + } + + pub(crate) fn decrypt_without_padding( + &mut self, + client_key: &ClientKey, + ct: &Ciphertext, + ) -> EngineResult { + self.decrypt_message_and_carry_without_padding(client_key, ct) + .map(|message_and_carry| message_and_carry % ct.message_modulus.0 as u64) + } + + pub(crate) fn encrypt_native_crt( + &mut self, + client_key: &ClientKey, + message: u64, + message_modulus: u8, + ) -> EngineResult { + let carry_modulus = 1; + let m = (message % message_modulus as u64) as u128; + let shifted_message = m * (1 << 64) / message_modulus as u128; + // encode the message + let plain: Plaintext64 = self + .engine + .create_plaintext_from(&(shifted_message as u64))?; + + // convert into a variance + let var = Variance(client_key.parameters.lwe_modular_std_dev.get_variance()); + + // encryption + let ct = self + .engine + .encrypt_lwe_ciphertext(&client_key.lwe_secret_key, &plain, var)?; + Ok(Ciphertext { + ct, + degree: Degree(message_modulus as usize - 1), + message_modulus: MessageModulus(message_modulus as usize), + carry_modulus: CarryModulus(carry_modulus), + }) + } + + pub(crate) fn decrypt_message_native_crt( + &mut self, + client_key: &ClientKey, + ct: &Ciphertext, + basis: u64, + ) -> EngineResult { + // decryption + let decrypted = self + .engine + .decrypt_lwe_ciphertext(&client_key.lwe_secret_key, &ct.ct)?; + + let mut decrypted_u64: u64 = 0; + self.engine + .discard_retrieve_plaintext(&mut decrypted_u64, &decrypted)?; + + let mut result = decrypted_u64 as u128 * basis as u128; + result = result.wrapping_add((result & 1 << 63) << 1) / (1 << 64); + + Ok(result as u64 % basis) + } +} diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs new file mode 100644 index 000000000..558c7e231 --- /dev/null +++ b/tfhe/src/shortint/engine/mod.rs @@ -0,0 +1,243 @@ +use crate::core_crypto::prelude::*; +use crate::seeders::new_seeder; +use crate::shortint::ServerKey; +use std::cell::RefCell; +use std::collections::BTreeMap; +use std::fmt::Debug; + +mod client_side; +mod public_side; +mod server_side; +#[cfg(not(feature = "__wasm_api"))] +mod wopbs; + +use crate::core_crypto::backends::default::engines::ActivatedRandomGenerator; +use crate::core_crypto::commons::crypto::secret::generators::DeterministicSeeder; + +thread_local! { + static LOCAL_ENGINE: RefCell = RefCell::new(ShortintEngine::new()); +} + +/// Stores buffers associated to a ServerKey +pub struct Buffers { + pub(crate) accumulator: GlweCiphertext64, + pub(crate) buffer_lwe_after_ks: LweCiphertext64, +} + +/// This allows to store and retrieve the `Buffers` +/// corresponding to a `ServerKey` in a `BTreeMap` +#[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] +struct KeyId { + accumulator_dim: GlweSize, + lwe_dim_after_pbs: usize, + glwe_size: GlweSize, + poly_size: PolynomialSize, +} + +impl ServerKey { + #[inline] + fn key_id(&self) -> KeyId { + KeyId { + accumulator_dim: self.bootstrapping_key.glwe_dimension().to_glwe_size(), + lwe_dim_after_pbs: self.bootstrapping_key.output_lwe_dimension().0, + glwe_size: self.bootstrapping_key.glwe_dimension().to_glwe_size(), + poly_size: self.bootstrapping_key.polynomial_size(), + } + } +} + +/// Simple wrapper around `std::error::Error` to be able to +/// forward all the possible `EngineError` type from [`core_cryto`](crate::core_crypto) +#[allow(dead_code)] +#[derive(Debug)] +pub struct EngineError { + error: Box, +} + +impl From for EngineError +where + T: std::error::Error + 'static, +{ + fn from(error: T) -> Self { + Self { + error: Box::new(error), + } + } +} + +pub(crate) type EngineResult = Result; + +/// ShortintEngine +/// +/// This 'engine' holds the necessary engines from [`core_crypto`](crate::core_crypto) +/// as well as the buffers that we want to keep around to save processing time. +/// +/// This structs actually implements the logics into its methods. +pub struct ShortintEngine { + pub(crate) engine: DefaultEngine, + pub(crate) fft_engine: FftEngine, + pub(crate) par_engine: DefaultParallelEngine, + buffers: BTreeMap, +} + +impl ShortintEngine { + /// Safely gives access to the `thead_local` shortint engine + /// to call one (or many) of its method. + #[inline] + pub fn with_thread_local_mut(func: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + LOCAL_ENGINE.with(|engine_cell| func(&mut engine_cell.borrow_mut())) + } + + /// Creates a new shortint engine + /// + /// Creating a `ShortintEngine` should not be needed, as each + /// rust thread gets its own `thread_local` engine created automatically, + /// see [ShortintEngine::with_thread_local_mut] + /// + /// + /// # Panics + /// + /// This will panic if the `CoreEngine` failed to create. + pub fn new() -> Self { + let root_seeder = new_seeder(); + + Self::new_from_seeder(root_seeder) + } + + pub fn new_from_seeder(mut root_seeder: Box) -> Self { + let mut deterministic_seeder = + DeterministicSeeder::::new(root_seeder.seed()); + + let default_engine_seeder = Box::new(DeterministicSeeder::::new( + deterministic_seeder.seed(), + )); + let default_parallel_engine_seeder = Box::new(DeterministicSeeder::< + ActivatedRandomGenerator, + >::new(deterministic_seeder.seed())); + + let engine = + DefaultEngine::new(default_engine_seeder).expect("Failed to create a DefaultEngine"); + let par_engine = DefaultParallelEngine::new(default_parallel_engine_seeder) + .expect("Failed to create a DefaultParallelEngine"); + let fft_engine = FftEngine::new(()).unwrap(); + Self { + engine, + fft_engine, + par_engine, + buffers: Default::default(), + } + } + + fn generate_accumulator_with_engine( + engine: &mut DefaultEngine, + server_key: &ServerKey, + f: F, + ) -> EngineResult + where + F: Fn(u64) -> u64, + { + // Modulus of the msg contained in the msg bits and operations buffer + let modulus_sup = server_key.message_modulus.0 * server_key.carry_modulus.0; + + // N/(p/2) = size of each block + let box_size = server_key.bootstrapping_key.polynomial_size().0 / modulus_sup; + + // Value of the shift we multiply our messages by + let delta = + (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64; + + // Create the accumulator + let mut accumulator_u64 = vec![0_u64; server_key.bootstrapping_key.polynomial_size().0]; + + // This accumulator extracts the carry bits + for i in 0..modulus_sup { + let index = i * box_size; + accumulator_u64[index..index + box_size] + .iter_mut() + .for_each(|a| *a = f(i as u64) * delta); + } + + let half_box_size = box_size / 2; + + // Negate the first half_box_size coefficients + for a_i in accumulator_u64[0..half_box_size].iter_mut() { + *a_i = (*a_i).wrapping_neg(); + } + + // Rotate the accumulator + accumulator_u64.rotate_left(half_box_size); + + // Everywhere + let accumulator_plaintext = engine.create_plaintext_vector_from(&accumulator_u64)?; + + let accumulator = engine.trivially_encrypt_glwe_ciphertext( + server_key.bootstrapping_key.glwe_dimension().to_glwe_size(), + &accumulator_plaintext, + )?; + + Ok(accumulator) + } + + fn generate_accumulator_bivariate_with_engine( + engine: &mut DefaultEngine, + server_key: &ServerKey, + f: F, + ) -> EngineResult + where + F: Fn(u64, u64) -> u64, + { + let modulus = server_key.message_modulus.0 as u64; + let wrapped_f = |input: u64| -> u64 { + let lhs = (input / modulus) % modulus; + let rhs = input % modulus; + + f(lhs, rhs) + }; + ShortintEngine::generate_accumulator_with_engine(engine, server_key, wrapped_f) + } + + /// Returns the `Buffers` for the given `ServerKey` + /// + /// Takes care creating the buffers if they do not exists for the given key + /// + /// This also `&mut CoreEngine` to simply borrow checking for the caller + /// (since returned buffers are borrowed from `self`, using the `self.engine` + /// wouldn't be possible after calling `buffers_for_key`) + pub fn buffers_for_key( + &mut self, + server_key: &ServerKey, + ) -> (&mut Buffers, &mut DefaultEngine, &mut FftEngine) { + let key = server_key.key_id(); + // To make borrow checker happy + let engine = &mut self.engine; + let buffers_map = &mut self.buffers; + let buffers = buffers_map.entry(key).or_insert_with(|| { + let accumulator = Self::generate_accumulator_with_engine(engine, server_key, |n| { + n % server_key.message_modulus.0 as u64 + }) + .unwrap(); + + // Allocate the buffer for the output of the PBS + let zero_plaintext = engine.create_plaintext_from(&0_u64).unwrap(); + let buffer_lwe_after_pbs = engine + .trivially_encrypt_lwe_ciphertext( + server_key + .key_switching_key + .output_lwe_dimension() + .to_lwe_size(), + &zero_plaintext, + ) + .unwrap(); + + Buffers { + accumulator, + buffer_lwe_after_ks: buffer_lwe_after_pbs, + } + }); + + (buffers, engine, &mut self.fft_engine) + } +} diff --git a/tfhe/src/shortint/engine/public_side.rs b/tfhe/src/shortint/engine/public_side.rs new file mode 100644 index 000000000..8cc2daaa8 --- /dev/null +++ b/tfhe/src/shortint/engine/public_side.rs @@ -0,0 +1,201 @@ +//! All the `ShortintEngine` method related to client side (encrypt / decrypt) +use super::{EngineResult, ShortintEngine}; +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use crate::shortint::{Ciphertext, ClientKey, PublicKey, ServerKey}; + +// We have q = 64 so log2q = 6 +const LOG2_Q_64: usize = 6; + +impl ShortintEngine { + pub(crate) fn new_public_key(&mut self, client_key: &ClientKey) -> EngineResult { + let client_parameters = client_key.parameters; + + // Formula is (k*N + 1) * log2(q) + 128 + let zero_encryption_count = LwePublicKeyZeroEncryptionCount( + (client_parameters.polynomial_size.0 * client_parameters.glwe_dimension.0 + 1) + * LOG2_Q_64 + + 128, + ); + + Ok(PublicKey { + lwe_public_key: self.par_engine.generate_new_lwe_public_key( + &client_key.lwe_secret_key, + Variance(client_key.parameters.lwe_modular_std_dev.get_variance()), + zero_encryption_count, + )?, + parameters: client_key.parameters.to_owned(), + }) + } + + pub(crate) fn encrypt_with_public_key( + &mut self, + public_key: &PublicKey, + server_key: &ServerKey, + message: u64, + ) -> EngineResult { + let mut ciphertext = self.encrypt_with_message_modulus_and_public_key( + public_key, + message, + public_key.parameters.message_modulus, + )?; + + let acc = self.generate_accumulator(server_key, |x| x)?; + + self.programmable_bootstrap_keyswitch_assign(server_key, &mut ciphertext, &acc)?; + + Ok(ciphertext) + } + + pub(crate) fn encrypt_with_message_modulus_and_public_key( + &mut self, + public_key: &PublicKey, + message: u64, + message_modulus: MessageModulus, + ) -> EngineResult { + //This ensures that the space message_modulus*carry_modulus < param.message_modulus * + // param.carry_modulus + let carry_modulus = (public_key.parameters.message_modulus.0 + * public_key.parameters.carry_modulus.0) + / message_modulus.0; + + //The delta is the one defined by the parameters + let delta = (1_u64 << 63) + / (public_key.parameters.message_modulus.0 * public_key.parameters.carry_modulus.0) + as u64; + + //The input is reduced modulus the message_modulus + let m = message % message_modulus.0 as u64; + + let shifted_message = m * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // This allocates the required ct + let mut encrypted_ct = self.engine.trivially_encrypt_lwe_ciphertext( + public_key.lwe_public_key.lwe_dimension().to_lwe_size(), + &plain, + )?; + + // encryption + self.engine.discard_encrypt_lwe_ciphertext_with_public_key( + &public_key.lwe_public_key, + &mut encrypted_ct, + &plain, + )?; + + Ok(Ciphertext { + ct: encrypted_ct, + degree: Degree(message_modulus.0 - 1), + message_modulus, + carry_modulus: CarryModulus(carry_modulus), + }) + } + + pub(crate) fn encrypt_without_padding_with_public_key( + &mut self, + public_key: &PublicKey, + message: u64, + ) -> EngineResult { + //Multiply by 2 to reshift and exclude the padding bit + let delta = ((1_u64 << 63) + / (public_key.parameters.message_modulus.0 * public_key.parameters.carry_modulus.0) + as u64) + * 2; + + let shifted_message = message * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // This allocates the required ct + let mut encrypted_ct = self.engine.trivially_encrypt_lwe_ciphertext( + public_key.lwe_public_key.lwe_dimension().to_lwe_size(), + &plain, + )?; + + // encryption + self.engine.discard_encrypt_lwe_ciphertext_with_public_key( + &public_key.lwe_public_key, + &mut encrypted_ct, + &plain, + )?; + + Ok(Ciphertext { + ct: encrypted_ct, + degree: Degree(public_key.parameters.message_modulus.0 - 1), + message_modulus: public_key.parameters.message_modulus, + carry_modulus: public_key.parameters.carry_modulus, + }) + } + + pub(crate) fn encrypt_native_crt_with_public_key( + &mut self, + public_key: &PublicKey, + message: u64, + message_modulus: u8, + ) -> EngineResult { + let carry_modulus = 1; + let m = (message % message_modulus as u64) as u128; + let shifted_message = m * (1 << 64) / message_modulus as u128; + // encode the message + let plain: Plaintext64 = self + .engine + .create_plaintext_from(&(shifted_message as u64))?; + + // This allocates the required ct + let mut encrypted_ct = self.engine.trivially_encrypt_lwe_ciphertext( + public_key.lwe_public_key.lwe_dimension().to_lwe_size(), + &plain, + )?; + + // encryption + self.engine.discard_encrypt_lwe_ciphertext_with_public_key( + &public_key.lwe_public_key, + &mut encrypted_ct, + &plain, + )?; + + Ok(Ciphertext { + ct: encrypted_ct, + degree: Degree(message_modulus as usize - 1), + message_modulus: MessageModulus(message_modulus as usize), + carry_modulus: CarryModulus(carry_modulus), + }) + } + + pub(crate) fn unchecked_encrypt_with_public_key( + &mut self, + public_key: &PublicKey, + message: u64, + ) -> EngineResult { + let delta = (1_u64 << 63) + / (public_key.parameters.message_modulus.0 * public_key.parameters.carry_modulus.0) + as u64; + let shifted_message = message * delta; + // encode the message + let plain: Plaintext64 = self.engine.create_plaintext_from(&shifted_message)?; + + // This allocates the required ct + let mut encrypted_ct = self.engine.trivially_encrypt_lwe_ciphertext( + public_key.lwe_public_key.lwe_dimension().to_lwe_size(), + &plain, + )?; + + // encryption + self.engine.discard_encrypt_lwe_ciphertext_with_public_key( + &public_key.lwe_public_key, + &mut encrypted_ct, + &plain, + )?; + + Ok(Ciphertext { + ct: encrypted_ct, + degree: Degree( + public_key.parameters.message_modulus.0 * public_key.parameters.carry_modulus.0 - 1, + ), + message_modulus: public_key.parameters.message_modulus, + carry_modulus: public_key.parameters.carry_modulus, + }) + } +} diff --git a/tfhe/src/shortint/engine/server_side/add.rs b/tfhe/src/shortint/engine/server_side/add.rs new file mode 100644 index 000000000..5a5855c8b --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/add.rs @@ -0,0 +1,59 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_add( + &mut self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_add_assign(&mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_add_assign( + &mut self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + self.engine + .fuse_add_lwe_ciphertext(&mut ct_left.ct, &ct_right.ct)?; + ct_left.degree = Degree(ct_left.degree.0 + ct_right.degree.0); + Ok(()) + } + + pub(crate) fn smart_add( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_add_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_add_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + //If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !server_key.is_add_possible(ct_left, ct_right) { + if ct_left.message_modulus.0 - 1 + ct_right.degree.0 <= server_key.max_degree.0 { + self.message_extract_assign(server_key, ct_left)?; + } else if ct_right.message_modulus.0 - 1 + ct_left.degree.0 <= server_key.max_degree.0 { + self.message_extract_assign(server_key, ct_right)?; + } else { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + } + self.unchecked_add_assign(ct_left, ct_right)?; + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/bitwise_op.rs b/tfhe/src/shortint/engine/server_side/bitwise_op.rs new file mode 100644 index 000000000..36f794a1d --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/bitwise_op.rs @@ -0,0 +1,154 @@ +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_bitand( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_bitand_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_bitand_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (x / modulus) & (x % modulus) + })?; + ct_left.degree = ct_left.degree.after_bitand(ct_right.degree); + Ok(()) + } + + pub(crate) fn smart_bitand( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_bitand_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_bitand_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_bitand_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_bitxor( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_bitxor_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_bitxor_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (x / modulus) ^ (x % modulus) + })?; + ct_left.degree = ct_left.degree.after_bitxor(ct_right.degree); + Ok(()) + } + + pub(crate) fn smart_bitxor( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_bitxor_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_bitxor_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_bitxor_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_bitor( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_bitor_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_bitor_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (x / modulus) | (x % modulus) + })?; + ct_left.degree = ct_left.degree.after_bitor(ct_right.degree); + Ok(()) + } + + pub(crate) fn smart_bitor( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_bitor_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_bitor_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_bitor_assign(server_key, ct_left, ct_right)?; + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/comp_op.rs b/tfhe/src/shortint/engine/server_side/comp_op.rs new file mode 100644 index 000000000..eb26fb8d7 --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/comp_op.rs @@ -0,0 +1,463 @@ +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_greater( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_greater_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_greater_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (((x % large_mod / modulus) % modulus_msg) > (x % modulus_msg)) as u64 + })?; + + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_greater( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_greater_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_greater_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + + self.unchecked_greater_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_greater_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_greater_or_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_greater_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (((x % large_mod / modulus) % modulus_msg) >= (x % modulus_msg)) as u64 + })?; + + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_greater_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_greater_or_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_greater_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_greater_or_equal_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_less( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_less_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_less_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (((x % large_mod / modulus) % modulus_msg) < (x % modulus_msg)) as u64 + })?; + + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_less( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_less_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_less_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_less_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_less_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_less_or_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_less_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + (((x % large_mod / modulus) % modulus_msg) <= (x % modulus_msg)) as u64 + })?; + + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_less_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_less_or_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_less_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_less_or_equal_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + ((((x % large_mod) / modulus) % modulus_msg) == (x % modulus_msg)) as u64 + })?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_equal( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_equal_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn smart_scalar_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_equal_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let modulus = ct_left.message_modulus.0 as u64; + let acc = + self.generate_accumulator(server_key, |x| (x % modulus == scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn unchecked_not_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_not_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + fn unchecked_not_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let modulus_msg = ct_left.message_modulus.0 as u64; + let large_mod = modulus * modulus_msg; + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + ((((x % large_mod) / modulus) % modulus_msg) != (x % modulus_msg)) as u64 + })?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_not_equal( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_not_equal_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_not_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_not_equal_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn smart_scalar_not_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_not_equal_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_not_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let modulus = ct_left.message_modulus.0 as u64; + let acc = + self.generate_accumulator(server_key, |x| (x % modulus != scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_scalar_greater_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_greater_or_equal_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_greater_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let acc = self.generate_accumulator(server_key, |x| (x >= scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_scalar_less_or_equal( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_less_or_equal_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_less_or_equal_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let acc = self.generate_accumulator(server_key, |x| (x <= scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_scalar_greater( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_greater_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_greater_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let acc = self.generate_accumulator(server_key, |x| (x > scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } + + pub(crate) fn smart_scalar_less( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_scalar_less_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + fn smart_scalar_less_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let acc = self.generate_accumulator(server_key, |x| (x < scalar as u64) as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree.0 = 1; + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/div_mod.rs b/tfhe/src/shortint/engine/server_side/div_mod.rs new file mode 100644 index 000000000..95598dd3d --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/div_mod.rs @@ -0,0 +1,130 @@ +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +// Specific division function returning 0 in case of a division by 0 +pub(crate) fn division(x: u64, modulus: u64) -> u64 { + if x % modulus == 0 { + 0 + } else { + (x / modulus) / (x % modulus) + } +} + +impl ShortintEngine { + pub(crate) fn unchecked_div( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_div_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_div_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + + //In this case the degree of the result is equal to the degree of ct_left + self.unchecked_functional_bivariate_pbs_assign(server_key, ct_left, ct_right, |x| { + division(x, modulus) + })?; + Ok(()) + } + + pub(crate) fn smart_div( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_div_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_div_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + if ct_left.message_modulus.0 + ct_right.degree.0 <= server_key.max_degree.0 { + self.message_extract_assign(server_key, ct_left)?; + } else if ct_right.message_modulus.0 + (ct_left.degree.0 + 1) <= server_key.max_degree.0 + { + self.message_extract_assign(server_key, ct_right)?; + } else { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + } + self.unchecked_div_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + /// # Panics + /// + /// This function will panic if `scalar == 0` + pub(crate) fn unchecked_scalar_div( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut result = ct.clone(); + self.unchecked_scalar_div_assign(server_key, &mut result, scalar)?; + Ok(result) + } + + /// # Panics + /// + /// This function will panic if `scalar == 0` + pub(crate) fn unchecked_scalar_div_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + assert_ne!(scalar, 0); + //generate the accumulator for the multiplication + let acc = self.generate_accumulator(server_key, |x| x / (scalar as u64))?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + ct.degree = Degree(ct.degree.0 / scalar as usize); + Ok(()) + } + + pub(crate) fn unchecked_scalar_mod( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + modulus: u8, + ) -> EngineResult { + let mut result = ct.clone(); + self.unchecked_scalar_mod_assign(server_key, &mut result, modulus)?; + Ok(result) + } + + /// # Panics + /// + /// This function will panic if `modulus == 0` + pub(crate) fn unchecked_scalar_mod_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + modulus: u8, + ) -> EngineResult<()> { + assert_ne!(modulus, 0); + let acc = self.generate_accumulator(server_key, |x| x % modulus as u64)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + ct.degree = Degree(modulus as usize - 1); + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/mod.rs b/tfhe/src/shortint/engine/server_side/mod.rs new file mode 100644 index 000000000..685bdcf3f --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/mod.rs @@ -0,0 +1,412 @@ +use super::ShortintEngine; +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::EngineResult; +use crate::shortint::server_key::MaxDegree; +use crate::shortint::{Ciphertext, ClientKey, ServerKey}; +use std::cmp::min; + +mod add; +mod bitwise_op; +mod comp_op; +mod div_mod; +mod mul; +mod neg; +mod scalar_add; +mod scalar_mul; +mod scalar_sub; +mod shift; +mod sub; + +impl ShortintEngine { + pub(crate) fn new_server_key(&mut self, cks: &ClientKey) -> EngineResult { + // Plaintext Max Value + let max_value = cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0 - 1; + + // The maximum number of operations before we need to clean the carry buffer + let max = MaxDegree(max_value); + self.new_server_key_with_max_degree(cks, max) + } + + pub(crate) fn new_server_key_with_max_degree( + &mut self, + cks: &ClientKey, + max_degree: MaxDegree, + ) -> EngineResult { + // Convert into a variance for rlwe context + let var_rlwe = Variance(cks.parameters.glwe_modular_std_dev.get_variance()); + + let bootstrap_key: LweBootstrapKey64 = self.par_engine.generate_new_lwe_bootstrap_key( + &cks.lwe_secret_key_after_ks, + &cks.glwe_secret_key, + cks.parameters.pbs_base_log, + cks.parameters.pbs_level, + var_rlwe, + )?; + + // Creation of the bootstrapping key in the Fourier domain + + let fourier_bsk: FftFourierLweBootstrapKey64 = + self.fft_engine.convert_lwe_bootstrap_key(&bootstrap_key)?; + + // Convert into a variance for lwe context + let var_lwe = Variance(cks.parameters.lwe_modular_std_dev.get_variance()); + + // Creation of the key switching key + let ksk = self.engine.generate_new_lwe_keyswitch_key( + &cks.lwe_secret_key, + &cks.lwe_secret_key_after_ks, + cks.parameters.ks_level, + cks.parameters.ks_base_log, + var_lwe, + )?; + + // Pack the keys in the server key set: + Ok(ServerKey { + key_switching_key: ksk, + bootstrapping_key: fourier_bsk, + message_modulus: cks.parameters.message_modulus, + carry_modulus: cks.parameters.carry_modulus, + max_degree, + }) + } + + pub(crate) fn generate_accumulator( + &mut self, + server_key: &ServerKey, + f: F, + ) -> EngineResult + where + F: Fn(u64) -> u64, + { + Self::generate_accumulator_with_engine(&mut self.engine, server_key, f) + } + + pub(crate) fn keyswitch_bootstrap( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + ) -> EngineResult { + let mut ct_in = ct.clone(); + self.keyswitch_bootstrap_assign(server_key, &mut ct_in)?; + Ok(ct_in) + } + + pub(crate) fn keyswitch_bootstrap_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult<()> { + // Compute the programmable bootstrapping with fixed test polynomial + let (buffers, engine, fft_engine) = self.buffers_for_key(server_key); + + // Compute a keyswitch + engine.discard_keyswitch_lwe_ciphertext( + &mut buffers.buffer_lwe_after_ks, + &ct.ct, + &server_key.key_switching_key, + )?; + + // Compute a bootstrap + fft_engine.discard_bootstrap_lwe_ciphertext( + &mut ct.ct, + &buffers.buffer_lwe_after_ks, + &buffers.accumulator, + &server_key.bootstrapping_key, + )?; + Ok(()) + } + + pub(crate) fn programmable_bootstrap_keyswitch( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult { + let mut ct_res = ct.clone(); + self.programmable_bootstrap_keyswitch_assign(server_key, &mut ct_res, acc)?; + Ok(ct_res) + } + + pub(crate) fn programmable_bootstrap_keyswitch_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult<()> { + // Compute the programmable bootstrapping with fixed test polynomial + let (buffers, engine, fftw_engine) = self.buffers_for_key(server_key); + + // Compute a key switch + engine.discard_keyswitch_lwe_ciphertext( + &mut buffers.buffer_lwe_after_ks, + &ct.ct, + &server_key.key_switching_key, + )?; + + // Compute a bootstrap + fftw_engine.discard_bootstrap_lwe_ciphertext( + &mut ct.ct, + &buffers.buffer_lwe_after_ks, + acc, + &server_key.bootstrapping_key, + )?; + Ok(()) + } + + pub(crate) fn programmable_bootstrap_keyswitch_bivariate( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult { + let mut ct_res = ct_left.clone(); + self.programmable_bootstrap_keyswitch_bivariate_assign( + server_key, + &mut ct_res, + ct_right, + acc, + )?; + Ok(ct_res) + } + + pub(crate) fn programmable_bootstrap_keyswitch_bivariate_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + + // Message 1 is shifted to the carry bits + self.unchecked_scalar_mul_assign(ct_left, modulus as u8)?; + + // Message 2 is placed in the message bits + self.unchecked_add_assign(ct_left, ct_right)?; + + // Compute the PBS + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, acc)?; + + Ok(()) + } + + pub(crate) fn generate_accumulator_bivariate( + &mut self, + server_key: &ServerKey, + f: F, + ) -> EngineResult + where + F: Fn(u64, u64) -> u64, + { + Self::generate_accumulator_bivariate_with_engine(&mut self.engine, server_key, f) + } + + pub(crate) fn unchecked_functional_bivariate_pbs( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + f: F, + ) -> EngineResult + where + F: Fn(u64) -> u64, + { + let mut ct_res = ct_left.clone(); + self.unchecked_functional_bivariate_pbs_assign(server_key, &mut ct_res, ct_right, f)?; + Ok(ct_res) + } + + pub(crate) fn unchecked_functional_bivariate_pbs_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + f: F, + ) -> EngineResult<()> + where + F: Fn(u64) -> u64, + { + let modulus = (ct_right.degree.0 + 1) as u64; + + // Message 1 is shifted to the carry bits + self.unchecked_scalar_mul_assign(ct_left, modulus as u8)?; + + // Message 2 is placed in the message bits + self.unchecked_add_assign(ct_left, ct_right)?; + + // Generate the accumulator for the function + let acc = self.generate_accumulator(server_key, f)?; + + // Compute the PBS + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + Ok(()) + } + + // Those are currently not used in shortint, we therefore disable the warning when not compiling + // the C API + #[cfg_attr(not(feature = "__c_api"), allow(dead_code))] + pub(crate) fn smart_bivariate_pbs( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &mut Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult { + let mut ct_res = ct_left.clone(); + self.smart_bivariate_pbs_assign(server_key, &mut ct_res, ct_right, acc)?; + Ok(ct_res) + } + + #[cfg_attr(not(feature = "__c_api"), allow(dead_code))] + pub(crate) fn smart_bivariate_pbs_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult<()> { + if !server_key.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + + self.unchecked_bivariate_pbs_assign(server_key, ct_left, ct_right, acc) + } + + #[cfg_attr(not(feature = "__c_api"), allow(dead_code))] + pub(crate) fn unchecked_bivariate_pbs_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + acc: &GlweCiphertext64, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + + // Message 1 is shifted to the carry bits + self.unchecked_scalar_mul_assign(ct_left, modulus as u8)?; + + // Message 2 is placed in the message bits + self.unchecked_add_assign(ct_left, ct_right)?; + + // Compute the PBS + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, acc)?; + Ok(()) + } + + pub(crate) fn carry_extract_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult<()> { + let modulus = ct.message_modulus.0 as u64; + + let accumulator = self.generate_accumulator(server_key, |x| x / modulus)?; + + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &accumulator)?; + + // The degree of the carry + ct.degree = Degree(min(modulus - 1, ct.degree.0 as u64 / modulus) as usize); + Ok(()) + } + + pub(crate) fn carry_extract( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + ) -> EngineResult { + let mut result = ct.clone(); + self.carry_extract_assign(server_key, &mut result)?; + Ok(result) + } + + pub(crate) fn message_extract_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult<()> { + let modulus = ct.message_modulus.0 as u64; + + let acc = self.generate_accumulator(server_key, |x| x % modulus)?; + + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + + ct.degree = Degree(ct.message_modulus.0 - 1); + Ok(()) + } + + pub(crate) fn message_extract( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + ) -> EngineResult { + let mut result = ct.clone(); + self.message_extract_assign(server_key, &mut result)?; + Ok(result) + } + + // Impossible to call the assign function in this case + pub(crate) fn create_trivial( + &mut self, + server_key: &ServerKey, + value: u8, + ) -> EngineResult { + let lwe_size = server_key + .bootstrapping_key + .output_lwe_dimension() + .to_lwe_size(); + + let modular_value = value as usize % server_key.message_modulus.0; + + let delta = + (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64; + + let shifted_value = (modular_value as u64) * delta; + + let plaintext = self.engine.create_plaintext_from(&shifted_value).unwrap(); + + let ct = self + .engine + .trivially_encrypt_lwe_ciphertext(lwe_size, &plaintext) + .unwrap(); + + let degree = Degree(modular_value); + + Ok(Ciphertext { + ct, + degree, + message_modulus: server_key.message_modulus, + carry_modulus: server_key.carry_modulus, + }) + } + + pub(crate) fn create_trivial_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + value: u8, + ) -> EngineResult<()> { + let lwe_size = server_key + .bootstrapping_key + .input_lwe_dimension() + .to_lwe_size(); + + let modular_value = value as usize % server_key.message_modulus.0; + + let delta = + (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64; + + let shifted_value = (modular_value as u64) * delta; + + let plaintext = self.engine.create_plaintext_from(&shifted_value).unwrap(); + + ct.ct = self + .engine + .trivially_encrypt_lwe_ciphertext(lwe_size, &plaintext) + .unwrap(); + ct.degree = Degree(modular_value); + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/mul.rs b/tfhe/src/shortint/engine/server_side/mul.rs new file mode 100644 index 000000000..e2a634e7f --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/mul.rs @@ -0,0 +1,193 @@ +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_mul_lsb( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_mul_lsb_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn unchecked_mul_lsb_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + + //message 1 is shifted to the carry bits + self.unchecked_scalar_mul_assign(ct_left, modulus as u8)?; + + //message 2 is placed in the message bits + self.unchecked_add_assign(ct_left, ct_right)?; + + //Modulus of the msg in the msg bits + let res_modulus = ct_left.message_modulus.0 as u64; + + //generate the accumulator for the multiplication + let acc = self.generate_accumulator(server_key, |x| { + ((x / modulus) * (x % modulus)) % res_modulus + })?; + + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + ct_left.degree = Degree(ct_left.message_modulus.0 - 1); + Ok(()) + } + + pub(crate) fn unchecked_mul_msb( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_mul_msb_assign(server_key, &mut result, ct_right)?; + + Ok(result) + } + + pub(crate) fn unchecked_mul_msb_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + let modulus = (ct_right.degree.0 + 1) as u64; + let deg = (ct_left.degree.0 * ct_right.degree.0) / ct_right.message_modulus.0; + + // Message 1 is shifted to the carry bits + self.unchecked_scalar_mul_assign(ct_left, modulus as u8)?; + + // Message 2 is placed in the message bits + self.unchecked_add_assign(ct_left, ct_right)?; + + // Modulus of the msg in the msg bits + let res_modulus = server_key.message_modulus.0 as u64; + + // Generate the accumulator for the multiplication + let acc = self.generate_accumulator(server_key, |x| { + ((x / modulus) * (x % modulus)) / res_modulus + })?; + + self.programmable_bootstrap_keyswitch_assign(server_key, ct_left, &acc)?; + + ct_left.degree = Degree(deg); + Ok(()) + } + + pub(crate) fn unchecked_mul_lsb_small_carry_modulus( + &mut self, + server_key: &ServerKey, + ct1: &mut Ciphertext, + ct2: &mut Ciphertext, + ) -> EngineResult { + //ct1 + ct2 + let mut ct_tmp_left = self.unchecked_add(ct1, ct2)?; + + //ct1-ct2 + let (mut ct_tmp_right, z) = self.unchecked_sub_with_z(server_key, ct1, ct2)?; + + //Modulus of the msg in the msg bits + let modulus = ct1.message_modulus.0 as u64; + + let acc_add = self.generate_accumulator(server_key, |x| ((x * x) / 4) % modulus)?; + let acc_sub = + self.generate_accumulator(server_key, |x| (((x - z) * (x - z)) / 4) % modulus)?; + + self.programmable_bootstrap_keyswitch_assign(server_key, &mut ct_tmp_left, &acc_add)?; + self.programmable_bootstrap_keyswitch_assign(server_key, &mut ct_tmp_right, &acc_sub)?; + + //Last subtraction might fill one bit of carry + self.unchecked_sub(server_key, &ct_tmp_left, &ct_tmp_right) + } + + pub(crate) fn unchecked_mul_lsb_small_carry_modulus_assign( + &mut self, + server_key: &ServerKey, + ct1: &mut Ciphertext, + ct2: &mut Ciphertext, + ) -> EngineResult<()> { + *ct1 = self.unchecked_mul_lsb_small_carry_modulus(server_key, ct1, ct2)?; + Ok(()) + } + + pub(crate) fn smart_mul_lsb_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + //Choice of the multiplication algorithm depending on the parameters + if ct_left.message_modulus.0 > ct_left.carry_modulus.0 { + //If the ciphertext cannot be added together without exceeding the capacity of a + // ciphertext + if !server_key.is_mul_small_carry_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_mul_lsb_small_carry_modulus_assign(server_key, ct_left, ct_right)?; + } else { + //If the ciphertext cannot be added together without exceeding the capacity of a + // ciphertext + if !server_key.is_mul_possible(ct_left, ct_right) { + if server_key.message_modulus.0 * (ct_right.degree.0 + 1) + < (ct_right.carry_modulus.0 * ct_right.message_modulus.0 - 1) + { + self.message_extract_assign(server_key, ct_left)?; + } else if (server_key.message_modulus.0 + 1) + (ct_left.degree.0 + 1) + < (ct_right.carry_modulus.0 * ct_right.message_modulus.0 - 1) + { + self.message_extract_assign(server_key, ct_right)?; + } else { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + } + self.unchecked_mul_lsb_assign(server_key, ct_left, ct_right)?; + } + Ok(()) + } + + pub(crate) fn smart_mul_lsb( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_mul_lsb_assign(server_key, &mut result, ct_right)?; + Ok(result) + } + + pub(crate) fn smart_mul_msb_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + if !server_key.is_mul_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + self.unchecked_mul_msb_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn smart_mul_msb( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.smart_mul_msb_assign(server_key, &mut result, ct_right)?; + Ok(result) + } +} diff --git a/tfhe/src/shortint/engine/server_side/neg.rs b/tfhe/src/shortint/engine/server_side/neg.rs new file mode 100644 index 000000000..a399cafed --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/neg.rs @@ -0,0 +1,89 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_neg( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + ) -> EngineResult { + let mut result = ct.clone(); + self.unchecked_neg_assign(server_key, &mut result)?; + Ok(result) + } + + pub(crate) fn unchecked_neg_with_z( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + ) -> EngineResult<(Ciphertext, u64)> { + let mut result = ct.clone(); + let z = self.unchecked_neg_assign_with_z(server_key, &mut result)?; + Ok((result, z)) + } + + pub(crate) fn unchecked_neg_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult<()> { + let _z = self.unchecked_neg_assign_with_z(server_key, ct)?; + Ok(()) + } + + pub(crate) fn unchecked_neg_assign_with_z( + &mut self, + _server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult { + // z = ceil( degree / 2^p ) * 2^p + let msg_mod = ct.message_modulus.0; + let mut z = ((ct.degree.0 + msg_mod - 1) / msg_mod) as u64; + z *= msg_mod as u64; + + // Value of the shift we multiply our messages by + let delta = + (1_u64 << 63) / (_server_key.message_modulus.0 * _server_key.carry_modulus.0) as u64; + + //Scaling + 1 on the padding bit + let w = z * delta; + + // (0,Delta*z) - ct + self.engine.fuse_opp_lwe_ciphertext(&mut ct.ct)?; + + let clear_w = self.engine.create_plaintext_from(&w)?; + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct.ct, &clear_w)?; + + // Update the degree + ct.degree = Degree(z as usize); + + Ok(z) + } + + pub(crate) fn smart_neg( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !server_key.is_neg_possible(ct) { + self.keyswitch_bootstrap_assign(server_key, ct)?; + } + self.unchecked_neg(server_key, ct) + } + + pub(crate) fn smart_neg_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + ) -> EngineResult<()> { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if !server_key.is_neg_possible(ct) { + self.keyswitch_bootstrap_assign(server_key, ct)?; + } + self.unchecked_neg_assign(server_key, ct) + } +} diff --git a/tfhe/src/shortint/engine/server_side/scalar_add.rs b/tfhe/src/shortint/engine/server_side/scalar_add.rs new file mode 100644 index 000000000..4b4a31b4f --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/scalar_add.rs @@ -0,0 +1,79 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_scalar_add( + &mut self, + ct: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ct.clone(); + self.unchecked_scalar_add_assign(&mut ct_result, scalar)?; + Ok(ct_result) + } + + pub(crate) fn unchecked_scalar_add_assign( + &mut self, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let delta = (1_u64 << 63) / (ct.message_modulus.0 * ct.carry_modulus.0) as u64; + let shift_plaintext = u64::from(scalar) * delta; + let plaintext_scalar = self.engine.create_plaintext_from(&shift_plaintext).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct.ct, &plaintext_scalar)?; + + ct.degree = Degree(ct.degree.0 + scalar as usize); + Ok(()) + } + + pub(crate) fn unchecked_scalar_add_assign_crt( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let delta = + (1_u64 << 63) / (server_key.message_modulus.0 * server_key.carry_modulus.0) as u64; + let shift_plaintext = u64::from(scalar) * delta; + let plaintext_scalar = self.engine.create_plaintext_from(&shift_plaintext).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct.ct, &plaintext_scalar)?; + + ct.degree = Degree(ct.degree.0 + scalar as usize); + Ok(()) + } + + pub(crate) fn smart_scalar_add( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ct.clone(); + self.smart_scalar_add_assign(server_key, &mut ct_result, scalar)?; + + Ok(ct_result) + } + + pub(crate) fn smart_scalar_add_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let modulus = server_key.message_modulus.0 as u64; + // Direct scalar computation is possible + if server_key.is_scalar_add_possible(ct, scalar) { + self.unchecked_scalar_add_assign(ct, scalar)?; + } else { + // If the scalar is too large, PBS is used to compute the scalar mul + let acc = self.generate_accumulator(server_key, |x| (scalar as u64 + x) % modulus)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + ct.degree = Degree(server_key.message_modulus.0 - 1); + } + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/scalar_mul.rs b/tfhe/src/shortint/engine/server_side/scalar_mul.rs new file mode 100644 index 000000000..94b69dffb --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/scalar_mul.rs @@ -0,0 +1,64 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_scalar_mul( + &mut self, + ct: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ct.clone(); + self.unchecked_scalar_mul_assign(&mut ct_result, scalar)?; + + Ok(ct_result) + } + + pub(crate) fn unchecked_scalar_mul_assign( + &mut self, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let scalar = u64::from(scalar); + let cleartext_scalar = self.engine.create_cleartext_from(&scalar).unwrap(); + self.engine + .fuse_mul_lwe_ciphertext_cleartext(&mut ct.ct, &cleartext_scalar)?; + + ct.degree = Degree(ct.degree.0 * scalar as usize); + Ok(()) + } + + pub(crate) fn smart_scalar_mul( + &mut self, + server_key: &ServerKey, + ctxt: &mut Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ctxt.clone(); + self.smart_scalar_mul_assign(server_key, &mut ct_result, scalar)?; + + Ok(ct_result) + } + + pub(crate) fn smart_scalar_mul_assign( + &mut self, + server_key: &ServerKey, + ctxt: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let modulus = server_key.message_modulus.0 as u64; + // Direct scalar computation is possible + if server_key.is_scalar_mul_possible(ctxt, scalar) { + self.unchecked_scalar_mul_assign(ctxt, scalar)?; + ctxt.degree = Degree(ctxt.degree.0 * scalar as usize); + } + // If the ciphertext cannot be multiplied without exceeding the degree max + else { + let acc = self.generate_accumulator(server_key, |x| (scalar as u64 * x) % modulus)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ctxt, &acc)?; + ctxt.degree = Degree(server_key.message_modulus.0 - 1); + } + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/scalar_sub.rs b/tfhe/src/shortint/engine/server_side/scalar_sub.rs new file mode 100644 index 000000000..633dc0fbb --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/scalar_sub.rs @@ -0,0 +1,64 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_scalar_sub( + &mut self, + ct: &Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ct.clone(); + self.unchecked_scalar_sub_assign(&mut ct_result, scalar)?; + Ok(ct_result) + } + + pub(crate) fn unchecked_scalar_sub_assign( + &mut self, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let neg_scalar = u64::from(scalar.wrapping_neg()) % ct.message_modulus.0 as u64; + let delta = (1_u64 << 63) / (ct.message_modulus.0 * ct.carry_modulus.0) as u64; + let shift_plaintext = neg_scalar * delta; + let plaintext_scalar = self.engine.create_plaintext_from(&shift_plaintext).unwrap(); + self.engine + .fuse_add_lwe_ciphertext_plaintext(&mut ct.ct, &plaintext_scalar)?; + + ct.degree = Degree(ct.degree.0 + neg_scalar as usize); + Ok(()) + } + + pub(crate) fn smart_scalar_sub( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult { + let mut ct_result = ct.clone(); + self.smart_scalar_sub_assign(server_key, &mut ct_result, scalar)?; + + Ok(ct_result) + } + + pub(crate) fn smart_scalar_sub_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + scalar: u8, + ) -> EngineResult<()> { + let modulus = server_key.message_modulus.0 as u64; + // Direct scalar computation is possible + if server_key.is_scalar_sub_possible(ct, scalar) { + self.unchecked_scalar_sub_assign(ct, scalar)?; + } else { + let scalar = u64::from(scalar); + // If the scalar is too large, PBS is used to compute the scalar mul + let acc = self.generate_accumulator(server_key, |x| (x - scalar) % modulus)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + ct.degree = Degree(server_key.message_modulus.0 - 1); + } + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/shift.rs b/tfhe/src/shortint/engine/server_side/shift.rs new file mode 100644 index 000000000..fe9093627 --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/shift.rs @@ -0,0 +1,77 @@ +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_scalar_right_shift( + &mut self, + server_key: &ServerKey, + ct: &Ciphertext, + shift: u8, + ) -> EngineResult { + let mut result = ct.clone(); + self.unchecked_scalar_right_shift_assign(server_key, &mut result, shift)?; + Ok(result) + } + + pub(crate) fn unchecked_scalar_right_shift_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + shift: u8, + ) -> EngineResult<()> { + let acc = self.generate_accumulator(server_key, |x| x >> shift)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + + ct.degree = Degree(ct.degree.0 >> shift); + Ok(()) + } + + pub(crate) fn unchecked_scalar_left_shift( + &mut self, + ct: &Ciphertext, + shift: u8, + ) -> EngineResult { + let mut result = ct.clone(); + self.unchecked_scalar_left_shift_assign(&mut result, shift)?; + Ok(result) + } + + pub(crate) fn unchecked_scalar_left_shift_assign( + &mut self, + ct: &mut Ciphertext, + shift: u8, + ) -> EngineResult<()> { + let scalar = 1_u8 << shift; + self.unchecked_scalar_mul_assign(ct, scalar)?; + Ok(()) + } + + pub(crate) fn smart_scalar_left_shift( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + shift: u8, + ) -> EngineResult { + let mut result = ct.clone(); + self.smart_scalar_left_shift_assign(server_key, &mut result, shift)?; + Ok(result) + } + + pub(crate) fn smart_scalar_left_shift_assign( + &mut self, + server_key: &ServerKey, + ct: &mut Ciphertext, + shift: u8, + ) -> EngineResult<()> { + if server_key.is_scalar_left_shift_possible(ct, shift) { + self.unchecked_scalar_left_shift_assign(ct, shift)?; + } else { + let modulus = server_key.message_modulus.0 as u64; + let acc = self.generate_accumulator(server_key, |x| (x << shift) % modulus)?; + self.programmable_bootstrap_keyswitch_assign(server_key, ct, &acc)?; + ct.degree = ct.degree.after_left_shift(shift, modulus as usize); + } + Ok(()) + } +} diff --git a/tfhe/src/shortint/engine/server_side/sub.rs b/tfhe/src/shortint/engine/server_side/sub.rs new file mode 100644 index 000000000..1c7197ad2 --- /dev/null +++ b/tfhe/src/shortint/engine/server_side/sub.rs @@ -0,0 +1,101 @@ +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::{Ciphertext, ServerKey}; + +impl ShortintEngine { + pub(crate) fn unchecked_sub( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let mut result = ct_left.clone(); + self.unchecked_sub_assign(server_key, &mut result, ct_right)?; + + Ok(result) + } + + pub(crate) fn unchecked_sub_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<()> { + self.unchecked_sub_assign_with_z(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn unchecked_sub_with_z( + &mut self, + server_key: &ServerKey, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult<(Ciphertext, u64)> { + let mut result = ct_left.clone(); + let z = self.unchecked_sub_assign_with_z(server_key, &mut result, ct_right)?; + + Ok((result, z)) + } + + pub(crate) fn unchecked_sub_assign_with_z( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> EngineResult { + let (neg_right, z) = self.unchecked_neg_with_z(server_key, ct_right)?; + + self.engine + .fuse_add_lwe_ciphertext(&mut ct_left.ct, &neg_right.ct)?; + + ct_left.degree = Degree(ct_left.degree.0 + z as usize); + + Ok(z) + } + + pub(crate) fn smart_sub( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult { + // If the ciphertext cannot be subtracted together without exceeding the degree max + if !server_key.is_sub_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_right)?; + self.message_extract_assign(server_key, ct_left)?; + } + self.unchecked_sub(server_key, ct_left, ct_right) + } + + pub(crate) fn smart_sub_assign( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<()> { + // If the ciphertext cannot be subtracted together without exceeding the degree max + if !server_key.is_sub_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_right)?; + self.message_extract_assign(server_key, ct_left)?; + } + + self.unchecked_sub_assign(server_key, ct_left, ct_right)?; + Ok(()) + } + + pub(crate) fn smart_sub_with_z( + &mut self, + server_key: &ServerKey, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> EngineResult<(Ciphertext, u64)> { + //If the ciphertext cannot be added together without exceeding the capacity of a ciphertext + if !server_key.is_sub_possible(ct_left, ct_right) { + self.message_extract_assign(server_key, ct_left)?; + self.message_extract_assign(server_key, ct_right)?; + } + + self.unchecked_sub_with_z(server_key, ct_left, ct_right) + } +} diff --git a/tfhe/src/shortint/engine/wopbs/mod.rs b/tfhe/src/shortint/engine/wopbs/mod.rs new file mode 100644 index 000000000..a70307375 --- /dev/null +++ b/tfhe/src/shortint/engine/wopbs/mod.rs @@ -0,0 +1,479 @@ +//! # WARNING: this module is experimental. +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::{EngineResult, ShortintEngine}; +use crate::shortint::wopbs::WopbsKey; +use crate::shortint::{Ciphertext, ClientKey, Parameters, ServerKey}; + +use crate::core_crypto::prelude::*; +use crate::shortint::server_key::MaxDegree; + +impl ShortintEngine { + // Creates a key when ONLY a wopbs is used. + pub(crate) fn new_wopbs_key_only_for_wopbs( + &mut self, + cks: &ClientKey, + sks: &ServerKey, + ) -> EngineResult { + let cbs_pfpksk = self + .engine + .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &cks.lwe_secret_key, + &cks.glwe_secret_key, + cks.parameters.pfks_base_log, + cks.parameters.pfks_level, + Variance(cks.parameters.pfks_modular_std_dev.get_variance()), + )?; + + let sks_cpy = sks.clone(); + + let wopbs_key = WopbsKey { + wopbs_server_key: sks_cpy.clone(), + cbs_pfpksk, + ksk_pbs_to_wopbs: sks.key_switching_key.clone(), + param: cks.parameters, + pbs_server_key: sks_cpy, + }; + Ok(wopbs_key) + } + + //Creates a new WoPBS key. + pub(crate) fn new_wopbs_key( + &mut self, + cks: &ClientKey, + sks: &ServerKey, + parameters: &Parameters, + ) -> EngineResult { + //Independent client key generation dedicated to the WoPBS + let small_lwe_secret_key: LweSecretKey64 = self + .engine + .generate_new_lwe_secret_key(parameters.lwe_dimension)?; + + let glwe_secret_key: GlweSecretKey64 = self + .engine + .generate_new_glwe_secret_key(parameters.glwe_dimension, parameters.polynomial_size)?; + + let large_lwe_secret_key = self + .engine + .transform_glwe_secret_key_to_lwe_secret_key(glwe_secret_key.clone())?; + + //BSK dedicated to the WoPBS + let var_rlwe = Variance(parameters.glwe_modular_std_dev.get_variance()); + + let bootstrap_key: LweBootstrapKey64 = self.par_engine.generate_new_lwe_bootstrap_key( + &small_lwe_secret_key, + &glwe_secret_key, + parameters.pbs_base_log, + parameters.pbs_level, + var_rlwe, + )?; + + // Creation of the bootstrapping key in the Fourier domain + let small_bsk: FftFourierLweBootstrapKey64 = + self.fft_engine.convert_lwe_bootstrap_key(&bootstrap_key)?; + + // Convert into a variance for lwe context + let var_lwe = Variance(parameters.lwe_modular_std_dev.get_variance()); + //KSK encryption_key -> small WoPBS key (used in the 1st KS in the extract bit) + let ksk_wopbs_large_to_wopbs_small = self.engine.generate_new_lwe_keyswitch_key( + &large_lwe_secret_key, + &small_lwe_secret_key, + parameters.ks_level, + parameters.ks_base_log, + var_lwe, + )?; + + //KSK to convert from input ciphertext key to the wopbs input one + //let var_lwe = Variance(cks.parameters.lwe_modular_std_dev.get_variance()); + let ksk_pbs_large_to_wopbs_large = self.engine.generate_new_lwe_keyswitch_key( + &cks.lwe_secret_key, + &large_lwe_secret_key, + cks.parameters.ks_level, + cks.parameters.ks_base_log, + var_lwe, + )?; + + //KSK large_wopbs_key -> small PBS key (used after the WoPBS computation to compute a + // classical PBS. This allows compatibility between PBS and WoPBS + let var_lwe_pbs = Variance(cks.parameters.lwe_modular_std_dev.get_variance()); + let ksk_wopbs_large_to_pbs_small = self.engine.generate_new_lwe_keyswitch_key( + &large_lwe_secret_key, + &cks.lwe_secret_key_after_ks, + cks.parameters.ks_level, + cks.parameters.ks_base_log, + var_lwe_pbs, + )?; + + let cbs_pfpksk = self + .engine + .generate_new_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys( + &large_lwe_secret_key, + &glwe_secret_key, + parameters.pfks_base_log, + parameters.pfks_level, + Variance(parameters.pfks_modular_std_dev.get_variance()), + )?; + + let wopbs_server_key = ServerKey { + key_switching_key: ksk_wopbs_large_to_wopbs_small, + bootstrapping_key: small_bsk, + message_modulus: parameters.message_modulus, + carry_modulus: parameters.carry_modulus, + max_degree: MaxDegree(parameters.message_modulus.0 * parameters.carry_modulus.0 - 1), + }; + + let pbs_server_key = ServerKey { + key_switching_key: ksk_wopbs_large_to_pbs_small, + bootstrapping_key: sks.bootstrapping_key.clone(), + message_modulus: cks.parameters.message_modulus, + carry_modulus: cks.parameters.carry_modulus, + max_degree: MaxDegree( + cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0 - 1, + ), + }; + + let wopbs_key = WopbsKey { + wopbs_server_key, + pbs_server_key, + cbs_pfpksk, + ksk_pbs_to_wopbs: ksk_pbs_large_to_wopbs_large, + param: *parameters, + }; + Ok(wopbs_key) + } + + pub(crate) fn extract_bits( + &mut self, + delta_log: DeltaLog, + lwe_in: &LweCiphertext64, + wopbs_key: &WopbsKey, + extracted_bit_count: ExtractedBitsCount, + ) -> EngineResult { + let server_key = &wopbs_key.wopbs_server_key; + + let lwe_size = server_key + .key_switching_key + .output_lwe_dimension() + .to_lwe_size(); + let mut output = self.engine.create_lwe_ciphertext_vector_from( + vec![0u64; lwe_size.0 * extracted_bit_count.0], + lwe_size, + )?; + + self.fft_engine.discard_extract_bits_lwe_ciphertext( + &mut output, + lwe_in, + &server_key.bootstrapping_key, + &server_key.key_switching_key, + extracted_bit_count, + DeltaLog(delta_log.0), + )?; + Ok(output) + } + + pub(crate) fn circuit_bootstrap_with_bits( + &mut self, + wopbs_key: &WopbsKey, + extracted_bits: &LweCiphertextVectorView64<'_>, + lut: &PlaintextVector64, + count: LweCiphertextCount, + ) -> EngineResult { + let sks = &wopbs_key.wopbs_server_key; + let mut output_cbs_vp_ct_container = + vec![0u64; sks.bootstrapping_key.output_lwe_dimension().to_lwe_size().0 * count.0]; + + let mut output_cbs_vp_ct = self.engine.create_lwe_ciphertext_vector_from( + output_cbs_vp_ct_container.as_mut_slice(), + sks.bootstrapping_key.output_lwe_dimension().to_lwe_size(), + )?; + + self.fft_engine + .discard_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_vector( + &mut output_cbs_vp_ct, + extracted_bits, + &sks.bootstrapping_key, + lut, + wopbs_key.param.cbs_level, + wopbs_key.param.cbs_base_log, + &wopbs_key.cbs_pfpksk, + )?; + + let output_vector = self.engine.create_lwe_ciphertext_vector_from( + output_cbs_vp_ct_container, + sks.bootstrapping_key.output_lwe_dimension().to_lwe_size(), + )?; + + Ok(output_vector) + } + + pub(crate) fn extract_bits_circuit_bootstrapping( + &mut self, + wopbs_key: &WopbsKey, + ct_in: &Ciphertext, + lut: &[u64], + delta_log: DeltaLog, + nb_bit_to_extract: ExtractedBitsCount, + ) -> EngineResult { + let extracted_bits = + self.extract_bits(delta_log, &ct_in.ct, wopbs_key, nb_bit_to_extract)?; + + let extracted_bit_size = extracted_bits.lwe_dimension().to_lwe_size(); + let data = self + .engine + .consume_retrieve_lwe_ciphertext_vector(extracted_bits)?; + let extrated_bits_view = self + .engine + .create_lwe_ciphertext_vector_from(data.as_slice(), extracted_bit_size)?; + + let plaintext_lut = self.engine.create_plaintext_vector_from(lut)?; + + let ciphertext = self.circuit_bootstrap_with_bits( + wopbs_key, + &extrated_bits_view, + &plaintext_lut, + LweCiphertextCount(1), + )?; + + let container = self + .engine + .consume_retrieve_lwe_ciphertext_vector(ciphertext)?; + let ct_out = self.engine.create_lwe_ciphertext_from(container)?; + + let sks = &wopbs_key.wopbs_server_key; + let ct_out = Ciphertext { + ct: ct_out, + degree: Degree(sks.message_modulus.0 - 1), + message_modulus: sks.message_modulus, + carry_modulus: sks.carry_modulus, + }; + + Ok(ct_out) + } + + pub(crate) fn programmable_bootstrapping_without_padding( + &mut self, + wopbs_key: &WopbsKey, + ct_in: &Ciphertext, + lut: &[u64], + ) -> EngineResult { + let sks = &wopbs_key.wopbs_server_key; + let delta = (1_usize << 63) / (sks.message_modulus.0 * sks.carry_modulus.0) * 2; + let delta_log = DeltaLog(f64::log2(delta as f64) as usize); + + let nb_bit_to_extract = + f64::log2((sks.message_modulus.0 * sks.carry_modulus.0) as f64) as usize; + + let ciphertext = self.extract_bits_circuit_bootstrapping( + wopbs_key, + ct_in, + lut, + delta_log, + ExtractedBitsCount(nb_bit_to_extract), + )?; + + Ok(ciphertext) + } + + pub(crate) fn keyswitch_to_wopbs_params( + &mut self, + sks: &ServerKey, + wopbs_key: &WopbsKey, + ct_in: &Ciphertext, + ) -> EngineResult { + // First PBS to remove the noise + let acc = self.generate_accumulator(sks, |x| x)?; + let ct_clean = self.programmable_bootstrap_keyswitch(sks, ct_in, &acc)?; + + // To make borrow checker happy + let engine = &mut self.engine; + let zero_plaintext = engine.create_plaintext_from(&0_u64).unwrap(); + let mut buffer_lwe_after_ks = engine + .trivially_encrypt_lwe_ciphertext( + wopbs_key + .ksk_pbs_to_wopbs + .output_lwe_dimension() + .to_lwe_size(), + &zero_plaintext, + ) + .unwrap(); + // Compute a key switch + engine.discard_keyswitch_lwe_ciphertext( + &mut buffer_lwe_after_ks, + &ct_clean.ct, + &wopbs_key.ksk_pbs_to_wopbs, + )?; + + Ok(Ciphertext { + ct: buffer_lwe_after_ks, + degree: ct_clean.degree, + message_modulus: ct_clean.message_modulus, + carry_modulus: ct_clean.carry_modulus, + }) + } + + pub(crate) fn keyswitch_to_pbs_params( + &mut self, + wopbs_key: &WopbsKey, + ct_in: &Ciphertext, + ) -> EngineResult { + // move to wopbs parameters to pbs parameters + //Keyswitch-PBS: + // 1. KS to go back to the original encryption key + // 2. PBS to remove the noise added by the previous KS + // + let acc = self.generate_accumulator(&wopbs_key.pbs_server_key, |x| x)?; + let (buffers, engine, fftw_engine) = self.buffers_for_key(&wopbs_key.pbs_server_key); + // Compute a key switch + engine.discard_keyswitch_lwe_ciphertext( + &mut buffers.buffer_lwe_after_ks, + &ct_in.ct, + &wopbs_key.pbs_server_key.key_switching_key, + )?; + + let out_lwe_size = wopbs_key + .pbs_server_key + .bootstrapping_key + .output_lwe_dimension() + .to_lwe_size(); + let mut ct_out = engine.create_lwe_ciphertext_from(vec![0; out_lwe_size.0])?; + + // Compute a bootstrap + fftw_engine.discard_bootstrap_lwe_ciphertext( + &mut ct_out, + &buffers.buffer_lwe_after_ks, + &acc, + &wopbs_key.pbs_server_key.bootstrapping_key, + )?; + Ok(Ciphertext { + ct: ct_out, + degree: ct_in.degree, + message_modulus: ct_in.message_modulus, + carry_modulus: ct_in.carry_modulus, + }) + } + + pub(crate) fn wopbs( + &mut self, + wopbs_key: &WopbsKey, + ct_in: &Ciphertext, + lut: &[u64], + ) -> EngineResult { + let tmp_sks = &wopbs_key.wopbs_server_key; + let delta = (1_usize << 63) / (tmp_sks.message_modulus.0 * tmp_sks.carry_modulus.0); + let delta_log = DeltaLog(f64::log2(delta as f64) as usize); + let nb_bit_to_extract = + f64::log2((tmp_sks.message_modulus.0 * tmp_sks.carry_modulus.0) as f64) as usize; + + let ct_out = self.extract_bits_circuit_bootstrapping( + wopbs_key, + ct_in, + lut, + delta_log, + ExtractedBitsCount(nb_bit_to_extract), + )?; + + Ok(ct_out) + } + + pub(crate) fn programmable_bootstrapping( + &mut self, + wopbs_key: &WopbsKey, + sks: &ServerKey, + ct_in: &Ciphertext, + lut: &[u64], + ) -> EngineResult { + let ct_wopbs = self.keyswitch_to_wopbs_params(sks, wopbs_key, ct_in)?; + let result_ct = self.wopbs(wopbs_key, &ct_wopbs, lut)?; + let ct_out = self.keyswitch_to_pbs_params(wopbs_key, &result_ct)?; + + Ok(ct_out) + } + + pub(crate) fn programmable_bootstrapping_native_crt( + &mut self, + wopbs_key: &WopbsKey, + ct_in: &mut Ciphertext, + lut: &[u64], + ) -> EngineResult { + let nb_bit_to_extract = + f64::log2((ct_in.message_modulus.0 * ct_in.carry_modulus.0) as f64).ceil() as usize; + let delta_log = DeltaLog(64 - nb_bit_to_extract); + + // trick ( ct - delta/2 + delta/2^4 ) + let lwe_size = ct_in.ct.lwe_dimension().to_lwe_size().0; + let mut cont = vec![0u64; lwe_size]; + cont[lwe_size - 1] = + (1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5)); + let tmp = self.engine.create_lwe_ciphertext_from(cont)?; + self.engine.fuse_sub_lwe_ciphertext(&mut ct_in.ct, &tmp)?; + + let ciphertext = self.extract_bits_circuit_bootstrapping( + wopbs_key, + ct_in, + lut, + delta_log, + ExtractedBitsCount(nb_bit_to_extract), + )?; + + Ok(ciphertext) + } + + /// Temporary wrapper. + /// + /// # Warning Experimental + pub fn circuit_bootstrapping_vertical_packing( + &mut self, + wopbs_key: &WopbsKey, + vec_lut: Vec>, + extracted_bits_blocks: Vec, + ) -> Vec { + let lwe_size = extracted_bits_blocks[0].lwe_dimension().to_lwe_size(); + + let mut all_datas = vec![]; + for lwe_vec in extracted_bits_blocks.into_iter() { + let data = self + .engine + .consume_retrieve_lwe_ciphertext_vector(lwe_vec) + .unwrap(); + + all_datas.extend_from_slice(data.as_slice()); + } + + let flatenned_extracted_bits_view = self + .engine + .create_lwe_ciphertext_vector_from(all_datas.as_slice(), lwe_size) + .unwrap(); + + let flattened_lut: Vec = vec_lut.iter().flatten().copied().collect(); + let plaintext_lut = self + .engine + .create_plaintext_vector_from(&flattened_lut) + .unwrap(); + let output = self + .circuit_bootstrap_with_bits( + wopbs_key, + &flatenned_extracted_bits_view, + &plaintext_lut, + LweCiphertextCount(vec_lut.len()), + ) + .unwrap(); + + assert_eq!(output.lwe_ciphertext_count().0, vec_lut.len()); + + let output_container = self + .engine + .consume_retrieve_lwe_ciphertext_vector(output) + .unwrap(); + let lwes: Result, Box> = output_container + .chunks_exact(output_container.len() / vec_lut.len()) + .map(|s| { + let lwe = self.engine.create_lwe_ciphertext_from(s.to_vec())?; + Ok(lwe) + }) + .collect(); + + let lwes = lwes.unwrap(); + + assert_eq!(lwes.len(), vec_lut.len()); + lwes + } +} diff --git a/tfhe/src/shortint/keycache.rs b/tfhe/src/shortint/keycache.rs new file mode 100644 index 000000000..65c3e28eb --- /dev/null +++ b/tfhe/src/shortint/keycache.rs @@ -0,0 +1,478 @@ +use crate::shortint::parameters::parameters_wopbs::*; +use crate::shortint::parameters::parameters_wopbs_message_carry::*; +use crate::shortint::parameters::parameters_wopbs_prime_moduli::*; +use crate::shortint::parameters::*; +use crate::shortint::wopbs::WopbsKey; +use crate::shortint::{ClientKey, ServerKey}; +use lazy_static::*; +use serde::{Deserialize, Serialize}; + +pub use utils::{ + FileStorage, KeyCache as TKeyCache, NamedParam, PersistentStorage, + SharedKey as GenericSharedKey, +}; + +#[macro_use] +pub mod utils { + use fs2::FileExt; + use once_cell::sync::OnceCell; + use serde::de::DeserializeOwned; + use serde::Serialize; + use std::fs::File; + use std::io::{BufReader, BufWriter}; + use std::ops::Deref; + use std::path::PathBuf; + use std::sync::{Arc, RwLock}; + + pub trait PersistentStorage { + fn load(&self, param: P) -> Option; + fn store(&self, param: P, key: &K); + } + + pub trait NamedParam { + fn name(&self) -> String; + } + + #[macro_export] + macro_rules! named_params_impl( + ( $thing:ident == ( $($const_param:ident),* $(,)? )) => { + named_params_impl!({ *$thing } == ( $($const_param),* )) + }; + + ( { $thing:expr } == ( $($const_param:ident),* $(,)? )) => { + $( + if $thing == $const_param { + return stringify!($const_param).to_string(); + } + )* + + panic!("Unnamed parameters"); + } + ); + + pub struct FileStorage { + prefix: String, + } + + impl FileStorage { + pub fn new(prefix: String) -> Self { + Self { prefix } + } + } + + impl PersistentStorage for FileStorage + where + P: NamedParam + DeserializeOwned + Serialize + PartialEq, + K: DeserializeOwned + Serialize, + { + fn load(&self, param: P) -> Option { + let mut path_buf = PathBuf::with_capacity(256); + path_buf.push(&self.prefix); + path_buf.push(param.name()); + path_buf.set_extension("bin"); + + if path_buf.exists() { + let file = File::open(&path_buf).unwrap(); + // Lock for reading + file.lock_shared().unwrap(); + let file_reader = BufReader::new(file); + bincode::deserialize_from::<_, (P, K)>(file_reader) + .ok() + .and_then(|(p, k)| if p == param { Some(k) } else { None }) + } else { + None + } + } + + fn store(&self, param: P, key: &K) { + let mut path_buf = PathBuf::with_capacity(256); + path_buf.push(&self.prefix); + std::fs::create_dir_all(&path_buf).unwrap(); + path_buf.push(param.name()); + path_buf.set_extension("bin"); + + let file = File::create(&path_buf).unwrap(); + // Lock for writing + file.lock_exclusive().unwrap(); + + let file_writer = BufWriter::new(file); + bincode::serialize_into(file_writer, &(param, key)).unwrap(); + } + } + + pub struct SharedKey { + inner: Arc>, + } + + impl Clone for SharedKey { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } + } + + impl Deref for SharedKey { + type Target = K; + + fn deref(&self) -> &Self::Target { + self.inner.get().unwrap() + } + } + + pub struct KeyCache { + // Where the keys will be stored persistently + // So they are not generated between each run + persistent_storage: S, + // Temporary memory storage to avoid querying the persistent storage each time + // the outer Arc makes it so that we don't clone the OnceCell contents when initializing it + memory_storage: RwLock)>>, + } + + impl KeyCache { + pub fn new(storage: S) -> Self { + Self { + persistent_storage: storage, + memory_storage: RwLock::new(vec![]), + } + } + } + + impl KeyCache + where + P: Copy + PartialEq + NamedParam, + S: PersistentStorage, + K: From

+ Clone, + { + pub fn get(&self, param: P) -> SharedKey { + self.with_key(param, |k| k.clone()) + } + + pub fn with_key(&self, param: P, f: F) -> R + where + F: FnOnce(&SharedKey) -> R, + { + let load_from_persistent_storage = || { + // we check if we can load the key from persistent storage + let persistent_storage = &self.persistent_storage; + let maybe_key = persistent_storage.load(param); + match maybe_key { + Some(key) => key, + None => { + let key = K::from(param); + persistent_storage.store(param, &key); + key + } + } + }; + + let try_load_from_memory_and_init = || { + // we only hold a read lock for a short duration to find the key + let memory_storage = self.memory_storage.read().unwrap(); + let maybe_shared_cell = memory_storage + .iter() + .find(|(p, _)| *p == param) + .map(|param_key| param_key.1.clone()); + drop(memory_storage); + + if let Some(shared_cell) = maybe_shared_cell { + shared_cell.inner.get_or_init(load_from_persistent_storage); + Ok(shared_cell) + } else { + Err(()) + } + }; + + match try_load_from_memory_and_init() { + Ok(result) => f(&result), + Err(()) => { + { + // we only hold a write lock for a short duration to push the lazily + // evaluated key without actually evaluating the key + let mut memory_storage = self.memory_storage.write().unwrap(); + if !memory_storage.iter().any(|(p, _)| *p == param) { + memory_storage.push(( + param, + SharedKey { + inner: Arc::new(OnceCell::new()), + }, + )); + } + } + f(&try_load_from_memory_and_init().ok().unwrap()) + } + } + } + } +} + +impl NamedParam for Parameters { + fn name(&self) -> String { + named_params_impl!( + self == ( + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_1_CARRY_2, + PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, + PARAM_MESSAGE_1_CARRY_5, + PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, + PARAM_MESSAGE_2_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, + PARAM_MESSAGE_2_CARRY_5, + PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_1, + PARAM_MESSAGE_3_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_3_CARRY_4, + PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_1, + PARAM_MESSAGE_4_CARRY_2, + PARAM_MESSAGE_4_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, + PARAM_MESSAGE_5_CARRY_1, + PARAM_MESSAGE_5_CARRY_2, + PARAM_MESSAGE_5_CARRY_3, + PARAM_MESSAGE_6_CARRY_1, + PARAM_MESSAGE_6_CARRY_2, + PARAM_MESSAGE_7_CARRY_1, + WOPBS_PARAM_MESSAGE_1_NORM2_2, + WOPBS_PARAM_MESSAGE_1_NORM2_4, + WOPBS_PARAM_MESSAGE_1_NORM2_6, + WOPBS_PARAM_MESSAGE_1_NORM2_8, + WOPBS_PARAM_MESSAGE_2_NORM2_2, + WOPBS_PARAM_MESSAGE_2_NORM2_4, + WOPBS_PARAM_MESSAGE_2_NORM2_6, + WOPBS_PARAM_MESSAGE_2_NORM2_8, + WOPBS_PARAM_MESSAGE_3_NORM2_2, + WOPBS_PARAM_MESSAGE_3_NORM2_4, + WOPBS_PARAM_MESSAGE_3_NORM2_6, + WOPBS_PARAM_MESSAGE_3_NORM2_8, + WOPBS_PARAM_MESSAGE_4_NORM2_2, + WOPBS_PARAM_MESSAGE_4_NORM2_4, + WOPBS_PARAM_MESSAGE_4_NORM2_6, + WOPBS_PARAM_MESSAGE_4_NORM2_8, + WOPBS_PARAM_MESSAGE_5_NORM2_2, + WOPBS_PARAM_MESSAGE_5_NORM2_4, + WOPBS_PARAM_MESSAGE_5_NORM2_6, + WOPBS_PARAM_MESSAGE_5_NORM2_8, + WOPBS_PARAM_MESSAGE_6_NORM2_2, + WOPBS_PARAM_MESSAGE_6_NORM2_4, + WOPBS_PARAM_MESSAGE_6_NORM2_6, + WOPBS_PARAM_MESSAGE_6_NORM2_8, + WOPBS_PARAM_MESSAGE_7_NORM2_2, + WOPBS_PARAM_MESSAGE_7_NORM2_4, + WOPBS_PARAM_MESSAGE_7_NORM2_6, + WOPBS_PARAM_MESSAGE_7_NORM2_8, + WOPBS_PARAM_MESSAGE_8_NORM2_2, + WOPBS_PARAM_MESSAGE_8_NORM2_4, + //WOPBS_PARAM_MESSAGE_8_NORM2_5, + WOPBS_PARAM_MESSAGE_8_NORM2_6, + WOPBS_PARAM_MESSAGE_1_CARRY_0, + WOPBS_PARAM_MESSAGE_1_CARRY_1, + WOPBS_PARAM_MESSAGE_1_CARRY_2, + WOPBS_PARAM_MESSAGE_1_CARRY_3, + WOPBS_PARAM_MESSAGE_1_CARRY_4, + WOPBS_PARAM_MESSAGE_1_CARRY_5, + WOPBS_PARAM_MESSAGE_1_CARRY_6, + WOPBS_PARAM_MESSAGE_1_CARRY_7, + WOPBS_PARAM_MESSAGE_2_CARRY_0, + WOPBS_PARAM_MESSAGE_2_CARRY_1, + WOPBS_PARAM_MESSAGE_2_CARRY_2, + WOPBS_PARAM_MESSAGE_2_CARRY_3, + WOPBS_PARAM_MESSAGE_2_CARRY_4, + WOPBS_PARAM_MESSAGE_2_CARRY_5, + WOPBS_PARAM_MESSAGE_2_CARRY_6, + WOPBS_PARAM_MESSAGE_3_CARRY_0, + WOPBS_PARAM_MESSAGE_3_CARRY_1, + WOPBS_PARAM_MESSAGE_3_CARRY_2, + WOPBS_PARAM_MESSAGE_3_CARRY_3, + WOPBS_PARAM_MESSAGE_3_CARRY_4, + WOPBS_PARAM_MESSAGE_3_CARRY_5, + WOPBS_PARAM_MESSAGE_4_CARRY_0, + WOPBS_PARAM_MESSAGE_4_CARRY_1, + WOPBS_PARAM_MESSAGE_4_CARRY_2, + WOPBS_PARAM_MESSAGE_4_CARRY_3, + WOPBS_PARAM_MESSAGE_4_CARRY_4, + WOPBS_PARAM_MESSAGE_5_CARRY_0, + WOPBS_PARAM_MESSAGE_5_CARRY_1, + WOPBS_PARAM_MESSAGE_5_CARRY_2, + WOPBS_PARAM_MESSAGE_5_CARRY_3, + WOPBS_PARAM_MESSAGE_6_CARRY_0, + WOPBS_PARAM_MESSAGE_6_CARRY_1, + WOPBS_PARAM_MESSAGE_6_CARRY_2, + WOPBS_PARAM_MESSAGE_7_CARRY_0, + WOPBS_PARAM_MESSAGE_7_CARRY_1, + WOPBS_PARAM_MESSAGE_8_CARRY_0, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_7, + PARAM_4_BITS_5_BLOCKS, + ) + ); + } +} + +impl From for (ClientKey, ServerKey) { + fn from(param: Parameters) -> Self { + let cks = ClientKey::new(param); + let sks = ServerKey::new(&cks); + (cks, sks) + } +} + +pub struct Keycache { + inner: TKeyCache, +} + +impl Default for Keycache { + fn default() -> Self { + Self { + inner: TKeyCache::new(FileStorage::new( + "../keys/shortint/client_server".to_string(), + )), + } + } +} + +pub struct SharedKey { + inner: GenericSharedKey<(ClientKey, ServerKey)>, +} + +pub struct SharedWopbsKey { + inner: GenericSharedKey<(ClientKey, ServerKey)>, + wopbs: GenericSharedKey, +} + +impl SharedKey { + pub fn client_key(&self) -> &ClientKey { + &self.inner.0 + } + pub fn server_key(&self) -> &ServerKey { + &self.inner.1 + } +} + +impl SharedWopbsKey { + pub fn client_key(&self) -> &ClientKey { + &self.inner.0 + } + pub fn server_key(&self) -> &ServerKey { + &self.inner.1 + } + pub fn wopbs_key(&self) -> &WopbsKey { + &self.wopbs + } +} + +impl Keycache { + pub fn get_from_param(&self, param: Parameters) -> SharedKey { + SharedKey { + inner: self.inner.get(param), + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct WopbsParamPair(pub Parameters, pub Parameters); + +impl From<(Parameters, Parameters)> for WopbsParamPair { + fn from(tuple: (Parameters, Parameters)) -> Self { + Self(tuple.0, tuple.1) + } +} + +impl From for WopbsKey { + fn from(params: WopbsParamPair) -> Self { + // use with_key to avoid doing a temporary cloning + KEY_CACHE.inner.with_key(params.0, |keys| { + WopbsKey::new_wopbs_key(&keys.0, &keys.1, ¶ms.1) + }) + } +} + +impl NamedParam for WopbsParamPair { + fn name(&self) -> String { + self.1.name() + } +} + +/// The KeyCache struct for shortint. +/// +/// You should not create an instance yourself, +/// but rather use the global variable defined: [KEY_CACHE_WOPBS] +pub struct KeycacheWopbsV0 { + inner: TKeyCache, +} + +impl Default for KeycacheWopbsV0 { + fn default() -> Self { + Self { + inner: TKeyCache::new(FileStorage::new("../keys/shortint/wopbs_v0".to_string())), + } + } +} + +impl KeycacheWopbsV0 { + pub fn get_from_param>(&self, params: T) -> SharedWopbsKey { + let params = params.into(); + let key = KEY_CACHE.get_from_param(params.0); + let wk = self.inner.get(params); + SharedWopbsKey { + inner: key.inner, + wopbs: wk, + } + } +} + +lazy_static! { + pub static ref KEY_CACHE: Keycache = Default::default(); + pub static ref KEY_CACHE_WOPBS: KeycacheWopbsV0 = Default::default(); +} diff --git a/tfhe/src/shortint/mod.rs b/tfhe/src/shortint/mod.rs new file mode 100755 index 000000000..a6241f795 --- /dev/null +++ b/tfhe/src/shortint/mod.rs @@ -0,0 +1,87 @@ +#![allow(clippy::excessive_precision)] +//! Welcome to the tfhe.rs `shortint` module documentation! +//! +//! # Description +//! +//! This library makes it possible to execute modular operations over encrypted short integer. +//! +//! It allows to execute an integer circuit on an untrusted server because both circuit inputs and +//! outputs are kept private. +//! +//! Data are encrypted on the client side, before being sent to the server. +//! On the server side every computation is performed on ciphertexts. +//! +//! The server however, has to know the integer circuit to be evaluated. +//! At the end of the computation, the server returns the encryption of the result to the user. +//! +//! # Keys +//! +//! This crates exposes two type of keys: +//! * The [`ClientKey`](crate::shortint::client_key::ClientKey) is used to encrypt and decrypt and +//! has to be kept secret; +//! * The [`ServerKey`](crate::shortint::server_key::ServerKey) is used to perform homomorphic +//! operations on the server side and it is meant to be published (the client sends it to the +//! server). +//! +//! +//! # Quick Example +//! +//! The following piece of code shows how to generate keys and run a small integer circuit +//! homomorphically. +//! +//! ```rust +//! use tfhe::shortint::{gen_keys, Parameters}; +//! +//! // We generate a set of client/server keys, using the default parameters: +//! let (mut client_key, mut server_key) = gen_keys(Parameters::default()); +//! +//! let msg1 = 1; +//! let msg2 = 0; +//! +//! // We use the client key to encrypt two messages: +//! let ct_1 = client_key.encrypt(msg1); +//! let ct_2 = client_key.encrypt(msg2); +//! +//! // We use the server public key to execute an integer circuit: +//! let ct_3 = server_key.unchecked_add(&ct_1, &ct_2); +//! +//! // We use the client key to decrypt the output of the circuit: +//! let output = client_key.decrypt(&ct_3); +//! assert_eq!(output, 1); +//! ``` +pub mod ciphertext; +pub mod client_key; +pub mod engine; +#[cfg(any(test, doctest, feature = "internal-keycache"))] +pub mod keycache; +pub mod parameters; +pub mod prelude; +pub mod public_key; +pub mod server_key; +#[cfg(not(feature = "__wasm_api"))] +pub mod wopbs; + +pub use ciphertext::Ciphertext; +pub use client_key::ClientKey; +pub use parameters::Parameters; +pub use public_key::PublicKey; +pub use server_key::{CheckError, ServerKey}; + +/// Generate a couple of client and server keys. +/// +/// # Example +/// +/// Generating a pair of [ClientKey] and [ServerKey] using the default parameters. +/// +/// ```rust +/// use tfhe::shortint::gen_keys; +/// +/// // generate the client key and the server key: +/// let (cks, sks) = gen_keys(Default::default()); +/// ``` +pub fn gen_keys(parameters_set: Parameters) -> (ClientKey, ServerKey) { + let cks = ClientKey::new(parameters_set); + let sks = ServerKey::new(&cks); + + (cks, sks) +} diff --git a/tfhe/src/shortint/parameters/mod.rs b/tfhe/src/shortint/parameters/mod.rs new file mode 100644 index 000000000..4ad76101d --- /dev/null +++ b/tfhe/src/shortint/parameters/mod.rs @@ -0,0 +1,849 @@ +//! Module with the definition of parameters for short-integers. +//! +//! This module provides the structure containing the cryptographic parameters required for the +//! homomorphic evaluation of integer circuits as well as a list of secure cryptographic parameter +//! sets. + +pub use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension, + LweDimension, PolynomialSize, StandardDev, +}; +use serde::{Deserialize, Serialize}; + +pub mod parameters_wopbs; +pub mod parameters_wopbs_message_carry; +pub(crate) mod parameters_wopbs_prime_moduli; + +/// The number of bits on which the message will be encoded. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct MessageModulus(pub usize); + +/// The number of bits on which the carry will be encoded. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct CarryModulus(pub usize); + +/// A structure defining the set of cryptographic parameters for homomorphic integer circuit +/// evaluation. +#[derive(Serialize, Copy, Clone, Deserialize, Debug, PartialEq)] +pub struct Parameters { + pub lwe_dimension: LweDimension, + pub glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub lwe_modular_std_dev: StandardDev, + pub glwe_modular_std_dev: StandardDev, + pub pbs_base_log: DecompositionBaseLog, + pub pbs_level: DecompositionLevelCount, + pub ks_base_log: DecompositionBaseLog, + pub ks_level: DecompositionLevelCount, + pub pfks_level: DecompositionLevelCount, + pub pfks_base_log: DecompositionBaseLog, + pub pfks_modular_std_dev: StandardDev, + pub cbs_level: DecompositionLevelCount, + pub cbs_base_log: DecompositionBaseLog, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, +} + +impl Parameters { + /// Constructs a new set of parameters for integer circuit evaluation. + /// + /// # Safety + /// + /// This function is unsafe, as failing to fix the parameters properly would yield incorrect + /// and unsecure computation. Unless you are a cryptographer who really knows the impact of each + /// of those parameters, you __must__ stick with the provided parameters. + #[allow(clippy::too_many_arguments)] + pub unsafe fn new( + lwe_dimension: LweDimension, + glwe_dimension: GlweDimension, + polynomial_size: PolynomialSize, + lwe_modular_std_dev: StandardDev, + glwe_modular_std_dev: StandardDev, + pbs_base_log: DecompositionBaseLog, + pbs_level: DecompositionLevelCount, + ks_base_log: DecompositionBaseLog, + ks_level: DecompositionLevelCount, + pfks_modular_std_dev: StandardDev, + pfks_base_log: DecompositionBaseLog, + pfks_level: DecompositionLevelCount, + cbs_level: DecompositionLevelCount, + cbs_base_log: DecompositionBaseLog, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + ) -> Parameters { + Parameters { + lwe_dimension, + glwe_dimension, + polynomial_size, + lwe_modular_std_dev, + glwe_modular_std_dev, + pbs_base_log, + pbs_level, + ks_level, + ks_base_log, + pfks_level, + pfks_base_log, + pfks_modular_std_dev, + cbs_level, + cbs_base_log, + message_modulus, + carry_modulus, + } + } +} + +impl Default for Parameters { + fn default() -> Self { + DEFAULT_PARAMETERS + } +} + +/// Vector containing all parameter sets +pub const ALL_PARAMETER_VEC: [Parameters; 28] = WITH_CARRY_PARAMETERS_VEC; + +/// Vector containing all parameter sets where the carry space is strictly greater than one +pub const WITH_CARRY_PARAMETERS_VEC: [Parameters; 28] = [ + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_1_CARRY_2, + PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, + PARAM_MESSAGE_1_CARRY_5, + PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, + PARAM_MESSAGE_2_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, + PARAM_MESSAGE_2_CARRY_5, + PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_1, + PARAM_MESSAGE_3_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_3_CARRY_4, + PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_1, + PARAM_MESSAGE_4_CARRY_2, + PARAM_MESSAGE_4_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, + PARAM_MESSAGE_5_CARRY_1, + PARAM_MESSAGE_5_CARRY_2, + PARAM_MESSAGE_5_CARRY_3, + PARAM_MESSAGE_6_CARRY_1, + PARAM_MESSAGE_6_CARRY_2, + PARAM_MESSAGE_7_CARRY_1, +]; + +/// Vector containing all parameter sets where the carry space is strictly greater than one +pub const BIVARIATE_PBS_COMPLIANT_PARAMETER_SET_VEC: [Parameters; 16] = [ + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_1_CARRY_2, + PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, + PARAM_MESSAGE_1_CARRY_5, + PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, + PARAM_MESSAGE_2_CARRY_5, + PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_3_CARRY_4, + PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_4, +]; + +/// Default parameter set +pub const DEFAULT_PARAMETERS: Parameters = PARAM_MESSAGE_2_CARRY_2; + +/// Nomenclature: PARAM_MESSAGE_X_CARRY_Y: the message (respectively carry) modulus is +/// encoded over X (reps. Y) bits, i.e., message_modulus = 2^{X} (resp. carry_modulus = 2^{Y}). +/// All parameter sets guarantee 128-bits of security and an error probability smaller than +/// 2^{-40} for a PBS. +pub const PARAM_MESSAGE_1_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(678), + glwe_dimension: GlweDimension(5), + polynomial_size: PolynomialSize(256), + lwe_modular_std_dev: StandardDev(0.000022810107419132102), + glwe_modular_std_dev: StandardDev(0.00000000037411618952047216), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(2), + ks_base_log: DecompositionBaseLog(5), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000037411618952047216), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(684), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00002043784477291318), + glwe_modular_std_dev: StandardDev(0.0000000000034525330484572114), + pbs_base_log: DecompositionBaseLog(18), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(3), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(18), + pfks_modular_std_dev: StandardDev(0.0000000000034525330484572114), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_2_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(656), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.000034119201269311964), + glwe_modular_std_dev: StandardDev(0.00000004053919869756513), + pbs_base_log: DecompositionBaseLog(8), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000037411618952047216), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(3), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(18), + pfks_modular_std_dev: StandardDev(0.0000000000034525330484572114), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_2_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(3), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(18), + pfks_modular_std_dev: StandardDev(0.0000000000034525330484572114), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_3_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(3), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(18), + pfks_modular_std_dev: StandardDev(0.0000000000034525330484572114), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.000006692125069956277), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(8), +}; +pub const PARAM_MESSAGE_2_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_3_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_4_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(742), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.000007069849454709433), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(807), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(4096), + lwe_modular_std_dev: StandardDev(0.0000021515145918907506), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(16), +}; +pub const PARAM_MESSAGE_2_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(856), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(4096), + lwe_modular_std_dev: StandardDev(0.0000008775214009854235), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(8), +}; +pub const PARAM_MESSAGE_3_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(812), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(4096), + lwe_modular_std_dev: StandardDev(0.0000019633637461248447), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_4_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(808), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(4096), + lwe_modular_std_dev: StandardDev(0.0000021124945159091033), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_5_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(807), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(4096), + lwe_modular_std_dev: StandardDev(0.0000021515145918907506), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(864), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.000000757998020150446), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(32), +}; +pub const PARAM_MESSAGE_2_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(864), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.000000757998020150446), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(16), +}; +pub const PARAM_MESSAGE_3_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(864), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.000000757998020150446), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(8), +}; +pub const PARAM_MESSAGE_4_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(864), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.000000757998020150446), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_5_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(875), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.0000006197725091905067), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_6_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(915), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(8192), + lwe_modular_std_dev: StandardDev(0.00000029804653749339636), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(22), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_6: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(64), +}; +pub const PARAM_MESSAGE_2_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(934), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000021050318566634375), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(32), +}; +pub const PARAM_MESSAGE_3_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(16), +}; +pub const PARAM_MESSAGE_4_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(8), +}; +pub const PARAM_MESSAGE_5_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_6_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_7_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(930), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(16384), + lwe_modular_std_dev: StandardDev(0.00000022649232786295453), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const PARAM_MESSAGE_1_CARRY_7: Parameters = Parameters { + lwe_dimension: LweDimension(1004), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.00000005845871624688967), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(128), +}; +pub const PARAM_MESSAGE_2_CARRY_6: Parameters = Parameters { + lwe_dimension: LweDimension(987), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.00000007979529246348835), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(64), +}; +pub const PARAM_MESSAGE_3_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(985), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.00000008277032914509569), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(32), +}; +pub const PARAM_MESSAGE_4_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(996), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.00000006767666038309478), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(3), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(16), +}; +pub const PARAM_MESSAGE_5_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(1020), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.000000043618425315728666), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(8), +}; +pub const PARAM_MESSAGE_6_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(1018), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.000000045244666805696514), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(4), +}; +pub const PARAM_MESSAGE_7_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(1017), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.0000000460803851108693), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(2), +}; +pub const PARAM_MESSAGE_8_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(1017), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(32768), + lwe_modular_std_dev: StandardDev(0.0000000460803851108693), + glwe_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(4), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.0000000000000000002168404344971009), + cbs_level: DecompositionLevelCount(0), + cbs_base_log: DecompositionBaseLog(0), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; + +/// Return a parameter set from a message and carry moduli. +/// +/// # Example +/// +/// ```rust +/// use tfhe::shortint::parameters::{ +/// get_parameters_from_message_and_carry, PARAM_MESSAGE_3_CARRY_1, +/// }; +/// let message_space = 7; +/// let carry_space = 2; +/// let param = get_parameters_from_message_and_carry(message_space, carry_space); +/// assert_eq!(param, PARAM_MESSAGE_3_CARRY_1); +/// ``` +pub fn get_parameters_from_message_and_carry(msg_space: usize, carry_space: usize) -> Parameters { + let mut out = Parameters::default(); + let mut flag: bool = false; + let mut rescaled_message_space = f64::ceil(f64::log2(msg_space as f64)) as usize; + rescaled_message_space = 1 << rescaled_message_space; + let mut rescaled_carry_space = f64::ceil(f64::log2(carry_space as f64)) as usize; + rescaled_carry_space = 1 << rescaled_carry_space; + + for param in ALL_PARAMETER_VEC { + if param.message_modulus.0 == rescaled_message_space + && param.carry_modulus.0 == rescaled_carry_space + { + out = param; + flag = true; + break; + } + } + if !flag { + println!( + "### WARNING: NO PARAMETERS FOUND for msg_space = {} and carry_space = {} ### ", + rescaled_message_space, rescaled_carry_space + ); + } + out +} diff --git a/tfhe/src/shortint/parameters/parameters_wopbs.rs b/tfhe/src/shortint/parameters/parameters_wopbs.rs new file mode 100644 index 000000000..664fcfe9d --- /dev/null +++ b/tfhe/src/shortint/parameters/parameters_wopbs.rs @@ -0,0 +1,618 @@ +pub use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension, + LweDimension, PolynomialSize, StandardDev, +}; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use crate::shortint::Parameters; + +pub const ALL_PARAMETER_VEC_WOPBS_NORM2: [Parameters; 31] = [ + WOPBS_PARAM_MESSAGE_1_NORM2_2, + WOPBS_PARAM_MESSAGE_1_NORM2_4, + WOPBS_PARAM_MESSAGE_1_NORM2_6, + WOPBS_PARAM_MESSAGE_1_NORM2_8, + WOPBS_PARAM_MESSAGE_2_NORM2_2, + WOPBS_PARAM_MESSAGE_2_NORM2_4, + WOPBS_PARAM_MESSAGE_2_NORM2_6, + WOPBS_PARAM_MESSAGE_2_NORM2_8, + WOPBS_PARAM_MESSAGE_3_NORM2_2, + WOPBS_PARAM_MESSAGE_3_NORM2_4, + WOPBS_PARAM_MESSAGE_3_NORM2_6, + WOPBS_PARAM_MESSAGE_3_NORM2_8, + WOPBS_PARAM_MESSAGE_4_NORM2_2, + WOPBS_PARAM_MESSAGE_4_NORM2_4, + WOPBS_PARAM_MESSAGE_4_NORM2_6, + WOPBS_PARAM_MESSAGE_4_NORM2_8, + WOPBS_PARAM_MESSAGE_5_NORM2_2, + WOPBS_PARAM_MESSAGE_5_NORM2_4, + WOPBS_PARAM_MESSAGE_5_NORM2_6, + WOPBS_PARAM_MESSAGE_5_NORM2_8, + WOPBS_PARAM_MESSAGE_6_NORM2_2, + WOPBS_PARAM_MESSAGE_6_NORM2_4, + WOPBS_PARAM_MESSAGE_6_NORM2_6, + WOPBS_PARAM_MESSAGE_6_NORM2_8, + WOPBS_PARAM_MESSAGE_7_NORM2_2, + WOPBS_PARAM_MESSAGE_7_NORM2_4, + WOPBS_PARAM_MESSAGE_7_NORM2_6, + WOPBS_PARAM_MESSAGE_7_NORM2_8, + WOPBS_PARAM_MESSAGE_8_NORM2_2, + WOPBS_PARAM_MESSAGE_8_NORM2_4, + WOPBS_PARAM_MESSAGE_8_NORM2_6, +]; + +pub const WOPBS_PARAM_MESSAGE_1_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(512), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.0003472352121441949901), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(24), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(24), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(2), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_1_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(502), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00041688866384199045524), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_1_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(499), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.0004403915565001254653), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_1_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(500), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00043241360644590172285), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_2_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(488), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.0005384866525630595423), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_2_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(488), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.0005384866525630595423), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_2_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_2_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_3_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(488), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.0005384866525630595423), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(3), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_3_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_3_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(494), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00048254425233109359873), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_3_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(494), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00048254425233109359873), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_4_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_4_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_4_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_4_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_5_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_5_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_5_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_5_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(5), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_6_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_6_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_6_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_6_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_7_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_7_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_7_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_7_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_8_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_8_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_8_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; + +pub const PARAM_4_BITS_5_BLOCKS: Parameters = Parameters { + lwe_dimension: LweDimension(667), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.0000000004168323308734758), + glwe_modular_std_dev: StandardDev(0.00000000000000000000000000000004905643852600863), + pbs_base_log: DecompositionBaseLog(7), + pbs_level: DecompositionLevelCount(6), + ks_base_log: DecompositionBaseLog(1), + ks_level: DecompositionLevelCount(14), + pfks_level: DecompositionLevelCount(6), + pfks_base_log: DecompositionBaseLog(7), + pfks_modular_std_dev: StandardDev(0.00000000000000000000000000000004905643852600863), + cbs_level: DecompositionLevelCount(7), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; diff --git a/tfhe/src/shortint/parameters/parameters_wopbs_message_carry.rs b/tfhe/src/shortint/parameters/parameters_wopbs_message_carry.rs new file mode 100644 index 000000000..80dc4840e --- /dev/null +++ b/tfhe/src/shortint/parameters/parameters_wopbs_message_carry.rs @@ -0,0 +1,967 @@ +pub use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension, + LweDimension, PolynomialSize, StandardDev, +}; +use crate::shortint::parameters::parameters_wopbs::*; +use crate::shortint::parameters::parameters_wopbs_prime_moduli::*; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use crate::shortint::Parameters; + +pub const ALL_PARAMETER_VEC_WOPBS: [Parameters; 116] = [ + WOPBS_PARAM_MESSAGE_1_CARRY_0, + WOPBS_PARAM_MESSAGE_1_CARRY_1, + WOPBS_PARAM_MESSAGE_1_CARRY_2, + WOPBS_PARAM_MESSAGE_1_CARRY_3, + WOPBS_PARAM_MESSAGE_1_CARRY_4, + WOPBS_PARAM_MESSAGE_1_CARRY_5, + WOPBS_PARAM_MESSAGE_1_CARRY_6, + WOPBS_PARAM_MESSAGE_1_CARRY_7, + WOPBS_PARAM_MESSAGE_2_CARRY_0, + WOPBS_PARAM_MESSAGE_2_CARRY_1, + WOPBS_PARAM_MESSAGE_2_CARRY_2, + WOPBS_PARAM_MESSAGE_2_CARRY_3, + WOPBS_PARAM_MESSAGE_2_CARRY_4, + WOPBS_PARAM_MESSAGE_2_CARRY_5, + WOPBS_PARAM_MESSAGE_2_CARRY_6, + WOPBS_PARAM_MESSAGE_3_CARRY_0, + WOPBS_PARAM_MESSAGE_3_CARRY_1, + WOPBS_PARAM_MESSAGE_3_CARRY_2, + WOPBS_PARAM_MESSAGE_3_CARRY_3, + WOPBS_PARAM_MESSAGE_3_CARRY_4, + WOPBS_PARAM_MESSAGE_3_CARRY_5, + WOPBS_PARAM_MESSAGE_4_CARRY_0, + WOPBS_PARAM_MESSAGE_4_CARRY_1, + WOPBS_PARAM_MESSAGE_4_CARRY_2, + WOPBS_PARAM_MESSAGE_4_CARRY_3, + WOPBS_PARAM_MESSAGE_4_CARRY_4, + WOPBS_PARAM_MESSAGE_5_CARRY_0, + WOPBS_PARAM_MESSAGE_5_CARRY_1, + WOPBS_PARAM_MESSAGE_5_CARRY_2, + WOPBS_PARAM_MESSAGE_5_CARRY_3, + WOPBS_PARAM_MESSAGE_6_CARRY_0, + WOPBS_PARAM_MESSAGE_6_CARRY_1, + WOPBS_PARAM_MESSAGE_6_CARRY_2, + WOPBS_PARAM_MESSAGE_7_CARRY_0, + WOPBS_PARAM_MESSAGE_7_CARRY_1, + WOPBS_PARAM_MESSAGE_8_CARRY_0, + WOPBS_PARAM_MESSAGE_1_NORM2_2, + WOPBS_PARAM_MESSAGE_1_NORM2_4, + WOPBS_PARAM_MESSAGE_1_NORM2_6, + WOPBS_PARAM_MESSAGE_1_NORM2_8, + WOPBS_PARAM_MESSAGE_2_NORM2_2, + WOPBS_PARAM_MESSAGE_2_NORM2_4, + WOPBS_PARAM_MESSAGE_2_NORM2_6, + WOPBS_PARAM_MESSAGE_2_NORM2_8, + WOPBS_PARAM_MESSAGE_3_NORM2_2, + WOPBS_PARAM_MESSAGE_3_NORM2_4, + WOPBS_PARAM_MESSAGE_3_NORM2_6, + WOPBS_PARAM_MESSAGE_3_NORM2_8, + WOPBS_PARAM_MESSAGE_4_NORM2_2, + WOPBS_PARAM_MESSAGE_4_NORM2_4, + WOPBS_PARAM_MESSAGE_4_NORM2_6, + WOPBS_PARAM_MESSAGE_4_NORM2_8, + WOPBS_PARAM_MESSAGE_5_NORM2_2, + WOPBS_PARAM_MESSAGE_5_NORM2_4, + WOPBS_PARAM_MESSAGE_5_NORM2_6, + WOPBS_PARAM_MESSAGE_5_NORM2_8, + WOPBS_PARAM_MESSAGE_6_NORM2_2, + WOPBS_PARAM_MESSAGE_6_NORM2_4, + WOPBS_PARAM_MESSAGE_6_NORM2_6, + WOPBS_PARAM_MESSAGE_6_NORM2_8, + WOPBS_PARAM_MESSAGE_7_NORM2_2, + WOPBS_PARAM_MESSAGE_7_NORM2_4, + WOPBS_PARAM_MESSAGE_7_NORM2_6, + WOPBS_PARAM_MESSAGE_7_NORM2_8, + WOPBS_PARAM_MESSAGE_8_NORM2_2, + WOPBS_PARAM_MESSAGE_8_NORM2_4, + WOPBS_PARAM_MESSAGE_8_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_8, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_2, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_3, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_4, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_5, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_6, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_7, + WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_8, +]; + +pub const WOPBS_PARAM_MESSAGE_1_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(32), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(64), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_7: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(8), + pbs_level: DecompositionLevelCount(5), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(5), + pfks_base_log: DecompositionBaseLog(8), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(128), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(16), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_8: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(7), + pbs_level: DecompositionLevelCount(6), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(6), + pfks_base_log: DecompositionBaseLog(7), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(7), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(256), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(498), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00044851669823869648209), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(24), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(24), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(2), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(653), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00003604499526942373), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(2), + ks_base_log: DecompositionBaseLog(5), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(3), + cbs_base_log: DecompositionBaseLog(5), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_1_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(487), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00054842163045222410337), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(2), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(16), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(32), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(64), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_7: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(8), + pbs_level: DecompositionLevelCount(5), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(5), + pfks_base_log: DecompositionBaseLog(8), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(128), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(500), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00043241360644590172285), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(23), + pbs_level: DecompositionLevelCount(1), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(1), + pfks_base_log: DecompositionBaseLog(23), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(2), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(769), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.0000043131554647504185), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(2), + ks_base_log: DecompositionBaseLog(6), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(3), + cbs_base_log: DecompositionBaseLog(5), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_2_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(487), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00054842163045222410337), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(16), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(32), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(873), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.0000006428797112843789), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(1), + ks_base_log: DecompositionBaseLog(10), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(3), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_6: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(64), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_3_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(487), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00054842163045222410337), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(2), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(953), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.0000001486733969411098), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(1), + ks_base_log: DecompositionBaseLog(11), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(16), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_5: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(32), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_4_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_5_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_5_CARRY_4: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(16), +}; +pub const WOPBS_PARAM_MESSAGE_5_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_5_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_5_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_6_CARRY_3: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(8), +}; +pub const WOPBS_PARAM_MESSAGE_6_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_6_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_6_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(486), + glwe_dimension: GlweDimension(3), + polynomial_size: PolynomialSize(512), + lwe_modular_std_dev: StandardDev(0.00055853990682276860028), + glwe_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.000000000002573000821792597679153983627), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_7_CARRY_2: Parameters = Parameters { + lwe_dimension: LweDimension(481), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00061200133780220371345), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(9), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(4), +}; +pub const WOPBS_PARAM_MESSAGE_7_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_7_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_8_CARRY_1: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(11), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(11), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(2), +}; +pub const WOPBS_PARAM_MESSAGE_8_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(497), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00045679174732062467505), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(4), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PARAM_MESSAGE_9_CARRY_0: Parameters = Parameters { + lwe_dimension: LweDimension(493), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00049144710341316649172), + glwe_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(5), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000022148688116005568513645324585951), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(512), + carry_modulus: CarryModulus(1), +}; + +pub fn get_parameters_from_message_and_carry_wopbs( + msg_space: usize, + carry_space: usize, +) -> Parameters { + let mut out = Parameters::default(); + let mut flag: bool = false; + let mut rescaled_message_space = f64::ceil(f64::log2(msg_space as f64)) as usize; + rescaled_message_space = 1 << rescaled_message_space; + let mut rescaled_carry_space = f64::ceil(f64::log2(carry_space as f64)) as usize; + rescaled_carry_space = 1 << rescaled_carry_space; + + for param in ALL_PARAMETER_VEC_WOPBS { + if param.message_modulus.0 == rescaled_message_space + && param.carry_modulus.0 == rescaled_carry_space + { + out = param; + flag = true; + break; + } + } + if !flag { + println!( + "### WARNING: NO PARAMETERS FOUND for msg_space = {} and carry_space = {} ### ", + rescaled_message_space, rescaled_carry_space + ); + } + out +} diff --git a/tfhe/src/shortint/parameters/parameters_wopbs_prime_moduli.rs b/tfhe/src/shortint/parameters/parameters_wopbs_prime_moduli.rs new file mode 100644 index 000000000..71b8d9292 --- /dev/null +++ b/tfhe/src/shortint/parameters/parameters_wopbs_prime_moduli.rs @@ -0,0 +1,889 @@ +pub use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, DispersionParameter, GlweDimension, + LweDimension, PolynomialSize, StandardDev, +}; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use crate::shortint::Parameters; + +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(693), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.0000173339182921315917918), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(757), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000005372539047440715995675), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(695), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001671088050446407327190), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(705), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001391593132168288907584), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(710), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001269897734067647866200), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(4), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(697), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001611023673517825963297), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(16), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(16), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(728), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00000913465281899372298196), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(690), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001831229863526819043776), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(699), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001553118206991877872242), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(694), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000017019548679502491437), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(730), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000880632348297507352018), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(15), + ks_base_log: DecompositionBaseLog(1), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_3_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(706), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001366355065014387319960), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(8), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(702), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001470138983326210590285), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(696), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000016407810365194741608), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(713), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001202050272339788291268), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(716), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001137827730902298847640), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000669212506995627734883), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_4_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(692), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001765409465411734898801), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(16), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(702), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001470138983326210590285), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(696), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000016407810365194741608), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(713), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001202050272339788291268), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(716), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001137827730902298847640), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000669212506995627734883), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_5_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(692), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001765409465411734898801), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(32), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(702), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001470138983326210590285), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(696), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000016407810365194741608), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(713), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001202050272339788291268), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(716), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001137827730902298847640), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000669212506995627734883), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_6_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(692), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001765409465411734898801), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(64), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(702), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001470138983326210590285), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(696), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000016407810365194741608), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(713), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001202050272339788291268), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(716), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001137827730902298847640), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000669212506995627734883), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_7_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(692), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001765409465411734898801), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(128), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_2: Parameters = Parameters { + lwe_dimension: LweDimension(702), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001470138983326210590285), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(15), + pbs_level: DecompositionLevelCount(2), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(2), + pfks_base_log: DecompositionBaseLog(15), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_3: Parameters = Parameters { + lwe_dimension: LweDimension(689), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001865054674846586206642), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_4: Parameters = Parameters { + lwe_dimension: LweDimension(696), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.000016407810365194741608), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_5: Parameters = Parameters { + lwe_dimension: LweDimension(713), + glwe_dimension: GlweDimension(1), + polynomial_size: PolynomialSize(2048), + lwe_modular_std_dev: StandardDev(0.00001202050272339788291268), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(6), + cbs_base_log: DecompositionBaseLog(3), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_6: Parameters = Parameters { + lwe_dimension: LweDimension(716), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001137827730902298847640), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(12), + pbs_level: DecompositionLevelCount(3), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(3), + pfks_base_log: DecompositionBaseLog(12), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_7: Parameters = Parameters { + lwe_dimension: LweDimension(745), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00000669212506995627734883), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(5), + cbs_base_log: DecompositionBaseLog(4), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; +pub const WOPBS_PRIME_PARAM_MESSAGE_8_NORM2_8: Parameters = Parameters { + lwe_dimension: LweDimension(692), + glwe_dimension: GlweDimension(2), + polynomial_size: PolynomialSize(1024), + lwe_modular_std_dev: StandardDev(0.00001765409465411734898801), + glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + pbs_base_log: DecompositionBaseLog(9), + pbs_level: DecompositionLevelCount(4), + ks_level: DecompositionLevelCount(7), + ks_base_log: DecompositionBaseLog(2), + pfks_level: DecompositionLevelCount(4), + pfks_base_log: DecompositionBaseLog(9), + pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432531092229224715860), + cbs_level: DecompositionLevelCount(4), + cbs_base_log: DecompositionBaseLog(6), + message_modulus: MessageModulus(256), + carry_modulus: CarryModulus(1), +}; diff --git a/tfhe/src/shortint/prelude.rs b/tfhe/src/shortint/prelude.rs new file mode 100644 index 000000000..367f842d7 --- /dev/null +++ b/tfhe/src/shortint/prelude.rs @@ -0,0 +1,16 @@ +pub use super::client_key::ClientKey; +pub use super::gen_keys; +pub use super::parameters::{ + CarryModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, + MessageModulus, Parameters, PolynomialSize, StandardDev, DEFAULT_PARAMETERS, + PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_1_CARRY_2, PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, PARAM_MESSAGE_1_CARRY_5, PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, PARAM_MESSAGE_2_CARRY_5, PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_3, PARAM_MESSAGE_3_CARRY_4, PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_4, +}; +pub use super::public_key::PublicKey; + +pub use super::ciphertext::Ciphertext; +pub use super::server_key::ServerKey; diff --git a/tfhe/src/shortint/public_key/mod.rs b/tfhe/src/shortint/public_key/mod.rs new file mode 100644 index 000000000..957da8c6a --- /dev/null +++ b/tfhe/src/shortint/public_key/mod.rs @@ -0,0 +1,244 @@ +//! Module with the definition of the PublicKey. +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Ciphertext; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::parameters::{MessageModulus, Parameters}; +use crate::shortint::{ClientKey, ServerKey}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::Debug; + +/// A structure containing a public key. +#[derive(Clone, Debug, PartialEq)] +pub struct PublicKey { + pub(crate) lwe_public_key: LwePublicKey64, + pub parameters: Parameters, +} + +impl PublicKey { + /// Generates a public key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::client_key::ClientKey; + /// use tfhe::shortint::parameters::Parameters; + /// use tfhe::shortint::public_key::PublicKey; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let pk = PublicKey::new(&cks); + /// ``` + pub fn new(client_key: &ClientKey) -> PublicKey { + ShortintEngine::with_thread_local_mut(|engine| engine.new_public_key(client_key).unwrap()) + } + + /// Encrypts a small integer message using the client key. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::{ClientKey, PublicKey, ServerKey}; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// + /// let sks = ServerKey::new(&cks); + /// + /// let pk = PublicKey::new(&cks); + /// + /// // Encryption of one message that is within the encrypted message modulus: + /// let msg = 3; + /// let ct = pk.encrypt(&sks, msg); + /// + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg, dec); + /// + /// // Encryption of one message that is outside the encrypted message modulus: + /// let msg = 5; + /// let ct = pk.encrypt(&sks, msg); + /// + /// let dec = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(msg % modulus, dec); + /// ``` + pub fn encrypt(&self, server_key: &ServerKey, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_with_public_key(self, server_key, message) + .unwrap() + }) + } + + /// Encrypts a small integer message using the client key with a specific message modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::MessageModulus; + /// use tfhe::shortint::{ClientKey, Parameters, PublicKey}; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let pk = PublicKey::new(&cks); + /// + /// let msg = 3; + /// + /// // Encryption of one message: + /// let ct = pk.encrypt_with_message_modulus(msg, MessageModulus(6)); + /// + /// // Decryption: + /// let dec = cks.decrypt(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_with_message_modulus( + &self, + message: u64, + message_modulus: MessageModulus, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_with_message_modulus_and_public_key(self, message, message_modulus) + .unwrap() + }) + } + + /// Encrypts an integer without reducing the input message modulus the message space + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters, PublicKey}; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let pk = PublicKey::new(&cks); + /// + /// let msg = 7; + /// let ct = pk.unchecked_encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 1 1 | + /// + /// let dec = cks.decrypt_message_and_carry(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn unchecked_encrypt(&self, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_encrypt_with_public_key(self, message) + .unwrap() + }) + } + + /// Encrypts a small integer message using the client key without padding bit. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::{ClientKey, PublicKey}; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2); + /// // DISCLAIMER: Note that this parameter is not guaranteed to be secure + /// let pk = PublicKey::new(&cks); + /// + /// // Encryption of one message that is within the encrypted message modulus: + /// let msg = 6; + /// let ct = pk.encrypt_without_padding(msg); + /// + /// let dec = cks.decrypt_message_and_carry_without_padding(&ct); + /// assert_eq!(msg, dec); + /// ``` + pub fn encrypt_without_padding(&self, message: u64) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_without_padding_with_public_key(self, message) + .unwrap() + }) + } + + /// Encrypts a small integer message using the client key without padding bit with some modulus. + /// + /// The input message is reduced to the encrypted message space modulus + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{ClientKey, Parameters, PublicKey}; + /// + /// // Generate the client key: + /// let cks = ClientKey::new(Parameters::default()); + /// + /// let pk = PublicKey::new(&cks); + /// + /// let msg = 2; + /// let modulus = 3; + /// + /// // Encryption of one message: + /// let ct = pk.encrypt_native_crt(msg, modulus); + /// + /// // Decryption: + /// let dec = cks.decrypt_message_native_crt(&ct, modulus); + /// assert_eq!(msg, dec % modulus as u64); + /// ``` + pub fn encrypt_native_crt(&self, message: u64, message_modulus: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .encrypt_native_crt_with_public_key(self, message, message_modulus) + .unwrap() + }) + } +} + +#[derive(Serialize, Deserialize)] +struct SerializablePublicKey { + lwe_public_key: Vec, + parameters: Parameters, +} + +impl Serialize for PublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let lwe_public_key = ser_eng + .serialize(&self.lwe_public_key) + .map_err(serde::ser::Error::custom)?; + + SerializablePublicKey { + lwe_public_key, + parameters: self.parameters, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for PublicKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + SerializablePublicKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + let mut de_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(Self { + lwe_public_key: de_eng + .deserialize(thing.lwe_public_key.as_slice()) + .map_err(serde::de::Error::custom)?, + parameters: thing.parameters, + }) + } +} diff --git a/tfhe/src/shortint/server_key/add.rs b/tfhe/src/shortint/server_key/add.rs new file mode 100644 index 000000000..baab867f7 --- /dev/null +++ b/tfhe/src/shortint/server_key/add.rs @@ -0,0 +1,256 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This function computes the addition without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg1 = 1; + /// let msg2 = 2; + /// let ct1 = cks.encrypt(msg1); + /// let ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.unchecked_add(&ct1, &ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg1 + msg2, res); + /// ``` + pub fn unchecked_add(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_add(ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// The result is _stored_ in the `ct_left` ciphertext. + /// + /// This function computes the addition without checking if it exceeds the capacity of the + /// ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// let mut ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// sks.unchecked_add_assign(&mut ct_left, &ct_right); + /// + /// // Decrypt: + /// let two = cks.decrypt(&ct_left); + /// assert_eq!(msg + msg, two); + /// ``` + pub fn unchecked_add_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_add_assign(ct_left, ct_right).unwrap() + }) + } + + /// Verifies if ct_left and ct_right can be added together. + /// + /// This checks that the sum of their degree is + /// smaller than the maximum degree. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Check if we can perform an addition + /// let can_be_added = sks.is_add_possible(&ct_left, &ct_right); + /// + /// assert_eq!(can_be_added, true); + /// ``` + pub fn is_add_possible(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> bool { + let final_operation_count = ct_left.degree.0 + ct_right.degree.0; + final_operation_count <= self.max_degree.0 + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.checked_add(&ct1, &ct2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, msg + msg); + /// ``` + pub fn checked_add( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_add_possible(ct_left, ct_right) { + let ct_result = self.unchecked_add(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let res = sks.checked_add_assign(&mut ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct_left); + /// assert_eq!(clear_res, msg + msg); + /// ``` + pub fn checked_add_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + if self.is_add_possible(ct_left, ct_right) { + self.unchecked_add_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an addition: + /// let ct_res = sks.smart_add(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let two = cks.decrypt(&ct_res); + /// assert_eq!(msg + msg, two); + /// ``` + pub fn smart_add(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_add(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically an addition between two ciphertexts + /// + /// The result is stored in the `ct_left` cipher text. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let msg1 = 15; + /// let msg2 = 3; + /// + /// let mut ct1 = cks.unchecked_encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an addition: + /// sks.smart_add_assign(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let two = cks.decrypt(&ct1); + /// + /// // 15 + 3 mod 4 -> 3 + 3 mod 4 -> 2 mod 4 + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!((msg2 + msg1) % modulus, two); + /// ``` + pub fn smart_add_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_add_assign(self, ct_left, ct_right).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/bitwise_op.rs b/tfhe/src/shortint/server_key/bitwise_op.rs new file mode 100644 index 000000000..d1e5a8ee2 --- /dev/null +++ b/tfhe/src/shortint/server_key/bitwise_op.rs @@ -0,0 +1,651 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::CheckError::CarryFull; +use crate::shortint::{CheckError, Ciphertext}; + +impl ServerKey { + /// Compute bitwise AND between two ciphertexts without checks. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 2; + /// let clear_2 = 1; + /// + /// let ct_1 = cks.encrypt(clear_1); + /// let ct_2 = cks.encrypt(clear_2); + /// + /// let ct_res = sks.unchecked_bitand(&ct_1, &ct_2); + /// + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_1 & clear_2, res); + /// ``` + pub fn unchecked_bitand(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_bitand(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute bitwise AND between two ciphertexts without checks. + /// + /// The result is assigned in the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 1; + /// let clear_2 = 2; + /// + /// let mut ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt(clear_2); + /// + /// sks.unchecked_bitand_assign(&mut ct_left, &ct_right); + /// + /// let res = cks.decrypt(&ct_left); + /// assert_eq!(clear_1 & clear_2, res); + /// ``` + pub fn unchecked_bitand_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_bitand_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Compute bitwise AND between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an AND: + /// let ct_res = sks.checked_bitand(&ct1, &ct2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, msg & msg); + /// ``` + pub fn checked_bitand( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + let ct_result = self.unchecked_bitand(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Compute bitwise AND between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Compute homomorphically an AND: + /// let res = sks.checked_bitand_assign(&mut ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct_left); + /// assert_eq!(clear_res, msg & msg); + /// ``` + pub fn checked_bitand_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitand_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an AND between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an AND: + /// let ct_res = sks.smart_bitand(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg & msg, res); + /// ``` + pub fn smart_bitand(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitand(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically an AND between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// The result is stored in the `ct_left` cipher text. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let modulus = 4; + /// // Encrypt two messages: + /// let msg1 = 15; + /// let msg2 = 3; + /// + /// let mut ct1 = cks.unchecked_encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an AND: + /// sks.smart_bitand_assign(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct1); + /// + /// assert_eq!((msg2 & msg1) % modulus, res); + /// ``` + pub fn smart_bitand_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitand_assign(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute bitwise XOR between two ciphertexts without checks. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 1; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(clear_1); + /// let ct_right = cks.encrypt(clear_2); + /// + /// let ct_res = sks.unchecked_bitxor(&ct_left, &ct_right); + /// + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_1 ^ clear_2, res); + /// ``` + pub fn unchecked_bitxor(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_bitxor(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute bitwise XOR between two ciphertexts without checks. + /// + /// The result is assigned in the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 2; + /// let clear_2 = 0; + /// + /// // Encrypt two messages + /// let mut ct_left = cks.encrypt(clear_1); + /// let mut ct_right = cks.encrypt(clear_2); + /// + /// sks.smart_bitxor(&mut ct_left, &mut ct_right); + /// + /// let res = cks.decrypt(&ct_left); + /// assert_eq!(clear_1 ^ clear_2, res); + /// ``` + pub fn unchecked_bitxor_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_bitxor_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Compute bitwise XOR between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically a xor: + /// let ct_res = sks.checked_bitxor(&ct1, &ct2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, msg ^ msg); + /// ``` + pub fn checked_bitxor( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + let ct_result = self.unchecked_bitxor(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Compute bitwise XOR between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Compute homomorphically a xor: + /// let res = sks.checked_bitxor_assign(&mut ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct_left); + /// assert_eq!(clear_res, msg ^ msg); + /// ``` + pub fn checked_bitxor_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitxor_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an XOR between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically a XOR: + /// let ct_res = sks.smart_bitxor(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg ^ msg, res); + /// ``` + pub fn smart_bitxor(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitxor(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically a XOR between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// The result is stored in the `ct_left` cipher text. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let modulus = 4; + /// // Encrypt two messages: + /// let msg1 = 15; + /// let msg2 = 3; + /// + /// let mut ct1 = cks.unchecked_encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically a XOR: + /// sks.smart_bitxor_assign(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct1); + /// + /// assert_eq!((msg2 ^ msg1) % modulus, res); + /// ``` + pub fn smart_bitxor_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitxor_assign(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute bitwise OR between two ciphertexts. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_left = 1; + /// let clear_right = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(clear_left); + /// let ct_right = cks.encrypt(clear_right); + /// + /// let ct_res = sks.unchecked_bitor(&ct_left, &ct_right); + /// + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_left | clear_right, res); + /// ``` + pub fn unchecked_bitor(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_bitor(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute bitwise OR between two ciphertexts. + /// + /// The result is assigned in the `ct_left` ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_left = 2; + /// let clear_right = 1; + /// + /// // Encrypt two messages + /// let mut ct_left = cks.encrypt(clear_left); + /// let ct_right = cks.encrypt(clear_right); + /// + /// sks.unchecked_bitor_assign(&mut ct_left, &ct_right); + /// + /// let res = cks.decrypt(&ct_left); + /// assert_eq!(clear_left | clear_right, res); + /// ``` + pub fn unchecked_bitor_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_bitor_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Compute bitwise OR between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically a or: + /// let ct_res = sks.checked_bitor(&ct1, &ct2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, msg | msg); + /// ``` + pub fn checked_bitor( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + let ct_result = self.unchecked_bitor(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Compute bitwise OR between two ciphertexts without checks. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct_left = cks.encrypt(msg); + /// let ct_right = cks.encrypt(msg); + /// + /// // Compute homomorphically an or: + /// let res = sks.checked_bitor_assign(&mut ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct_left); + /// assert_eq!(clear_res, msg | msg); + /// ``` + pub fn checked_bitor_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + self.unchecked_bitor_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an OR between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_bitor(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(msg | msg, res); + /// ``` + pub fn smart_bitor(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitor(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically an OR between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// The result is stored in the `ct_left` cipher text. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let modulus = 4; + /// // Encrypt two messages: + /// let msg1 = 15; + /// let msg2 = 3; + /// + /// let mut ct1 = cks.unchecked_encrypt(msg1); + /// let mut ct2 = cks.encrypt(msg2); + /// + /// // Compute homomorphically an OR: + /// sks.smart_bitor_assign(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct1); + /// + /// assert_eq!((msg2 | msg1) % modulus, res); + /// ``` + pub fn smart_bitor_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_bitor_assign(self, ct_left, ct_right).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/comp_op.rs b/tfhe/src/shortint/server_key/comp_op.rs new file mode 100644 index 000000000..7e47d7880 --- /dev/null +++ b/tfhe/src/shortint/server_key/comp_op.rs @@ -0,0 +1,822 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +// # Note: +// _assign comparison operation are not made public (if they exists) as we don't think there are +// uses for them. For instance: adding has an assign variants because you can do "+" and "+=" +// however, comparisons like equality do not have that, "==" does not have and "===", +// ">=" is greater of equal, not greater_assign. + +impl ServerKey { + /// Implements the "greater" (`>`) operator between two ciphertexts without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let ct_res = sks.unchecked_greater(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg_1 > msg_2) as u64, res); + /// ``` + pub fn unchecked_greater(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_greater(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "greater" (`>`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_greater(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 > msg_2) as u64, clear_res); + /// ``` + pub fn checked_greater( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_greater(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a `>` between two ciphertexts encrypting integer values. + /// + /// This checks that the operation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_greater(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg > msg) as u64, res); + /// ``` + pub fn smart_greater(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_greater(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "greater or equal" (`>=`) operator between two ciphertexts without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let ct_res = sks.unchecked_greater_or_equal(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg_1 >= msg_2) as u64, res); + /// ``` + pub fn unchecked_greater_or_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_greater_or_equal(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Computes homomorphically a `>=` between two ciphertexts encrypting integer values. + /// + /// This checks that the operation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_greater_or_equal(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg >= msg) as u64, res); + /// ``` + pub fn smart_greater_or_equal( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_greater_or_equal(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Implements the "greater or equal" (`>=`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_greater(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 >= msg_2) as u64, clear_res); + /// ``` + pub fn checked_greater_or_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_greater_or_equal(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Implements the "less" (`<`) operator between two ciphertexts without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// // Do the comparison + /// let ct_res = sks.unchecked_less(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg_1 < msg_2) as u64, res); + /// ``` + pub fn unchecked_less(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_less(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "less" (`<`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_less(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 < msg_2) as u64, clear_res); + /// ``` + pub fn checked_less( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_less(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a `<` between two ciphertexts encrypting integer values. + /// + /// This checks that the operation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_less(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg < msg) as u64, res); + /// ``` + pub fn smart_less(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_less(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "less or equal" (`<=`) between two ciphertexts operator without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let ct_res = sks.unchecked_less_or_equal(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg_1 <= msg_2) as u64, res); + /// ``` + pub fn unchecked_less_or_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_less_or_equal(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Implements the "less or equal" (`<=`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_less_or_equal(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 <= msg_2) as u64, clear_res); + /// ``` + pub fn checked_less_or_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_less(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a `<=` between two ciphertexts encrypting integer values. + /// + /// This checks that the operation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_less_or_equal(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg <= msg) as u64, res); + /// ``` + pub fn smart_less_or_equal( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_less_or_equal(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "equal" operator (`==`) between two ciphertexts without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let ct_res = sks.unchecked_equal(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, 1); + /// ``` + pub fn unchecked_equal(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_equal(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "equal" (`==`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_equal(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 == msg_2) as u64, clear_res); + /// ``` + pub fn checked_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_equal(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a `==` between two ciphertexts encrypting integer values. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_equal(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg == msg) as u64, res); + /// ``` + pub fn smart_equal(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_equal(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "not equal" operator (`!=`) between two ciphertexts without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let ct_res = sks.unchecked_not_equal(&ct_left, &ct_right); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, 1); + /// ``` + pub fn unchecked_not_equal(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_not_equal(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "not equal" (`!=`) operator between two ciphertexts with checks. + /// + /// If the operation can be performed, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 1; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_left = cks.encrypt(msg_1); + /// let ct_right = cks.encrypt(msg_2); + /// + /// let res = sks.checked_not_equal(&ct_left, &ct_right); + /// + /// assert!(res.is_ok()); + /// let res = res.unwrap(); + /// + /// let clear_res = cks.decrypt(&res); + /// assert_eq!((msg_1 != msg_2) as u64, clear_res); + /// ``` + pub fn checked_not_equal( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_functional_bivariate_pbs_possible(ct_left, ct_right) { + Ok(self.unchecked_not_equal(ct_left, ct_right)) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a `!=` between two ciphertexts encrypting integer values. + /// + /// This checks that the operation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Encrypt two messages: + /// let mut ct1 = cks.encrypt(msg); + /// let mut ct2 = cks.encrypt(msg); + /// + /// // Compute homomorphically an OR: + /// let ct_res = sks.smart_not_equal(&mut ct1, &mut ct2); + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((msg != msg) as u64, res); + /// ``` + pub fn smart_not_equal( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_not_equal(self, ct_left, ct_right).unwrap() + }) + } + + /// Implements the "equal" operator (`==`) between a ciphertext and a scalar without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_equal(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 == scalar as u64) as u64); + /// ``` + pub fn smart_scalar_equal(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_equal(self, ct_left, scalar).unwrap() + }) + } + + /// Implements the "not equal" operator (`!=`) between a ciphertext and a scalar without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_not_equal(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 != scalar as u64) as u64); + /// ``` + pub fn smart_scalar_not_equal(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_scalar_not_equal(self, ct_left, scalar) + .unwrap() + }) + } + + /// Implements the "greater or equal" operator (`>=`) between a ciphertext and a scalar without + /// checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_greater_or_equal(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 >= scalar as u64) as u64); + /// ``` + pub fn smart_scalar_greater_or_equal(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_scalar_greater_or_equal(self, ct_left, scalar) + .unwrap() + }) + } + + /// Implements the "less or equal" operator (`<=`) between a ciphertext and a scalar without + /// checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_less_or_equal(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 <= scalar as u64) as u64); + /// ``` + pub fn smart_scalar_less_or_equal(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_scalar_less_or_equal(self, ct_left, scalar) + .unwrap() + }) + } + + /// Implements the "greater" operator (`>`) between a ciphertext and a scalar without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_greater(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 > scalar as u64) as u64); + /// ``` + pub fn smart_scalar_greater(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_greater(self, ct_left, scalar).unwrap() + }) + } + + /// Implements the "less" operator (`<`) between a ciphertext and a scalar without checks. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let scalar = 2; + /// + /// // Encrypt two messages + /// let ct_left = cks.encrypt(msg_1); + /// + /// let ct_res = sks.smart_scalar_less(&ct_left, scalar); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (msg_1 < scalar as u64) as u64); + /// ``` + pub fn smart_scalar_less(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_less(self, ct_left, scalar).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/div_mod.rs b/tfhe/src/shortint/server_key/div_mod.rs new file mode 100644 index 000000000..2067d635d --- /dev/null +++ b/tfhe/src/shortint/server_key/div_mod.rs @@ -0,0 +1,238 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Compute a division between two ciphertexts without checks. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Warning + /// + /// /!\ A division by zero returns 0! + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 1; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let ct_1 = cks.encrypt(clear_1); + /// let ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_div(&ct_1, &ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_1 / clear_2, res); + /// ``` + pub fn unchecked_div(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_div(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute a division between two ciphertexts without checks. + /// + /// The result is _assigned_ in `ct_left`. + /// + /// # Warning + /// + /// /!\ A division by zero returns 0! + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 1; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_div_assign(&mut ct_1, &ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_1); + /// assert_eq!(clear_1 / clear_2, res); + /// ``` + pub fn unchecked_div_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_div_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Compute a division between two ciphertexts. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Warning + /// + /// /!\ A division by zero returns 0! + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 1; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let mut ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.smart_div(&mut ct_1, &mut ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_1 / clear_2, res); + /// ``` + pub fn smart_div(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_div(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute a division between two ciphertexts without checks. + /// + /// The result is _assigned_ in `ct_left`. + /// + /// # Warning + /// + /// /!\ A division by zero returns 0! + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 3; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let mut ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_div_assign(&mut ct_1, &ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_1); + /// assert_eq!(clear_1 / clear_2, res); + /// ``` + pub fn smart_div_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_div_assign(self, ct_left, ct_right).unwrap() + }) + } + + /// Compute a division of a ciphertext by a scalar without checks. + /// + /// # Panics + /// + /// This function will panic if `scalar == 0`. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear_1 = 3; + /// let clear_2 = 2; + /// + /// // Encrypt one message + /// let mut ct_1 = cks.encrypt(clear_1); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_scalar_div(&mut ct_1, clear_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(clear_1 / (clear_2 as u64), res); + /// ``` + pub fn unchecked_scalar_div(&self, ct_left: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_div(self, ct_left, scalar).unwrap() + }) + } + + pub fn unchecked_scalar_div_assign(&self, ct_left: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_div_assign(self, ct_left, scalar) + .unwrap() + }) + } + + /// Computes homomorphically a modular reduction without checks. + /// + /// # Panics + /// + /// This function will panic if `modulus == 0`. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// let modulus: u8 = 2; + /// // Compute homomorphically an addition: + /// let ct_res = sks.unchecked_scalar_mod(&mut ct, modulus); + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// assert_eq!(1, dec); + /// ``` + pub fn unchecked_scalar_mod(&self, ct_left: &Ciphertext, modulus: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_mod(self, ct_left, modulus).unwrap() + }) + } + + pub fn unchecked_scalar_mod_assign(&self, ct_left: &mut Ciphertext, modulus: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_mod_assign(self, ct_left, modulus) + .unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs new file mode 100644 index 000000000..4af3287d4 --- /dev/null +++ b/tfhe/src/shortint/server_key/mod.rs @@ -0,0 +1,598 @@ +//! Module with the definition of the ServerKey. +//! +//! This module implements the generation of the server public key, together with all the +//! available homomorphic integer operations. +mod add; +mod bitwise_op; +mod comp_op; +mod div_mod; +mod mul; +mod neg; +mod scalar_add; +mod scalar_mul; +mod scalar_sub; +mod shift; +mod sub; + +#[cfg(test)] +mod tests; + +use crate::core_crypto::prelude::*; +use crate::shortint::ciphertext::Ciphertext; +use crate::shortint::client_key::ClientKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::parameters::{CarryModulus, MessageModulus}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::fmt::{Debug, Display, Formatter}; + +/// Maximum value that the degree can reach. +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct MaxDegree(pub usize); + +/// Error returned when the carry buffer is full. +#[derive(Debug)] +pub enum CheckError { + CarryFull, +} + +impl Display for CheckError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + CheckError::CarryFull => { + write!(f, "The carry buffer is full") + } + } + } +} + +impl std::error::Error for CheckError {} + +/// A structure containing the server public key. +/// +/// The server key is generated by the client and is meant to be published: the client +/// sends it to the server so it can compute homomorphic circuits. +#[derive(Clone, Debug, PartialEq)] +pub struct ServerKey { + pub key_switching_key: LweKeyswitchKey64, + pub bootstrapping_key: FftFourierLweBootstrapKey64, + // Size of the message buffer + pub message_modulus: MessageModulus, + // Size of the carry buffer + pub carry_modulus: CarryModulus, + // Maximum number of operations that can be done before emptying the operation buffer + pub max_degree: MaxDegree, +} + +impl ServerKey { + /// Generates a server key. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::{gen_keys, ServerKey}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Generate the server key: + /// let sks = ServerKey::new(&cks); + /// ``` + pub fn new(cks: &ClientKey) -> ServerKey { + ShortintEngine::with_thread_local_mut(|engine| engine.new_server_key(cks).unwrap()) + } + + /// Generates a server key with a chosen maximum degree + pub fn new_with_max_degree(cks: &ClientKey, max_degree: MaxDegree) -> ServerKey { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .new_server_key_with_max_degree(cks, max_degree) + .unwrap() + }) + } + + /// Constructs the accumulator given a function as input. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 3; + /// + /// let ct = cks.encrypt(msg); + /// + /// // Generate the accumulator for the function f: x -> x^2 mod 2^2 + /// let f = |x| x ^ 2 % 4; + /// + /// let acc = sks.generate_accumulator(f); + /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc); + /// + /// let dec = cks.decrypt(&ct_res); + /// // 3^2 mod 4 = 1 + /// assert_eq!(dec, f(msg)); + /// ``` + pub fn generate_accumulator(&self, f: F) -> GlweCiphertext64 + where + F: Fn(u64) -> u64, + { + ShortintEngine::with_thread_local_mut(|engine| { + engine.generate_accumulator(self, f).unwrap() + }) + } + + /// Constructs the accumulator for a given bivariate function as input. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 3; + /// + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(0); + /// // Generate the accumulator for the function f: x -> x^2 mod 2^2 + /// let f = |x, y| (x + y) ^ 2 % 4; + /// + /// let acc = sks.generate_accumulator_bivariate(f); + /// let ct_res = sks.keyswitch_programmable_bootstrap_bivariate(&ct1, &ct2, &acc); + /// + /// let dec = cks.decrypt(&ct_res); + /// // 3^2 mod 4 = 1 + /// assert_eq!(dec, f(msg, 0)); + /// ``` + pub fn generate_accumulator_bivariate(&self, f: F) -> GlweCiphertext64 + where + F: Fn(u64, u64) -> u64, + { + ShortintEngine::with_thread_local_mut(|engine| { + engine.generate_accumulator_bivariate(self, f).unwrap() + }) + } + + /// Computes a keyswitch and a bootstrap, returning a new ciphertext with empty + /// carry bits. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let mut ct1 = cks.encrypt(3); + /// // | ct1 | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 1 | + /// let mut ct2 = cks.encrypt(2); + /// // | ct2 | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// let ct_res = sks.smart_add(&mut ct1, &mut ct2); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 0 1 | + /// + /// // Get the carry + /// let ct_carry = sks.carry_extract(&ct_res); + /// let carry = cks.decrypt(&ct_carry); + /// assert_eq!(carry, 1); + /// + /// let ct_res = sks.keyswitch_bootstrap(&ct_res); + /// + /// let ct_carry = sks.carry_extract(&ct_res); + /// let carry = cks.decrypt(&ct_carry); + /// assert_eq!(carry, 0); + /// + /// let clear = cks.decrypt(&ct_res); + /// + /// assert_eq!(clear, (3 + 2) % 4); + /// ``` + pub fn keyswitch_bootstrap(&self, ct_in: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.keyswitch_bootstrap(self, ct_in).unwrap() + }) + } + + pub fn keyswitch_bootstrap_assign(&self, ct_in: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.keyswitch_bootstrap_assign(self, ct_in).unwrap() + }) + } + + /// Computes a keyswitch and programmable bootstrap. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg: u64 = 3; + /// let ct1 = cks.encrypt(msg); + /// let ct2 = cks.encrypt(msg); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// // Generate the accumulator for the function f: x -> x^3 mod 2^2 + /// let acc = sks.generate_accumulator_bivariate(|x, y| x * y * x % modulus); + /// let ct_res = sks.keyswitch_programmable_bootstrap_bivariate(&ct1, &ct2, &acc); + /// + /// let dec = cks.decrypt(&ct_res); + /// // 3^3 mod 4 = 3 + /// assert_eq!(dec, (msg * msg * msg) % modulus); + /// ``` + pub fn keyswitch_programmable_bootstrap_bivariate( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + acc: &GlweCiphertext64, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrap_keyswitch_bivariate(self, ct_left, ct_right, acc) + .unwrap() + }) + } + + pub fn keyswitch_programmable_bootstrap_bivariate_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + acc: &GlweCiphertext64, + ) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrap_keyswitch_bivariate_assign(self, ct_left, ct_right, acc) + .unwrap() + }) + } + + /// Computes a keyswitch and programmable bootstrap. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg: u64 = 3; + /// let ct = cks.encrypt(msg); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// // Generate the accumulator for the function f: x -> x^3 mod 2^2 + /// let acc = sks.generate_accumulator(|x| x * x * x % modulus); + /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc); + /// + /// let dec = cks.decrypt(&ct_res); + /// // 3^3 mod 4 = 3 + /// assert_eq!(dec, (msg * msg * msg) % modulus); + /// ``` + pub fn keyswitch_programmable_bootstrap( + &self, + ct_in: &Ciphertext, + acc: &GlweCiphertext64, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrap_keyswitch(self, ct_in, acc) + .unwrap() + }) + } + + pub fn keyswitch_programmable_bootstrap_assign( + &self, + ct_in: &mut Ciphertext, + acc: &GlweCiphertext64, + ) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrap_keyswitch_assign(self, ct_in, acc) + .unwrap() + }) + } + + /// Generic programmable bootstrap where messages are concatenated + /// into one ciphertext to compute bivariate functions. + /// This is used to apply many binary operations (comparisons, multiplications, division). + pub fn unchecked_functional_bivariate_pbs( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + f: F, + ) -> Ciphertext + where + F: Fn(u64) -> u64, + { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_functional_bivariate_pbs(self, ct_left, ct_right, f) + .unwrap() + }) + } + + pub fn unchecked_functional_bivariate_pbs_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + f: F, + ) where + F: Fn(u64) -> u64, + { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_functional_bivariate_pbs_assign(self, ct_left, ct_right, f) + .unwrap() + }) + } + + /// Verifies if a bivariate functional pbs can be applied on ct_left and ct_right. + pub fn is_functional_bivariate_pbs_possible(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> bool { + //product of the degree + let final_degree = ct1.degree.0 * (ct2.degree.0 + 1) + ct2.degree.0; + final_degree < ct1.carry_modulus.0 * ct1.message_modulus.0 + } + + /// Replace the input encrypted message by the value of its carry buffer. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 9; + /// + /// // Encrypt a message + /// let mut ct = cks.unchecked_encrypt(clear); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 0 | 0 1 | + /// + /// // Compute homomorphically carry extraction + /// sks.carry_extract_assign(&mut ct); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Decrypt: + /// let res = cks.decrypt_message_and_carry(&ct); + /// assert_eq!(2, res); + /// ``` + pub fn carry_extract_assign(&self, ct: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.carry_extract_assign(self, ct).unwrap() + }) + } + + /// Extracts a new ciphertext encrypting the input carry buffer. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 9; + /// + /// // Encrypt a message + /// let ct = cks.unchecked_encrypt(clear); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 0 | 0 1 | + /// + /// // Compute homomorphically carry extraction + /// let ct_res = sks.carry_extract(&ct); + /// + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(2, res); + /// ``` + pub fn carry_extract(&self, ct: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.carry_extract(self, ct).unwrap()) + } + + /// Clears the carry buffer of the input ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let clear = 9; + /// + /// // Encrypt a message + /// let mut ct = cks.unchecked_encrypt(clear); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 0 | 0 1 | + /// + /// // Compute homomorphically the message extraction + /// sks.message_extract_assign(&mut ct); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct); + /// assert_eq!(1, res); + /// ``` + pub fn message_extract_assign(&self, ct: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.message_extract_assign(self, ct).unwrap() + }) + } + + /// Extracts a new ciphertext containing only the message i.e., with a cleared carry buffer. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + /// + /// let clear = 9; + /// + /// // Encrypt a message + /// let ct = cks.unchecked_encrypt(clear); + /// + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 0 | 0 1 | + /// + /// // Compute homomorphically the message extraction + /// let ct_res = sks.message_extract(&ct); + /// + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// // Decrypt: + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(1, res); + /// ``` + pub fn message_extract(&self, ct: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.message_extract(self, ct).unwrap()) + } + + /// Computes a trivial shortint from a given value. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1; + /// + /// // Trivial encryption + /// let ct1 = sks.create_trivial(msg); + /// + /// let ct_res = cks.decrypt(&ct1); + /// assert_eq!(1, ct_res); + /// ``` + pub fn create_trivial(&self, value: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial(self, value).unwrap()) + } + + pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.create_trivial_assign(self, ct, value).unwrap() + }) + } +} + +#[derive(Serialize, Deserialize)] +pub(super) struct SerializableServerKey { + pub key_switching_key: Vec, + pub bootstrapping_key: Vec, + // Size of the message buffer + pub message_modulus: MessageModulus, + // Size of the carry buffer + pub carry_modulus: CarryModulus, + // Maximum number of operations that can be done before emptying the operation buffer + pub max_degree: MaxDegree, +} + +impl Serialize for ServerKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let key_switching_key = ser_eng + .serialize(&self.key_switching_key) + .map_err(serde::ser::Error::custom)?; + let bootstrapping_key = fft_ser_eng + .serialize(&self.bootstrapping_key) + .map_err(serde::ser::Error::custom)?; + + SerializableServerKey { + key_switching_key, + bootstrapping_key, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + max_degree: self.max_degree, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for ServerKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + SerializableServerKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + Ok(Self { + key_switching_key: ser_eng + .deserialize(thing.key_switching_key.as_slice()) + .map_err(serde::de::Error::custom)?, + bootstrapping_key: fft_ser_eng + .deserialize(thing.bootstrapping_key.as_slice()) + .map_err(serde::de::Error::custom)?, + message_modulus: thing.message_modulus, + carry_modulus: thing.carry_modulus, + max_degree: thing.max_degree, + }) + } +} diff --git a/tfhe/src/shortint/server_key/mul.rs b/tfhe/src/shortint/server_key/mul.rs new file mode 100644 index 000000000..5de52d6f2 --- /dev/null +++ b/tfhe/src/shortint/server_key/mul.rs @@ -0,0 +1,575 @@ +use super::ServerKey; +use crate::shortint::ciphertext::Degree; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Multiplies two ciphertexts together without checks. + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// + /// // Generate the client key and the server key + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + /// + /// let clear_1 = 1; + /// let clear_2 = 1; + /// + /// // Encrypt two messages + /// let ct_1 = cks.encrypt(clear_1); + /// let ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_mul_lsb(&ct_1, &ct_2); + /// // 2*3 == 6 == 01_10 (base 2) + /// // Only the message part is returned (lsb) so `ct_res` is: + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!((clear_1 * clear_2) % modulus, res); + /// ``` + pub fn unchecked_mul_lsb(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_mul_lsb(self, ct_left, ct_right).unwrap() + }) + } + + /// Multiplies two ciphertexts together without checks. + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// The result is _assigned_ in the first ciphertext + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key + /// let (mut cks, mut sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 3; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// sks.unchecked_mul_lsb_assign(&mut ct_1, &ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_1); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!((clear_1 * clear_2) % modulus, res); + /// ``` + pub fn unchecked_mul_lsb_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_mul_lsb_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Multiplies two ciphertexts together without checks. + /// + /// Returns the "most significant bits" of the multiplication, i.e., the part in the carry + /// buffer. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::DEFAULT_PARAMETERS; + /// + /// // Generate the client key and the server key + /// let (mut cks, mut sks) = gen_keys(DEFAULT_PARAMETERS); + /// + /// let clear_1 = 3; + /// let clear_2 = 2; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let mut ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_mul_msb(&ct_1, &ct_2); + /// // 2*3 == 6 == 01_10 (base 2) + /// // however the ciphertext will contain only the carry buffer + /// // as the message, the ct_res is actually: + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!((clear_1 * clear_2) / modulus, res); + /// ``` + pub fn unchecked_mul_msb(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_mul_msb(self, ct_left, ct_right).unwrap() + }) + } + + pub fn unchecked_mul_msb_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_mul_msb_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Verifies if two ciphertexts can be multiplied together. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(Parameters::default()); + /// + /// let msg = 2; + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(msg); + /// let ct_2 = cks.encrypt(msg); + /// + /// // Check if we can perform a multiplication + /// let res = sks.is_mul_possible(&ct_1, &ct_2); + /// + /// assert_eq!(true, res); + /// ``` + pub fn is_mul_possible(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> bool { + self.is_functional_bivariate_pbs_possible(ct1, ct2) + } + + /// Multiplies two ciphertexts together with checks. + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// If the operation can be performed, a _new_ ciphertext with the result is returned. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(2); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a multiplication: + /// let ct_res = sks.checked_mul_lsb(&ct_1, &ct_2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt_message_and_carry(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res % modulus, 2); + /// ``` + pub fn checked_mul_lsb( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_mul_possible(ct_left, ct_right) { + let ct_result = self.unchecked_mul_lsb(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Multiplies two ciphertexts together with checks. + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// If the operation can be performed, the result is assigned to the first ciphertext given + /// as a parameter. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(2); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a multiplication: + /// let ct_res = sks.checked_mul_lsb_assign(&mut ct_1, &ct_2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let clear_res = cks.decrypt_message_and_carry(&ct_1); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res % modulus, 2); + /// ``` + pub fn checked_mul_lsb_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + if self.is_mul_possible(ct_left, ct_right) { + self.unchecked_mul_lsb_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Multiplies two ciphertexts together without checks. + /// + /// Returns the "most significant bits" of the multiplication, i.e., the part in the carry + /// buffer. + /// + /// If the operation can be performed, a _new_ ciphertext with the result is returned. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let msg_2 = 2; + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(msg_1); + /// let ct_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a multiplication: + /// let ct_res = sks.checked_mul_msb(&ct_1, &ct_2); + /// assert!(ct_res.is_ok()); + /// + /// // 2*2 == 4 == 01_00 (base 2) + /// // however the ciphertext will contain only the carry buffer + /// // as the message, the ct_res is actually: + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!( + /// clear_res, + /// (msg_1 * msg_2) / cks.parameters.message_modulus.0 as u64 + /// ); + /// ``` + pub fn checked_mul_msb( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + if self.is_mul_possible(ct_left, ct_right) { + let ct_result = self.unchecked_mul_msb(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Multiply two ciphertexts together using one bit of carry only. + /// + /// The algorithm uses the (.)^2/4 trick. + /// For more information: page 4, §Computing a multiplication in + /// Chillotti, I., Joye, M., Ligier, D., Orfila, J. B., & Tap, S. (2020, December). + /// CONCRETE: Concrete operates on ciphertexts rapidly by extending TfhE. + /// In WAHC 2020–8th Workshop on Encrypted Computing & Applied Homomorphic Cryptography (Vol. + /// 15). + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + /// + /// let clear_1 = 1; + /// let clear_2 = 1; + /// + /// // Encrypt two messages + /// let mut ct_1 = cks.encrypt(clear_1); + /// let mut ct_2 = cks.encrypt(clear_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.unchecked_mul_lsb_small_carry(&mut ct_1, &mut ct_2); + /// + /// // Decrypt + /// let res = cks.decrypt(&ct_res); + /// assert_eq!((clear_2 * clear_1), res); + /// ``` + pub fn unchecked_mul_lsb_small_carry( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_mul_lsb_small_carry_modulus(self, ct_left, ct_right) + .unwrap() + }) + } + + pub fn unchecked_mul_lsb_small_carry_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_mul_lsb_small_carry_modulus_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Verifies if two ciphertexts can be multiplied together in the case where the carry + /// modulus is smaller than the message modulus. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_1; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_2_CARRY_1); + /// + /// let msg = 2; + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(msg); + /// let ct_2 = cks.encrypt(msg); + /// + /// // Check if we can perform a multiplication + /// let mut res = sks.is_mul_small_carry_possible(&ct_1, &ct_2); + /// + /// assert_eq!(true, res); + /// + /// //Encryption with a full carry buffer + /// let large_msg = 7; + /// let ct_3 = cks.unchecked_encrypt(large_msg); + /// + /// // Check if we can perform a multiplication + /// res = sks.is_mul_small_carry_possible(&ct_1, &ct_3); + /// + /// assert_eq!(false, res); + /// ``` + pub fn is_mul_small_carry_possible(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> bool { + // Check if an addition is possible + let b1 = self.is_add_possible(ct_left, ct_right); + let b2 = self.is_sub_possible(ct_left, ct_right); + b1 & b2 + } + + /// Computes homomorphically a multiplication between two ciphertexts encrypting integer values. + /// + /// The operation is done using a small carry buffer. + /// + /// If the operation can be performed, a _new_ ciphertext with the result of the + /// multiplication is returned. Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg_1 = 2; + /// let msg_2 = 3; + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(msg_1); + /// let mut ct_2 = cks.encrypt(msg_2); + /// + /// // Compute homomorphically a multiplication + /// let ct_res = sks.checked_mul_lsb_with_small_carry(&mut ct_1, &mut ct_2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res % modulus, (msg_1 * msg_2) % modulus); + /// ``` + pub fn checked_mul_lsb_with_small_carry( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> Result { + if self.is_mul_small_carry_possible(ct_left, ct_right) { + let mut ct_result = self.unchecked_mul_lsb_small_carry(ct_left, ct_right); + ct_result.degree = Degree(ct_left.degree.0 * 2); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Multiplies two ciphertexts. + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// The result is _assigned_ in the first ciphertext + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_1; + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_1); + /// + /// // Encrypt two messages: + /// let msg1 = 5; + /// let msg2 = 3; + /// + /// let mut ct_1 = cks.unchecked_encrypt(msg1); + /// let mut ct_2 = cks.unchecked_encrypt(msg2); + /// + /// // Compute homomorphically a multiplication + /// sks.smart_mul_lsb_assign(&mut ct_1, &mut ct_2); + /// + /// let res = cks.decrypt(&ct_1); + /// let modulus = sks.message_modulus.0 as u64; + /// assert_eq!(res % modulus, (msg1 * msg2) % modulus); + /// ``` + pub fn smart_mul_lsb_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_mul_lsb_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + pub fn smart_mul_msb_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_mul_msb_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Multiply two ciphertexts together + /// + /// Returns the "least significant bits" of the multiplication, i.e., the result modulus the + /// message_modulus. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let msg1 = 12; + /// let msg2 = 13; + /// + /// let mut ct_left = cks.unchecked_encrypt(msg1); + /// // | ct_left | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 1 | 0 0 | + /// let mut ct_right = cks.unchecked_encrypt(msg2); + /// // | ct_right | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 1 | 0 1 | + /// + /// // Compute homomorphically a multiplication: + /// let ct_res = sks.smart_mul_lsb(&mut ct_left, &mut ct_right); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 0 | + /// + /// let res = cks.decrypt(&ct_res); + /// let modulus = sks.message_modulus.0; + /// assert_eq!(res, (msg1 * msg2) % modulus as u64); + /// ``` + pub fn smart_mul_lsb(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_mul_lsb(self, ct_left, ct_right).unwrap() + }) + } + + /// Multiply two ciphertexts together + /// + /// Returns the "most significant bits" of the multiplication, i.e., the part in the carry + /// buffer. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let msg1 = 12; + /// let msg2 = 12; + /// + /// let mut ct_1 = cks.unchecked_encrypt(msg1); + /// let mut ct_2 = cks.unchecked_encrypt(msg2); + /// + /// // Compute homomorphically a multiplication: + /// let ct_res = sks.smart_mul_msb(&mut ct_1, &mut ct_2); + /// + /// let res = cks.decrypt(&ct_res); + /// let modulus = sks.carry_modulus.0; + /// assert_eq!(res, (msg1 * msg2) % modulus as u64); + /// ``` + pub fn smart_mul_msb(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_mul_msb(self, ct_left, ct_right).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/neg.rs b/tfhe/src/shortint/server_key/neg.rs new file mode 100644 index 000000000..660a4dc90 --- /dev/null +++ b/tfhe/src/shortint/server_key/neg.rs @@ -0,0 +1,241 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Homomorphically negates a message without checks. + /// + /// Negation here means the opposite value in the modulo set. + /// + /// This function computes the opposite of a message without checking if it exceeds the + /// capacity of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(Parameters::default()); + /// + /// let msg = 1; + /// + /// // Encrypt a message + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// let mut ct_res = sks.unchecked_neg(&ct); + /// + /// // Decrypt + /// let three = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(modulus - msg, three); + /// ``` + pub fn unchecked_neg(&self, ct: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.unchecked_neg(self, ct).unwrap()) + } + + pub fn unchecked_neg_with_z(&self, ct: &Ciphertext) -> (Ciphertext, u64) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_neg_with_z(self, ct).unwrap() + }) + } + + /// Homomorphically negates a message inplace without checks. + /// + /// Negation here means the opposite value in the modulo set. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt a message + /// let msg = 3; + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// sks.unchecked_neg_assign(&mut ct); + /// + /// // Decrypt + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(modulus - msg, cks.decrypt(&ct)); + /// ``` + pub fn unchecked_neg_assign(&self, ct: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_neg_assign(self, ct).unwrap() + }) + } + + pub fn unchecked_neg_assign_with_z(&self, ct: &mut Ciphertext) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_neg_assign_with_z(self, ct).unwrap() + }) + } + + /// Verifies if a ciphertext can be negated. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt a message + /// let msg = 2; + /// let ct = cks.encrypt(msg); + /// + /// // Check if we can perform a negation + /// let can_be_negated = sks.is_neg_possible(&ct); + /// + /// assert_eq!(can_be_negated, true); + /// ``` + pub fn is_neg_possible(&self, ct: &Ciphertext) -> bool { + // z = ceil( degree / 2^p ) x 2^p + let msg_mod = self.message_modulus.0; + let mut z = (ct.degree.0 + msg_mod - 1) / msg_mod; + z = z.wrapping_mul(msg_mod); + + // counter = z / (2^p-1) + let counter = z / (self.message_modulus.0 - 1); + + counter <= self.max_degree.0 + } + + /// Computes homomorphically a negation of a ciphertext. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt a message + /// let msg = 1; + /// let ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation: + /// let ct_res = sks.checked_neg(&ct); + /// + /// assert!(ct_res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct_res.unwrap()); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res, modulus - msg); + /// ``` + pub fn checked_neg(&self, ct: &Ciphertext) -> Result { + // If the ciphertext cannot be negated without exceeding the capacity of a ciphertext + if self.is_neg_possible(ct) { + let ct_result = self.unchecked_neg(ct); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a negation of a ciphertext. + /// + /// If the operation is possible, the result is stored _in_ the input ciphertext. + /// Otherwise [CheckError::CarryFull] is returned and the ciphertext is not . + /// + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt a message: + /// let msg = 1; + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically the negation: + /// let res = sks.checked_neg_assign(&mut ct); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res, modulus - msg); + /// ``` + pub fn checked_neg_assign(&self, ct: &mut Ciphertext) -> Result<(), CheckError> { + if self.is_neg_possible(ct) { + self.unchecked_neg_assign(ct); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a negation of a ciphertext. + /// + /// This checks that the negation is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let msg = 3; + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// let ct_res = sks.smart_neg(&mut ct); + /// + /// // Decrypt + /// let clear_res = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res, modulus - msg); + /// ``` + pub fn smart_neg(&self, ct: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.smart_neg(self, ct).unwrap()) + } + + /// Computes homomorphically a negation of a ciphertext. + /// + /// This checks that the addition is possible. In the case where the carry buffers are full, + /// then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// // Encrypt two messages: + /// let msg = 3; + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a negation + /// sks.smart_neg_assign(&mut ct); + /// + /// // Decrypt + /// let clear_res = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res, modulus - msg); + /// ``` + pub fn smart_neg_assign(&self, ct: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| engine.smart_neg_assign(self, ct).unwrap()) + } +} diff --git a/tfhe/src/shortint/server_key/scalar_add.rs b/tfhe/src/shortint/server_key/scalar_add.rs new file mode 100644 index 000000000..011a309b5 --- /dev/null +++ b/tfhe/src/shortint/server_key/scalar_add.rs @@ -0,0 +1,247 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Computes homomorphically an addition between a ciphertext and a scalar. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This function does _not_ check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar addition: + /// let ct_res = sks.unchecked_scalar_add(&ct, 2); + /// + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_add(&self, ct: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_add(ct, scalar).unwrap() + }) + } + + /// Computes homomorphically an addition between a ciphertext and a scalar. + /// + /// The result it stored in the given ciphertext. + /// + /// This function does not check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let mut ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar addition: + /// sks.unchecked_scalar_add_assign(&mut ct, 2); + /// + /// let clear = cks.decrypt(&ct); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_add_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_add_assign(ct, scalar).unwrap() + }) + } + + pub fn unchecked_scalar_add_assign_crt(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_add_assign_crt(self, ct, scalar) + .unwrap() + }) + } + + /// Verifies if a scalar can be added to the ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(2); + /// + /// // Verification if the scalar addition can be computed: + /// let can_be_computed = sks.is_scalar_add_possible(&ct, 3); + /// + /// assert_eq!(can_be_computed, true); + /// ``` + pub fn is_scalar_add_possible(&self, ct: &Ciphertext, scalar: u8) -> bool { + let final_degree = scalar as usize + ct.degree.0; + + final_degree <= self.max_degree.0 + } + + /// Computes homomorphically an addition between a ciphertext and a scalar. + /// + /// If the operation is possible, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let ct = cks.encrypt(1); + /// + /// // Compute homomorphically a addition multiplication: + /// let ct_res = sks.checked_scalar_add(&ct, 2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_add( + &self, + ct: &Ciphertext, + scalar: u8, + ) -> Result { + //If the ciphertext cannot be multiplied without exceeding the max degree + if self.is_scalar_add_possible(ct, scalar) { + let ct_result = self.unchecked_scalar_add(ct, scalar); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between a ciphertext and a scalar. + /// + /// If the operation is possible, the result is stored _in_ the input ciphertext. + /// Otherwise [CheckError::CarryFull] is returned and the ciphertext is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let mut ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar addition: + /// let res = sks.checked_scalar_add_assign(&mut ct, 2); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_add_assign( + &self, + ct: &mut Ciphertext, + scalar: u8, + ) -> Result<(), CheckError> { + if self.is_scalar_add_possible(ct, scalar) { + self.unchecked_scalar_add_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically an addition between a ciphertext and a scalar. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This checks that the scalar addition is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1_u64; + /// let scalar = 9_u8; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_scalar_add(&mut ct, scalar); + /// + /// // The input ciphertext content is not changed + /// assert_eq!(cks.decrypt(&ct), msg); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(2, clear % modulus); + /// ``` + pub fn smart_scalar_add(&self, ct: &mut Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_add(self, ct, scalar).unwrap() + }) + } + + /// Computes homomorphically an addition of a ciphertext by a scalar. + /// + /// The result is _stored_ in the `ct` ciphertext. + /// + /// This checks that the scalar addition is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1_u64; + /// let scalar = 5_u8; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.smart_scalar_add_assign(&mut ct, scalar); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt_message_and_carry(&ct); + /// assert_eq!(6, clear); + /// ``` + pub fn smart_scalar_add_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_add_assign(self, ct, scalar).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/scalar_mul.rs b/tfhe/src/shortint/server_key/scalar_mul.rs new file mode 100644 index 000000000..23fe420fe --- /dev/null +++ b/tfhe/src/shortint/server_key/scalar_mul.rs @@ -0,0 +1,245 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// The operation is modulo the the precision bits to the power of two. + /// + /// This function does _not_ check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.unchecked_scalar_mul(&ct, 3); + /// + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_mul(&self, ct: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_mul(ct, scalar).unwrap() + }) + } + + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// The result it stored in the given ciphertext. + /// + /// The operation is modulo the the precision bits to the power of two. + /// + /// This function does not check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let mut ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.unchecked_scalar_mul_assign(&mut ct, 3); + /// + /// let clear = cks.decrypt(&ct); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_mul_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_mul_assign(ct, scalar).unwrap() + }) + } + + /// Verifies if the ciphertext can be multiplied by a scalar. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(2); + /// + /// // Verification if the scalar multiplication can be computed: + /// let can_be_computed = sks.is_scalar_mul_possible(&ct, 3); + /// + /// assert_eq!(can_be_computed, true); + /// ``` + pub fn is_scalar_mul_possible(&self, ct: &Ciphertext, scalar: u8) -> bool { + //scalar * ct.counter + let final_degree = scalar as usize * ct.degree.0; + + final_degree <= self.max_degree.0 + } + + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// If the operation is possible, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// The operation is modulo the precision bits to the power of two. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.checked_scalar_mul(&ct, 3); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_mul( + &self, + ct: &Ciphertext, + scalar: u8, + ) -> Result { + //If the ciphertext cannot be multiplied without exceeding the degree max + if self.is_scalar_mul_possible(ct, scalar) { + let ct_result = self.unchecked_scalar_mul(ct, scalar); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// If the operation is possible, the result is stored _in_ the input ciphertext. + /// Otherwise [CheckError::CarryFull] is returned and the ciphertext is not . + /// + /// The operation is modulo the precision bits to the power of two. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let mut ct = cks.encrypt(1); + /// + /// // Compute homomorphically a scalar multiplication: + /// let res = sks.checked_scalar_mul_assign(&mut ct, 3); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_mul_assign( + &self, + ct: &mut Ciphertext, + scalar: u8, + ) -> Result<(), CheckError> { + if self.is_scalar_mul_possible(ct, scalar) { + self.unchecked_scalar_mul_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// This checks that the multiplication is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1_u64; + /// let scalar = 3_u8; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_scalar_mul(&mut ct, scalar); + /// + /// // The input ciphertext content is not changed + /// assert_eq!(cks.decrypt(&ct), msg); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(3, clear % modulus); + /// ``` + pub fn smart_scalar_mul(&self, ct: &mut Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_mul(self, ct, scalar).unwrap() + }) + } + + /// Computes homomorphically a multiplication of a ciphertext by a scalar. + /// + /// This checks that the multiplication is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 1_u64; + /// let scalar = 3_u8; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.smart_scalar_mul_assign(&mut ct, scalar); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt(&ct); + /// assert_eq!(3, clear); + /// ``` + pub fn smart_scalar_mul_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_mul_assign(self, ct, scalar).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/scalar_sub.rs b/tfhe/src/shortint/server_key/scalar_sub.rs new file mode 100644 index 000000000..0e278604a --- /dev/null +++ b/tfhe/src/shortint/server_key/scalar_sub.rs @@ -0,0 +1,239 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This function does _not_ check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(5); + /// + /// // Compute homomorphically a scalar subtraction: + /// let ct_res = sks.unchecked_scalar_sub(&ct, 6); + /// + /// // 5 - 6 mod 4 = 3 mod 4 + /// let clear = cks.decrypt(&ct_res); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_sub(&self, ct: &Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_sub(ct, scalar).unwrap() + }) + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// The result it stored in the given ciphertext. + /// + /// This function does not check whether the capacity of the ciphertext is exceeded. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let mut ct = cks.encrypt(5); + /// + /// // Compute homomorphically a scalar subtraction: + /// sks.unchecked_scalar_sub_assign(&mut ct, 2); + /// + /// let clear = cks.decrypt(&ct); + /// assert_eq!(3, clear); + /// ``` + pub fn unchecked_scalar_sub_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_sub_assign(ct, scalar).unwrap() + }) + } + + /// Verifies if a scalar can be subtracted to the ciphertext. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let ct = cks.encrypt(5); + /// + /// // Verification if the scalar subtraction can be computed: + /// let can_be_computed = sks.is_scalar_sub_possible(&ct, 3); + /// + /// assert_eq!(can_be_computed, true); + /// ``` + pub fn is_scalar_sub_possible(&self, ct: &Ciphertext, scalar: u8) -> bool { + let neg_scalar = u64::from(scalar.wrapping_neg()) % self.message_modulus.0 as u64; + let final_degree = neg_scalar as usize + ct.degree.0; + final_degree <= self.max_degree.0 + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation is possible, the result is returned in a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let ct = cks.encrypt(5); + /// + /// // Compute homomorphically a subtraction multiplication: + /// let ct_res = sks.checked_scalar_sub(&ct, 2); + /// + /// assert!(ct_res.is_ok()); + /// + /// let ct_res = ct_res.unwrap(); + /// let clear_res = cks.decrypt(&ct_res); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_sub( + &self, + ct: &Ciphertext, + scalar: u8, + ) -> Result { + //If the scalar subtraction cannot be done without exceeding the max degree + if self.is_scalar_sub_possible(ct, scalar) { + let ct_result = self.unchecked_scalar_sub(ct, scalar); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// If the operation is possible, the result is stored _in_ the input ciphertext. + /// Otherwise [CheckError::CarryFull] is returned and the ciphertext is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt a message: + /// let mut ct = cks.encrypt(5); + /// + /// // Compute homomorphically a scalar subtraction: + /// let res = sks.checked_scalar_sub_assign(&mut ct, 2); + /// + /// assert!(res.is_ok()); + /// + /// let clear_res = cks.decrypt(&ct); + /// assert_eq!(clear_res, 3); + /// ``` + pub fn checked_scalar_sub_assign( + &self, + ct: &mut Ciphertext, + scalar: u8, + ) -> Result<(), CheckError> { + if self.is_scalar_sub_possible(ct, scalar) { + self.unchecked_scalar_sub_assign(ct, scalar); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This checks that the scalar subtraction is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 3; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// let ct_res = sks.smart_scalar_sub(&mut ct, scalar); + /// + /// // The input ciphertext content is not changed + /// assert_eq!(cks.decrypt(&ct), msg); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt(&ct_res); + /// + /// assert_eq!(msg - scalar as u64, clear); + /// ``` + pub fn smart_scalar_sub(&self, ct: &mut Ciphertext, scalar: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_sub(self, ct, scalar).unwrap() + }) + } + + /// Computes homomorphically a subtraction of a ciphertext by a scalar. + /// + /// The result is _stored_ in the `ct` ciphertext. + /// + /// This checks that the scalar subtraction is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 5; + /// let scalar = 3; + /// + /// let mut ct = cks.encrypt(msg); + /// + /// // Compute homomorphically a scalar multiplication: + /// sks.smart_scalar_sub_assign(&mut ct, scalar); + /// + /// // Our result is what we expect + /// let clear = cks.decrypt(&ct); + /// assert_eq!(msg - scalar as u64, clear); + /// ``` + pub fn smart_scalar_sub_assign(&self, ct: &mut Ciphertext, scalar: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_sub_assign(self, ct, scalar).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/shift.rs b/tfhe/src/shortint/server_key/shift.rs new file mode 100644 index 000000000..1c2751766 --- /dev/null +++ b/tfhe/src/shortint/server_key/shift.rs @@ -0,0 +1,306 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Computes homomorphically a right shift of the bits without checks. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// let ct = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Compute homomorphically a right shift + /// let shift: u8 = 1; + /// let ct_res = sks.unchecked_scalar_right_shift(&ct, shift); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift(&self, ct: &Ciphertext, shift: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_right_shift(self, ct, shift) + .unwrap() + }) + } + + /// Computes homomorphically a right shift of the bits without checks. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// let mut ct = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Compute homomorphically a right shift + /// let shift: u8 = 1; + /// sks.unchecked_scalar_right_shift_assign(&mut ct, shift); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 0 1 | + /// + /// // Decrypt: + /// let dec = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(msg >> shift, dec); + /// ``` + pub fn unchecked_scalar_right_shift_assign(&self, ct: &mut Ciphertext, shift: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_right_shift_assign(self, ct, shift) + .unwrap() + }) + } + + /// Computes homomorphically a left shift of the bits without checks. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// + /// let ct = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Compute homomorphically a left shift + /// let shift: u8 = 1; + /// let ct_res = sks.unchecked_scalar_left_shift(&ct, shift); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 0 0 | + /// + /// // Decrypt: + /// let msg_and_carry = cks.decrypt_message_and_carry(&ct_res); + /// let msg_only = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// assert_eq!(msg << shift, msg_and_carry); + /// assert_eq!((msg << shift) % modulus, msg_only); + /// ``` + pub fn unchecked_scalar_left_shift(&self, ct: &Ciphertext, shift: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_scalar_left_shift(ct, shift).unwrap() + }) + } + + /// Computes homomorphically a left shift of the bits without checks + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// let mut ct = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Compute homomorphically a left shift + /// let shift: u8 = 1; + /// sks.unchecked_scalar_left_shift_assign(&mut ct, shift); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 0 0 | + /// + /// // Decrypt: + /// let msg_and_carry = cks.decrypt_message_and_carry(&ct); + /// let msg_only = cks.decrypt(&ct); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// assert_eq!(msg << shift, msg_and_carry); + /// assert_eq!((msg << shift) % modulus, msg_only); + /// ``` + pub fn unchecked_scalar_left_shift_assign(&self, ct: &mut Ciphertext, shift: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_scalar_left_shift_assign(ct, shift) + .unwrap() + }) + } + + /// Checks if the left shift operation can be applied. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::{gen_keys, Parameters}; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(Parameters::default()); + /// + /// let msg = 2; + /// let shift = 5; + /// let ct1 = cks.encrypt(msg); + /// + /// // Check if we can perform an addition + /// let res = sks.is_scalar_left_shift_possible(&ct1, shift); + /// + /// assert_eq!(false, res); + /// ``` + pub fn is_scalar_left_shift_possible(&self, ct1: &Ciphertext, shift: u8) -> bool { + let final_operation_count = ct1.degree.0 << shift as usize; + final_operation_count <= self.max_degree.0 + } + + /// Computes homomorphically a left shift of the bits. + /// + /// If the operation can be performed, a new ciphertext with the result is returned. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// + /// let ct1 = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// // Shifting 3 times is not ok, as it exceeds the carry buffer + /// let ct_res = sks.checked_scalar_left_shift(&ct1, 3); + /// assert!(ct_res.is_err()); + /// + /// // Shifting 2 times is ok + /// let shift = 2; + /// let ct_res = sks.checked_scalar_left_shift(&ct1, shift); + /// assert!(ct_res.is_ok()); + /// let ct_res = ct_res.unwrap(); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 1 0 | 0 0 | + /// + /// // Decrypt: + /// let msg_and_carry = cks.decrypt_message_and_carry(&ct_res); + /// let msg_only = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// assert_eq!(msg << shift, msg_and_carry); + /// assert_eq!((msg << shift) % modulus, msg_only); + /// ``` + pub fn checked_scalar_left_shift( + &self, + ct: &Ciphertext, + shift: u8, + ) -> Result { + if self.is_scalar_left_shift_possible(ct, shift) { + let ct_result = self.unchecked_scalar_left_shift(ct, shift); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + pub fn checked_scalar_left_shift_assign( + &self, + ct: &mut Ciphertext, + shift: u8, + ) -> Result<(), CheckError> { + if self.is_scalar_left_shift_possible(ct, shift) { + self.unchecked_scalar_left_shift_assign(ct, shift); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a left shift of the bits + /// + /// This checks that the operation is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// let mut ct = cks.encrypt(msg); + /// // | ct | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 0 | 1 0 | + /// + /// let shift: u8 = 1; + /// let ct_res = sks.smart_scalar_left_shift(&mut ct, shift); + /// // | ct_res | + /// // | carry | message | + /// // |-------|---------| + /// // | 0 1 | 0 0 | + /// + /// // Decrypt: + /// let msg_and_carry = cks.decrypt_message_and_carry(&ct_res); + /// let msg_only = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// + /// assert_eq!(msg << shift, msg_and_carry); + /// assert_eq!((msg << shift) % modulus, msg_only); + /// ``` + pub fn smart_scalar_left_shift(&self, ct: &mut Ciphertext, shift: u8) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_scalar_left_shift(self, ct, shift).unwrap() + }) + } + + pub fn smart_scalar_left_shift_assign(&self, ct: &mut Ciphertext, shift: u8) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .smart_scalar_left_shift_assign(self, ct, shift) + .unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/sub.rs b/tfhe/src/shortint/server_key/sub.rs new file mode 100644 index 000000000..aaab8420b --- /dev/null +++ b/tfhe/src/shortint/server_key/sub.rs @@ -0,0 +1,303 @@ +use super::ServerKey; +use crate::shortint::engine::ShortintEngine; +use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::CheckError::CarryFull; +use crate::shortint::Ciphertext; + +impl ServerKey { + /// Homomorphically subtracts ct_right to ct_left. + /// + /// The result is returned in a _new_ ciphertext. + /// + /// This function computes the subtraction without checking + /// if it exceeds the capacity of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(2); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// let ct_res = sks.unchecked_sub(&ct_1, &ct_2); + /// + /// // Decrypt: + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(cks.decrypt(&ct_res), 2 - 1); + /// ``` + pub fn unchecked_sub(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.unchecked_sub(self, ct_left, ct_right).unwrap() + }) + } + + /// Homomorphically subtracts ct_right to ct_left. + /// + /// The result is assigned in the `ct_left` ciphertext. + /// + /// This function computes the subtraction without checking + /// if it exceeds the capacity of the ciphertext. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(2); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// sks.unchecked_sub_assign(&mut ct_1, &ct_2); + /// + /// // Decrypt: + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(cks.decrypt(&ct_1) % modulus, 1); + /// ``` + pub fn unchecked_sub_assign(&self, ct_left: &mut Ciphertext, ct_right: &Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_sub_assign(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Verifies if ct_right can be subtracted to ct_left. + /// + /// # Example + /// + ///```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// let msg = 2; + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(msg); + /// let ct_2 = cks.encrypt(msg); + /// + /// // Check if we can perform an subtraction + /// let can_be_subtracted = sks.is_sub_possible(&ct_1, &ct_2); + /// + /// assert_eq!(true, can_be_subtracted); + /// ``` + pub fn is_sub_possible(&self, ct_left: &Ciphertext, ct_right: &Ciphertext) -> bool { + // z = ceil( degree / 2^p ) x 2^p + let msg_mod = self.message_modulus.0; + let mut z = (ct_right.degree.0 + msg_mod - 1) / msg_mod; + z = z.wrapping_mul(msg_mod); + + let final_operation_count = ct_left.degree.0 + z; + + final_operation_count <= self.max_degree.0 + } + + /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values. + /// + /// If the operation can be performed, the result is returned a _new_ ciphertext. + /// Otherwise [CheckError::CarryFull] is returned. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let ct_1 = cks.encrypt(3); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// let ct_res = sks.checked_sub(&ct_1, &ct_2); + /// + /// assert!(ct_res.is_ok()); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// let clear_res = cks.decrypt(&ct_res.unwrap()); + /// assert_eq!(clear_res % modulus, 2); + /// ``` + pub fn checked_sub( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> Result { + // If the ciphertexts cannot be subtracted without exceeding the degree max + if self.is_sub_possible(ct_left, ct_right) { + let ct_result = self.unchecked_sub(ct_left, ct_right); + Ok(ct_result) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction between two ciphertexts. + /// + /// If the operation can be performed, the result is stored in the `ct_left` ciphertext. + /// Otherwise [CheckError::CarryFull] is returned, and `ct_left` is not modified. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(3); + /// let ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// let res = sks.checked_sub_assign(&mut ct_1, &ct_2); + /// + /// assert!(res.is_ok()); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// let clear_res = cks.decrypt(&ct_1); + /// assert_eq!(clear_res % modulus, 2); + /// ``` + pub fn checked_sub_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> Result<(), CheckError> { + // If the ciphertexts cannot be subtracted without exceeding the degree max + if self.is_sub_possible(ct_left, ct_right) { + self.unchecked_sub_assign(ct_left, ct_right); + Ok(()) + } else { + Err(CarryFull) + } + } + + /// Computes homomorphically a subtraction between two ciphertexts. + /// + /// This checks that the subtraction is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(3); + /// let mut ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// let ct_res = sks.smart_sub(&mut ct_1, &mut ct_2); + /// + /// let clear_res = cks.decrypt(&ct_res); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(clear_res % modulus, 2); + /// ``` + pub fn smart_sub(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_sub(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically a subtraction between two ciphertexts. + /// + /// This checks that the subtraction is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// + /// // Encrypt two messages: + /// let mut ct_1 = cks.encrypt(3); + /// let mut ct_2 = cks.encrypt(1); + /// + /// // Compute homomorphically a subtraction: + /// sks.smart_sub_assign(&mut ct_1, &mut ct_2); + /// let modulus = cks.parameters.message_modulus.0 as u64; + /// assert_eq!(cks.decrypt(&ct_1) % modulus, 2); + /// ``` + pub fn smart_sub_assign(&self, ct_left: &mut Ciphertext, ct_right: &mut Ciphertext) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_sub_assign(self, ct_left, ct_right).unwrap() + }) + } + + /// Computes homomorphically a subtraction between two ciphertexts without checks, and returns + /// a correcting term. + /// + /// This checks that the subtraction is possible. In the case where the carry buffers are + /// full, then it is automatically cleared to allow the operation. + /// + /// # Warning + /// + /// This is an advanced functionality, needed for internal requirements. + pub fn unchecked_sub_with_correcting_term( + &self, + ct_left: &Ciphertext, + ct_right: &Ciphertext, + ) -> (Ciphertext, u64) { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_sub_with_z(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Computes homomorphically a subtraction between two ciphertexts without checks, and returns + /// a correcting term. + /// + /// # Warning + /// + /// This is an advanced functionality, needed for internal requirements. + pub fn unchecked_sub_with_correcting_term_assign( + &self, + ct_left: &mut Ciphertext, + ct_right: &Ciphertext, + ) -> u64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .unchecked_sub_assign_with_z(self, ct_left, ct_right) + .unwrap() + }) + } + + /// Computes homomorphically a subtraction between two ciphertexts without checks, and returns + /// a correcting term. + /// + /// # Warning + /// + /// This is an advanced functionality, needed for internal requirements. + pub fn smart_sub_with_correcting_term( + &self, + ct_left: &mut Ciphertext, + ct_right: &mut Ciphertext, + ) -> (Ciphertext, u64) { + ShortintEngine::with_thread_local_mut(|engine| { + engine.smart_sub_with_z(self, ct_left, ct_right).unwrap() + }) + } +} diff --git a/tfhe/src/shortint/server_key/tests.rs b/tfhe/src/shortint/server_key/tests.rs new file mode 100644 index 000000000..5f2dc3cb9 --- /dev/null +++ b/tfhe/src/shortint/server_key/tests.rs @@ -0,0 +1,1820 @@ +use crate::shortint::keycache::KEY_CACHE; +use crate::shortint::parameters::*; +use paste::paste; +use rand::Rng; + +/// Number of assert in randomized tests +const NB_TEST: usize = 30; + +// Macro to generate tests for all parameter sets +macro_rules! create_parametrized_test{ + ($name:ident { $($param:ident),* }) => { + paste! { + $( + #[test] + fn []() { + $name($param) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_1_CARRY_2, + PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, + PARAM_MESSAGE_1_CARRY_5, + PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, + PARAM_MESSAGE_2_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, + PARAM_MESSAGE_2_CARRY_5, + PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_1, + PARAM_MESSAGE_3_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_3_CARRY_4, + PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_1, + PARAM_MESSAGE_4_CARRY_2, + PARAM_MESSAGE_4_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, + PARAM_MESSAGE_5_CARRY_1, + PARAM_MESSAGE_5_CARRY_2, + PARAM_MESSAGE_5_CARRY_3, + PARAM_MESSAGE_6_CARRY_1, + PARAM_MESSAGE_6_CARRY_2, + PARAM_MESSAGE_7_CARRY_1 + }); + }; +} + +//Macro to generate tests for parameters sets compatible with the bivariate pbs +macro_rules! create_parametrized_test_bivariate_pbs_compliant{ + ($name:ident { $($param:ident),* }) => { + paste! { + $( + #[test] + fn []() { + $name($param) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_1_CARRY_2, + PARAM_MESSAGE_1_CARRY_3, + PARAM_MESSAGE_1_CARRY_4, + PARAM_MESSAGE_1_CARRY_5, + PARAM_MESSAGE_1_CARRY_6, + PARAM_MESSAGE_1_CARRY_7, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_2_CARRY_3, + PARAM_MESSAGE_2_CARRY_4, + PARAM_MESSAGE_2_CARRY_5, + PARAM_MESSAGE_2_CARRY_6, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_3_CARRY_4, + PARAM_MESSAGE_3_CARRY_5, + PARAM_MESSAGE_4_CARRY_4 + }); + }; +} + +//These functions are compatible with all parameter sets. +create_parametrized_test!(shortint_encrypt_decrypt); +create_parametrized_test!(shortint_encrypt_with_message_modulus_decrypt); +create_parametrized_test!(shortint_encrypt_decrypt_without_padding); +create_parametrized_test!(shortint_keyswitch_bootstrap); +create_parametrized_test!(shortint_keyswitch_programmable_bootstrap); +create_parametrized_test!(shortint_carry_extract); +create_parametrized_test!(shortint_message_extract); +create_parametrized_test!(shortint_generate_accumulator); +create_parametrized_test!(shortint_unchecked_add); +create_parametrized_test!(shortint_smart_add); +create_parametrized_test!(shortint_smart_mul_lsb); +create_parametrized_test!(shortint_unchecked_neg); +create_parametrized_test!(shortint_smart_neg); +create_parametrized_test!(shortint_unchecked_scalar_add); +create_parametrized_test!(shortint_smart_scalar_add); +create_parametrized_test!(shortint_unchecked_scalar_sub); +create_parametrized_test!(shortint_smart_scalar_sub); +create_parametrized_test!(shortint_unchecked_scalar_mul); +create_parametrized_test!(shortint_smart_scalar_mul); +create_parametrized_test!(shortint_unchecked_right_shift); +create_parametrized_test!(shortint_unchecked_left_shift); +create_parametrized_test!(shortint_unchecked_sub); +create_parametrized_test!(shortint_smart_sub); +create_parametrized_test!(shortint_mul_small_carry); + +//These functions are compatible with some parameter sets where the carry modulus is larger than +// the message modulus. +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_bitand); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_bitor); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_bitxor); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_greater); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_greater_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_less); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_less_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_bitand); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_bitor); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_bitxor); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_greater); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_greater_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_less); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_less_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_scalar_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_scalar_less); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_scalar_less_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_scalar_greater); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_scalar_greater_or_equal); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_div); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_scalar_div); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_mod); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_mul_lsb); +create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_mul_msb); +create_parametrized_test_bivariate_pbs_compliant!(shortint_smart_mul_msb); +create_parametrized_test_bivariate_pbs_compliant!( + shortint_keyswitch_bivariate_programmable_bootstrap +); +create_parametrized_test_bivariate_pbs_compliant!( + shortint_encrypt_with_message_modulus_smart_add_and_mul +); + +/// test encryption and decryption with the LWE client key +fn shortint_encrypt_decrypt(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let cks = keys.client_key(); + + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let ct = cks.encrypt(clear); + + // decryption of ct_zero + let dec = cks.decrypt(&ct); + + // assert + assert_eq!(clear, dec); + } +} + +/// test encryption and decryption with the LWE client key +fn shortint_encrypt_with_message_modulus_decrypt(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let cks = keys.client_key(); + + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST { + let mut modulus = rng.gen::() % cks.parameters.message_modulus.0 as u64; + while modulus == 0 { + modulus = rng.gen::() % cks.parameters.message_modulus.0 as u64; + } + + let clear = rng.gen::() % modulus; + + let ct = cks.encrypt_with_message_modulus(clear, MessageModulus(modulus as usize)); + + // decryption of ct_zero + let dec = cks.decrypt(&ct); + + // assert + assert_eq!(clear, dec); + } +} + +fn shortint_encrypt_decrypt_without_padding(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let cks = keys.client_key(); + + let mut rng = rand::thread_rng(); + + // We assume that the modulus is the largest possible without padding bit + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + let ct = cks.encrypt_without_padding(clear); + + // decryption of ct_zero + let dec = cks.decrypt_message_and_carry_without_padding(&ct); + + // assert + assert_eq!(clear, dec); + } +} + +fn shortint_keyswitch_bootstrap(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mut failures = 0; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // keyswitch and bootstrap + let ct_res = sks.keyswitch_bootstrap(&ctxt_0); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + if clear_0 != dec_res { + failures += 1; + } + // assert + // assert_eq!(clear_0, dec_res); + } + + println!("fail_rate = {}/{}", failures, 100); + assert_eq!(0, failures); +} + +fn shortint_keyswitch_programmable_bootstrap(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + //define the accumulator as identity + let acc = sks.generate_accumulator(|n| n % modulus); + // add the two ciphertexts + let ct_res = sks.keyswitch_programmable_bootstrap(&ctxt_0, &acc); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0, dec_res); + } +} + +fn shortint_keyswitch_bivariate_programmable_bootstrap(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + let ctxt_1 = cks.encrypt(clear_1); + //define the accumulator as identity + let acc = sks.generate_accumulator_bivariate(|x, y| x * 2 * y % modulus); + // add the two ciphertexts + let ct_res = sks.keyswitch_programmable_bootstrap_bivariate(&ctxt_0, &ctxt_1, &acc); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((2 * clear_0 * clear_1) % modulus, dec_res); + } +} + +/// test extraction of a carry +fn shortint_carry_extract(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let full_modulus = + cks.parameters.message_modulus.0 as u64 + cks.parameters.carry_modulus.0 as u64; + let msg_modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + // shift to the carry bits + let clear = rng.gen::() % full_modulus; + + // unchecked encryption of the message to have a larger message encrypted. + let ctxt = cks.unchecked_encrypt(clear); + + // extract the carry + let ct_carry = sks.carry_extract(&ctxt); + + // decryption of message and carry + let dec = cks.decrypt_message_and_carry(&ct_carry); + + // assert + println!( + "msg = {}, modulus = {}, msg/modulus = {}", + clear, + msg_modulus, + clear / msg_modulus + ); + assert_eq!(clear / msg_modulus, dec); + } +} + +/// test extraction of a message +fn shortint_message_extract(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus_sup = (param.message_modulus.0 * param.carry_modulus.0) as u64; + + let modulus = param.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus_sup; + + // encryption of an integer + let ctxt = cks.unchecked_encrypt(clear); + + // message extraction + let ct_msg = sks.message_extract(&ctxt); + + // decryption of ct_msg + let dec = cks.decrypt(&ct_msg); + + // assert + assert_eq!(clear % modulus, dec); + } +} + +/// test multiplication with the LWE server key +fn shortint_generate_accumulator(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + let double = |x| 2 * x; + let acc = sks.generate_accumulator(double); + + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % modulus; + + // encryption of an integer + let ct = cks.encrypt(clear); + + let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear * 2) % modulus, dec_res); + } +} + +/// test addition with the LWE server key +fn shortint_unchecked_add(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_add(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + println!( + "The parameters set is CARRY_{}_MESSAGE_{}", + cks.parameters.carry_modulus.0, cks.parameters.message_modulus.0 + ); + assert_eq!((clear_0 + clear_1) % modulus, dec_res); + } +} + +/// test addition with the LWE server key +fn shortint_smart_add(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..10 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_add(&mut ctxt_0, &mut ctxt_1); + let mut clear = clear_0 + clear_1; + + //add multiple times to raise the degree and test the smart operation + for _ in 0..40 { + ct_res = sks.smart_add(&mut ct_res, &mut ctxt_0); + clear += clear_0; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear % modulus, dec_res); + } + } +} + +/// test bitwise 'and' with the LWE server key +fn shortint_unchecked_bitand(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitand(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 & clear_1, dec_res); + } +} + +/// test bitwise 'or' with the LWE server key +fn shortint_unchecked_bitor(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitor(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 | clear_1, dec_res); + } +} + +/// test bitwise 'xor' with the LWE server key +fn shortint_unchecked_bitxor(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_bitxor(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 ^ clear_1, dec_res); + } +} + +/// test bitwise 'and' with the LWE server key +fn shortint_smart_bitand(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 *= scalar as u64; + clear_1 *= scalar as u64; + + // add the two ciphertexts + let ct_res = sks.smart_bitand(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 & clear_1) % modulus, dec_res); + } +} + +/// test bitwise 'or' with the LWE server key +fn shortint_smart_bitor(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 *= scalar as u64; + clear_1 *= scalar as u64; + + // add the two ciphertexts + let ct_res = sks.smart_bitor(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 | clear_1) % modulus, dec_res); + } +} + +/// test bitwise 'xor' with the LWE server key +fn shortint_smart_bitxor(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 *= scalar as u64; + clear_1 *= scalar as u64; + + // add the two ciphertexts + let ct_res = sks.smart_bitxor(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 ^ clear_1) % modulus, dec_res); + } +} + +/// test '>' with the LWE server key +fn shortint_unchecked_greater(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_greater(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 > clear_1) as u64, dec_res); + } +} + +/// test '>' with the LWE server key +fn shortint_smart_greater(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.smart_greater(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 > clear_1) as u64, dec_res); + } +} + +/// test '>=' with the LWE server key +fn shortint_unchecked_greater_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_greater_or_equal(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 >= clear_1) as u64, dec_res); + } +} + +/// test '>=' with the LWE server key +fn shortint_smart_greater_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 = (clear_0 * scalar as u64) % modulus; + clear_1 = (clear_1 * scalar as u64) % modulus; + + // add the two ciphertexts + let ct_res = sks.smart_greater_or_equal(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 >= clear_1) as u64, dec_res); + } +} + +/// test '<' with the LWE server key +fn shortint_unchecked_less(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_less(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 < clear_1) as u64, dec_res); + } +} + +/// test '<' with the LWE server key +fn shortint_smart_less(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 = (clear_0 * scalar as u64) % modulus; + clear_1 = (clear_1 * scalar as u64) % modulus; + + // add the two ciphertexts + let ct_res = sks.smart_less(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 < clear_1) as u64, dec_res); + } +} + +/// test '<=' with the LWE server key +fn shortint_unchecked_less_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_less_or_equal(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 <= clear_1) as u64, dec_res); + } +} + +/// test '<=' with the LWE server key +fn shortint_smart_less_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 *= scalar as u64; + clear_1 *= scalar as u64; + + // add the two ciphertexts + let ct_res = sks.smart_less_or_equal(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(((clear_0 % modulus) <= (clear_1 % modulus)) as u64, dec_res); + } +} + +fn shortint_unchecked_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_equal(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 == clear_1) as u64, dec_res); + } +} + +/// test '==' with the LWE server key +fn shortint_smart_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + let mod_scalar = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let mut clear_0 = rng.gen::() % modulus; + let mut clear_1 = rng.gen::() % modulus; + let scalar = rng.gen::() % mod_scalar; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + sks.unchecked_scalar_mul_assign(&mut ctxt_0, scalar); + sks.unchecked_scalar_mul_assign(&mut ctxt_1, scalar); + + clear_0 *= scalar as u64; + clear_1 *= scalar as u64; + + // add the two ciphertexts + let ct_res = sks.smart_equal(&mut ctxt_0, &mut ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(((clear_0 % modulus) == (clear_1 % modulus)) as u64, dec_res); + } +} + +/// test '==' with the LWE server key +fn shortint_smart_scalar_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let msg_modulus = cks.parameters.message_modulus.0 as u64; + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % msg_modulus; + + let scalar = (rng.gen::() % modulus as u16) as u8; + + // encryption of an integer + let ctxt = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.smart_scalar_equal(&ctxt, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear == scalar as u64) as u64, dec_res); + } +} + +/// test '<' with the LWE server key +fn shortint_smart_scalar_less(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let msg_modulus = cks.parameters.message_modulus.0 as u64; + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % msg_modulus; + + let scalar = (rng.gen::() % modulus as u16) as u8; + + // encryption of an integer + let ctxt = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.smart_scalar_less(&ctxt, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear < scalar as u64) as u64, dec_res); + } +} + +/// test '<=' with the LWE server key +fn shortint_smart_scalar_less_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let msg_modulus = cks.parameters.message_modulus.0 as u64; + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % msg_modulus; + + let scalar = (rng.gen::() % modulus as u16) as u8; + + // encryption of an integer + let ctxt = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.smart_scalar_less_or_equal(&ctxt, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear <= scalar as u64) as u64, dec_res); + } +} + +/// test '>' with the LWE server key +fn shortint_smart_scalar_greater(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let msg_modulus = cks.parameters.message_modulus.0 as u64; + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % msg_modulus; + + let scalar = (rng.gen::() % modulus as u16) as u8; + + // encryption of an integer + let ctxt = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.smart_scalar_greater(&ctxt, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear > scalar as u64) as u64, dec_res); + } +} + +/// test '>' with the LWE server key +fn shortint_smart_scalar_greater_or_equal(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let msg_modulus = cks.parameters.message_modulus.0 as u64; + let modulus = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) as u64; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % msg_modulus; + + let scalar = (rng.gen::() % modulus as u16) as u8; + + // encryption of an integer + let ctxt = cks.encrypt(clear); + + // add the two ciphertexts + let ct_res = sks.smart_scalar_greater_or_equal(&ctxt, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear >= scalar as u64) as u64, dec_res); + } +} + +/// test division with the LWE server key +fn shortint_unchecked_div(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = (rng.gen::() % (modulus - 1)) + 1; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_div(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 / clear_1, dec_res); + } +} + +/// test scalar division with the LWE server key +fn shortint_unchecked_scalar_div(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = (rng.gen::() % (modulus - 1)) + 1; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_div(&ctxt_0, clear_1 as u8); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 / clear_1, dec_res); + } +} + +/// test modulus with the LWE server key +fn shortint_unchecked_mod(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = (rng.gen::() % (modulus - 1)) + 1; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_mod(&ctxt_0, clear_1 as u8); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 % clear_1, dec_res); + } +} + +/// test LSB multiplication with the LWE server key +fn shortint_unchecked_mul_lsb(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_mul_lsb(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res); + } +} + +/// test MSB multiplication with the LWE server key +fn shortint_unchecked_mul_msb(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let ct_res = sks.unchecked_mul_msb(&ctxt_0, &ctxt_1); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 * clear_1) / modulus, dec_res); + } +} + +/// test LSB multiplication with the LWE server key +fn shortint_smart_mul_lsb(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..10 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_mul_lsb(&mut ctxt_0, &mut ctxt_1); + + let mut clear = clear_0 * clear_1; + + //add multiple times to raise the degree + for _ in 0..30 { + ct_res = sks.smart_mul_lsb(&mut ct_res, &mut ctxt_0); + clear = (clear * clear_0) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear, dec_res); + } + } +} + +/// test MSB multiplication with the LWE server key +fn shortint_smart_mul_msb(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..10 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_1 = cks.encrypt(clear_1); + + // add the two ciphertexts + let mut ct_res = sks.smart_mul_msb(&mut ctxt_0, &mut ctxt_1); + + let mut clear = (clear_0 * clear_1) / modulus; + + // let dec_res = cks.decrypt(&ct_res); + // println!("clear_0 = {}, clear_1 = {}, dec = {}, clear = {}", clear_0, clear_1, dec_res, + // clear); + + //add multiple times to raise the degree + for _ in 0..30 { + ct_res = sks.smart_mul_msb(&mut ct_res, &mut ctxt_0); + clear = (clear * clear_0) / modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear % modulus, dec_res); + } + } +} + +/// test unchecked negation +fn shortint_unchecked_neg(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + // Define the cleartexts + let clear = rng.gen::() % modulus; + + // Encrypt the integers + let ctxt = cks.encrypt(clear); + + // Negates the ctxt + let ct_tmp = sks.unchecked_neg(&ctxt); + + // Decrypt the result + let dec = cks.decrypt(&ct_tmp); + + // Check the correctness + let clear_result = clear.wrapping_neg() % modulus; + + assert_eq!(clear_result, dec); + } +} + +/// test smart negation +fn shortint_smart_neg(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..10 { + let clear1 = rng.gen::() % modulus; + + let mut ct1 = cks.encrypt(clear1); + + let mut ct_res = sks.smart_neg(&mut ct1); + + let mut clear_result = clear1.wrapping_neg() % modulus; + + for _ in 0..30 { + // scalar multiplication + ct_res = sks.smart_neg(&mut ct_res); + + clear_result = clear_result.wrapping_neg() % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_result, dec_res); + } + } +} + +/// test scalar add +fn shortint_unchecked_scalar_add(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let message_modulus = param.message_modulus.0 as u8; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % message_modulus as u8; + + let scalar = rng.gen::() % message_modulus as u8; + + // encryption of an integer + let ct = cks.encrypt(clear as u64); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_add(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear + scalar) % message_modulus, dec_res as u8); + } +} + +/// test smart scalar add +fn shortint_smart_scalar_add(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u8; + + for _ in 0..10 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0 as u64); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_add(&mut ctxt_0, clear_1); + + let mut clear = (clear_0 + clear_1) % modulus; + + //add multiple times to raise the degree + for _ in 0..30 { + ct_res = sks.smart_scalar_add(&mut ct_res, clear_1); + clear = (clear + clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + assert_eq!(clear, dec_res as u8); + } + } +} + +/// test unchecked scalar sub +fn shortint_unchecked_scalar_sub(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let message_modulus = param.message_modulus.0 as u8; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % message_modulus; + + let scalar = rng.gen::() % message_modulus; + + // encryption of an integer + let ct = cks.encrypt(clear as u64); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_sub(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear - scalar) % message_modulus, dec_res as u8); + } +} + +fn shortint_smart_scalar_sub(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u8; + + for _ in 0..10 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_0 = cks.encrypt(clear_0 as u64); + + // add the two ciphertexts + let mut ct_res = sks.smart_scalar_sub(&mut ctxt_0, clear_1); + + let mut clear = (clear_0 - clear_1) % modulus; + + // let dec_res = cks.decrypt(&ct_res); + // println!("clear_0 = {}, clear_1 = {}, dec = {}, clear = {}", clear_0, clear_1, dec_res, + // clear); + + //add multiple times to raise the degree + for _ in 0..30 { + ct_res = sks.smart_scalar_sub(&mut ct_res, clear_1); + clear = (clear - clear_1) % modulus; + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // println!("clear_1 = {}, dec = {}, clear = {}", clear_1, dec_res, clear); + // assert + assert_eq!(clear, dec_res as u8); + } + } +} + +/// test scalar multiplication with the LWE server key +fn shortint_unchecked_scalar_mul(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + + let message_modulus = param.message_modulus.0 as u8; + let carry_modulus = param.carry_modulus.0 as u8; + + for _ in 0..NB_TEST { + let clear = rng.gen::() % message_modulus; + + let scalar = rng.gen::() % carry_modulus; + + // encryption of an integer + let ct = cks.encrypt(clear as u64); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_mul(&ct, scalar); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear * scalar) % message_modulus, dec_res as u8); + } +} + +/// test smart scalar multiplication with the LWE server key +fn shortint_smart_scalar_mul(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u8; + + let scalar_modulus = cks.parameters.carry_modulus.0 as u8; + + for _ in 0..10 { + let clear = rng.gen::() % modulus; + + let scalar = rng.gen::() % scalar_modulus; + + // encryption of an integer + let mut ct = cks.encrypt(clear as u64); + + let mut ct_res = sks.smart_scalar_mul(&mut ct, scalar); + + let mut clear_res = clear * scalar; + for _ in 0..10 { + // scalar multiplication + ct_res = sks.smart_scalar_mul(&mut ct_res, scalar); + clear_res = (clear_res * scalar) % modulus; + } + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_res, dec_res as u8); + } +} + +/// test unchecked '>>' operation +fn shortint_unchecked_right_shift(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let shift = rng.gen::() % 2; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_right_shift(&ctxt_0, shift as u8); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_0 >> shift, dec_res); + } +} + +/// test '<<' operation +fn shortint_unchecked_left_shift(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..NB_TEST { + let clear_0 = rng.gen::() % modulus; + let shift = rng.gen::() % 2; + + // encryption of an integer + let ctxt_0 = cks.encrypt(clear_0); + + // add the two ciphertexts + let ct_res = sks.unchecked_scalar_left_shift(&ctxt_0, shift as u8); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 << shift) % modulus, dec_res); + } +} + +/// test unchecked subtraction +fn shortint_unchecked_sub(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + for _ in 0..NB_TEST { + // Define the cleartexts + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + // Encrypt the integers + let ctxt_1 = cks.encrypt(clear1); + let ctxt_2 = cks.encrypt(clear2); + + // Add the ciphertext 1 and 2 + let ct_tmp = sks.unchecked_sub(&ctxt_1, &ctxt_2); + + // Decrypt the result + let dec = cks.decrypt(&ct_tmp); + + // Check the correctness + let clear_result = (clear1 - clear2) % modulus; + assert_eq!(clear_result, dec % modulus); + } +} + +fn shortint_smart_sub(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..10 { + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + let mut ct1 = cks.encrypt(clear1); + let mut ct2 = cks.encrypt(clear2); + + let mut ct_res = sks.smart_sub(&mut ct1, &mut ct2); + + let mut clear_res = (clear1 - clear2) % modulus; + for _ in 0..10 { + // scalar multiplication + ct_res = sks.smart_sub(&mut ct_res, &mut ct2); + clear_res = (clear_res - clear2) % modulus; + } + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!(clear_res, dec_res); + } +} + +/// test multiplication +fn shortint_mul_small_carry(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + //RNG + let mut rng = rand::thread_rng(); + + let modulus = cks.parameters.message_modulus.0 as u64; + + for _ in 0..50 { + let clear_0 = rng.gen::() % modulus; + + let clear_1 = rng.gen::() % modulus; + + // encryption of an integer + let mut ctxt_zero = cks.encrypt(clear_0); + + // encryption of an integer + let mut ctxt_one = cks.encrypt(clear_1); + + // multiply together the two ciphertexts + let ct_res = sks.unchecked_mul_lsb_small_carry(&mut ctxt_zero, &mut ctxt_one); + + // decryption of ct_res + let dec_res = cks.decrypt(&ct_res); + + // assert + assert_eq!((clear_0 * clear_1) % modulus, dec_res % modulus); + } +} + +/// test encryption and decryption with the LWE client key +fn shortint_encrypt_with_message_modulus_smart_add_and_mul(param: Parameters) { + let keys = KEY_CACHE.get_from_param(param); + let (cks, sks) = (keys.client_key(), keys.server_key()); + + let mut rng = rand::thread_rng(); + let full_mod = (cks.parameters.message_modulus.0 * cks.parameters.carry_modulus.0) / 3; + + for _ in 0..NB_TEST { + let mut modulus = rng.gen::() % full_mod as u64; + while modulus == 0 { + modulus = rng.gen::() % full_mod as u64; + } + + let clear1 = rng.gen::() % modulus; + let clear2 = rng.gen::() % modulus; + + let mut ct1 = cks.encrypt_with_message_modulus(clear1, MessageModulus(modulus as usize)); + let mut ct2 = cks.encrypt_with_message_modulus(clear2, MessageModulus(modulus as usize)); + + println!( + "MUL SMALL CARRY:: clear1 = {}, clear2 = {}, mod = {}", + clear1, clear2, modulus + ); + let ct_res = sks.unchecked_mul_lsb_small_carry(&mut ct1, &mut ct2); + assert_eq!( + (clear1 * clear2) % modulus, + cks.decrypt_message_and_carry(&ct_res) % modulus + ); + + println!( + "ADD:: clear1 = {}, clear2 = {}, mod = {}", + clear1, clear2, modulus + ); + let ct_res = sks.unchecked_add(&ct1, &ct2); + assert_eq!((clear1 + clear2), cks.decrypt_message_and_carry(&ct_res)); + } +} diff --git a/tfhe/src/shortint/wopbs/mod.rs b/tfhe/src/shortint/wopbs/mod.rs new file mode 100644 index 000000000..2c2c11486 --- /dev/null +++ b/tfhe/src/shortint/wopbs/mod.rs @@ -0,0 +1,443 @@ +//! Module with the definition of the WopbsKey (WithOut padding PBS Key). +//! +//! This module implements the generation of another server public key, which allows to compute +//! an alternative version of the programmable bootstrapping. This does not require the use of a +//! bit of padding. +//! +//! In the case where a padding bit is defined, keys are generated so that there a compatible for +//! both uses. + +use crate::shortint::engine::ShortintEngine; +use crate::shortint::{Ciphertext, ClientKey, Parameters, ServerKey}; + +use crate::core_crypto::prelude::*; + +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +#[cfg(test)] +mod test; + +// Struct for WoPBS based on the private functional packing keyswitch. +#[derive(Clone, Debug)] +pub struct WopbsKey { + //Key for the private functional keyswitch + pub wopbs_server_key: ServerKey, + pub pbs_server_key: ServerKey, + pub cbs_pfpksk: LweCircuitBootstrapPrivateFunctionalPackingKeyswitchKeys64, + pub ksk_pbs_to_wopbs: LweKeyswitchKey64, + pub param: Parameters, +} + +impl WopbsKey { + /// Generates the server key required to compute a WoPBS from the client and the server keys. + /// + /// #Warning + /// Only when the classical PBS is not used in the circuit + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::wopbs::*; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(WOPBS_PARAM_MESSAGE_1_CARRY_1); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// ``` + pub fn new_wopbs_key_only_for_wopbs(cks: &ClientKey, sks: &ServerKey) -> WopbsKey { + ShortintEngine::with_thread_local_mut(|engine| { + engine.new_wopbs_key_only_for_wopbs(cks, sks).unwrap() + }) + } + + /// Generates the server key required to compute a WoPBS from the client and the server keys. + /// # Example + /// + /// ```rust + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_1; + /// use tfhe::shortint::wopbs::*; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1); + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_1_CARRY_1); + /// ``` + pub fn new_wopbs_key(cks: &ClientKey, sks: &ServerKey, parameters: &Parameters) -> WopbsKey { + ShortintEngine::with_thread_local_mut(|engine| { + engine.new_wopbs_key(cks, sks, parameters).unwrap() + }) + } + + /// Generates the Look-Up Table homomorphically using the WoPBS approach. + /// + /// # Warning: this assumes one bit of padding. + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::wopbs::*; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let message_modulus = WOPBS_PARAM_MESSAGE_2_CARRY_2.message_modulus.0 as u64; + /// let m = 2; + /// let mut ct = cks.encrypt(m); + /// let lut = wopbs_key.generate_lut(&ct, |x| x * x % message_modulus); + /// let ct_res = wopbs_key.programmable_bootstrapping(&mut sks, &mut ct, &lut); + /// let res = cks.decrypt(&ct_res); + /// assert_eq!(res, (m * m) % message_modulus); + /// ``` + pub fn generate_lut(&self, ct: &Ciphertext, f: F) -> Vec + where + F: Fn(u64) -> u64, + { + // The function is applied only on the message modulus bits + let basis = ct.message_modulus.0 * ct.carry_modulus.0; + let delta = 64 - f64::log2((basis) as f64).ceil() as u64 - 1; + let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let mut vec_lut = vec![0; poly_size]; + for (i, value) in vec_lut.iter_mut().enumerate().take(basis) { + *value = f((i % ct.message_modulus.0) as u64) << delta; + } + vec_lut + } + + /// Generates the Look-Up Table homomorphically using the WoPBS approach. + /// + /// # Warning: this assumes no bit of padding. + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::wopbs::WopbsKey; + /// + /// // Generate the client key and the server key: + /// let (mut cks, mut sks) = gen_keys(WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let message_modulus = WOPBS_PARAM_MESSAGE_2_CARRY_2.message_modulus.0 as u64; + /// let m = 2; + /// let ct = cks.encrypt_without_padding(m); + /// let lut = wopbs_key.generate_lut(&ct, |x| x * x % message_modulus); + /// let ct_res = wopbs_key.programmable_bootstrapping_without_padding(&ct, &lut); + /// let res = cks.decrypt_without_padding(&ct_res); + /// assert_eq!(res, (m * m) % message_modulus); + /// ``` + pub fn generate_lut_without_padding(&self, ct: &Ciphertext, f: F) -> Vec + where + F: Fn(u64) -> u64, + { + // The function is applied only on the message modulus bits + let basis = ct.message_modulus.0 * ct.carry_modulus.0; + let delta = 64 - f64::log2((basis) as f64).ceil() as u64; + let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let mut vec_lut = vec![0; poly_size]; + for (i, value) in vec_lut.iter_mut().enumerate().take(basis) { + *value = f((i % ct.message_modulus.0) as u64) << delta; + } + vec_lut + } + + /// Generates the Look-Up Table homomorphically using the WoPBS approach. + /// + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_3_NORM2_2; + /// use tfhe::shortint::wopbs::WopbsKey; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(WOPBS_PARAM_MESSAGE_3_NORM2_2); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let message_modulus = 5; + /// let m = 2; + /// let mut ct = cks.encrypt_native_crt(m, message_modulus); + /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x * x % message_modulus as u64); + /// let ct_res = wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &lut); + /// let res = cks.decrypt_message_native_crt(&ct_res, message_modulus); + /// assert_eq!(res, (m * m) % message_modulus as u64); + /// ``` + pub fn generate_lut_native_crt(&self, ct: &Ciphertext, f: F) -> Vec + where + F: Fn(u64) -> u64, + { + // The function is applied only on the message modulus bits + let basis = ct.message_modulus.0 * ct.carry_modulus.0; + let nb_bit = f64::log2((basis) as f64).ceil() as u64; + let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let mut vec_lut = vec![0; poly_size]; + for i in 0..basis { + let index_lut = (((i as u64 % basis as u64) << nb_bit) / basis as u64) as usize; + vec_lut[index_lut] = + (((f(i as u64) % basis as u64) as u128 * (1 << 64)) / basis as u128) as u64; + } + vec_lut + } + + /// Applies the Look-Up Table homomorphically using the WoPBS approach. + /// + /// #Warning: this assumes one bit of padding. + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::wopbs::*; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let mut rng = rand::thread_rng(); + /// let message_modulus = WOPBS_PARAM_MESSAGE_2_CARRY_2.message_modulus.0; + /// let ct = cks.encrypt(rng.gen::() % message_modulus as u64); + /// let lut = vec![(1_u64 << 59); wopbs_key.param.polynomial_size.0]; + /// let ct_res = wopbs_key.programmable_bootstrapping(&sks, &ct, &lut); + /// let res = cks.decrypt_message_and_carry(&ct_res); + /// assert_eq!(res, 1); + /// ``` + pub fn programmable_bootstrapping( + &self, + sks: &ServerKey, + ct_in: &Ciphertext, + lut: &[u64], + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrapping(self, sks, ct_in, lut) + .unwrap() + }) + } + + /// Applies the Look-Up Table homomorphically using the WoPBS approach. + /// + /// #Warning: this assumes one bit of padding. + /// #Warning: to use in a WoPBS context ONLY (i.e., non compliant with classical PBS) + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2; + /// use tfhe::shortint::wopbs::*; + /// + /// // Generate the client key and the server key: + /// let (cks, sks) = gen_keys(WOPBS_PARAM_MESSAGE_2_CARRY_2); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let mut rng = rand::thread_rng(); + /// let message_modulus = WOPBS_PARAM_MESSAGE_2_CARRY_2.message_modulus.0; + /// let ct = cks.encrypt(rng.gen::() % message_modulus as u64); + /// let lut = vec![(1_u64 << 59); wopbs_key.param.polynomial_size.0]; + /// let ct_res = wopbs_key.wopbs(&ct, &lut); + /// let res = cks.decrypt_message_and_carry(&ct_res); + /// assert_eq!(res, 1); + /// ``` + pub fn wopbs(&self, ct_in: &Ciphertext, lut: &[u64]) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.wopbs(self, ct_in, lut).unwrap()) + } + + /// Applies the Look-Up Table homomorphically using the WoPBS approach. + /// + /// # Example + /// + /// ```rust + /// use rand::Rng; + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_1_NORM2_2; + /// use tfhe::shortint::wopbs::*; + /// + /// let (cks, sks) = gen_keys(WOPBS_PARAM_MESSAGE_1_NORM2_2); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let mut rng = rand::thread_rng(); + /// let ct = cks.encrypt_without_padding(rng.gen::() % 2); + /// let lut = vec![(1_u64 << 63); wopbs_key.param.polynomial_size.0]; + /// let ct_res = wopbs_key.programmable_bootstrapping_without_padding(&ct, &lut); + /// let res = cks.decrypt_message_and_carry_without_padding(&ct_res); + /// assert_eq!(res, 1); + /// ``` + pub fn programmable_bootstrapping_without_padding( + &self, + ct_in: &Ciphertext, + lut: &[u64], + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrapping_without_padding(self, ct_in, lut) + .unwrap() + }) + } + + /// Applies the Look-Up Table homomorphically using the WoPBS approach. + /// + /// # Example + /// + /// ```rust + /// use tfhe::shortint::ciphertext::Ciphertext; + /// use tfhe::shortint::gen_keys; + /// use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_3_NORM2_2; + /// use tfhe::shortint::wopbs::*; + /// + /// let (cks, sks) = gen_keys(WOPBS_PARAM_MESSAGE_3_NORM2_2); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let msg = 2; + /// let modulus = 5; + /// let mut ct = cks.encrypt_native_crt(msg, modulus); + /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x); + /// let ct_res = wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &lut); + /// let res = cks.decrypt_message_native_crt(&ct_res, modulus); + /// assert_eq!(res, msg); + /// ``` + pub fn programmable_bootstrapping_native_crt( + &self, + ct_in: &mut Ciphertext, + lut: &[u64], + ) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .programmable_bootstrapping_native_crt(self, ct_in, lut) + .unwrap() + }) + } + + /// Extracts the given number of bits from a ciphertext. + /// + /// # Warning Experimental + pub fn extract_bits( + &self, + delta_log: DeltaLog, + ciphertext: &Ciphertext, + num_bits_to_extract: usize, + ) -> LweCiphertextVector64 { + ShortintEngine::with_thread_local_mut(|engine| { + engine + .extract_bits( + delta_log, + &ciphertext.ct, + self, + ExtractedBitsCount(num_bits_to_extract), + ) + .unwrap() + }) + } + + /// Temporary wrapper. + /// + /// # Warning Experimental + pub fn circuit_bootstrapping_vertical_packing( + &self, + vec_lut: Vec>, + extracted_bits_blocks: Vec, + ) -> Vec { + ShortintEngine::with_thread_local_mut(|engine| { + engine.circuit_bootstrapping_vertical_packing(self, vec_lut, extracted_bits_blocks) + }) + } + + pub fn keyswitch_to_wopbs_params(&self, sks: &ServerKey, ct_in: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| { + engine.keyswitch_to_wopbs_params(sks, self, ct_in) + }) + .unwrap() + } + + pub fn keyswitch_to_pbs_params(&self, ct_in: &Ciphertext) -> Ciphertext { + ShortintEngine::with_thread_local_mut(|engine| engine.keyswitch_to_pbs_params(self, ct_in)) + .unwrap() + } +} + +#[derive(Serialize)] +struct SerializableWopbsKey<'a> { + wopbs_server_key: &'a ServerKey, + pbs_server_key: &'a ServerKey, + cbs_pfpksk: Vec, + ksk_pbs_to_wopbs: Vec, + param: Parameters, +} + +#[derive(Deserialize)] +struct DeserializableWopbsKey { + wopbs_server_key: ServerKey, + pbs_server_key: ServerKey, + cbs_pfpksk: Vec, + ksk_pbs_to_wopbs: Vec, + param: Parameters, +} + +impl Serialize for WopbsKey { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut default_ser_eng = + DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?; + + let cbs_pfpksk = default_ser_eng + .serialize(&self.cbs_pfpksk) + .map_err(serde::ser::Error::custom)?; + + let ksk_pbs_to_wopbs = default_ser_eng + .serialize(&self.ksk_pbs_to_wopbs) + .map_err(serde::ser::Error::custom)?; + + SerializableWopbsKey { + wopbs_server_key: &self.wopbs_server_key, + pbs_server_key: &self.pbs_server_key, + cbs_pfpksk, + ksk_pbs_to_wopbs, + param: self.param, + } + .serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for WopbsKey { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let thing = + DeserializableWopbsKey::deserialize(deserializer).map_err(serde::de::Error::custom)?; + + let mut default_ser_eng = + DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?; + + let cbs_pfpksk = default_ser_eng + .deserialize(thing.cbs_pfpksk.as_slice()) + .map_err(serde::de::Error::custom)?; + + let ksk_pbs_to_wopbs = default_ser_eng + .deserialize(thing.ksk_pbs_to_wopbs.as_slice()) + .map_err(serde::de::Error::custom)?; + + Ok(Self { + wopbs_server_key: thing.wopbs_server_key, + pbs_server_key: thing.pbs_server_key, + cbs_pfpksk, + ksk_pbs_to_wopbs, + param: thing.param, + }) + } +} diff --git a/tfhe/src/shortint/wopbs/test.rs b/tfhe/src/shortint/wopbs/test.rs new file mode 100644 index 000000000..3d33d3664 --- /dev/null +++ b/tfhe/src/shortint/wopbs/test.rs @@ -0,0 +1,106 @@ +use crate::shortint::keycache::KEY_CACHE_WOPBS; +use crate::shortint::parameters::parameters_wopbs_message_carry::*; +use crate::shortint::parameters::{ + MessageModulus, PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, +}; +use crate::shortint::wopbs::WopbsKey; +use crate::shortint::{gen_keys, Parameters}; +use paste::paste; +use rand::Rng; + +const NB_TEST: usize = 1; + +macro_rules! create_parametrized_test{ + ($name:ident { $( ($sks_param:ident, $wopbs_param:ident) ),* }) => { + paste! { + $( + #[test] + fn []() { + $name(($sks_param, $wopbs_param)) + } + )* + } + }; + ($name:ident)=> { + create_parametrized_test!($name + { + (PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1), + (PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2), + (PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3), + (PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4) + }); + }; +} + +create_parametrized_test!(generate_lut); +create_parametrized_test!(generate_lut_modulus); +create_parametrized_test!(generate_lut_modulus_not_power_of_two); + +fn generate_lut(params: (Parameters, Parameters)) { + let keys = KEY_CACHE_WOPBS.get_from_param(params); + let (cks, sks, wopbs_key) = (keys.client_key(), keys.server_key(), keys.wopbs_key()); + let mut rng = rand::thread_rng(); + + let mut tmp = 0; + for _ in 0..NB_TEST { + let message_modulus = params.0.message_modulus.0; + let m = rng.gen::() % message_modulus; + let ct = cks.encrypt(m as u64); + let lut = wopbs_key.generate_lut(&ct, |x| x % message_modulus as u64); + let ct_res = wopbs_key.programmable_bootstrapping(sks, &ct, &lut); + + let res = cks.decrypt(&ct_res); + if res != (m % message_modulus) as u64 { + tmp += 1; + } + } + if 0 != tmp { + println!("______"); + println!("failure rate {:?}/{:?}", tmp, NB_TEST); + println!("______"); + } + assert_eq!(0, tmp); +} + +fn generate_lut_modulus(params: (Parameters, Parameters)) { + let keys = KEY_CACHE_WOPBS.get_from_param(params); + let (cks, sks, wopbs_key) = (keys.client_key(), keys.server_key(), keys.wopbs_key()); + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST { + let message_modulus = MessageModulus(params.0.message_modulus.0 - 1); + let m = rng.gen::() % message_modulus.0; + + let ct = cks.encrypt_with_message_modulus(m as u64, message_modulus); + + let ct = wopbs_key.keyswitch_to_wopbs_params(sks, &ct); + let lut = wopbs_key.generate_lut(&ct, |x| (x * x) % message_modulus.0 as u64); + let ct_res = wopbs_key.wopbs(&ct, &lut); + let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res); + + let res = cks.decrypt(&ct_res); + assert_eq!(res as usize, (m * m) % message_modulus.0); + } +} + +fn generate_lut_modulus_not_power_of_two(params: (Parameters, Parameters)) { + let (cks, sks) = gen_keys(params.1); + let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + // let keys = KEY_CACHE_WOPBS.get_from_param((WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_2, + // WOPBS_PRIME_PARAM_MESSAGE_2_NORM2_2)); let (cks, _, wopbs_key) = (keys.client_key(), + // keys.server_key(), keys.wopbs_key()); + let mut rng = rand::thread_rng(); + + for _ in 0..NB_TEST { + let message_modulus = MessageModulus(params.0.message_modulus.0 - 1); + + let m = rng.gen::() % message_modulus.0; + let mut ct = cks.encrypt_native_crt(m as u64, message_modulus.0 as u8); + let lut = wopbs_key.generate_lut_native_crt(&ct, |x| (x * x) % message_modulus.0 as u64); + + let ct_res = wopbs_key.programmable_bootstrapping_native_crt(&mut ct, &lut); + let res = cks.decrypt_message_native_crt(&ct_res, message_modulus.0 as u8); + assert_eq!(res as usize, (m * m) % message_modulus.0); + } +} diff --git a/tfhe/src/test_user_docs.rs b/tfhe/src/test_user_docs.rs new file mode 100644 index 000000000..f674e405e --- /dev/null +++ b/tfhe/src/test_user_docs.rs @@ -0,0 +1,23 @@ +use doc_comment::doctest; + +// Getting started +doctest!("../docs/getting_started/quick_start.md", quick_start); +doctest!("../docs/getting_started/operations.md", operations); + +// Booleans +doctest!("../docs/Booleans/parameters.md", booleans_parameters); +doctest!("../docs/Booleans/operations.md", booleans_operations); +doctest!("../docs/Booleans/serialization.md", booleans_serialization); +doctest!("../docs/Booleans/tutorial.md", booleans_tutorial); + +// Shortint +doctest!("../docs/shortint/parameters.md", shortint_parameters); +doctest!("../docs/shortint/serialization.md", shortint_serialization); +doctest!("../docs/shortint/tutorial.md", shortint_tutorial); +doctest!("../docs/shortint/operations.md", shortint_operations); + +// doctest!("../docs/tutorials/serialization.md", serialization_tuto); +// doctest!( +// "../docs/tutorials/circuit_evaluation.md", +// circuit_evaluation +// ); diff --git a/toolchain.txt b/toolchain.txt new file mode 100644 index 000000000..762dae8c2 --- /dev/null +++ b/toolchain.txt @@ -0,0 +1 @@ +nightly-2022-11-03 \ No newline at end of file