mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 15:48:20 -05:00
Compare commits
52 Commits
multibit-d
...
chore/para
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00fc67e84c | ||
|
|
a5906bb7cb | ||
|
|
90b7494acd | ||
|
|
3508019cd2 | ||
|
|
200c8a177a | ||
|
|
2f6c1cf0b5 | ||
|
|
b96027f417 | ||
|
|
90c850ca0d | ||
|
|
c8d3008a8d | ||
|
|
08c264f193 | ||
|
|
4ae202d8a4 | ||
|
|
7eb8601540 | ||
|
|
8a1691c536 | ||
|
|
d1cb55ba24 | ||
|
|
2b9a49db87 | ||
|
|
62ddb24f00 | ||
|
|
c6ae463b41 | ||
|
|
4947eefad4 | ||
|
|
71209e3927 | ||
|
|
2a66ea3d16 | ||
|
|
d4ff1f5595 | ||
|
|
8ae92a960d | ||
|
|
b042c2f7d6 | ||
|
|
e307da5c7f | ||
|
|
3d5b88d608 | ||
|
|
4fbf0691c5 | ||
|
|
5d277e85b9 | ||
|
|
778eea30e9 | ||
|
|
63247fa227 | ||
|
|
799291a1f0 | ||
|
|
509fe7a63e | ||
|
|
4eac45f0c6 | ||
|
|
ddb3451087 | ||
|
|
e66a329e33 | ||
|
|
d79b1d9b19 | ||
|
|
b501cc078a | ||
|
|
800878d89e | ||
|
|
20d0e81bae | ||
|
|
d3dbf4ecc9 | ||
|
|
c20ca07cd3 | ||
|
|
9f6c7e9139 | ||
|
|
3c8d6a6f8b | ||
|
|
1c837fa6f0 | ||
|
|
1ec7e4762a | ||
|
|
20fb697d57 | ||
|
|
0429d56cf3 | ||
|
|
509bf3e284 | ||
|
|
b2fc1d5266 | ||
|
|
62d94dbee8 | ||
|
|
fbe911d7db | ||
|
|
ba72faf828 | ||
|
|
c387b9340f |
27
.github/workflows/aws_tfhe_integer_tests.yml
vendored
27
.github/workflows/aws_tfhe_integer_tests.yml
vendored
@@ -25,26 +25,35 @@ on:
|
||||
request_id:
|
||||
description: 'Slab request ID'
|
||||
type: string
|
||||
matrix_item:
|
||||
description: 'Build matrix item'
|
||||
fork_repo:
|
||||
description: 'Name of forked repo as user/repo'
|
||||
type: string
|
||||
fork_git_sha:
|
||||
description: 'Git SHA to checkout from fork'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
integer-tests:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }}
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ${{ github.event.inputs.runner_name }}
|
||||
runs-on: ${{ inputs.runner_name }}
|
||||
steps:
|
||||
# Step used for log purpose.
|
||||
- name: Instance configuration used
|
||||
run: |
|
||||
echo "ID: ${{ github.event.inputs.instance_id }}"
|
||||
echo "AMI: ${{ github.event.inputs.instance_image_id }}"
|
||||
echo "Type: ${{ github.event.inputs.instance_type }}"
|
||||
echo "Request ID: ${{ github.event.inputs.request_id }}"
|
||||
echo "ID: ${{ inputs.instance_id }}"
|
||||
echo "AMI: ${{ inputs.instance_image_id }}"
|
||||
echo "Type: ${{ inputs.instance_type }}"
|
||||
echo "Request ID: ${{ inputs.request_id }}"
|
||||
echo "Fork repo: ${{ inputs.fork_repo }}"
|
||||
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
|
||||
|
||||
- uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
- name: Checkout tfhe-rs
|
||||
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
with:
|
||||
repository: ${{ inputs.fork_repo }}
|
||||
ref: ${{ inputs.fork_git_sha }}
|
||||
|
||||
- name: Set up home
|
||||
run: |
|
||||
|
||||
90
.github/workflows/aws_tfhe_multi_bit_tests.yml
vendored
Normal file
90
.github/workflows/aws_tfhe_multi_bit_tests.yml
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
name: AWS Multi Bit Tests on CPU
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
RUSTFLAGS: "-C target-cpu=native"
|
||||
|
||||
on:
|
||||
# Allows you to run this workflow manually from the Actions tab as an alternative.
|
||||
workflow_dispatch:
|
||||
# All the inputs are provided by Slab
|
||||
inputs:
|
||||
instance_id:
|
||||
description: "AWS instance ID"
|
||||
type: string
|
||||
instance_image_id:
|
||||
description: "AWS instance AMI ID"
|
||||
type: string
|
||||
instance_type:
|
||||
description: "AWS instance product type"
|
||||
type: string
|
||||
runner_name:
|
||||
description: "Action runner name"
|
||||
type: string
|
||||
request_id:
|
||||
description: 'Slab request ID'
|
||||
type: string
|
||||
fork_repo:
|
||||
description: 'Name of forked repo as user/repo'
|
||||
type: string
|
||||
fork_git_sha:
|
||||
description: 'Git SHA to checkout from fork'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
shortint-tests:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ${{ inputs.runner_name }}
|
||||
steps:
|
||||
# Step used for log purpose.
|
||||
- name: Instance configuration used
|
||||
run: |
|
||||
echo "ID: ${{ inputs.instance_id }}"
|
||||
echo "AMI: ${{ inputs.instance_image_id }}"
|
||||
echo "Type: ${{ inputs.instance_type }}"
|
||||
echo "Request ID: ${{ inputs.request_id }}"
|
||||
echo "Fork repo: ${{ inputs.fork_repo }}"
|
||||
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
|
||||
|
||||
- name: Checkout tfhe-rs
|
||||
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
with:
|
||||
repository: ${{ inputs.fork_repo }}
|
||||
ref: ${{ inputs.fork_git_sha }}
|
||||
|
||||
- name: Set up home
|
||||
run: |
|
||||
echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Install latest stable
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af
|
||||
with:
|
||||
toolchain: stable
|
||||
default: true
|
||||
|
||||
- name: Gen Keys if required
|
||||
run: |
|
||||
MULTI_BIT_ONLY=TRUE make gen_key_cache
|
||||
|
||||
- name: Run shortint multi-bit tests
|
||||
run: |
|
||||
make test_shortint_multi_bit_ci
|
||||
|
||||
- name: Run integer multi-bit tests
|
||||
run: |
|
||||
make test_integer_multi_bit_ci
|
||||
|
||||
- name: Slack Notification
|
||||
if: ${{ always() }}
|
||||
continue-on-error: true
|
||||
uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7
|
||||
env:
|
||||
SLACK_COLOR: ${{ job.status }}
|
||||
SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }}
|
||||
SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png
|
||||
SLACK_MESSAGE: "Shortint tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})"
|
||||
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
31
.github/workflows/aws_tfhe_tests.yml
vendored
31
.github/workflows/aws_tfhe_tests.yml
vendored
@@ -25,26 +25,35 @@ on:
|
||||
request_id:
|
||||
description: 'Slab request ID'
|
||||
type: string
|
||||
matrix_item:
|
||||
description: 'Build matrix item'
|
||||
fork_repo:
|
||||
description: 'Name of forked repo as user/repo'
|
||||
type: string
|
||||
fork_git_sha:
|
||||
description: 'Git SHA to checkout from fork'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
shortint-tests:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }}
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ${{ github.event.inputs.runner_name }}
|
||||
runs-on: ${{ inputs.runner_name }}
|
||||
steps:
|
||||
# Step used for log purpose.
|
||||
- name: Instance configuration used
|
||||
run: |
|
||||
echo "ID: ${{ github.event.inputs.instance_id }}"
|
||||
echo "AMI: ${{ github.event.inputs.instance_image_id }}"
|
||||
echo "Type: ${{ github.event.inputs.instance_type }}"
|
||||
echo "Request ID: ${{ github.event.inputs.request_id }}"
|
||||
echo "ID: ${{ inputs.instance_id }}"
|
||||
echo "AMI: ${{ inputs.instance_image_id }}"
|
||||
echo "Type: ${{ inputs.instance_type }}"
|
||||
echo "Request ID: ${{ inputs.request_id }}"
|
||||
echo "Fork repo: ${{ inputs.fork_repo }}"
|
||||
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
|
||||
|
||||
- uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
- name: Checkout tfhe-rs
|
||||
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
with:
|
||||
repository: ${{ inputs.fork_repo }}
|
||||
ref: ${{ inputs.fork_git_sha }}
|
||||
|
||||
- name: Set up home
|
||||
run: |
|
||||
@@ -72,10 +81,6 @@ jobs:
|
||||
run: |
|
||||
make test_user_doc
|
||||
|
||||
- name: Run js on wasm API tests
|
||||
run: |
|
||||
make test_nodejs_wasm_api_in_docker
|
||||
|
||||
- name: Gen Keys if required
|
||||
run: |
|
||||
make gen_key_cache
|
||||
|
||||
87
.github/workflows/aws_tfhe_wasm_tests.yml
vendored
Normal file
87
.github/workflows/aws_tfhe_wasm_tests.yml
vendored
Normal file
@@ -0,0 +1,87 @@
|
||||
name: AWS WASM Tests on CPU
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||
RUSTFLAGS: "-C target-cpu=native"
|
||||
|
||||
on:
|
||||
# Allows you to run this workflow manually from the Actions tab as an alternative.
|
||||
workflow_dispatch:
|
||||
# All the inputs are provided by Slab
|
||||
inputs:
|
||||
instance_id:
|
||||
description: "AWS instance ID"
|
||||
type: string
|
||||
instance_image_id:
|
||||
description: "AWS instance AMI ID"
|
||||
type: string
|
||||
instance_type:
|
||||
description: "AWS instance product type"
|
||||
type: string
|
||||
runner_name:
|
||||
description: "Action runner name"
|
||||
type: string
|
||||
request_id:
|
||||
description: 'Slab request ID'
|
||||
type: string
|
||||
fork_repo:
|
||||
description: 'Name of forked repo as user/repo'
|
||||
type: string
|
||||
fork_git_sha:
|
||||
description: 'Git SHA to checkout from fork'
|
||||
type: string
|
||||
|
||||
jobs:
|
||||
wasm-tests:
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ${{ inputs.runner_name }}
|
||||
steps:
|
||||
# Step used for log purpose.
|
||||
- name: Instance configuration used
|
||||
run: |
|
||||
echo "ID: ${{ inputs.instance_id }}"
|
||||
echo "AMI: ${{ inputs.instance_image_id }}"
|
||||
echo "Type: ${{ inputs.instance_type }}"
|
||||
echo "Request ID: ${{ inputs.request_id }}"
|
||||
echo "Fork repo: ${{ inputs.fork_repo }}"
|
||||
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
|
||||
|
||||
- name: Checkout tfhe-rs
|
||||
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
|
||||
with:
|
||||
repository: ${{ inputs.fork_repo }}
|
||||
ref: ${{ inputs.fork_git_sha }}
|
||||
|
||||
- name: Set up home
|
||||
run: |
|
||||
echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}"
|
||||
|
||||
- name: Install latest stable
|
||||
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af
|
||||
with:
|
||||
toolchain: stable
|
||||
default: true
|
||||
|
||||
- name: Run js on wasm API tests
|
||||
run: |
|
||||
make test_nodejs_wasm_api_in_docker
|
||||
|
||||
- name: Run parallel wasm tests
|
||||
run: |
|
||||
make install_node
|
||||
make ci_test_web_js_api_parallel
|
||||
|
||||
- name: Slack Notification
|
||||
if: ${{ always() }}
|
||||
continue-on-error: true
|
||||
uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7
|
||||
env:
|
||||
SLACK_COLOR: ${{ job.status }}
|
||||
SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }}
|
||||
SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png
|
||||
SLACK_MESSAGE: "WASM tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})"
|
||||
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
4
.github/workflows/make_release.yml
vendored
4
.github/workflows/make_release.yml
vendored
@@ -31,7 +31,7 @@ jobs:
|
||||
make build_web_js_api
|
||||
|
||||
- name: Publish web package
|
||||
uses: JS-DevTools/npm-publish@541aa6b21b4a1e9990c95a92c21adc16b35e9551
|
||||
uses: JS-DevTools/npm-publish@a25b4180b728b0279fca97d4e5bccf391685aead
|
||||
with:
|
||||
token: ${{ secrets.NPM_TOKEN }}
|
||||
package: tfhe/pkg/package.json
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
sed -i 's/"tfhe"/"node-tfhe"/g' tfhe/pkg/package.json
|
||||
|
||||
- name: Publish Node package
|
||||
uses: JS-DevTools/npm-publish@541aa6b21b4a1e9990c95a92c21adc16b35e9551
|
||||
uses: JS-DevTools/npm-publish@a25b4180b728b0279fca97d4e5bccf391685aead
|
||||
with:
|
||||
token: ${{ secrets.NPM_TOKEN }}
|
||||
package: tfhe/pkg/package.json
|
||||
|
||||
@@ -16,3 +16,5 @@ jobs:
|
||||
message: |
|
||||
@slab-ci cpu_test
|
||||
@slab-ci cpu_integer_test
|
||||
@slab-ci cpu_multi_bit_test
|
||||
@slab-ci cpu_wasm_test
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,3 +12,4 @@ target/
|
||||
|
||||
# Some of our bench outputs
|
||||
/tfhe/benchmarks_parameters
|
||||
/tfhe/shortint_key_sizes.csv
|
||||
|
||||
131
CODE_OF_CONDUCT.md
Normal file
131
CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
- Demonstrating empathy and kindness toward other people
|
||||
- Being respectful of differing opinions, viewpoints, and experiences
|
||||
- Giving and gracefully accepting constructive feedback
|
||||
- Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
- Focusing on what is best not just for us as individuals, but for the overall
|
||||
community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
- The use of sexualized language or imagery, and sexual attention or advances of
|
||||
any kind
|
||||
- Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
- Public or private harassment
|
||||
- Publishing others' private information, such as a physical or email address,
|
||||
without their explicit permission
|
||||
- Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting us anonymously through [this form](https://forms.gle/569j3cZqGRFgrR3u9).
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series of
|
||||
actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or permanent
|
||||
ban.
|
||||
|
||||
### 3. Temporary ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||
community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.1, available at
|
||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
||||
|
||||
Community Impact Guidelines were inspired by
|
||||
[Mozilla's code of conduct enforcement ladder][mozilla coc].
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
[https://www.contributor-covenant.org/faq][faq]. Translations are available at
|
||||
[https://www.contributor-covenant.org/translations][translations].
|
||||
|
||||
[faq]: https://www.contributor-covenant.org/faq
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
[mozilla coc]: https://github.com/mozilla/diversity
|
||||
[translations]: https://www.contributor-covenant.org/translations
|
||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
||||
119
Makefile
119
Makefile
@@ -21,6 +21,10 @@ else
|
||||
AVX512_FEATURE=
|
||||
endif
|
||||
|
||||
# Variables used only for regex_engine example
|
||||
REGEX_STRING?=''
|
||||
REGEX_PATTERN?=''
|
||||
|
||||
.PHONY: rs_check_toolchain # Echo the rust toolchain used for checks
|
||||
rs_check_toolchain:
|
||||
@echo $(RS_CHECK_TOOLCHAIN)
|
||||
@@ -58,6 +62,13 @@ install_wasm_pack: install_rs_build_toolchain
|
||||
cargo $(CARGO_RS_BUILD_TOOLCHAIN) install wasm-pack || \
|
||||
( echo "Unable to install cargo wasm-pack, unknown error." && exit 1 )
|
||||
|
||||
.PHONY: install_node # Install last version of NodeJS via nvm
|
||||
install_node:
|
||||
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh | $(SHELL)
|
||||
source ~/.bashrc
|
||||
$(SHELL) -i -c 'nvm install node' || \
|
||||
( echo "Unable to install node, unknown error." && exit 1 )
|
||||
|
||||
.PHONY: fmt # Format rust code
|
||||
fmt: install_rs_check_toolchain
|
||||
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt
|
||||
@@ -105,10 +116,10 @@ clippy_c_api: install_rs_check_toolchain
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint and the js wasm API
|
||||
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API
|
||||
clippy_js_wasm_api: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api \
|
||||
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_tasks # Run clippy lints on helper tasks crate.
|
||||
@@ -131,9 +142,13 @@ clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_ta
|
||||
|
||||
.PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests
|
||||
gen_key_cache: install_rs_build_toolchain
|
||||
if [[ "$${MULTI_BIT_ONLY}" == TRUE ]]; then \
|
||||
multi_bit_flag="--multi-bit-only"; \
|
||||
fi && \
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example generates_test_keys \
|
||||
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe
|
||||
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe -- \
|
||||
$${multi_bit_flag:+"$${multi_bit_flag}"}
|
||||
|
||||
.PHONY: build_core # Build core_crypto without experimental features
|
||||
build_core: install_rs_build_toolchain install_rs_check_toolchain
|
||||
@@ -184,14 +199,23 @@ build_web_js_api: install_rs_build_toolchain install_wasm_pack
|
||||
cd tfhe && \
|
||||
RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \
|
||||
wasm-pack build --release --target=web \
|
||||
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api
|
||||
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api
|
||||
|
||||
.PHONY: build_web_js_api_parallel # Build the js API targeting the web browser with parallelism support
|
||||
build_web_js_api_parallel: install_rs_check_toolchain install_wasm_pack
|
||||
cd tfhe && \
|
||||
rustup component add rust-src --toolchain $(RS_CHECK_TOOLCHAIN) && \
|
||||
RUSTFLAGS="$(WASM_RUSTFLAGS) -C target-feature=+atomics,+bulk-memory,+mutable-globals" rustup run $(RS_CHECK_TOOLCHAIN) \
|
||||
wasm-pack build --release --target=web \
|
||||
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,parallel-wasm-api \
|
||||
-Z build-std=panic_abort,std
|
||||
|
||||
.PHONY: build_node_js_api # Build the js API targeting nodejs
|
||||
build_node_js_api: install_rs_build_toolchain install_wasm_pack
|
||||
cd tfhe && \
|
||||
RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \
|
||||
wasm-pack build --release --target=nodejs \
|
||||
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api
|
||||
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api
|
||||
|
||||
.PHONY: test_core_crypto # Run the tests of the core_crypto module including experimental ones
|
||||
test_core_crypto: install_rs_build_toolchain install_rs_check_toolchain
|
||||
@@ -208,13 +232,24 @@ test_boolean: install_rs_build_toolchain
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean -p tfhe -- boolean::
|
||||
|
||||
.PHONY: test_c_api # Run the tests for the C API
|
||||
test_c_api: build_c_api
|
||||
test_c_api: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api \
|
||||
-p tfhe \
|
||||
c_api
|
||||
|
||||
"$(MAKE)" build_c_api
|
||||
./scripts/c_api_tests.sh
|
||||
|
||||
.PHONY: test_shortint_ci # Run the tests for shortint ci
|
||||
test_shortint_ci: install_rs_build_toolchain install_cargo_nextest
|
||||
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
|
||||
./scripts/shortint-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN)
|
||||
./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN)
|
||||
|
||||
.PHONY: test_shortint_multi_bit_ci # Run the tests for shortint ci running only multibit tests
|
||||
test_shortint_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest
|
||||
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
|
||||
./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit
|
||||
|
||||
.PHONY: test_shortint # Run all the tests for shortint
|
||||
test_shortint: install_rs_build_toolchain
|
||||
@@ -224,7 +259,12 @@ test_shortint: install_rs_build_toolchain
|
||||
.PHONY: test_integer_ci # Run the tests for integer ci
|
||||
test_integer_ci: install_rs_build_toolchain install_cargo_nextest
|
||||
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
|
||||
./scripts/integer-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN)
|
||||
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN)
|
||||
|
||||
.PHONY: test_integer_multi_bit_ci # Run the tests for integer ci running only multibit tests
|
||||
test_integer_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest
|
||||
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
|
||||
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit
|
||||
|
||||
.PHONY: test_integer # Run all the tests for integer
|
||||
test_integer: install_rs_build_toolchain
|
||||
@@ -242,12 +282,27 @@ test_user_doc: install_rs_build_toolchain
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p tfhe \
|
||||
-- test_user_docs::
|
||||
|
||||
.PHONY: test_regex_engine # Run tests for regex_engine example
|
||||
test_regex_engine: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--example regex_engine \
|
||||
--features=$(TARGET_ARCH_FEATURE),integer
|
||||
|
||||
.PHONY: test_sha256_bool # Run tests for sha256_bool example
|
||||
test_sha256_bool: install_rs_build_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
|
||||
--example sha256_bool \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean
|
||||
|
||||
.PHONY: doc # Build rust doc
|
||||
doc: install_rs_check_toolchain
|
||||
RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \
|
||||
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps
|
||||
|
||||
.PHONY: docs # Build rust doc alias for doc
|
||||
docs: doc
|
||||
|
||||
.PHONY: format_doc_latex # Format the documentation latex equations to avoid broken rendering.
|
||||
format_doc_latex:
|
||||
cargo xtask format_latex_doc
|
||||
@@ -258,13 +313,15 @@ format_doc_latex:
|
||||
@printf "\n===============================\n"
|
||||
|
||||
.PHONY: check_compile_tests # Build tests in debug without running them
|
||||
check_compile_tests: build_c_api
|
||||
check_compile_tests:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
|
||||
--features=$(TARGET_ARCH_FEATURE),experimental,boolean,shortint,integer,internal-keycache \
|
||||
-p tfhe
|
||||
@if [[ "$(OS)" == "Linux" || "$(OS)" == "Darwin" ]]; then \
|
||||
./scripts/c_api_tests.sh --build-only; \
|
||||
fi
|
||||
|
||||
@if [[ "$(OS)" == "Linux" || "$(OS)" == "Darwin" ]]; then \
|
||||
"$(MAKE)" build_c_api; \
|
||||
./scripts/c_api_tests.sh --build-only; \
|
||||
fi
|
||||
|
||||
.PHONY: build_nodejs_test_docker # Build a docker image with tools to run nodejs tests for wasm API
|
||||
build_nodejs_test_docker:
|
||||
@@ -286,6 +343,24 @@ test_nodejs_wasm_api_in_docker: build_nodejs_test_docker
|
||||
test_nodejs_wasm_api: build_node_js_api
|
||||
cd tfhe && node --test js_on_wasm_tests
|
||||
|
||||
.PHONY: test_web_js_api_parallel # Run tests for the web wasm api
|
||||
test_web_js_api_parallel: build_web_js_api_parallel
|
||||
$(MAKE) -C tfhe/web_wasm_parallel_tests test
|
||||
|
||||
.PHONY: ci_test_web_js_api_parallel # Run tests for the web wasm api
|
||||
ci_test_web_js_api_parallel: build_web_js_api_parallel
|
||||
# Auto-retry since WASM tests can be flaky
|
||||
@for i in 1 2 3 ; do \
|
||||
source ~/.nvm/nvm.sh && \
|
||||
nvm use node && \
|
||||
$(MAKE) -C tfhe/web_wasm_parallel_tests test-ci | tee web_js_tests_output; \
|
||||
if grep -q -i "timeout" web_js_tests_output; then \
|
||||
echo "Timeout occurred starting attempt #${i}"; \
|
||||
else \
|
||||
break; \
|
||||
fi; \
|
||||
done
|
||||
|
||||
.PHONY: no_tfhe_typo # Check we did not invert the h and f in tfhe
|
||||
no_tfhe_typo:
|
||||
@./scripts/no_tfhe_typo.sh
|
||||
@@ -326,6 +401,26 @@ measure_boolean_key_sizes: install_rs_check_toolchain
|
||||
--example boolean_key_sizes \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,internal-keycache
|
||||
|
||||
.PHONY: regex_engine # Run regex_engine example
|
||||
regex_engine: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example regex_engine \
|
||||
--features=$(TARGET_ARCH_FEATURE),integer \
|
||||
-- $(REGEX_STRING) $(REGEX_PATTERN)
|
||||
|
||||
.PHONY: dark_market # Run dark market example
|
||||
dark_market: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example dark_market \
|
||||
--features=$(TARGET_ARCH_FEATURE),integer,internal-keycache \
|
||||
-- fhe-modified fhe-parallel plain fhe
|
||||
|
||||
.PHONY: sha256_bool # Run sha256_bool example
|
||||
sha256_bool: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
|
||||
--example sha256_bool \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean
|
||||
|
||||
.PHONY: pcc # pcc stands for pre commit checks
|
||||
pcc: no_tfhe_typo check_fmt doc clippy_all check_compile_tests
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ use tfhe::boolean::prelude::*;
|
||||
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let (mut client_key, mut server_key) = gen_keys();
|
||||
let (client_key, server_key) = gen_keys();
|
||||
|
||||
// We use the client secret key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(true);
|
||||
@@ -132,7 +132,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We create keys to create 16 bits integers
|
||||
// using 8 blocks of 2 bits
|
||||
let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, 8);
|
||||
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);
|
||||
|
||||
let clear_a = 2382u16;
|
||||
let clear_b = 29374u16;
|
||||
|
||||
15
ci/slab.toml
15
ci/slab.toml
@@ -3,6 +3,11 @@ region = "eu-west-3"
|
||||
image_id = "ami-04deffe45b5b236fd"
|
||||
instance_type = "m6i.32xlarge"
|
||||
|
||||
[profile.cpu-small]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-04deffe45b5b236fd"
|
||||
instance_type = "m6i.4xlarge"
|
||||
|
||||
[profile.bench]
|
||||
region = "eu-west-3"
|
||||
image_id = "ami-04deffe45b5b236fd"
|
||||
@@ -18,6 +23,16 @@ workflow = "aws_tfhe_integer_tests.yml"
|
||||
profile = "cpu-big"
|
||||
check_run_name = "CPU Integer AWS Tests"
|
||||
|
||||
[command.cpu_multi_bit_test]
|
||||
workflow = "aws_tfhe_multi_bit_tests.yml"
|
||||
profile = "cpu-big"
|
||||
check_run_name = "CPU AWS Multi Bit Tests"
|
||||
|
||||
[command.cpu_wasm_test]
|
||||
workflow = "aws_tfhe_wasm_tests.yml"
|
||||
profile = "cpu-small"
|
||||
check_run_name = "CPU AWS WASM Tests"
|
||||
|
||||
[command.integer_bench]
|
||||
workflow = "integer_benchmark.yml"
|
||||
profile = "bench"
|
||||
|
||||
@@ -2,6 +2,49 @@
|
||||
|
||||
set -e
|
||||
|
||||
function usage() {
|
||||
echo "$0: shortint test runner"
|
||||
echo
|
||||
echo "--help Print this message"
|
||||
echo "--rust-toolchain The toolchain to run the tests with default: stable"
|
||||
echo "--multi-bit Run multi-bit tests only: default off"
|
||||
echo
|
||||
}
|
||||
|
||||
RUST_TOOLCHAIN="+stable"
|
||||
multi_bit=""
|
||||
not_multi_bit="_multi_bit"
|
||||
|
||||
while [ -n "$1" ]
|
||||
do
|
||||
case "$1" in
|
||||
"--help" | "-h" )
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
|
||||
"--rust-toolchain" )
|
||||
shift
|
||||
RUST_TOOLCHAIN="$1"
|
||||
;;
|
||||
|
||||
"--multi-bit" )
|
||||
multi_bit="_multi_bit"
|
||||
not_multi_bit=""
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Unknown param : $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
|
||||
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
|
||||
fi
|
||||
|
||||
CURR_DIR="$(dirname "$0")"
|
||||
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
|
||||
|
||||
@@ -29,14 +72,15 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
# mul_crt_4_4 is extremely flaky (~80% failure)
|
||||
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
|
||||
# test_integer_smart_mul_param_message_4_carry_4 is too slow
|
||||
filter_expression=''\
|
||||
'test(/^integer::.*$/)'\
|
||||
'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\
|
||||
'and not test(~mul_crt_param_message_4_carry_4)'\
|
||||
'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\
|
||||
'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)'
|
||||
filter_expression="""\
|
||||
test(/^integer::.*${multi_bit}/) \
|
||||
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
|
||||
and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \
|
||||
and not test(~mul_crt_param_message_4_carry_4) \
|
||||
and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \
|
||||
and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)"""
|
||||
|
||||
cargo ${1:+"${1}"} nextest run \
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--tests \
|
||||
--release \
|
||||
--package tfhe \
|
||||
@@ -45,39 +89,46 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
--test-threads "${n_threads}" \
|
||||
-E "$filter_expression"
|
||||
|
||||
cargo ${1:+"${1}"} test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache \
|
||||
--doc \
|
||||
integer::
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache \
|
||||
--doc \
|
||||
integer::
|
||||
fi
|
||||
else
|
||||
# block pbs are too slow for high params
|
||||
# mul_crt_4_4 is extremely flaky (~80% failure)
|
||||
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
|
||||
# test_integer_smart_mul_param_message_4_carry_4 is too slow
|
||||
filter_expression=''\
|
||||
'test(/^integer::.*$/)'\
|
||||
'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\
|
||||
'and not test(~mul_crt_param_message_4_carry_4)'\
|
||||
'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\
|
||||
'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)'
|
||||
filter_expression="""\
|
||||
test(/^integer::.*${multi_bit}/) \
|
||||
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
|
||||
and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \
|
||||
and not test(~mul_crt_param_message_4_carry_4) \
|
||||
and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \
|
||||
and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)"""
|
||||
|
||||
cargo ${1:+"${1}"} nextest run \
|
||||
num_cpu_threads="$(${nproc_bin})"
|
||||
num_threads=$((num_cpu_threads * 2 / 3))
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--tests \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--profile ci \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache \
|
||||
--test-threads "$(${nproc_bin})" \
|
||||
--test-threads $num_threads \
|
||||
-E "$filter_expression"
|
||||
|
||||
cargo ${1:+"${1}"} test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache \
|
||||
--doc \
|
||||
integer:: -- --test-threads="$(${nproc_bin})"
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache \
|
||||
--doc \
|
||||
integer:: -- --test-threads="$(${nproc_bin})"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Test ran in $SECONDS seconds"
|
||||
|
||||
@@ -2,6 +2,47 @@
|
||||
|
||||
set -e
|
||||
|
||||
function usage() {
|
||||
echo "$0: shortint test runner"
|
||||
echo
|
||||
echo "--help Print this message"
|
||||
echo "--rust-toolchain The toolchain to run the tests with default: stable"
|
||||
echo "--multi-bit Run multi-bit tests only: default off"
|
||||
echo
|
||||
}
|
||||
|
||||
RUST_TOOLCHAIN="+stable"
|
||||
multi_bit=""
|
||||
|
||||
while [ -n "$1" ]
|
||||
do
|
||||
case "$1" in
|
||||
"--help" | "-h" )
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
|
||||
"--rust-toolchain" )
|
||||
shift
|
||||
RUST_TOOLCHAIN="$1"
|
||||
;;
|
||||
|
||||
"--multi-bit" )
|
||||
multi_bit="_multi_bit"
|
||||
;;
|
||||
|
||||
*)
|
||||
echo "Unknown param : $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
|
||||
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
|
||||
fi
|
||||
|
||||
CURR_DIR="$(dirname "$0")"
|
||||
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
|
||||
|
||||
@@ -31,25 +72,25 @@ else
|
||||
fi
|
||||
|
||||
if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
filter_expression_small_params=''\
|
||||
'('\
|
||||
' test(/^shortint::.*_param_message_1_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_3$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_4$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_5$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_6$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_3$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_3$/)'\
|
||||
')'\
|
||||
'and not test(~smart_add_and_mul)' # This test is too slow
|
||||
filter_expression_small_params="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \
|
||||
) \
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
|
||||
# Run tests only no examples or benches with small params and more threads
|
||||
cargo ${1:+"${1}"} nextest run \
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--tests \
|
||||
--release \
|
||||
--package tfhe \
|
||||
@@ -58,14 +99,14 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
--test-threads "${n_threads_small}" \
|
||||
-E "${filter_expression_small_params}"
|
||||
|
||||
filter_expression_big_params=''\
|
||||
'('\
|
||||
' test(/^shortint::.*_param_message_4_carry_4$/)'\
|
||||
')'\
|
||||
'and not test(~smart_add_and_mul)'
|
||||
filter_expression_big_params="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \
|
||||
) \
|
||||
and not test(~smart_add_and_mul)"""
|
||||
|
||||
# Run tests only no examples or benches with big params and less threads
|
||||
cargo ${1:+"${1}"} nextest run \
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--tests \
|
||||
--release \
|
||||
--package tfhe \
|
||||
@@ -74,33 +115,35 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
--test-threads "${n_threads_big}" \
|
||||
-E "${filter_expression_big_params}"
|
||||
|
||||
cargo ${1:+"${1}"} test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",shortint,internal-keycache \
|
||||
--doc \
|
||||
shortint::
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",shortint,internal-keycache \
|
||||
--doc \
|
||||
shortint::
|
||||
fi
|
||||
else
|
||||
filter_expression=''\
|
||||
'('\
|
||||
' test(/^shortint::.*_param_message_1_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_3$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_4$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_5$/)'\
|
||||
'or test(/^shortint::.*_param_message_1_carry_6$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_2_carry_3$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_1$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_2$/)'\
|
||||
'or test(/^shortint::.*_param_message_3_carry_3$/)'\
|
||||
'or test(/^shortint::.*_param_message_4_carry_4$/)'\
|
||||
')'\
|
||||
'and not test(~smart_add_and_mul)' # This test is too slow
|
||||
filter_expression="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \
|
||||
)\
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
|
||||
# Run tests only no examples or benches with small params and more threads
|
||||
cargo ${1:+"${1}"} nextest run \
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--tests \
|
||||
--release \
|
||||
--package tfhe \
|
||||
@@ -109,12 +152,14 @@ else
|
||||
--test-threads "$(${nproc_bin})" \
|
||||
-E "${filter_expression}"
|
||||
|
||||
cargo ${1:+"${1}"} test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",shortint,internal-keycache \
|
||||
--doc \
|
||||
shortint:: -- --test-threads="$(${nproc_bin})"
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--release \
|
||||
--package tfhe \
|
||||
--features="${ARCH_FEATURE}",shortint,internal-keycache \
|
||||
--doc \
|
||||
shortint:: -- --test-threads="$(${nproc_bin})"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "Test ran in $SECONDS seconds"
|
||||
|
||||
@@ -18,17 +18,23 @@ rust-version = "1.67"
|
||||
[dev-dependencies]
|
||||
rand = "0.8.5"
|
||||
rand_distr = "0.4.3"
|
||||
kolmogorov_smirnov = "1.1.0"
|
||||
paste = "1.0.7"
|
||||
lazy_static = { version = "1.4.0" }
|
||||
criterion = "0.4.0"
|
||||
doc-comment = "0.3.3"
|
||||
serde_json = "1.0.94"
|
||||
clap = "4.2.7"
|
||||
# Used in user documentation
|
||||
bincode = "1.3.3"
|
||||
fs2 = { version = "0.4.3" }
|
||||
itertools = "0.10.5"
|
||||
num_cpus = "1.15"
|
||||
# For erf and normality test
|
||||
libm = "0.2.6"
|
||||
test-case = "*"
|
||||
combine = "*"
|
||||
env_logger = "*"
|
||||
log = "*"
|
||||
|
||||
[build-dependencies]
|
||||
cbindgen = { version = "0.24.3", optional = true }
|
||||
@@ -53,9 +59,10 @@ fs2 = { version = "0.4.3", optional = true }
|
||||
itertools = "0.10.5"
|
||||
|
||||
# wasm deps
|
||||
wasm-bindgen = { version = "0.2.63", features = [
|
||||
wasm-bindgen = { version = "0.2.86", features = [
|
||||
"serde-serialize",
|
||||
], optional = true }
|
||||
wasm-bindgen-rayon = { version = "1.0", optional = true }
|
||||
js-sys = { version = "0.3", optional = true }
|
||||
console_error_panic_hook = { version = "0.1.7", optional = true }
|
||||
serde-wasm-bindgen = { version = "0.4", optional = true }
|
||||
@@ -76,7 +83,12 @@ experimental-force_fft_algo_dif4 = []
|
||||
__c_api = ["cbindgen", "bincode"]
|
||||
boolean-c-api = ["boolean", "__c_api"]
|
||||
shortint-c-api = ["shortint", "__c_api"]
|
||||
high-level-c-api = ["boolean", "shortint", "integer", "__c_api"]
|
||||
high-level-c-api = [
|
||||
"boolean-c-api",
|
||||
"shortint-c-api",
|
||||
"integer",
|
||||
"__c_api"
|
||||
]
|
||||
|
||||
__wasm_api = [
|
||||
"wasm-bindgen",
|
||||
@@ -89,6 +101,9 @@ __wasm_api = [
|
||||
]
|
||||
boolean-client-js-wasm-api = ["boolean", "__wasm_api"]
|
||||
shortint-client-js-wasm-api = ["shortint", "__wasm_api"]
|
||||
integer-client-js-wasm-api = ["integer", "__wasm_api"]
|
||||
high-level-client-js-wasm-api = ["boolean", "shortint", "integer", "__wasm_api"]
|
||||
parallel-wasm-api = ["wasm-bindgen-rayon"]
|
||||
|
||||
nightly-avx512 = ["concrete-fft/nightly", "pulp/nightly"]
|
||||
|
||||
@@ -182,13 +197,29 @@ required-features = ["shortint", "internal-keycache"]
|
||||
name = "boolean_key_sizes"
|
||||
required-features = ["boolean", "internal-keycache"]
|
||||
|
||||
[[example]]
|
||||
name = "dark_market"
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[example]]
|
||||
name = "shortint_key_sizes"
|
||||
required-features = ["shortint", "internal-keycache"]
|
||||
|
||||
[[example]]
|
||||
name = "integer_compact_pk_ct_sizes"
|
||||
required-features = ["integer", "internal-keycache"]
|
||||
|
||||
[[example]]
|
||||
name = "micro_bench_and"
|
||||
required-features = ["boolean"]
|
||||
|
||||
[[example]]
|
||||
name = "regex_engine"
|
||||
required-features = ["integer"]
|
||||
|
||||
[[example]]
|
||||
name = "sha256_bool"
|
||||
required-features = ["boolean"]
|
||||
|
||||
[lib]
|
||||
crate-type = ["lib", "staticlib", "cdylib"]
|
||||
|
||||
@@ -7,9 +7,9 @@ use tfhe::boolean::parameters::{BooleanParameters, DEFAULT_PARAMETERS, TFHE_LIB_
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::*;
|
||||
use tfhe::shortint::Parameters;
|
||||
use tfhe::shortint::ClassicPBSParameters;
|
||||
|
||||
const SHORTINT_BENCH_PARAMS: [Parameters; 15] = [
|
||||
const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 15] = [
|
||||
PARAM_MESSAGE_1_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_0,
|
||||
@@ -125,7 +125,7 @@ fn multi_bit_benchmark_parameters<Scalar: Numeric>(
|
||||
(
|
||||
CryptoParametersRecord {
|
||||
lwe_dimension: Some(LweDimension(888)),
|
||||
lwe_modular_std_dev: Some(StandardDev(0.000002226459789930014)),
|
||||
lwe_modular_std_dev: Some(StandardDev(0.0000006125031601933181)),
|
||||
pbs_base_log: Some(DecompositionBaseLog(21)),
|
||||
pbs_level: Some(DecompositionLevelCount(1)),
|
||||
glwe_dimension: Some(GlweDimension(1)),
|
||||
@@ -400,7 +400,7 @@ fn multi_bit_pbs<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Syn
|
||||
&mut out_pbs_ct,
|
||||
&accumulator.as_view(),
|
||||
&multi_bit_bsk,
|
||||
ThreadCount(num_cpus::get()),
|
||||
ThreadCount(10),
|
||||
);
|
||||
black_box(&mut out_pbs_ct);
|
||||
})
|
||||
|
||||
@@ -23,13 +23,13 @@ use tfhe::shortint::parameters::{
|
||||
/// in radix decomposition
|
||||
struct ParamsAndNumBlocksIter {
|
||||
params_and_bit_sizes:
|
||||
itertools::Product<IntoIter<tfhe::shortint::Parameters, 1>, IntoIter<usize, 7>>,
|
||||
itertools::Product<IntoIter<tfhe::shortint::ClassicPBSParameters, 1>, IntoIter<usize, 7>>,
|
||||
}
|
||||
|
||||
impl Default for ParamsAndNumBlocksIter {
|
||||
fn default() -> Self {
|
||||
// FIXME One set of parameter is tested since we want to benchmark only quickest operations.
|
||||
const PARAMS: [tfhe::shortint::Parameters; 1] = [
|
||||
const PARAMS: [tfhe::shortint::ClassicPBSParameters; 1] = [
|
||||
PARAM_MESSAGE_2_CARRY_2,
|
||||
// PARAM_MESSAGE_3_CARRY_3,
|
||||
// PARAM_MESSAGE_4_CARRY_4,
|
||||
@@ -42,7 +42,7 @@ impl Default for ParamsAndNumBlocksIter {
|
||||
}
|
||||
}
|
||||
impl Iterator for ParamsAndNumBlocksIter {
|
||||
type Item = (tfhe::shortint::Parameters, usize, usize);
|
||||
type Item = (tfhe::shortint::ClassicPBSParameters, usize, usize);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (param, bit_size) = self.params_and_bit_sizes.next()?;
|
||||
@@ -664,11 +664,29 @@ define_server_key_bench_fn!(
|
||||
define_server_key_bench_default_fn!(method_name: max_parallelized, display_name: max);
|
||||
define_server_key_bench_default_fn!(method_name: min_parallelized, display_name: min);
|
||||
define_server_key_bench_default_fn!(method_name: eq_parallelized, display_name: equal);
|
||||
define_server_key_bench_default_fn!(method_name: ne_parallelized, display_name: not_equal);
|
||||
define_server_key_bench_default_fn!(method_name: lt_parallelized, display_name: less_than);
|
||||
define_server_key_bench_default_fn!(method_name: le_parallelized, display_name: less_or_equal);
|
||||
define_server_key_bench_default_fn!(method_name: gt_parallelized, display_name: greater_than);
|
||||
define_server_key_bench_default_fn!(method_name: ge_parallelized, display_name: greater_or_equal);
|
||||
|
||||
define_server_key_bench_default_fn!(
|
||||
method_name: left_shift_parallelized,
|
||||
display_name: left_shift
|
||||
);
|
||||
define_server_key_bench_default_fn!(
|
||||
method_name: right_shift_parallelized,
|
||||
display_name: right_shift
|
||||
);
|
||||
define_server_key_bench_default_fn!(
|
||||
method_name: rotate_left_parallelized,
|
||||
display_name: rotate_left
|
||||
);
|
||||
define_server_key_bench_default_fn!(
|
||||
method_name: rotate_right_parallelized,
|
||||
display_name: rotate_right
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
smart_arithmetic_operation,
|
||||
smart_neg,
|
||||
@@ -793,10 +811,15 @@ criterion_group!(
|
||||
min_parallelized,
|
||||
max_parallelized,
|
||||
eq_parallelized,
|
||||
ne_parallelized,
|
||||
lt_parallelized,
|
||||
le_parallelized,
|
||||
gt_parallelized,
|
||||
ge_parallelized,
|
||||
left_shift_parallelized,
|
||||
right_shift_parallelized,
|
||||
rotate_left_parallelized,
|
||||
rotate_right_parallelized,
|
||||
scalar_add_parallelized,
|
||||
scalar_sub_parallelized,
|
||||
scalar_mul_parallelized,
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::utilities::{write_to_json, OperatorType};
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::*;
|
||||
use tfhe::shortint::{CiphertextBig, Parameters, ServerKey};
|
||||
use tfhe::shortint::{CiphertextBig, ClassicPBSParameters, ServerKey, ShortintParameterSet};
|
||||
|
||||
use rand::Rng;
|
||||
use tfhe::shortint::keycache::KEY_CACHE;
|
||||
@@ -13,14 +13,14 @@ use tfhe::shortint::keycache::KEY_CACHE;
|
||||
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
|
||||
use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6;
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [Parameters; 4] = [
|
||||
const SERVER_KEY_BENCH_PARAMS: [ClassicPBSParameters; 4] = [
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_2,
|
||||
PARAM_MESSAGE_3_CARRY_3,
|
||||
PARAM_MESSAGE_4_CARRY_4,
|
||||
];
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS_EXTENDED: [Parameters; 15] = [
|
||||
const SERVER_KEY_BENCH_PARAMS_EXTENDED: [ClassicPBSParameters; 15] = [
|
||||
PARAM_MESSAGE_1_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_0,
|
||||
@@ -43,7 +43,7 @@ fn bench_server_key_unary_function<F>(
|
||||
bench_name: &str,
|
||||
display_name: &str,
|
||||
unary_op: F,
|
||||
params: &[Parameters],
|
||||
params: &[ClassicPBSParameters],
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut CiphertextBig),
|
||||
{
|
||||
@@ -55,7 +55,7 @@ fn bench_server_key_unary_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
let clear_text = rng.gen::<u64>() % modulus;
|
||||
|
||||
@@ -87,7 +87,7 @@ fn bench_server_key_binary_function<F>(
|
||||
bench_name: &str,
|
||||
display_name: &str,
|
||||
binary_op: F,
|
||||
params: &[Parameters],
|
||||
params: &[ClassicPBSParameters],
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut CiphertextBig, &mut CiphertextBig),
|
||||
{
|
||||
@@ -99,7 +99,7 @@ fn bench_server_key_binary_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
@@ -133,7 +133,7 @@ fn bench_server_key_binary_scalar_function<F>(
|
||||
bench_name: &str,
|
||||
display_name: &str,
|
||||
binary_op: F,
|
||||
params: &[Parameters],
|
||||
params: &[ClassicPBSParameters],
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut CiphertextBig, u8),
|
||||
{
|
||||
@@ -145,7 +145,7 @@ fn bench_server_key_binary_scalar_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
@@ -178,7 +178,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
|
||||
bench_name: &str,
|
||||
display_name: &str,
|
||||
binary_op: F,
|
||||
params: &[Parameters],
|
||||
params: &[ClassicPBSParameters],
|
||||
) where
|
||||
F: Fn(&ServerKey, &mut CiphertextBig, u8),
|
||||
{
|
||||
@@ -190,7 +190,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
assert_ne!(modulus, 1);
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
@@ -231,7 +231,7 @@ fn carry_extract(c: &mut Criterion) {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
|
||||
@@ -267,7 +267,7 @@ fn programmable_bootstrapping(c: &mut Criterion) {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = cks.parameters.message_modulus.0 as u64;
|
||||
let modulus = cks.parameters.message_modulus().0 as u64;
|
||||
|
||||
let acc = sks.generate_accumulator(|x| x);
|
||||
|
||||
@@ -297,12 +297,15 @@ fn programmable_bootstrapping(c: &mut Criterion) {
|
||||
bench_group.finish();
|
||||
}
|
||||
|
||||
// TODO: remove?
|
||||
fn _bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) {
|
||||
let mut bench_group = c.benchmark_group("programmable_bootstrap");
|
||||
|
||||
let param = WOPBS_PARAM_MESSAGE_4_NORM2_6;
|
||||
let param_set: ShortintParameterSet = param.try_into().unwrap();
|
||||
let pbs_params = param_set.pbs_parameters().unwrap();
|
||||
|
||||
let keys = KEY_CACHE_WOPBS.get_from_param((param, param));
|
||||
let keys = KEY_CACHE_WOPBS.get_from_param((pbs_params, param));
|
||||
let (cks, _, wopbs_key) = (keys.client_key(), keys.server_key(), keys.wopbs_key());
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::path::PathBuf;
|
||||
use tfhe::boolean::parameters::BooleanParameters;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
#[cfg(feature = "shortint")]
|
||||
use tfhe::shortint::Parameters;
|
||||
use tfhe::shortint::ClassicPBSParameters;
|
||||
|
||||
#[derive(Clone, Copy, Default, Serialize)]
|
||||
pub struct CryptoParametersRecord {
|
||||
@@ -52,8 +52,8 @@ impl From<BooleanParameters> for CryptoParametersRecord {
|
||||
}
|
||||
|
||||
#[cfg(feature = "shortint")]
|
||||
impl From<Parameters> for CryptoParametersRecord {
|
||||
fn from(params: Parameters) -> Self {
|
||||
impl From<ClassicPBSParameters> for CryptoParametersRecord {
|
||||
fn from(params: ClassicPBSParameters) -> Self {
|
||||
CryptoParametersRecord {
|
||||
lwe_dimension: Some(params.lwe_dimension),
|
||||
glwe_dimension: Some(params.glwe_dimension),
|
||||
@@ -64,11 +64,11 @@ impl From<Parameters> for CryptoParametersRecord {
|
||||
pbs_level: Some(params.pbs_level),
|
||||
ks_base_log: Some(params.ks_base_log),
|
||||
ks_level: Some(params.ks_level),
|
||||
pfks_level: Some(params.pfks_level),
|
||||
pfks_base_log: Some(params.pfks_base_log),
|
||||
pfks_modular_std_dev: Some(params.pfks_modular_std_dev),
|
||||
cbs_level: Some(params.cbs_level),
|
||||
cbs_base_log: Some(params.cbs_base_log),
|
||||
pfks_level: None,
|
||||
pfks_base_log: None,
|
||||
pfks_modular_std_dev: None,
|
||||
cbs_level: None,
|
||||
cbs_base_log: None,
|
||||
message_modulus: Some(params.message_modulus.0),
|
||||
carry_modulus: Some(params.carry_modulus.0),
|
||||
}
|
||||
|
||||
@@ -9,22 +9,24 @@ int uint128_client_key(const ClientKey *client_key) {
|
||||
FheUint128 *lhs = NULL;
|
||||
FheUint128 *rhs = NULL;
|
||||
FheUint128 *result = NULL;
|
||||
U128 lhs_clear = {10, 20};
|
||||
U128 rhs_clear = {1, 2};
|
||||
U128 result_clear = {0};
|
||||
|
||||
ok = fhe_uint128_try_encrypt_with_client_key_u128(10, 20, client_key, &lhs);
|
||||
ok = fhe_uint128_try_encrypt_with_client_key_u128(lhs_clear, client_key, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_try_encrypt_with_client_key_u128(1, 2, client_key, &rhs);
|
||||
ok = fhe_uint128_try_encrypt_with_client_key_u128(rhs_clear, client_key, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_sub(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1;
|
||||
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
|
||||
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(w0 == 9);
|
||||
assert(w1 == 18);
|
||||
assert(result_clear.w0 == 9);
|
||||
assert(result_clear.w1 == 18);
|
||||
|
||||
fhe_uint128_destroy(lhs);
|
||||
fhe_uint128_destroy(rhs);
|
||||
@@ -37,22 +39,24 @@ int uint128_encrypt_trivial(const ClientKey *client_key) {
|
||||
FheUint128 *lhs = NULL;
|
||||
FheUint128 *rhs = NULL;
|
||||
FheUint128 *result = NULL;
|
||||
|
||||
ok = fhe_uint128_try_encrypt_trivial_u128(10, 20, &lhs);
|
||||
U128 lhs_clear = {10, 20};
|
||||
U128 rhs_clear = {1, 2};
|
||||
U128 result_clear = {0};
|
||||
|
||||
ok = fhe_uint128_try_encrypt_trivial_u128(lhs_clear, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_try_encrypt_trivial_u128(1, 2, &rhs);
|
||||
ok = fhe_uint128_try_encrypt_trivial_u128(rhs_clear, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_sub(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1;
|
||||
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
|
||||
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(w0 == 9);
|
||||
assert(w1 == 18);
|
||||
assert(result_clear.w0 == 9);
|
||||
assert(result_clear.w1 == 18);
|
||||
|
||||
fhe_uint128_destroy(lhs);
|
||||
fhe_uint128_destroy(rhs);
|
||||
@@ -65,22 +69,24 @@ int uint128_public_key(const ClientKey *client_key, const PublicKey *public_key)
|
||||
FheUint128 *lhs = NULL;
|
||||
FheUint128 *rhs = NULL;
|
||||
FheUint128 *result = NULL;
|
||||
U128 lhs_clear = {10, 20};
|
||||
U128 rhs_clear = {1, 2};
|
||||
U128 result_clear = {0};
|
||||
|
||||
ok = fhe_uint128_try_encrypt_with_public_key_u128(1, 2, public_key, &lhs);
|
||||
ok = fhe_uint128_try_encrypt_with_public_key_u128(lhs_clear, public_key, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_try_encrypt_with_public_key_u128(10, 20, public_key, &rhs);
|
||||
ok = fhe_uint128_try_encrypt_with_public_key_u128(rhs_clear, public_key, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint128_add(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1;
|
||||
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
|
||||
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(w0 == 11);
|
||||
assert(w1 == 22);
|
||||
assert(result_clear.w0 == 11);
|
||||
assert(result_clear.w1 == 22);
|
||||
|
||||
fhe_uint128_destroy(lhs);
|
||||
fhe_uint128_destroy(rhs);
|
||||
|
||||
@@ -10,14 +10,9 @@ int uint256_client_key(const ClientKey *client_key) {
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
FheUint64 *cast_result = NULL;
|
||||
U256 *lhs_clear = NULL;
|
||||
U256 *rhs_clear = NULL;
|
||||
U256 *result_clear = NULL;
|
||||
|
||||
ok = u256_from_u64_words(1, 2, 3, 4, &lhs_clear);
|
||||
assert(ok == 0);
|
||||
ok = u256_from_u64_words(5, 6, 7, 8, &rhs_clear);
|
||||
assert(ok == 0);
|
||||
U256 lhs_clear = {1 , 2, 3, 4};
|
||||
U256 rhs_clear = {5, 6, 7, 8};
|
||||
U256 result_clear = { 0 };
|
||||
|
||||
ok = fhe_uint256_try_encrypt_with_client_key_u256(lhs_clear, client_key, &lhs);
|
||||
assert(ok == 0);
|
||||
@@ -31,14 +26,10 @@ int uint256_client_key(const ClientKey *client_key) {
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1, w2, w3;
|
||||
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(w0 == 6);
|
||||
assert(w1 == 8);
|
||||
assert(w2 == 10);
|
||||
assert(w3 == 12);
|
||||
assert(result_clear.w0 == 6);
|
||||
assert(result_clear.w1 == 8);
|
||||
assert(result_clear.w2 == 10);
|
||||
assert(result_clear.w3 == 12);
|
||||
|
||||
// try some casting
|
||||
ok = fhe_uint256_cast_into_fhe_uint64(result, &cast_result);
|
||||
@@ -48,9 +39,6 @@ int uint256_client_key(const ClientKey *client_key) {
|
||||
assert(ok == 0);
|
||||
assert(u64_clear == 6);
|
||||
|
||||
u256_destroy(lhs_clear);
|
||||
u256_destroy(rhs_clear);
|
||||
u256_destroy(result_clear);
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
@@ -63,14 +51,9 @@ int uint256_encrypt_trivial(const ClientKey *client_key) {
|
||||
FheUint256 *lhs = NULL;
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
U256 *lhs_clear = NULL;
|
||||
U256 *rhs_clear = NULL;
|
||||
U256 *result_clear = NULL;
|
||||
|
||||
ok = u256_from_u64_words(1, 2, 3, 4, &lhs_clear);
|
||||
assert(ok == 0);
|
||||
ok = u256_from_u64_words(5, 6, 7, 8, &rhs_clear);
|
||||
assert(ok == 0);
|
||||
U256 lhs_clear = {1 , 2, 3, 4};
|
||||
U256 rhs_clear = {5, 6, 7, 8};
|
||||
U256 result_clear = { 0 };
|
||||
|
||||
ok = fhe_uint256_try_encrypt_trivial_u256(lhs_clear, &lhs);
|
||||
assert(ok == 0);
|
||||
@@ -84,18 +67,11 @@ int uint256_encrypt_trivial(const ClientKey *client_key) {
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1, w2, w3;
|
||||
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
|
||||
assert(ok == 0);
|
||||
assert(result_clear.w0 == 6);
|
||||
assert(result_clear.w1 == 8);
|
||||
assert(result_clear.w2 == 10);
|
||||
assert(result_clear.w3 == 12);
|
||||
|
||||
assert(w0 == 6);
|
||||
assert(w1 == 8);
|
||||
assert(w2 == 10);
|
||||
assert(w3 == 12);
|
||||
|
||||
u256_destroy(lhs_clear);
|
||||
u256_destroy(rhs_clear);
|
||||
u256_destroy(result_clear);
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
@@ -107,14 +83,9 @@ int uint256_public_key(const ClientKey *client_key, const PublicKey *public_key)
|
||||
FheUint256 *lhs = NULL;
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
U256 *lhs_clear = NULL;
|
||||
U256 *rhs_clear = NULL;
|
||||
U256 *result_clear = NULL;
|
||||
|
||||
ok = u256_from_u64_words(5, 6, 7, 8, &lhs_clear);
|
||||
assert(ok == 0);
|
||||
ok = u256_from_u64_words(1, 2, 3, 4, &rhs_clear);
|
||||
assert(ok == 0);
|
||||
U256 lhs_clear = {5, 6, 7, 8};
|
||||
U256 rhs_clear = {1 , 2, 3, 4};
|
||||
U256 result_clear = { 0 };
|
||||
|
||||
ok = fhe_uint256_try_encrypt_with_public_key_u256(lhs_clear, public_key, &lhs);
|
||||
assert(ok == 0);
|
||||
@@ -128,18 +99,11 @@ int uint256_public_key(const ClientKey *client_key, const PublicKey *public_key)
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
uint64_t w0, w1, w2, w3;
|
||||
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
|
||||
assert(ok == 0);
|
||||
assert(result_clear.w0 == 4);
|
||||
assert(result_clear.w1 == 4);
|
||||
assert(result_clear.w2 == 4);
|
||||
assert(result_clear.w3 == 4);
|
||||
|
||||
assert(w0 == 4);
|
||||
assert(w1 == 4);
|
||||
assert(w2 == 4);
|
||||
assert(w3 == 4);
|
||||
|
||||
u256_destroy(lhs_clear);
|
||||
u256_destroy(rhs_clear);
|
||||
u256_destroy(result_clear);
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
|
||||
213
tfhe/c_api_tests/test_high_level_custom_integers.c
Normal file
213
tfhe/c_api_tests/test_high_level_custom_integers.c
Normal file
@@ -0,0 +1,213 @@
|
||||
#include <tfhe.h>
|
||||
|
||||
#include <assert.h>
|
||||
#include <inttypes.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int uint256_client_key(const ClientKey *client_key) {
|
||||
int ok;
|
||||
FheUint256 *lhs = NULL;
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
FheUint64 *cast_result = NULL;
|
||||
U256 lhs_clear = {1, 2, 3, 4};
|
||||
U256 rhs_clear = {5, 6, 7, 8};
|
||||
U256 result_clear = {0};
|
||||
|
||||
ok = fhe_uint256_try_encrypt_with_client_key_u256(lhs_clear, client_key, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_try_encrypt_with_client_key_u256(rhs_clear, client_key, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_add(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(result_clear.w0 == 6);
|
||||
assert(result_clear.w1 == 8);
|
||||
assert(result_clear.w2 == 10);
|
||||
assert(result_clear.w3 == 12);
|
||||
|
||||
// try some casting
|
||||
ok = fhe_uint256_cast_into_fhe_uint64(result, &cast_result);
|
||||
assert(ok == 0);
|
||||
uint64_t u64_clear;
|
||||
ok = fhe_uint64_decrypt(cast_result, client_key, &u64_clear);
|
||||
assert(ok == 0);
|
||||
assert(u64_clear == 6);
|
||||
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
fhe_uint64_destroy(cast_result);
|
||||
return ok;
|
||||
}
|
||||
|
||||
int uint256_encrypt_trivial(const ClientKey *client_key) {
|
||||
int ok;
|
||||
FheUint256 *lhs = NULL;
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
U256 lhs_clear = {1, 2, 3, 4};
|
||||
U256 rhs_clear = {5, 6, 7, 8};
|
||||
U256 result_clear = {0};
|
||||
|
||||
ok = fhe_uint256_try_encrypt_trivial_u256(lhs_clear, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_try_encrypt_trivial_u256(rhs_clear, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_add(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(result_clear.w0 == 6);
|
||||
assert(result_clear.w1 == 8);
|
||||
assert(result_clear.w2 == 10);
|
||||
assert(result_clear.w3 == 12);
|
||||
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
return ok;
|
||||
}
|
||||
|
||||
int uint256_public_key(const ClientKey *client_key,
|
||||
const CompressedCompactPublicKey *compressed_public_key) {
|
||||
int ok;
|
||||
CompactPublicKey *public_key = NULL;
|
||||
FheUint256 *lhs = NULL;
|
||||
FheUint256 *rhs = NULL;
|
||||
FheUint256 *result = NULL;
|
||||
CompactFheUint256List *list = NULL;
|
||||
|
||||
U256 result_clear = {0};
|
||||
U256 clears[2] = {{5, 6, 7, 8}, {1, 2, 3, 4}};
|
||||
|
||||
ok = compressed_compact_public_key_decompress(compressed_public_key, &public_key);
|
||||
assert(ok == 0);
|
||||
|
||||
// Compact list example
|
||||
{
|
||||
ok = compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u256(
|
||||
&clears[0], 2, public_key, &list);
|
||||
assert(ok == 0);
|
||||
|
||||
size_t len = 0;
|
||||
ok = compact_fhe_uint256_list_len(list, &len);
|
||||
assert(ok == 0);
|
||||
assert(len == 2);
|
||||
|
||||
FheUint256 *expand_output[2] = {NULL};
|
||||
ok = compact_fhe_uint256_list_expand(list, &expand_output[0], 2);
|
||||
assert(ok == 0);
|
||||
|
||||
// transfer ownership
|
||||
lhs = expand_output[0];
|
||||
rhs = expand_output[1];
|
||||
|
||||
ok = fhe_uint256_sub(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(result_clear.w0 == 4);
|
||||
assert(result_clear.w1 == 4);
|
||||
assert(result_clear.w2 == 4);
|
||||
assert(result_clear.w3 == 4);
|
||||
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
}
|
||||
|
||||
{
|
||||
ok = fhe_uint256_try_encrypt_with_compact_public_key_u256(clears[0], public_key, &lhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_try_encrypt_with_compact_public_key_u256(clears[1], public_key, &rhs);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_sub(lhs, rhs, &result);
|
||||
assert(ok == 0);
|
||||
|
||||
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
|
||||
assert(ok == 0);
|
||||
|
||||
assert(result_clear.w0 == 4);
|
||||
assert(result_clear.w1 == 4);
|
||||
assert(result_clear.w2 == 4);
|
||||
assert(result_clear.w3 == 4);
|
||||
|
||||
fhe_uint256_destroy(lhs);
|
||||
fhe_uint256_destroy(rhs);
|
||||
fhe_uint256_destroy(result);
|
||||
}
|
||||
|
||||
compact_public_key_destroy(public_key);
|
||||
return ok;
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
int ok = 0;
|
||||
{
|
||||
ConfigBuilder *builder;
|
||||
Config *config;
|
||||
|
||||
config_builder_all_disabled(&builder);
|
||||
config_builder_enable_custom_integers(&builder, SHORTINT_PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
config_builder_build(builder, &config);
|
||||
|
||||
ClientKey *client_key = NULL;
|
||||
ServerKey *server_key = NULL;
|
||||
CompressedCompactPublicKey *compressed_public_key = NULL;
|
||||
|
||||
generate_keys(config, &client_key, &server_key);
|
||||
compressed_compact_public_key_new(client_key, &compressed_public_key);
|
||||
|
||||
set_server_key(server_key);
|
||||
|
||||
uint256_client_key(client_key);
|
||||
uint256_encrypt_trivial(client_key);
|
||||
uint256_public_key(client_key, compressed_public_key);
|
||||
|
||||
client_key_destroy(client_key);
|
||||
compressed_compact_public_key_destroy(compressed_public_key);
|
||||
server_key_destroy(server_key);
|
||||
}
|
||||
|
||||
{
|
||||
ConfigBuilder *builder;
|
||||
Config *config;
|
||||
|
||||
config_builder_all_disabled(&builder);
|
||||
config_builder_enable_custom_integers(&builder,
|
||||
SHORTINT_PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
config_builder_build(builder, &config);
|
||||
|
||||
ClientKey *client_key = NULL;
|
||||
ServerKey *server_key = NULL;
|
||||
CompressedCompactPublicKey *compressed_public_key = NULL;
|
||||
|
||||
generate_keys(config, &client_key, &server_key);
|
||||
compressed_compact_public_key_new(client_key, &compressed_public_key);
|
||||
|
||||
set_server_key(server_key);
|
||||
|
||||
uint256_client_key(client_key);
|
||||
uint256_encrypt_trivial(client_key);
|
||||
uint256_public_key(client_key, compressed_public_key);
|
||||
|
||||
client_key_destroy(client_key);
|
||||
compressed_compact_public_key_destroy(compressed_public_key);
|
||||
server_key_destroy(server_key);
|
||||
}
|
||||
return ok;
|
||||
}
|
||||
@@ -43,6 +43,8 @@ void micro_bench_and() {
|
||||
|
||||
destroy_boolean_client_key(cks);
|
||||
destroy_boolean_server_key(sks);
|
||||
destroy_boolean_ciphertext(ct_left);
|
||||
destroy_boolean_ciphertext(ct_right);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
||||
@@ -8,16 +8,13 @@
|
||||
void test_predefined_keygen_w_serde(void) {
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintServerKey *sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
ShortintCiphertext *ct = NULL;
|
||||
Buffer ct_ser_buffer = {.pointer = NULL, .length = 0};
|
||||
ShortintCiphertext *deser_ct = NULL;
|
||||
ShortintCompressedCiphertext *cct = NULL;
|
||||
ShortintCompressedCiphertext *deser_cct = NULL;
|
||||
ShortintCiphertext *decompressed_ct = NULL;
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -67,7 +64,6 @@ void test_predefined_keygen_w_serde(void) {
|
||||
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_server_key(sks);
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_ciphertext(ct);
|
||||
destroy_shortint_ciphertext(deser_ct);
|
||||
destroy_shortint_compressed_ciphertext(cct);
|
||||
@@ -79,11 +75,8 @@ void test_predefined_keygen_w_serde(void) {
|
||||
void test_server_key_trivial_encrypt(void) {
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintServerKey *sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
ShortintCiphertext *ct = NULL;
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -99,25 +92,32 @@ void test_server_key_trivial_encrypt(void) {
|
||||
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_server_key(sks);
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_ciphertext(ct);
|
||||
}
|
||||
|
||||
void test_custom_keygen(void) {
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintServerKey *sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
|
||||
int params_ok =
|
||||
shortint_create_parameters(10, 1, 1024, 10e-100, 10e-100, 2, 3, 2, 3, 2, 3, 10e-100, 2, 3, 2,
|
||||
2, 64, ShortintEncryptionKeyChoiceBig, ¶ms);
|
||||
assert(params_ok == 0);
|
||||
ShortintPBSParameters params = {
|
||||
.lwe_dimension = 10,
|
||||
.glwe_dimension = 1,
|
||||
.polynomial_size = 1024,
|
||||
.lwe_modular_std_dev = 10e-100,
|
||||
.glwe_modular_std_dev = 10e-100,
|
||||
.pbs_base_log = 2,
|
||||
.pbs_level = 3,
|
||||
.ks_base_log = 2,
|
||||
.ks_level = 3,
|
||||
.message_modulus = 2,
|
||||
.carry_modulus = 2,
|
||||
.modulus_power_of_2_exponent = 64,
|
||||
.encryption_key_choice = ShortintEncryptionKeyChoiceBig,
|
||||
};
|
||||
|
||||
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
|
||||
|
||||
assert(gen_keys_ok == 0);
|
||||
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_server_key(sks);
|
||||
}
|
||||
@@ -126,12 +126,9 @@ void test_public_keygen(ShortintPublicKeyKind pk_kind) {
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintPublicKey *pks = NULL;
|
||||
ShortintPublicKey *pks_deser = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
ShortintCiphertext *ct = NULL;
|
||||
Buffer pks_ser_buff = {.pointer = NULL, .length = 0};
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_client_key(params, &cks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -157,7 +154,6 @@ void test_public_keygen(ShortintPublicKeyKind pk_kind) {
|
||||
|
||||
assert(result == 2);
|
||||
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_public_key(pks);
|
||||
destroy_shortint_public_key(pks_deser);
|
||||
@@ -169,11 +165,8 @@ void test_compressed_public_keygen(ShortintPublicKeyKind pk_kind) {
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintCompressedPublicKey *cpks = NULL;
|
||||
ShortintPublicKey *pks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
ShortintCiphertext *ct = NULL;
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_client_key(params, &cks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -204,7 +197,6 @@ void test_compressed_public_keygen(ShortintPublicKeyKind pk_kind) {
|
||||
|
||||
assert(result == 2);
|
||||
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_compressed_public_key(cpks);
|
||||
destroy_shortint_public_key(pks);
|
||||
|
||||
@@ -41,10 +41,7 @@ void test_shortint_pbs_2_bits_message(void) {
|
||||
ShortintPBSLookupTable *accumulator = NULL;
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintServerKey *sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -111,17 +108,13 @@ void test_shortint_pbs_2_bits_message(void) {
|
||||
destroy_shortint_pbs_accumulator(accumulator);
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_server_key(sks);
|
||||
destroy_shortint_parameters(params);
|
||||
}
|
||||
|
||||
void test_shortint_bivariate_pbs_2_bits_message(void) {
|
||||
ShortintBivariatePBSLookupTable *accumulator = NULL;
|
||||
ShortintClientKey *cks = NULL;
|
||||
ShortintServerKey *sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
|
||||
int get_params_ok = shortint_get_parameters(2, 2, ¶ms);
|
||||
assert(get_params_ok == 0);
|
||||
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
|
||||
assert(gen_keys_ok == 0);
|
||||
@@ -187,7 +180,6 @@ void test_shortint_bivariate_pbs_2_bits_message(void) {
|
||||
destroy_shortint_bivariate_pbs_accumulator(accumulator);
|
||||
destroy_shortint_client_key(cks);
|
||||
destroy_shortint_server_key(sks);
|
||||
destroy_shortint_parameters(params);
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
|
||||
@@ -419,10 +419,10 @@ void test_server_key(void) {
|
||||
ShortintCompressedServerKey *deser_csks = NULL;
|
||||
Buffer sks_ser_buffer = {.pointer = NULL, .length = 0};
|
||||
ShortintServerKey *deser_sks = NULL;
|
||||
ShortintParameters *params = NULL;
|
||||
ShortintClientKey *cks_small = NULL;
|
||||
ShortintServerKey *sks_small = NULL;
|
||||
ShortintParameters *params_small = NULL;
|
||||
ShortintPBSParameters params = { 0 };
|
||||
ShortintPBSParameters params_small = { 0 };
|
||||
|
||||
const uint32_t message_bits = 2;
|
||||
const uint32_t carry_bits = 2;
|
||||
@@ -736,8 +736,6 @@ void test_server_key(void) {
|
||||
destroy_shortint_client_key(deser_cks);
|
||||
destroy_shortint_compressed_server_key(deser_csks);
|
||||
destroy_shortint_server_key(deser_sks);
|
||||
destroy_shortint_parameters(params);
|
||||
destroy_shortint_parameters(params_small);
|
||||
destroy_buffer(&cks_ser_buffer);
|
||||
destroy_buffer(&csks_ser_buffer);
|
||||
destroy_buffer(&sks_ser_buffer);
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
* [What is TFHE-rs?](README.md)
|
||||
|
||||
## Getting Started
|
||||
|
||||
* [Installation](getting\_started/installation.md)
|
||||
* [Quick Start](getting\_started/quick\_start.md)
|
||||
* [Supported Operations](getting\_started/operations.md)
|
||||
@@ -33,6 +32,12 @@
|
||||
* [Cryptographic Parameters](integer/parameters.md)
|
||||
* [Serialization/Deserialization](integer/serialization.md)
|
||||
|
||||
## Tutorials for real-life applications
|
||||
* [Dark Market](tutorial/dark_market.md)
|
||||
* [SHA256](tutorial/sha256_bool.md)
|
||||
* [Homomorphic Regular Expressions](tutorial/regex/tutorial.md)
|
||||
|
||||
|
||||
## C API
|
||||
* [High-Level API](c_api/high-level-api.md)
|
||||
* [Shortint API](c_api/shortint-api.md)
|
||||
|
||||
@@ -58,11 +58,11 @@ All timings are related to parallelized Radix-based integer operations, where ea
|
||||
To ensure predictable timings, the operation flavor is the `default` one: a carry propagation is computed after each operation. Operation cost could be reduced by using `unchecked`, `checked`, or `smart`.
|
||||
|
||||
| Plaintext size | add | mul | greater\_than (gt) | min |
|
||||
| -------------------| -------------- | ------------------- | --------- | ------- |
|
||||
| 8 bits | 129.0 ms | 227.2 ms | 111.9 ms | 186.8 ms |
|
||||
| 16 bits | 256.3 ms | 756.0 ms | 145.3 ms | 233.1 ms |
|
||||
| 32 bits | 469.4 ms | 2.10 s | 192.0 ms | 282.9 ms |
|
||||
| 40 bits | 608.0 ms | 3.37 s | 228.4 ms | 318.6 ms |
|
||||
| 64 bits | 959.9 ms | 5.53 s | 249.0 ms | 336.5 ms |
|
||||
| 128 bits | 1.88 s | 14.1 s | 294.7 ms | 398.6 ms |
|
||||
| 256 bits | 3.66 s | 29.2 s | 361.8 ms | 509.1 ms |
|
||||
| -------------------| ---------------| --------------------| ---------------------| -------------|
|
||||
| 8 bits | 129.0 ms | 227 ms | 111.9 ms | 186.8 ms |
|
||||
| 16 bits | 195 ms | 369 ms | 145.3 ms | 233.1 ms |
|
||||
| 32 bits | 238 ms | 519 ms | 192.0 ms | 282.9 ms |
|
||||
| 40 bits | 283 ms | 754 ms | 228.4 ms | 318.6 ms |
|
||||
| 64 bits | 297 ms | 1.18 s | 249.0 ms | 336.5 ms |
|
||||
| 128 bits | 424 ms | 3.13 s | 294.7 ms | 398.6 ms |
|
||||
| 256 bits | 500 ms | 11 s | 361.8 ms | 509.1 ms |
|
||||
|
||||
@@ -115,7 +115,7 @@ fn main() {
|
||||
let msg1 = 1;
|
||||
let msg2 = 0;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -143,7 +143,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We create keys for radix represention to create 16 bits integers
|
||||
// using 8 blocks of 2 bits
|
||||
let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, 8);
|
||||
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);
|
||||
|
||||
let clear_a = 2382u16;
|
||||
let clear_b = 29374u16;
|
||||
|
||||
@@ -26,7 +26,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
@@ -105,7 +105,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12u64;
|
||||
let msg2 = 11u64;
|
||||
@@ -113,7 +113,7 @@ fn main() {
|
||||
let scalar = 3u64;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -143,7 +143,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 2;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12u64;
|
||||
let msg2 = 11u64;
|
||||
@@ -151,7 +151,7 @@ fn main() {
|
||||
let scalar = 3u64;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -181,7 +181,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
@@ -189,7 +189,7 @@ fn main() {
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -220,7 +220,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
fn main() {
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 12;
|
||||
let msg2 = 11;
|
||||
@@ -228,7 +228,7 @@ fn main() {
|
||||
let scalar = 3;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -27,13 +27,13 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 201;
|
||||
let msg2 = 12;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
let ct_2 = client_key.encrypt(msg2);
|
||||
|
||||
@@ -34,7 +34,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
}
|
||||
```
|
||||
|
||||
@@ -49,7 +49,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128u64;
|
||||
let msg2 = 13u64;
|
||||
@@ -72,7 +72,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, _) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, _) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
//We generate the public key from the secret client key:
|
||||
let public_key = PublicKeyBig::new(&client_key);
|
||||
@@ -98,13 +98,13 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
fn main() {
|
||||
// We generate a set of client/server keys, using the default parameters:
|
||||
let num_block = 4;
|
||||
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
|
||||
|
||||
let msg1 = 128;
|
||||
let msg2 = 13;
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
|
||||
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -55,7 +55,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -88,7 +88,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -131,7 +131,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -165,7 +165,7 @@ fn main() {
|
||||
let msg2 = 3;
|
||||
let scalar = 4;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let mut ct_1 = client_key.encrypt(msg1);
|
||||
@@ -241,7 +241,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -272,7 +272,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -303,7 +303,7 @@ fn main() {
|
||||
let msg1 = 2;
|
||||
let msg2 = 1;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -331,7 +331,7 @@ fn main() {
|
||||
|
||||
let msg1 = 3;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
@@ -365,7 +365,7 @@ fn main() {
|
||||
let msg1 = 3;
|
||||
let msg2 = 2;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0 as u64;
|
||||
let modulus = client_key.parameters.message_modulus().0 as u64;
|
||||
|
||||
// We use the private client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
@@ -44,7 +44,7 @@ In the case of multiplication, two algorithms are implemented: the first one rel
|
||||
|
||||
## User-defined parameter sets
|
||||
|
||||
It is possible to define new parameter sets. To do so, it is sufficient to use the function `unsecure_parameters()` or to manually fill the `Parameter` structure fields.
|
||||
It is possible to define new parameter sets. To do so, it is sufficient to use the function `unsecure_parameters()` or to manually fill the `ClassicPBSParameters` structure fields.
|
||||
|
||||
For instance:
|
||||
|
||||
@@ -53,7 +53,7 @@ use tfhe::shortint::prelude::*;
|
||||
|
||||
fn main() {
|
||||
let param = unsafe {
|
||||
Parameters::new(
|
||||
ClassicPBSParameters::new(
|
||||
LweDimension(656),
|
||||
GlweDimension(2),
|
||||
PolynomialSize(512),
|
||||
@@ -63,11 +63,6 @@ fn main() {
|
||||
DecompositionLevelCount(2),
|
||||
DecompositionBaseLog(3),
|
||||
DecompositionLevelCount(4),
|
||||
StandardDev(0.00000000037411618952047216),
|
||||
DecompositionBaseLog(15),
|
||||
DecompositionLevelCount(1),
|
||||
DecompositionLevelCount(0),
|
||||
DecompositionBaseLog(0),
|
||||
MessageModulus(4),
|
||||
CarryModulus(1),
|
||||
CiphertextModulus::new_native(),
|
||||
|
||||
@@ -82,7 +82,7 @@ fn main() {
|
||||
let msg1 = 1;
|
||||
let msg2 = 0;
|
||||
|
||||
let modulus = client_key.parameters.message_modulus.0;
|
||||
let modulus = client_key.parameters.message_modulus().0;
|
||||
|
||||
// We use the client key to encrypt two messages:
|
||||
let ct_1 = client_key.encrypt(msg1);
|
||||
|
||||
479
tfhe/docs/tutorial/dark_market.md
Normal file
479
tfhe/docs/tutorial/dark_market.md
Normal file
@@ -0,0 +1,479 @@
|
||||
# Dark Market Tutorial
|
||||
|
||||
In this tutorial, we are going to build a dark market application using TFHE-rs. A dark market is a marketplace where
|
||||
buy and sell orders are not visible to the public before they are filled. Different algorithms aim to
|
||||
solve this problem, we are going to implement the algorithm defined [in this paper](https://eprint.iacr.org/2022/923.pdf) with TFHE-rs.
|
||||
|
||||
We will first implement the algorithm in plain Rust and then we will see how to use TFHE-rs to
|
||||
implement the same algorithm with FHE.
|
||||
|
||||
In addition, we will also implement a modified version of the algorithm that allows for more concurrent operations which
|
||||
improves the performance in hardware where there are multiple cores.
|
||||
|
||||
## Specifications
|
||||
|
||||
#### Inputs:
|
||||
|
||||
* A list of sell orders where each sell order is only defined in volume terms, it is assumed that the price is fetched
|
||||
from a different source.
|
||||
* A list of buy orders where each buy order is only defined in volume terms, it is assumed that the price is fetched
|
||||
from a different source.
|
||||
|
||||
#### Input constraints:
|
||||
|
||||
* The sell and buy orders are within the range [1,100].
|
||||
* The maximum number of sell and buy orders is 500, respectively.
|
||||
|
||||
#### Outputs:
|
||||
|
||||
There is no output returned at the end of the algorithm. Instead, the algorithm makes changes on the given input lists.
|
||||
The number of filled orders is written over the original order count in the respective lists. If it is not possible to
|
||||
fill the orders, the order count is set to zero.
|
||||
|
||||
#### Example input and output:
|
||||
|
||||
##### Example 1:
|
||||
|
||||
| | Sell | Buy |
|
||||
|--------|--------------------|-----------|
|
||||
| Input | [ 5, 12, 7, 4, 3 ] | [ 19, 2 ] |
|
||||
| Output | [ 5, 12, 4, 0, 0 ] | [ 19, 2 ] |
|
||||
|
||||
Last three indices of the filled sell orders are zero because there is no buy orders to match them.
|
||||
|
||||
##### Example 2:
|
||||
|
||||
| | Sell | Buy |
|
||||
|--------|-------------------|----------------------|
|
||||
| Input | [ 3, 1, 1, 4, 2 ] | [ 5, 3, 3, 2, 4, 1 ] |
|
||||
| Output | [ 3, 1, 1, 4, 2 ] | [ 5, 3, 3, 0, 0, 0 ] |
|
||||
|
||||
Last three indices of the filled buy orders are zero because there is no sell orders to match them.
|
||||
|
||||
## Plain Implementation
|
||||
|
||||
1. Calculate the total sell volume and the total buy volume.
|
||||
|
||||
```rust
|
||||
let total_sell_volume: u16 = sell_orders.iter().sum();
|
||||
let total_buy_volume: u16 = buy_orders.iter().sum();
|
||||
```
|
||||
|
||||
2. Find the total volume that will be transacted. In the paper, this amount is calculated with the formula:
|
||||
|
||||
```
|
||||
(total_sell_volume > total_buy_volume) * (total_buy_volume − total_sell_volume) + total_sell_volume
|
||||
```
|
||||
|
||||
When closely observed, we can see that this formula can be replaced with the `min` function. Therefore, we calculate this
|
||||
value by taking the minimum of the total sell volume and the total buy volume.
|
||||
|
||||
```rust
|
||||
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
|
||||
```
|
||||
|
||||
3. Beginning with the first item, start filling the sell orders one by one. We apply the `min` function replacement also
|
||||
here.
|
||||
|
||||
```rust
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
|
||||
*sell_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
```
|
||||
|
||||
The number of orders that are filled is indicated by modifying the input list. For example, if the first sell order is
|
||||
1000 and the total volume is 500, then the first sell order will be modified to 500 and the second sell order will be
|
||||
modified to 0.
|
||||
|
||||
4. Do the fill operation also for the buy orders.
|
||||
|
||||
```rust
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
|
||||
*buy_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
```
|
||||
|
||||
#### The complete algorithm in plain Rust:
|
||||
|
||||
```rust
|
||||
fn volume_match_plain(sell_orders: &mut Vec<u16>, buy_orders: &mut Vec<u16>) {
|
||||
let total_sell_volume: u16 = sell_orders.iter().sum();
|
||||
let total_buy_volume: u16 = buy_orders.iter().sum();
|
||||
|
||||
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
|
||||
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
|
||||
*sell_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
|
||||
*buy_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## FHE Implementation
|
||||
|
||||
For the FHE implementation, we first start with finding the right bit size for our algorithm to work without
|
||||
overflows.
|
||||
|
||||
The variables that are declared in the algorithm and their maximum values are described in the table below:
|
||||
|
||||
| Variable | Maximum Value | Bit Size |
|
||||
|-------------------------|---------------|----------|
|
||||
| total_sell_volume | 50000 | 16 |
|
||||
| total_buy_volume | 50000 | 16 |
|
||||
| total_volume | 50000 | 16 |
|
||||
| volume_left_to_transact | 50000 | 16 |
|
||||
| sell_order | 100 | 7 |
|
||||
| buy_order | 100 | 7 |
|
||||
|
||||
As we can observe from the table, we need **16 bits of message space** to be able to run the algorithm without
|
||||
overflows. TFHE-rs provides different presets for the different bit sizes. Since we need 16 bits of message, we are
|
||||
going to use the `integer` module to implement the algorithm.
|
||||
|
||||
Here are the input types of our algorithm:
|
||||
|
||||
* `sell_orders` is of type `Vec<tfhe::integer::RadixCipherText>`
|
||||
* `buy_orders` is of type `Vec<tfhe::integer::RadixCipherText>`
|
||||
* `server_key` is of type `tfhe::integer::ServerKey`
|
||||
|
||||
Now, we can start implementing the algorithm with FHE:
|
||||
|
||||
1. Calculate the total sell volume and the total buy volume.
|
||||
|
||||
```rust
|
||||
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
|
||||
}
|
||||
|
||||
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
|
||||
}
|
||||
```
|
||||
|
||||
2. Find the total volume that will be transacted by taking the minimum of the total sell volume and the total buy
|
||||
volume.
|
||||
|
||||
```rust
|
||||
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
|
||||
```
|
||||
|
||||
3. Beginning with the first item, start filling the sell and buy orders one by one. We can create `fill_orders` closure to
|
||||
reduce code duplication since the code for filling buy orders and sell orders are the same.
|
||||
|
||||
```rust
|
||||
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
|
||||
let mut volume_left_to_transact = total_volume.clone();
|
||||
for mut order in orders.iter_mut() {
|
||||
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, &mut order);
|
||||
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
|
||||
*order = filled_amount;
|
||||
}
|
||||
};
|
||||
|
||||
fill_orders(sell_orders);
|
||||
fill_orders(buy_orders);
|
||||
```
|
||||
|
||||
#### The complete algorithm in TFHE-rs:
|
||||
|
||||
```rust
|
||||
const NUMBER_OF_BLOCKS: usize = 8;
|
||||
|
||||
fn volume_match_fhe(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
|
||||
}
|
||||
|
||||
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
|
||||
}
|
||||
|
||||
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
|
||||
|
||||
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
|
||||
let mut volume_left_to_transact = total_volume.clone();
|
||||
for mut order in orders.iter_mut() {
|
||||
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, &mut order);
|
||||
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
|
||||
*order = filled_amount;
|
||||
}
|
||||
};
|
||||
|
||||
fill_orders(sell_orders);
|
||||
fill_orders(buy_orders);
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
### Optimizing the implementation
|
||||
|
||||
* TFHE-rs provides parallelized implementations of the operations. We can use these parallelized
|
||||
implementations to speed up the algorithm. For example, we can use `smart_add_assign_parallelized` instead of
|
||||
`smart_add_assign`.
|
||||
|
||||
* We can parallelize vector sum with Rayon and `reduce` operation.
|
||||
```rust
|
||||
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
|
||||
vec.to_vec().into_par_iter().reduce(
|
||||
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|
||||
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
|
||||
server_key.smart_add_parallelized(&mut acc, &mut ele)
|
||||
},
|
||||
)
|
||||
};
|
||||
```
|
||||
|
||||
* We can run vector summation on `buy_orders` and `sell_orders` in parallel since these operations do not depend on each other.
|
||||
```rust
|
||||
let (mut total_sell_volume, mut total_buy_volume) =
|
||||
rayon::join(|| vector_sum(sell_orders), || vector_sum(buy_orders));
|
||||
```
|
||||
|
||||
* We can match sell and buy orders in parallel since the matching does not depend on each other.
|
||||
```rust
|
||||
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
|
||||
```
|
||||
|
||||
#### Optimized algorithm
|
||||
```rust
|
||||
fn volume_match_fhe_parallelized(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
|
||||
vec.to_vec().into_par_iter().reduce(
|
||||
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|
||||
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
|
||||
server_key.smart_add_parallelized(&mut acc, &mut ele)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
let (mut total_sell_volume, mut total_buy_volume) = rayon::join(
|
||||
|| parallel_vector_sum(sell_orders),
|
||||
|| parallel_vector_sum(buy_orders),
|
||||
);
|
||||
|
||||
let total_volume =
|
||||
server_key.smart_min_parallelized(&mut total_sell_volume, &mut total_buy_volume);
|
||||
|
||||
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
|
||||
let mut volume_left_to_transact = total_volume.clone();
|
||||
for mut order in orders.iter_mut() {
|
||||
let mut filled_amount =
|
||||
server_key.smart_min_parallelized(&mut volume_left_to_transact, &mut order);
|
||||
server_key
|
||||
.smart_sub_assign_parallelized(&mut volume_left_to_transact, &mut filled_amount);
|
||||
*order = filled_amount;
|
||||
}
|
||||
};
|
||||
|
||||
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
|
||||
}
|
||||
```
|
||||
|
||||
## Modified Algorithm
|
||||
|
||||
When observed closely, there is only a small amount of concurrency introduced in the `fill_orders` part of the algorithm.
|
||||
The reason is that the `volume_left_to_transact` is shared between all the orders and should be modified sequentially.
|
||||
This means that the orders cannot be filled in parallel. If we can somehow remove this dependency, we can fill the orders in parallel.
|
||||
|
||||
In order to do so, we closely observe the function of `volume_left_to_transact` variable in the algorithm. We can see that it is being used to check whether we can fill the current order or not.
|
||||
Instead of subtracting the current order value from `volume_left_to_transact` in each loop, we can add this value to the next order
|
||||
index and check the availability by comparing the current order value with the total volume. If the current order value
|
||||
(now representing the sum of values before this order plus this order) is smaller than the total number of matching orders,
|
||||
we can safely fill all the orders and continue the loop. If not, we should partially fill the orders with what is left from
|
||||
matching orders.
|
||||
|
||||
We will call the new list the "prefix sum" of the array.
|
||||
|
||||
The new version for the plain `fill_orders` is as follows:
|
||||
```rust
|
||||
let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64|{
|
||||
orders.iter().for_each(|order : &mut u64| {
|
||||
if (total_orders >= prefix_sum[i]) {
|
||||
continue;
|
||||
} else if total_orders >= prefix_sum.get(i-1).unwrap_or(0) {
|
||||
*order = total_orders - prefix_sum.get(i-1).unwrap_or(0);
|
||||
} else {
|
||||
*order = 0;
|
||||
}
|
||||
});
|
||||
};
|
||||
```
|
||||
|
||||
To write this new function we need transform the conditional code into a mathematical expression since FHE does not support conditional operations.
|
||||
```rust
|
||||
|
||||
let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64| {
|
||||
orders.iter().for_each(|order| : &mut){
|
||||
*order = *order + ((total_orders >= prefix_sum - std::cmp::min(total_orders, prefix_sum.get(i - 1).unwrap_or(&0).clone()) - *order);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
New `fill_order` function requires a prefix sum array. We are going to calculate this prefix sum array in parallel
|
||||
with the algorithm described [here](https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda).
|
||||
|
||||
The sample code in the paper is written in CUDA. When we try to implement the algorithm in Rust we see that the compiler does not allow us to do so.
|
||||
The reason for that is while the algorithm does not access the same array element in any of the threads(the index calculations using `d` and `k` values never overlap),
|
||||
Rust compiler cannot understand this and does not let us share the same array between threads.
|
||||
So we modify how the algorithm is implemented, but we don't change the algorithm itself.
|
||||
|
||||
Here is the modified version of the algorithm in TFHE-rs:
|
||||
```rust
|
||||
fn volume_match_fhe_modified(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
let compute_prefix_sum = |arr: &[RadixCiphertextBig]| {
|
||||
if arr.is_empty() {
|
||||
return arr.to_vec();
|
||||
}
|
||||
let mut prefix_sum: Vec<RadixCiphertextBig> = (0..arr.len().next_power_of_two())
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
if i < arr.len() {
|
||||
arr[i].clone()
|
||||
} else {
|
||||
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
// Up sweep
|
||||
for d in 0..(prefix_sum.len().ilog2() as u32) {
|
||||
prefix_sum
|
||||
.par_chunks_exact_mut(2_usize.pow(d + 1))
|
||||
.for_each(move |chunk| {
|
||||
let length = chunk.len();
|
||||
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
|
||||
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
|
||||
});
|
||||
}
|
||||
// Down sweep
|
||||
let last = prefix_sum.last().unwrap().clone();
|
||||
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for d in (0..(prefix_sum.len().ilog2() as u32)).rev() {
|
||||
prefix_sum
|
||||
.par_chunks_exact_mut(2_usize.pow(d + 1))
|
||||
.for_each(move |chunk| {
|
||||
let length = chunk.len();
|
||||
let t = chunk.last().unwrap().clone();
|
||||
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
|
||||
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left);
|
||||
chunk[(length - 1) / 2] = t;
|
||||
});
|
||||
}
|
||||
prefix_sum.push(last);
|
||||
prefix_sum[1..=arr.len()].to_vec()
|
||||
};
|
||||
|
||||
println!("Creating prefix sum arrays...");
|
||||
let time = Instant::now();
|
||||
let (prefix_sum_sell_orders, prefix_sum_buy_orders) = rayon::join(
|
||||
|| compute_prefix_sum(sell_orders),
|
||||
|| compute_prefix_sum(buy_orders),
|
||||
);
|
||||
println!("Created prefix sum arrays in {:?}", time.elapsed());
|
||||
|
||||
let fill_orders = |total_orders: &RadixCiphertextBig,
|
||||
orders: &mut [RadixCiphertextBig],
|
||||
prefix_sum_arr: &[RadixCiphertextBig]| {
|
||||
orders
|
||||
.into_par_iter()
|
||||
.enumerate()
|
||||
.for_each(move |(i, order)| {
|
||||
server_key.smart_add_assign_parallelized(
|
||||
order,
|
||||
&mut server_key.smart_mul_parallelized(
|
||||
&mut server_key
|
||||
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
|
||||
&mut server_key.smart_sub_parallelized(
|
||||
&mut server_key.smart_sub_parallelized(
|
||||
&mut total_orders.clone(),
|
||||
&mut server_key.smart_min_parallelized(
|
||||
&mut total_orders.clone(),
|
||||
&mut prefix_sum_arr
|
||||
.get(i - 1)
|
||||
.unwrap_or(
|
||||
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|
||||
)
|
||||
.clone(),
|
||||
),
|
||||
),
|
||||
&mut order.clone(),
|
||||
),
|
||||
),
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
let total_buy_orders = &mut prefix_sum_buy_orders
|
||||
.last()
|
||||
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
|
||||
.clone();
|
||||
|
||||
let total_sell_orders = &mut prefix_sum_sell_orders
|
||||
.last()
|
||||
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
|
||||
.clone();
|
||||
|
||||
println!("Matching orders...");
|
||||
let time = Instant::now();
|
||||
rayon::join(
|
||||
|| fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders),
|
||||
|| fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders),
|
||||
);
|
||||
println!("Matched orders in {:?}", time.elapsed());
|
||||
}
|
||||
```
|
||||
|
||||
## Running the tutorial
|
||||
|
||||
The plain, FHE and parallel FHE implementations can be run by providing respective arguments as described below.
|
||||
|
||||
```bash
|
||||
# Runs FHE implementation
|
||||
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe
|
||||
|
||||
# Runs parallelized FHE implementation
|
||||
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe-parallel
|
||||
|
||||
# Runs modified FHE implementation
|
||||
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe-modified
|
||||
|
||||
# Runs plain implementation
|
||||
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- plain
|
||||
|
||||
# Multiple implementations can be run within same instance
|
||||
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- plain fhe-parallel
|
||||
```
|
||||
|
||||
## Conclusion
|
||||
|
||||
In this tutorial, we've learned how to implement the volume matching algorithm described [in this paper](https://eprint.iacr.org/2022/923.pdf) in plain Rust and in TFHE-rs.
|
||||
We've identified the right bit size for our problem at hand, used operations defined in `TFHE-rs`, and introduced concurrency to the algorithm to increase its performance.
|
||||
512
tfhe/docs/tutorial/regex/tutorial.md
Normal file
512
tfhe/docs/tutorial/regex/tutorial.md
Normal file
@@ -0,0 +1,512 @@
|
||||
# FHE Regex Pattern Matching Tutorial
|
||||
|
||||
This tutorial explains how to build a regex Pattern Matching Engine (PME) where ciphertext is the
|
||||
content that is evaluated.
|
||||
|
||||
A regex PME is an essential tool for programmers. It allows you to perform complex searches on content.
|
||||
A less powerful simple search on string can only find matches of the exact given sequence of
|
||||
characters (e.g., your browser's default search function). Regex PMEs
|
||||
are more powerful, allowing searches on certain structures of text, where a
|
||||
structure may take any form in multiple possible sequences of characters. The
|
||||
structure to be searched is defined with the regex, a very concise
|
||||
language.
|
||||
|
||||
Here are some example regexes to give you an idea of what is possible:
|
||||
|
||||
Regex | Semantics
|
||||
--- | ---
|
||||
/abc/ | Searches for the sequence `abc` (equivalent to a simple text search)
|
||||
/^abc/ | Searches for the sequence `abc` at the beginning of the content
|
||||
/a?bc/ | Searches for sequences `abc`, `bc`
|
||||
/ab\|c+d/ | Searches for sequences of `ab`, `c` repeated 1 or more times, followed by `d`
|
||||
|
||||
Regexes are powerful enough to be able to express structures like email address
|
||||
formats. This capability is what makes regexes useful for many programming
|
||||
solutions.
|
||||
|
||||
There are two main components identifiable in a PME:
|
||||
1. The pattern that is to be matched has to be parsed, translated from a
|
||||
textual representation into a recursively structured object (an Abstract
|
||||
Syntax Tree, or AST).
|
||||
2. This AST must then be applied to the text that it is to be matched against,
|
||||
resulting in a 'yes' or 'no' to whether the pattern has matched (in the case of
|
||||
our FHE implementation, this result is an encrypted 'yes' or an encrypted 'no').
|
||||
|
||||
Parsing is a well understood problem. There are a couple of different
|
||||
approaches possible here. Regardless of the approach chosen, it starts with
|
||||
figuring out what language we want to support. That is, what are
|
||||
the kinds of sentences we want our regex language to include? A few
|
||||
example sentences we definitely want to support are, for example: `/a/`,
|
||||
`/a?bc/`, `/^ab$/`, `/ab|cd/`, however example sentences don't suffice as
|
||||
a specification because they can never be exhaustive (they're endless). We need
|
||||
something to specify _exactly_ the full set of sentences our language supports.
|
||||
There exists a language that can help us describe our own language's structure exactly:
|
||||
Grammar.
|
||||
|
||||
## The Grammar and datastructure
|
||||
|
||||
It is useful to start with defining the Grammar before starting to write
|
||||
code for the parser because the code structure follows directly from the
|
||||
Grammar. A Grammar consists of a generally small set of rules. For example,
|
||||
a very basic Grammar could look like this:
|
||||
```
|
||||
Start := 'a'
|
||||
```
|
||||
This describes a language that only contains the sentence "a". Not a very interesting language.
|
||||
|
||||
We can make it more interesting though by introducing choice into the Grammar
|
||||
with \| (called a 'pipe') operators. If we want the above Grammar to accept
|
||||
either "a" or "b":
|
||||
```
|
||||
Start := 'a' | 'b'
|
||||
```
|
||||
|
||||
So far, only Grammars with a single rule have been shown. However, a Grammar can
|
||||
consist of multiple rules. Most languages require it. So let's consider a more meaningful language,
|
||||
one that accepts sentences consisting of one or more digits. We could describe such a language
|
||||
with the following Grammar:
|
||||
```
|
||||
Start := Digit+
|
||||
|
||||
Digit := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9'
|
||||
```
|
||||
|
||||
The `+` after `Digit` is another Grammar operator. With it, we specify that
|
||||
Digit must be matched one or more times. Here are all the Grammar operators that
|
||||
are relevant for this tutorial:
|
||||
|
||||
Operator | Example | Semantics
|
||||
--- | --- | ---
|
||||
`\|` | a \| b | we first try matching on 'a' - if no match, we try to match on 'b'
|
||||
`+` | a+ | match 'a' one or more times
|
||||
`*` | a* | match 'a' any amount of times (including zero times)
|
||||
`?` | a? | optionally match 'a' (match zero or one time)
|
||||
`.` | . | match any character
|
||||
`..` | a .. b | match on a range of alphabetically ordered characters from 'a', up to and including 'b'
|
||||
` ` | a b | sequencing; match on 'a' and then on 'b'
|
||||
|
||||
In the case of the example PME, the Grammar is as follows (notice the unquoted ? and quoted ?, etc. The unquoted characters are Grammar operators, and the quoted are characters we are matching in the parsing).
|
||||
```
|
||||
Start := '/' '^'? Regex '$'? '/' Modifier?
|
||||
|
||||
Regex := Term '|' Term
|
||||
| Term
|
||||
|
||||
Term := Factor*
|
||||
|
||||
Factor := Atom '?'
|
||||
| Repeated
|
||||
| Atom
|
||||
|
||||
Repeated := Atom '*'
|
||||
| Atom '+'
|
||||
| Atom '{' Digit* ','? '}'
|
||||
| Atom '{' Digit+ ',' Digit* '}'
|
||||
|
||||
Atom := '.'
|
||||
| '\' .
|
||||
| Character
|
||||
| '[' Range ']'
|
||||
| '(' Regex ')'
|
||||
|
||||
Range := '^' Range
|
||||
| AlphaNum '-' AlphaNum
|
||||
| AlphaNum+
|
||||
|
||||
Digit := '0' .. '9'
|
||||
|
||||
Character := AlphaNum
|
||||
| '&' | ';' | ':' | ',' | '`' | '~' | '-' | '_' | '!' | '@' | '#' | '%' | '\'' | '\"'
|
||||
|
||||
AlphaNum := 'a' .. 'z'
|
||||
| 'A' .. 'Z'
|
||||
| '0' .. '9'
|
||||
|
||||
Modifier := 'i'
|
||||
```
|
||||
We will refer occasionally to specific parts in the Grammar listed above by \<rule name\>.\<variant index\> (where the first rule variant has index 1).
|
||||
|
||||
With the Grammar defined, we can start defining a type to parse into. In Rust, we
|
||||
have the `enum` kind of type that is perfect for this, as it allows you to define
|
||||
multiple variants that may recurse. I prefer to start by defining variants that
|
||||
do not recurse (i.e., that don't contain nested regex expressions):
|
||||
```rust
|
||||
enum RegExpr {
|
||||
Char { c: char }, // matching against a single character (Atom.2 and Atom.3)
|
||||
AnyChar, // matching _any_ character (Atom.1)
|
||||
SOF, // matching only at the beginning of the content ('^' in Start.1)
|
||||
EOF, // matching only at the end of the content (the '$' in Start.1)
|
||||
Range { cs: Vec<char> }, // matching on a list of characters (Range.3, eg '[acd]')
|
||||
Between { from: char, to: char }, // matching between 2 characters based on ascii ordering (Range.2, eg '[a-g]')
|
||||
}
|
||||
```
|
||||
|
||||
With this, we can translate the following basic regexes:
|
||||
|
||||
Pattern | RegExpr value
|
||||
--- | ---
|
||||
`/a/` | `RegExpr::Char { c: 'a' }`
|
||||
`/\\^/` | `RegExpr::Char { c: '^' }`
|
||||
`/./` | `RegExpr::AnyChar`
|
||||
`/^/` | `RegExpr::SOF`
|
||||
`/$/` | `RegExpr::EOF`
|
||||
`/[acd]/` | `RegExpr::Range { vec!['a', 'c', 'd'] }`
|
||||
`/[a-g]/` | `RegExpr::Between { from: 'a', to: 'g' }`
|
||||
|
||||
Notice we're not yet able to sequence multiple components together. Let's define
|
||||
the first variant that captures recursive RegExpr for this:
|
||||
```rust
|
||||
enum RegExpr {
|
||||
...
|
||||
Seq { re_xs: Vec<RegExpr> }, // matching sequences of RegExpr components (Term.1)
|
||||
}
|
||||
```
|
||||
With this Seq (short for sequence) variant, we allow translating patterns that
|
||||
contain multiple components:
|
||||
|
||||
Pattern | RegExpr value
|
||||
--- | ---
|
||||
`/ab/` | `RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'b' }] }`
|
||||
`/^a.$/` | `RegExpr::Seq { re_xs: vec![RegExpr::SOF, RexExpr::Char { 'a' }, RegExpr::AnyChar, RegExpr::EOF] }`
|
||||
`/a[f-l]/` | `RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Between { from: 'f', to: 'l' }] }`
|
||||
|
||||
Let's finish the RegExpr datastructure by adding variants for 'Optional' matching,
|
||||
'Not' logic in a range, and 'Either' left or right matching:
|
||||
```rust
|
||||
enum RegExpr {
|
||||
...
|
||||
Optional { opt_re: Box<RegExpr> }, // matching optionally (Factor.1)
|
||||
Not { not_re: Box<RegExpr> }, // matching inversely on a range (Range.1)
|
||||
Either { l_re: Box<RegExpr>, r_re: Box<RegExpr> }, // matching the left or right regex (Regex.1)
|
||||
}
|
||||
```
|
||||
|
||||
Some features may make the most sense being implemented during post-processing of
|
||||
the parsed datastructure. For example, the case insensitivity feature (the `i`
|
||||
Modifier) is implemented in the example implementation by taking the parsed
|
||||
RegExpr and mutating every character mentioned inside to cover both the lower
|
||||
case as well as the upper case variant (see function `case_insensitive` in
|
||||
`parser.rs` for the example implementation).
|
||||
|
||||
The modifier `i` in our Grammar (for enabling case insensitivity) was easiest
|
||||
to implement by applying a post-processing step to the parser.
|
||||
|
||||
We are now able to translate any complex regex into a RegExpr value. For example:
|
||||
|
||||
Pattern | RegExpr value
|
||||
--- | ---
|
||||
`/a?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: 'a' }) }`
|
||||
`/[a-d]?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Between { from: 'a', to: 'd' }) }`
|
||||
`/[^ab]/` | `RegExpr::Not { not_re: Box::new(RegExpr::Range { cs: vec!['a', 'b'] }) }`
|
||||
`/av\|d?/` | `RegExpr::Either { l_re: Box::new(RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'v' }] }), r_re: Box::new(RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: 'd' }) }) }`
|
||||
`/(av\|d)?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Either { l_re: Box::new(RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'v' }] }), r_re: Box::new(RegExpr::Char { c: 'd' }) }) }`
|
||||
|
||||
With both the Grammar and the datastructure to parse into defined, we can now
|
||||
start implementing the actual parsing logic. There are multiple ways this can
|
||||
be done. For example, there exist tools that can automatically generate parser
|
||||
code by giving it the Grammar definition (these are called parser generators).
|
||||
However, you might prefer to write parsers with a parser combinator library.
|
||||
This may be the better option for you because the behavior in runtime is easier to understand
|
||||
for parsers constructed with a parser combinator library than of parsers that were
|
||||
generated with a parser generator tool.
|
||||
|
||||
Rust offers a number of popular parser combinator libraries. This tutorial used
|
||||
`combine`, but any other library would work just as well. Choose whichever appeals
|
||||
the most to you (including any parser generator tool). The implementation of
|
||||
our regex parser will differ significantly depending on the approach you choose,
|
||||
so we will not cover this in detail here. You may look at the parser code in the example
|
||||
implementation to get an idea of how this could be done. In general though, the Grammar and the
|
||||
datastructure are the important components, while the parser code follows directly from these.
|
||||
|
||||
## Matching the RegExpr to encrypted content
|
||||
|
||||
The next challenge is to build the execution engine, where we take a RegExpr
|
||||
value and recurse into it to apply the necessary actions on the encrypted
|
||||
content. We first have to define how we actually encode our content into an
|
||||
encrypted state. Once that is defined, we can start working on how we will
|
||||
execute our RegExpr onto the encrypted content.
|
||||
|
||||
### Encoding and encrypting the content.
|
||||
|
||||
It is not possible to encrypt the entire content into a single encrypted value.
|
||||
We can only encrypt numbers and preform operations on those encrypted numbers with
|
||||
FHE. Therefore, we have to find a scheme where we encode the content into a
|
||||
sequence of numbers that are then encrypted individually to form a sequence of
|
||||
encrypted numbers.
|
||||
|
||||
We recommend the following two strategies:
|
||||
1. to map each character of the content into the u8 ascii value, and then encrypt
|
||||
each bit of these u8 values individually.
|
||||
2. to, instead of encrypting each bit individually, encrypt each u8 ascii value in
|
||||
its entirety.
|
||||
|
||||
Strategy 1 requires more high-level TFHE-rs operations to check for
|
||||
a simple character match (we have to check each bit individually for
|
||||
equality as opposed to checking the entire byte in one, high-level TFHE-rs
|
||||
operation), though some experimentation did show that both options performed
|
||||
equally well on a regex like `/a/`. This is likely because bitwise FHE
|
||||
operations are relatively cheap compared to u8 FHE operations. However,
|
||||
option 1 falls apart as soon as you introduce '[a-z]' regex logic.
|
||||
With option 2, it is possible to complete this match with just three TFHE-rs
|
||||
operations: `ge`, `le`, and `bitand`.
|
||||
```rust
|
||||
// note: this is pseudocode
|
||||
c = <the encrypted character under inspection>;
|
||||
sk = <the server key, aka the public key>
|
||||
|
||||
ge_from = sk.ge(c, 'a');
|
||||
le_to = sk.le(c, 'z');
|
||||
result = sk.bitand(ge_from, le_to);
|
||||
```
|
||||
|
||||
If, on the other hand, we had encrypted the content with the first strategy,
|
||||
there would be no way to test for `greater/equal than from` and `less/equal
|
||||
than to`. We'd have to check for the potential equality of each character between
|
||||
`from` and `to`, and then join the results together with a sequence of
|
||||
`sk.bitor`; that would require far more cryptographic operations than in strategy 2.
|
||||
|
||||
Because FHE operations are computationally expensive, and strategy 1 requires
|
||||
significantly more FHE operations for matching on `[a-z]` regex logic, we
|
||||
should opt for strategy 2.
|
||||
|
||||
### Matching with the AST versus matching with a derived DFA.
|
||||
|
||||
There are a lot of regex PMEs. It's been built many times and it's been
|
||||
researched thoroughly. There are different strategies possible here.
|
||||
A straight forward strategy is to directly recurse into our RegExpr
|
||||
value and apply the necessary matching operations onto the content. In a way,
|
||||
this is nice because it allows us to link the RegExpr structure directly to
|
||||
the matching semantics, resulting in code that is easier to
|
||||
understand, maintain, etc.
|
||||
|
||||
Alternatively, there exists an algorithm that transforms the AST (i.e., the
|
||||
RegExpr, in our case) into a Deterministic Finite Automata (DFA). Normally, this
|
||||
is a favorable approach in terms of efficiency because the derived DFA can be
|
||||
walked over without needing to backtrack (whereas the former strategy cannot
|
||||
prevent backtracking). This means that the content can be walked over from
|
||||
character to character, and depending on what the character is at this
|
||||
cursor, the DFA is conjunctively traveled in a definite direction which
|
||||
ultimately leads us to the `yes, there is a match` or the `no, there is no
|
||||
match`. There is a small upfront cost of having to translate the AST into the
|
||||
DFA, but the lack of backtracking during matching generally makes up for
|
||||
this, especially if the content that it is matched against is significantly big.
|
||||
|
||||
In our case though, we are matching on encrypted content. We have no way to know
|
||||
what the character at our cursor is, and therefore no way to find this definite
|
||||
direction to go forward in the DFA. Therefore, translating the AST into the DFA does
|
||||
not help us as it does in normal regex PMEs. For this reason, consider opting for the
|
||||
former strategy because it allows for matching logic that is easier to understand.
|
||||
|
||||
### Matching.
|
||||
|
||||
In the previous section, we decided we'll match by traversing into the RegExpr
|
||||
value. This section will explain exactly how to do that. Similarly to defining
|
||||
the Grammar, it is often best to start with working out the non-recursive
|
||||
RegExpr variants.
|
||||
|
||||
We'll start by defining the function that will recursively traverse into the RegExpr value:
|
||||
```rust
|
||||
|
||||
type StringCiphertext = Vec<RadixCiphertextBig>;
|
||||
type ResultCiphertext = RadixCiphertextBig;
|
||||
|
||||
fn match(
|
||||
sk: &ServerKey,
|
||||
content: &StringCipherText,
|
||||
re: &RegExpr,
|
||||
content_pos: usize,
|
||||
) -> Vec<(ResultCiphertext, usize)> {
|
||||
let content_char = &content[c_pos];
|
||||
match re {
|
||||
...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`sk` is the server key (aka, public key),`content` is what we'll be matching
|
||||
against, `re` is the RegExpr value we built when parsing the regex, and `c_pos`
|
||||
is the cursor position (the index in content we are currently matching
|
||||
against).
|
||||
|
||||
The result is a vector of tuples, with the first value of the tuple being the computed
|
||||
ciphertext result, and the second value being the content position after the
|
||||
regex components were applied. It's a vector because certain RegExpr variants
|
||||
require the consideration of a list of possible execution paths. For example,
|
||||
RegExpr::Optional might succeed by applying _or_ and *not* applying the optional
|
||||
regex (notice that in the former case, `c_pos` moves forward whereas in the
|
||||
latter case it stays put).
|
||||
|
||||
On first call, a `match` of the entire regex pattern starts with `c_pos=0`.
|
||||
Then `match` is called again for the entire regex pattern with `c_pos=1`, etc. until
|
||||
`c_pos` exceeds the length of the content. Each of these alternative match results
|
||||
are then joined together with `sk.bitor` operations (this works because if one of them results
|
||||
in 'true' then, in general, our matching algorithm should return 'true').
|
||||
|
||||
The `...` within the match statement above is what we will be working out for
|
||||
some of the RegExpr variants now. Starting with `RegExpr::Char`:
|
||||
```rust
|
||||
case RegExpr::Char { c } => {
|
||||
vec![(sk.eq(content_char, c), c_pos + 1)]
|
||||
},
|
||||
```
|
||||
|
||||
Let's consider an example of the variant above. If we apply `/a/` to content
|
||||
`bac`, we'll have the following list of `match` calls `re` and `c_pos` values
|
||||
(for simplicity, `re` is denoted in regex pattern instead of in RegExpr value):
|
||||
|
||||
re | c\_pos | Ciphertext operation
|
||||
--- | --- | ---
|
||||
/a/ | 0 | sk.eq(content[0], a)
|
||||
/a/ | 1 | sk.eq(content[1], a)
|
||||
/a/ | 2 | sk.eq(content[2], a)
|
||||
|
||||
And we would arrive at the following sequence of ciphertext operations:
|
||||
```
|
||||
sk.bitor(sk.eq(content[0], a), sk.bitor(sk.eq(content[1], a), sk.eq(content[2], a)))
|
||||
```
|
||||
|
||||
AnyChar is a no operation:
|
||||
```rust
|
||||
case RegExpr::AnyChar => {
|
||||
// note: ct_true is just some constant representing True that is trivially encoded into ciphertext
|
||||
return vec![(ct_true, c_pos + 1)];
|
||||
}
|
||||
```
|
||||
|
||||
The sequence iterates over its `re_xs`, increasing the content position
|
||||
accordingly, and joins the results with `bitand` operations:
|
||||
```rust
|
||||
case RegExpr::Seq { re_xs } => {
|
||||
re_xs.iter().fold(|prev_results, re_x| {
|
||||
prev_results.iter().flat_map(|(prev_res, prev_c_pos)| {
|
||||
(x_res, new_c_pos) = match(sk, content, re_x, prev_c_pos);
|
||||
(sk.bitand(prev_res, x_res), new_c_pos)
|
||||
})
|
||||
}, (ct_true, c_pos))
|
||||
},
|
||||
```
|
||||
|
||||
Other variants are similar, as they recurse and manipulate `re` and `c_pos`
|
||||
accordingly. Hopefully, the general idea is already clear.
|
||||
|
||||
Ultimately the entire pattern-matching logic unfolds into a sequence of
|
||||
the following set of FHE operations:
|
||||
1. eq (tests for an exact character match)
|
||||
2. ge (tests for 'greater than' or 'equal to' a character)
|
||||
3. le (tests for 'less than' or 'equal to' a character)
|
||||
4. bitand (bitwise AND, used for sequencing multiple regex components)
|
||||
5. bitor (bitwise OR, used for folding multiple possible execution variants'
|
||||
results into a single result)
|
||||
6. bitxor (bitwise XOR, used for the 'not' logic in ranges)
|
||||
|
||||
### Optimizations.
|
||||
|
||||
Generally, the included example PME follows the approach outlined above. However, there were
|
||||
two additional optimizations applied. Both of these optimizations involved
|
||||
reducing the number of unnecessary FHE operations. Given how computationally expensive
|
||||
these operations are, it makes sense to optimize for this (and to ignore any suboptimal
|
||||
memory usage of our PME, etc.).
|
||||
|
||||
The first optimization involved delaying the execution of FHE operations to _after_
|
||||
the generation of all possible execution paths to be considered. This optimization
|
||||
allows us to prune execution paths during execution path construction that are provably
|
||||
going to result in an encrypted false value, without having already performed the FHE
|
||||
operations up to the point of pruning. Consider the regex `/^a+b$/`, and we are applying
|
||||
this to a content of size 4. If we are executing execution paths naively, we would go ahead
|
||||
and check for all possible amounts of `a` repetitions: `ab`, `aab`, `aaab`.
|
||||
However, while building the execution paths, we can use the fact that `a+` must
|
||||
begin at the beginning of the content, and that `b` must be the final character
|
||||
of the content. From this follows that we only have to check for the following
|
||||
sentence: `aaab`. Delaying execution of the FHE operations until after we've
|
||||
built the possible execution paths in this example reduced the number of FHE
|
||||
operations applied by approximately half.
|
||||
|
||||
The second optimization involved preventing the same FHE conditions to be
|
||||
re-evaluated. Consider the regex `/^a?ab/`. This would give us the following
|
||||
possible execution paths to consider:
|
||||
1. `content[0] == a && content[1] == a && content[2] == b` (we match the `a` in
|
||||
`a?`)
|
||||
2. `content[0] == a && content[1] == b` (we don't match the `a` in `a?`)
|
||||
|
||||
Notice that, for both execution paths, we are checking for `content[0] == a`.
|
||||
Even though we cannot see what the encrypted result is, we do know that it's
|
||||
either going to be an encrypted false for both cases or an encrypted true for
|
||||
both cases. Therefore, we can skip the re-evaluation of `content[0] == a` and
|
||||
simply copy the result from the first evaluation over. This optimization
|
||||
involves maintaining a cache of known expression evaluation results and
|
||||
reusing those where possible.
|
||||
|
||||
## Trying out the example implementation
|
||||
|
||||
The implementation that guided the writing of this tutorial can be found
|
||||
under `tfhe/examples/regex_engine`.
|
||||
|
||||
When compiling with `--example regex_engine`, a binary is produced that serves
|
||||
as a basic demo. Simply call it with the content string as a first argument and
|
||||
the pattern string as a second argument. For example,
|
||||
`cargo run --release --features=x86_64-unix,integer --example regex_engine -- 'this is the content' '/^pattern$/'`;
|
||||
note it's advised to compile the executable with `--release` flag as the key
|
||||
generation and homomorphic operations otherwise seem to experience a heavy
|
||||
performance penalty.
|
||||
|
||||
On execution, a private and public key pair are created. Then, the content is
|
||||
encrypted with the client key, and the regex pattern is applied onto the
|
||||
encrypted content string - with access given only to the server key. Finally, it
|
||||
decrypts the resulting encrypted result using the client key and prints the
|
||||
verdict to the console.
|
||||
|
||||
To get more information on exact computations and performance, set the `RUST_LOG`
|
||||
environment variable to `debug` or to `trace`.
|
||||
|
||||
|
||||
### Supported regex patterns
|
||||
|
||||
This section specifies the supported set of regex patterns in the regex engine.
|
||||
|
||||
#### Components
|
||||
|
||||
A regex is described by a sequence of components surrounded by `/`, the
|
||||
following components are supported:
|
||||
|
||||
Name | Notation | Examples
|
||||
--- | --- | ---
|
||||
Character | Simply the character itself | `/a/`, `/b/`, `/Z/`, `/5/`
|
||||
Character range | `[<character>-<character]` | `/[a-d]/`, `/[C-H]`/
|
||||
Any character | `.` | `/a.c/`
|
||||
Escaped symbol | `\<symbol>` | `/\^/`, `/\$/`
|
||||
Parenthesis | `(<regex>)` | `/(abc)*/`, `/d(ab)?/`
|
||||
Optional | `<regex>?` | `/a?/`, `/(az)?/`
|
||||
Zero or more | `<regex>*` | `/a*/`, `/ab*c/`
|
||||
One or more | `<regex>+` | `/a+/`, `/ab+c/`
|
||||
Exact repeat | `<regex{<number>}>` | `/ab{2}c/`
|
||||
At least repeat | `<regex{<number>,}>` | `/ab{2,}c/`
|
||||
At most repeat | `<regex{,<number>}>` | `/ab{,2}c/`
|
||||
Repeat between | `<regex{<number>,<number>}>` | `/ab{2,4}c/`
|
||||
Either | `<regex>\|<regex>` | `/a\|b/`, `/ab\|cd/`
|
||||
Start matching | `/^<regex>` | `/^abc/`
|
||||
End matching | `<regex>$/` | `/abc$/`
|
||||
|
||||
#### Modifiers
|
||||
|
||||
Modifiers are mode selectors that affect the entire regex behavior. One modifier is
|
||||
currently supported:
|
||||
|
||||
- Case insensitive matching, by appending an `i` after the regex pattern. For example: `/abc/i`
|
||||
|
||||
#### General examples
|
||||
|
||||
These components and modifiers can be combined to form any desired regex
|
||||
pattern. To give some idea of what is possible, here is a non-exhaustive list of
|
||||
supported regex patterns:
|
||||
|
||||
Pattern | Description
|
||||
--- | ---
|
||||
`/^abc$/` | Matches with content that equals exactly `abc` (case sensitive)
|
||||
`/^abc$/i` | Matches with content that equals `abc` (case insensitive)
|
||||
`/abc/` | Matches with content that contains somewhere `abc`
|
||||
`/ab?c/` | Matches with content that contains somewhere `abc` or somwhere `ab`
|
||||
`/^ab*c$/` | For example, matches with: `ac`, `abc`, `abbbbc`
|
||||
`/^[a-c]b\|cd$/` | Matches with: `ab`, `bb`, `cb`, `cd`
|
||||
`/^[a-c]b\|cd$/i` | Matches with: `ab`, `Ab`, `aB`, ..., `cD`, `CD`
|
||||
`/^d(abc)+d$/` | For example, matches with: `dabcd`, `dabcabcd`, `dabcabcabcd`
|
||||
`/^a.*d$/` | Matches with any content that starts with `a` and ends with `d`
|
||||
322
tfhe/docs/tutorial/sha256_bool.md
Normal file
322
tfhe/docs/tutorial/sha256_bool.md
Normal file
@@ -0,0 +1,322 @@
|
||||
# Tutorial
|
||||
|
||||
## Intro
|
||||
|
||||
In this tutorial we will go through the steps to turn a regular sha256 implementation into its homomorphic version. We explain the basics of the sha256 function first, and then how to implement it homomorphically with performance considerations.
|
||||
|
||||
## Sha256
|
||||
|
||||
The first step in this experiment is actually implementing the sha256 function. We can find the specification [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), but let's summarize the three main sections of the document.
|
||||
|
||||
#### Padding
|
||||
|
||||
The sha256 function processes the input data in blocks or chunks of 512 bits. Before actually performing the hash computations we have to pad the input in the following way:
|
||||
* Append a single "1" bit
|
||||
* Append a number of "0" bits such that exactly 64 bits are left to make the message length a multiple of 512
|
||||
* Append the last 64 bits as a binary encoding of the original input length
|
||||
|
||||
Or visually:
|
||||
|
||||
```
|
||||
0 L L+1 L+1+k L+1+k+64
|
||||
|-----------------------------------|---|--------------------------------|----------------------|
|
||||
Original input (L bits) "1" bit "0" bits Encoding of the number L
|
||||
```
|
||||
Where the numbers on the top represent the length of the padded input at each position, and L+1+k+64 is a multiple of 512 (the length of the padded input).
|
||||
|
||||
#### Operations and functions
|
||||
|
||||
Let's take a look at the operations that we will use as building blocks for functions inside the sha256 computation. These are bitwise AND, XOR, NOT, addition modulo 2^32 and the Rotate Right (ROTR) and Shift Right (SHR) operations, all working with 32-bit words and producing a new word.
|
||||
|
||||
We combine these operations inside the sigma (with 4 variations), Ch and Maj functions. At the end of the day, when we change the sha256 to be computed homomorphically, we will mainly change the isolated code of each operation.
|
||||
|
||||
Here is the definition of each function:
|
||||
```
|
||||
Ch(x, y, z) = (x AND y) XOR ((NOT x) AND z)
|
||||
Maj(x, y, z) = (x AND y) XOR (x AND z) XOR (y AND z)
|
||||
|
||||
Σ0(x) = ROTR-2(x) XOR ROTR-13(x) XOR ROTR-22(x)
|
||||
Σ1(x) = ROTR-6(x) XOR ROTR-11(x) XOR ROTR-25(x)
|
||||
σ0(x) = ROTR-7(x) XOR ROTR-18(x) XOR SHR-3(x)
|
||||
σ1(x) = ROTR-17(x) XOR ROTR-19(x) XOR SHR-10(x)
|
||||
```
|
||||
There are some things to note about the functions. Firstly we see that Maj can be simplified by applying the boolean distributive law (x AND y) XOR (x AND z) = x AND (y XOR z). So the new Maj function looks like this:
|
||||
|
||||
```
|
||||
Maj(x, y, z) = (x AND (y XOR z)) XOR (y AND z)
|
||||
```
|
||||
Next we can also see that Ch can be simplified by using a single bitwise multiplexer. Let's take a look at the truth table of the Ch expression.
|
||||
| x | y | z | Result |
|
||||
| - | - | - | ------ |
|
||||
| 0 | 0 | 0 | 0 |
|
||||
| 0 | 0 | 1 | 1 |
|
||||
| 0 | 1 | 0 | 0 |
|
||||
| 0 | 1 | 1 | 1 |
|
||||
| 1 | 0 | 0 | 0 |
|
||||
| 1 | 0 | 1 | 0 |
|
||||
| 1 | 1 | 0 | 1 |
|
||||
| 1 | 1 | 1 | 1 |
|
||||
|
||||
When ```x = 0``` the result is identical to ```z```, but when ```x = 1``` the result is identical to ```y```. This is the same as saying ```if x {y} else {z}```. Hence we can replace the 4 bitwise operations of Ch by a single bitwise multiplexer.
|
||||
|
||||
Note that all these operations can be evaluated homomorphically. ROTR and SHR can be evaluated by changing the index of each individual bit of the word, even if each bit is encrypted, without using any homomorphic operation. Bitwise AND, XOR and multiplexer can be computed homomorphically and addition modulo 2^32 can be broken down into boolean homomorphic operations as well.
|
||||
|
||||
#### Sha256 computation
|
||||
|
||||
As we have mentioned, the sha256 function works with chunks of 512 bits. For each chunk, we will compute 64 32-bit words. 16 will come from the 512 bits and the rest will be computed using the previous functions. After computing the 64 words, and still within the same chunk iteration, a compression loop will compute a hash value (8 32-bit words), again using the previous functions and some constants to mix everything up. When we finish the last chunk iteration, the resulting hash values will be the output of the sha256 function.
|
||||
|
||||
Here is how this function looks like using arrays of 32 bools to represent words:
|
||||
|
||||
```rust
|
||||
fn sha256(padded_input: Vec<bool>) -> [bool; 256] {
|
||||
|
||||
// Initialize hash values with constant values
|
||||
let mut hash: [[bool; 32]; 8] = [
|
||||
hex_to_bools(0x6a09e667), hex_to_bools(0xbb67ae85),
|
||||
hex_to_bools(0x3c6ef372), hex_to_bools(0xa54ff53a),
|
||||
hex_to_bools(0x510e527f), hex_to_bools(0x9b05688c),
|
||||
hex_to_bools(0x1f83d9ab), hex_to_bools(0x5be0cd19),
|
||||
];
|
||||
|
||||
let chunks = padded_input.chunks(512);
|
||||
|
||||
for chunk in chunks {
|
||||
let mut w = [[false; 32]; 64];
|
||||
|
||||
// Copy first 16 words from current chunk
|
||||
for i in 0..16 {
|
||||
w[i].copy_from_slice(&chunk[i * 32..(i + 1) * 32]);
|
||||
}
|
||||
|
||||
// Compute the other 48 words
|
||||
for i in 16..64 {
|
||||
w[i] = add(add(add(sigma1(&w[i - 2]), w[i - 7]), sigma0(&w[i - 15])), w[i - 16]);
|
||||
}
|
||||
|
||||
let mut a = hash[0];
|
||||
let mut b = hash[1];
|
||||
let mut c = hash[2];
|
||||
let mut d = hash[3];
|
||||
let mut e = hash[4];
|
||||
let mut f = hash[5];
|
||||
let mut g = hash[6];
|
||||
let mut h = hash[7];
|
||||
|
||||
// Compression loop, each iteration uses a specific constant from K
|
||||
for i in 0..64 {
|
||||
let temp1 = add(add(add(add(h, ch(&e, &f, &g)), w[i]), hex_to_bools(K[i])), sigma_upper_case_1(&e));
|
||||
let temp2 = add(sigma_upper_case_0(&a), maj(&a, &b, &c));
|
||||
h = g;
|
||||
g = f;
|
||||
f = e;
|
||||
e = add(d, temp1);
|
||||
d = c;
|
||||
c = b;
|
||||
b = a;
|
||||
a = add(temp1, temp2);
|
||||
}
|
||||
|
||||
hash[0] = add(hash[0], a);
|
||||
hash[1] = add(hash[1], b);
|
||||
hash[2] = add(hash[2], c);
|
||||
hash[3] = add(hash[3], d);
|
||||
hash[4] = add(hash[4], e);
|
||||
hash[5] = add(hash[5], f);
|
||||
hash[6] = add(hash[6], g);
|
||||
hash[7] = add(hash[7], h);
|
||||
}
|
||||
|
||||
// Concatenate the final hash values to produce a 256-bit hash
|
||||
let mut output = [false; 256];
|
||||
for i in 0..8 {
|
||||
output[i * 32..(i + 1) * 32].copy_from_slice(&hash[i]);
|
||||
}
|
||||
output
|
||||
}
|
||||
```
|
||||
|
||||
## Making it homomorphic
|
||||
|
||||
The key idea is that we can replace each bit of ```padded_input``` with a Fully Homomorphic Encryption of the same bit value, and operate over the encrypted values using homomorphic operations. To achieve this we need to change the function signatures and deal with the borrowing rules of the Ciphertext type (which represents an encrypted bit) but the structure of the sha256 function remains the same. The part of the code that requires more consideration is the implementation of the sha256 operations, since they will use homomorphic boolean operations internally.
|
||||
|
||||
Homomorphic operations are really expensive, so we have to remove their unnecessary use and maximize parallelization in order to speed up the program. To simplify our code we use the Rayon crate which provides parallel iterators and efficiently manages threads. Let's now take a look at each sha256 operation!
|
||||
|
||||
#### Rotate Right and Shift Right
|
||||
|
||||
As we have highlighted, these two operations can be evaluated by changing the position of each encrypted bit in the word, thereby requiring 0 homomorphic operations. Here is our implementation:
|
||||
|
||||
```rust
|
||||
fn rotate_right(x: &[Ciphertext; 32], n: usize) -> [Ciphertext; 32] {
|
||||
let mut result = x.clone();
|
||||
result.rotate_right(n);
|
||||
result
|
||||
}
|
||||
|
||||
fn shift_right(x: &[Ciphertext; 32], n: usize, sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut result = x.clone();
|
||||
result.rotate_right(n);
|
||||
result[..n].fill_with(|| sk.trivial_encrypt(false));
|
||||
result
|
||||
}
|
||||
```
|
||||
|
||||
#### Bitwise XOR, AND, Multiplexer
|
||||
|
||||
To implement these operations we will use the ```xor```, ```and``` and ```mux``` methods provided by the tfhe library to evaluate each boolean operation homomorphically. It's important to note that, since we will operate bitwise, we can parallelize the homomorphic computations. In other words, we can homomorphically XOR the bits at index 0 of two words using a thread, while XORing the bits at index 1 using another thread, and so on. This means we could compute these bitwise operations using up to 32 concurrent threads (since we work with 32-bit words).
|
||||
|
||||
Here is our implementation of the bitwise homomorphic XOR operation. The ```par_iter``` and ```par_iter_mut``` methods create a parallel iterator that we use to compute each individual XOR efficiently. The other two bitwise operations are implemented in the same way.
|
||||
|
||||
```rust
|
||||
fn xor(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut result = a.clone();
|
||||
result.par_iter_mut()
|
||||
.zip(a.par_iter().zip(b.par_iter()))
|
||||
.for_each(|(dst, (lhs, rhs))| *dst = sk.xor(lhs, rhs));
|
||||
result
|
||||
}
|
||||
```
|
||||
|
||||
#### Addition modulo 2^32
|
||||
|
||||
This is perhaps the trickiest operation to efficiently implement in a homomorphic fashion. A naive implementation could use the Ripple Carry Adder algorithm, which is straightforward but cannot be parallelized because each step depends on the previous one.
|
||||
|
||||
A better choice would be the Carry Lookahead Adder, which allows us to use the parallelized AND and XOR bitwise operations. With this design, our adder is around 50% faster than the Ripple Carry Adder.
|
||||
|
||||
```rust
|
||||
pub fn add(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let propagate = xor(a, b, sk); // Parallelized bitwise XOR
|
||||
let generate = and(a, b, sk); // Parallelized bitwise AND
|
||||
|
||||
let carry = compute_carry(&propagate, &generate, sk);
|
||||
let sum = xor(&propagate, &carry, sk); // Parallelized bitwise XOR
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
fn compute_carry(propagate: &[Ciphertext; 32], generate: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut carry = trivial_bools(&[false; 32], sk);
|
||||
carry[31] = sk.trivial_encrypt(false);
|
||||
|
||||
for i in (0..31).rev() {
|
||||
carry[i] = sk.or(&generate[i + 1], &sk.and(&propagate[i + 1], &carry[i + 1]));
|
||||
}
|
||||
|
||||
carry
|
||||
}
|
||||
```
|
||||
|
||||
To even improve performance more, the function that computes the carry signals can also be parallelized using parallel prefix algorithms. These algorithms involve more boolean operations (so homomorphic operations for us) but may be faster because of their parallel nature. We have implemented the Brent-Kung and Ladner-Fischer algorithms, which entail different tradeoffs.
|
||||
|
||||
Brent-Kung has the least amount of boolean operations we could find (140 when using grey cells, for 32-bit numbers), which makes it suitable when we can't process many operations concurrently and fast. Our results confirm that it's indeed faster than both the sequential algorithm and Ladner-Fischer when run on regular computers.
|
||||
|
||||
On the other hand, Ladner-Fischer performs more boolean operations (209 using grey cells) than Brent-Kung, but they are performed in larger batches. Hence we can compute more operations in parallel and finish earlier, but we need more fast threads available or they will slow down the carry signals computation. Ladner-Fischer can be suitable when using cloud-based computing services, which offer many high-speed threads.
|
||||
|
||||
Our implementation uses Brent-Kung by default, but Ladner-Fischer can be enabled when needed by using the ```--ladner-fischer``` command line argument.
|
||||
|
||||
For more information about parallel prefix adders you can read [this paper](https://www.iosrjournals.org/iosr-jece/papers/Vol6-Issue1/A0610106.pdf) or [this other paper](https://www.ijert.org/research/design-and-implementation-of-parallel-prefix-adder-for-improving-the-performance-of-carry-lookahead-adder-IJERTV4IS120608.pdf).
|
||||
|
||||
Finally, with all these sha256 operations working homomorphically, our functions will be homomomorphic as well along with the whole sha256 function (after adapting the code to work with the Ciphertext type). Let's talk about other performance improvements we can make before we finish.
|
||||
|
||||
### More parallel processing
|
||||
|
||||
If we inspect the main ```sha256_fhe``` function, we will find operations that can be performed in parallel. For instance, within the compression loop, ```temp1``` and ```temp2``` can be computed concurrently. An efficient way to parallelize computations here is using the ```rayon::join()``` function, which uses parallel processing only when there are available CPUs. Recall that the two temporary values in the compression loop are the result of several additions, so we can use nested calls to ```rayon::join()``` to potentially parallelize more operations.
|
||||
|
||||
Another way to speed up consecutive additions would be using the Carry Save Adder, a very efficient adder that takes 3 numbers and returns a sum and carry sequence. If our inputs are A, B and C, we can construct a CSA with our previously implemented Maj function and the bitwise XOR operation as follows:
|
||||
|
||||
```
|
||||
Carry = Maj(A, B, C)
|
||||
Sum = A XOR B XOR C
|
||||
```
|
||||
|
||||
By chaining CSAs, we can input the sum and carry from a preceding stage along with another number into a new CSA. Finally, to get the result of the additions we add the sum and carry sequences using a conventional adder. At the end we are performing the same number of additions, but some of them are now CSAs, speeding up the process. Let's see all this together in the ```temp1``` and ```temp2``` computations.
|
||||
|
||||
```rust
|
||||
let (temp1, temp2) = rayon::join(
|
||||
|| {
|
||||
let ((sum, carry), s1) = rayon::join(
|
||||
|| {
|
||||
let ((sum, carry), ch) = rayon::join(
|
||||
|| csa(&h, &w[i], &trivial_bools(&hex_to_bools(K[i]), sk), sk),
|
||||
|| ch(&e, &f, &g, sk),
|
||||
);
|
||||
csa(&sum, &carry, &ch, sk)
|
||||
},
|
||||
|| sigma_upper_case_1(&e, sk)
|
||||
);
|
||||
|
||||
let (sum, carry) = csa(&sum, &carry, &s1, sk);
|
||||
add(&sum, &carry, sk)
|
||||
},
|
||||
|| {
|
||||
add(&sigma_upper_case_0(&a, sk), &maj(&a, &b, &c, sk), sk)
|
||||
},
|
||||
);
|
||||
```
|
||||
|
||||
The first closure of the outer call to join will return ```temp1``` and the second ```temp2```. Inside the first outer closure we call join recursively until we reach the addition of the value ```h```, the current word ```w[i]``` and the current constant ```K[i]``` by using the CSA, while potentially computing in parallel the ```ch``` function. Then we take the sum, carry and ch values and add them again using the CSA.
|
||||
|
||||
All this is done while potentially computing the ```sigma_upper_case_1``` function. Finally we input the previous sum, carry and sigma values to the CSA and perform the final addition with ```add```. Once again, this is done while potentially computing ```sigma_upper_case_0``` and ```maj``` and adding them to get ```temp2```, in the second outer closure.
|
||||
|
||||
With some changes of this type, we finally get a homomorphic sha256 function that doesn't leave unused computational resources.
|
||||
|
||||
## How to use sha256_bool
|
||||
|
||||
First of all, the most important thing when running the program is using the ```--release``` flag. The use of sha256_bool would look like this, given the implementation of ```encrypt_bools``` and ```decrypt_bools```:
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
let matches = Command::new("Homomorphic sha256")
|
||||
.arg(Arg::new("ladner_fischer")
|
||||
.long("ladner-fischer")
|
||||
.help("Use the Ladner Fischer parallel prefix algorithm for additions")
|
||||
.action(ArgAction::SetTrue))
|
||||
.get_matches();
|
||||
|
||||
// If set using the command line flag "--ladner-fischer" this algorithm will be used in additions
|
||||
let ladner_fischer: bool = matches.get_flag("ladner_fischer");
|
||||
|
||||
// INTRODUCE INPUT FROM STDIN
|
||||
|
||||
let mut input = String::new();
|
||||
println!("Write input to hash:");
|
||||
|
||||
io::stdin()
|
||||
.read_line(&mut input)
|
||||
.expect("Failed to read line");
|
||||
|
||||
input = input.trim_end_matches('\n').to_string();
|
||||
|
||||
println!("You entered: \"{}\"", input);
|
||||
|
||||
// CLIENT PADS DATA AND ENCRYPTS IT
|
||||
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let padded_input = pad_sha256_input(&input);
|
||||
let encrypted_input = encrypt_bools(&padded_input, &ck);
|
||||
|
||||
// SERVER COMPUTES OVER THE ENCRYPTED PADDED DATA
|
||||
|
||||
println!("Computing the hash");
|
||||
let encrypted_output = sha256_fhe(encrypted_input, ladner_fischer, &sk);
|
||||
|
||||
// CLIENT DECRYPTS THE OUTPUT
|
||||
|
||||
let output = decrypt_bools(&encrypted_output, &ck);
|
||||
let outhex = bools_to_hex(output);
|
||||
|
||||
println!("{}", outhex);
|
||||
}
|
||||
```
|
||||
|
||||
By using ```stdin``` we can supply the data to hash using a file instead of the command line. For example, if our file ```input.txt``` is in the same directory as the project, we can use the following shell command after building with ```cargo build --release```:
|
||||
|
||||
```sh
|
||||
./target/release/examples/sha256_bool < input.txt
|
||||
```
|
||||
|
||||
Our implementation also accepts hexadecimal inputs. To be considered as such, the input must start with "0x" and contain only valid hex digits (otherwise it's interpreted as text).
|
||||
|
||||
Finally see that padding is executed on the client side. This has the advantage of hiding the exact length of the input to the server, who already doesn't know anything about the contents of it but may extract information from the length.
|
||||
|
||||
Another option would be to perform padding on the server side. The padding function would receive the encrypted input and pad it with trivial bit encryptions. We could then integrate the padding function inside the ```sha256_fhe``` function computed by the server.
|
||||
418
tfhe/examples/dark_market.rs
Normal file
418
tfhe/examples/dark_market.rs
Normal file
@@ -0,0 +1,418 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
use tfhe::integer::ciphertext::RadixCiphertextBig;
|
||||
use tfhe::integer::keycache::IntegerKeyCache;
|
||||
use tfhe::integer::ServerKey;
|
||||
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
/// The number of blocks to be used in the Radix.
|
||||
const NUMBER_OF_BLOCKS: usize = 8;
|
||||
|
||||
/// Plain implementation of the volume matching algorithm.
|
||||
///
|
||||
/// Matches the given [sell_orders] with [buy_orders].
|
||||
/// The amount of the orders that are successfully filled is written over the original order count.
|
||||
fn volume_match_plain(sell_orders: &mut [u16], buy_orders: &mut [u16]) {
|
||||
let total_sell_volume: u16 = sell_orders.iter().sum();
|
||||
let total_buy_volume: u16 = buy_orders.iter().sum();
|
||||
|
||||
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
|
||||
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
|
||||
*sell_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
|
||||
let mut volume_left_to_transact = total_volume;
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
|
||||
*buy_order = filled_amount;
|
||||
volume_left_to_transact -= filled_amount;
|
||||
}
|
||||
}
|
||||
|
||||
/// FHE implementation of the volume matching algorithm.
|
||||
///
|
||||
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
|
||||
/// [server_key]. The amount of the orders that are successfully filled is written over the original
|
||||
/// order count.
|
||||
fn volume_match_fhe(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
println!("Calculating total sell and buy volumes...");
|
||||
let time = Instant::now();
|
||||
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for sell_order in sell_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
|
||||
}
|
||||
|
||||
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for buy_order in buy_orders.iter_mut() {
|
||||
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
|
||||
}
|
||||
println!(
|
||||
"Total sell and buy volumes are calculated in {:?}",
|
||||
time.elapsed()
|
||||
);
|
||||
|
||||
println!("Calculating total volume to be matched...");
|
||||
let time = Instant::now();
|
||||
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
|
||||
println!(
|
||||
"Calculated total volume to be matched in {:?}",
|
||||
time.elapsed()
|
||||
);
|
||||
|
||||
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
|
||||
let mut volume_left_to_transact = total_volume.clone();
|
||||
for order in orders.iter_mut() {
|
||||
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, order);
|
||||
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
|
||||
*order = filled_amount;
|
||||
}
|
||||
};
|
||||
|
||||
println!("Filling orders...");
|
||||
let time = Instant::now();
|
||||
fill_orders(sell_orders);
|
||||
fill_orders(buy_orders);
|
||||
println!("Filled orders in {:?}", time.elapsed());
|
||||
}
|
||||
|
||||
/// FHE implementation of the volume matching algorithm.
|
||||
///
|
||||
/// This version of the algorithm utilizes parallelization to speed up the computation.
|
||||
///
|
||||
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
|
||||
/// [server_key]. The amount of the orders that are successfully filled is written over the original
|
||||
/// order count.
|
||||
fn volume_match_fhe_parallelized(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
// Calculate the element sum of the given vector in parallel
|
||||
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
|
||||
vec.to_vec().into_par_iter().reduce(
|
||||
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|
||||
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
|
||||
server_key.smart_add_parallelized(&mut acc, &mut ele)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
println!("Calculating total sell and buy volumes...");
|
||||
let time = Instant::now();
|
||||
// Total sell and buy volumes can be calculated in parallel because they have no dependency on
|
||||
// each other.
|
||||
let (mut total_sell_volume, mut total_buy_volume) = rayon::join(
|
||||
|| parallel_vector_sum(sell_orders),
|
||||
|| parallel_vector_sum(buy_orders),
|
||||
);
|
||||
println!(
|
||||
"Total sell and buy volumes are calculated in {:?}",
|
||||
time.elapsed()
|
||||
);
|
||||
|
||||
println!("Calculating total volume to be matched...");
|
||||
let time = Instant::now();
|
||||
let total_volume =
|
||||
server_key.smart_min_parallelized(&mut total_sell_volume, &mut total_buy_volume);
|
||||
println!(
|
||||
"Calculated total volume to be matched in {:?}",
|
||||
time.elapsed()
|
||||
);
|
||||
|
||||
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
|
||||
let mut volume_left_to_transact = total_volume.clone();
|
||||
for order in orders.iter_mut() {
|
||||
let mut filled_amount =
|
||||
server_key.smart_min_parallelized(&mut volume_left_to_transact, order);
|
||||
server_key
|
||||
.smart_sub_assign_parallelized(&mut volume_left_to_transact, &mut filled_amount);
|
||||
*order = filled_amount;
|
||||
}
|
||||
};
|
||||
println!("Filling orders...");
|
||||
let time = Instant::now();
|
||||
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
|
||||
println!("Filled orders in {:?}", time.elapsed());
|
||||
}
|
||||
|
||||
/// FHE implementation of the volume matching algorithm.
|
||||
///
|
||||
/// In this function, the implemented algorithm is modified to utilize more concurrency.
|
||||
///
|
||||
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
|
||||
/// [server_key]. The amount of the orders that are successfully filled is written over the original
|
||||
/// order count.
|
||||
fn volume_match_fhe_modified(
|
||||
sell_orders: &mut [RadixCiphertextBig],
|
||||
buy_orders: &mut [RadixCiphertextBig],
|
||||
server_key: &ServerKey,
|
||||
) {
|
||||
let compute_prefix_sum = |arr: &[RadixCiphertextBig]| {
|
||||
if arr.is_empty() {
|
||||
return arr.to_vec();
|
||||
}
|
||||
let mut prefix_sum: Vec<RadixCiphertextBig> = (0..arr.len().next_power_of_two())
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
if i < arr.len() {
|
||||
arr[i].clone()
|
||||
} else {
|
||||
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
for d in 0..prefix_sum.len().ilog2() {
|
||||
prefix_sum
|
||||
.par_chunks_exact_mut(2_usize.pow(d + 1))
|
||||
.for_each(move |chunk| {
|
||||
let length = chunk.len();
|
||||
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
|
||||
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
|
||||
});
|
||||
}
|
||||
let last = prefix_sum.last().unwrap().clone();
|
||||
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
|
||||
for d in (0..prefix_sum.len().ilog2()).rev() {
|
||||
prefix_sum
|
||||
.par_chunks_exact_mut(2_usize.pow(d + 1))
|
||||
.for_each(move |chunk| {
|
||||
let length = chunk.len();
|
||||
let temp = chunk.last().unwrap().clone();
|
||||
let mut mid = chunk.get((length - 1) / 2).unwrap().clone();
|
||||
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut mid);
|
||||
chunk[(length - 1) / 2] = temp;
|
||||
});
|
||||
}
|
||||
prefix_sum.push(last);
|
||||
prefix_sum[1..=arr.len()].to_vec()
|
||||
};
|
||||
|
||||
println!("Creating prefix sum arrays...");
|
||||
let time = Instant::now();
|
||||
let (prefix_sum_sell_orders, prefix_sum_buy_orders) = rayon::join(
|
||||
|| compute_prefix_sum(sell_orders),
|
||||
|| compute_prefix_sum(buy_orders),
|
||||
);
|
||||
println!("Created prefix sum arrays in {:?}", time.elapsed());
|
||||
|
||||
let fill_orders = |total_orders: &RadixCiphertextBig,
|
||||
orders: &mut [RadixCiphertextBig],
|
||||
prefix_sum_arr: &[RadixCiphertextBig]| {
|
||||
orders
|
||||
.into_par_iter()
|
||||
.enumerate()
|
||||
.for_each(move |(i, order)| {
|
||||
server_key.smart_add_assign_parallelized(
|
||||
order,
|
||||
&mut server_key.smart_mul_parallelized(
|
||||
&mut server_key
|
||||
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
|
||||
&mut server_key.smart_sub_parallelized(
|
||||
&mut server_key.smart_sub_parallelized(
|
||||
&mut total_orders.clone(),
|
||||
&mut server_key.smart_min_parallelized(
|
||||
&mut total_orders.clone(),
|
||||
&mut prefix_sum_arr
|
||||
.get(i - 1)
|
||||
.unwrap_or(
|
||||
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|
||||
)
|
||||
.clone(),
|
||||
),
|
||||
),
|
||||
&mut order.clone(),
|
||||
),
|
||||
),
|
||||
);
|
||||
});
|
||||
};
|
||||
|
||||
let total_buy_orders = &mut prefix_sum_buy_orders
|
||||
.last()
|
||||
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
|
||||
.clone();
|
||||
|
||||
let total_sell_orders = &mut prefix_sum_sell_orders
|
||||
.last()
|
||||
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
|
||||
.clone();
|
||||
|
||||
println!("Matching orders...");
|
||||
let time = Instant::now();
|
||||
rayon::join(
|
||||
|| fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders),
|
||||
|| fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders),
|
||||
);
|
||||
println!("Matched orders in {:?}", time.elapsed());
|
||||
}
|
||||
|
||||
/// Runs the given [tester] function with the test cases for volume matching algorithm.
|
||||
fn run_test_cases<F: Fn(&[u16], &[u16], &[u16], &[u16])>(tester: F) {
|
||||
println!("Testing empty sell orders...");
|
||||
tester(
|
||||
&[],
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
&[],
|
||||
&(1..11).map(|_| 0).collect::<Vec<_>>(),
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Testing empty buy orders...");
|
||||
tester(
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
&[],
|
||||
&(1..11).map(|_| 0).collect::<Vec<_>>(),
|
||||
&[],
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Testing exact matching of sell and buy orders...");
|
||||
tester(
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
&(1..11).collect::<Vec<_>>(),
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Testing the case where there are more buy orders than sell orders...");
|
||||
tester(
|
||||
&(1..11).map(|_| 10).collect::<Vec<_>>(),
|
||||
&[200],
|
||||
&(1..11).map(|_| 10).collect::<Vec<_>>(),
|
||||
&[100],
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Testing the case where there are more sell orders than buy orders...");
|
||||
tester(
|
||||
&[200],
|
||||
&(1..11).map(|_| 10).collect::<Vec<_>>(),
|
||||
&[100],
|
||||
&(1..11).map(|_| 10).collect::<Vec<_>>(),
|
||||
);
|
||||
println!();
|
||||
|
||||
println!("Testing maximum input size for sell and buy orders...");
|
||||
tester(
|
||||
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
|
||||
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
|
||||
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
|
||||
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Runs the test cases for the plain implementation of the volume matching algorithm.
|
||||
fn test_volume_match_plain() {
|
||||
let tester = |input_sell_orders: &[u16],
|
||||
input_buy_orders: &[u16],
|
||||
expected_filled_sells: &[u16],
|
||||
expected_filled_buys: &[u16]| {
|
||||
let mut sell_orders = input_sell_orders.to_vec();
|
||||
let mut buy_orders = input_buy_orders.to_vec();
|
||||
|
||||
println!("Running plain implementation...");
|
||||
let time = Instant::now();
|
||||
volume_match_plain(&mut sell_orders, &mut buy_orders);
|
||||
println!("Ran plain implementation in {:?}", time.elapsed());
|
||||
|
||||
assert_eq!(sell_orders, expected_filled_sells);
|
||||
assert_eq!(buy_orders, expected_filled_buys);
|
||||
};
|
||||
|
||||
println!("Running test cases for the plain implementation");
|
||||
run_test_cases(tester);
|
||||
}
|
||||
|
||||
/// Runs the test cases for the fhe implementation of the volume matching algorithm.
|
||||
///
|
||||
/// [parallelized] indicates whether the fhe implementation should be run in parallel.
|
||||
fn test_volume_match_fhe(
|
||||
fhe_function: fn(&mut [RadixCiphertextBig], &mut [RadixCiphertextBig], &ServerKey),
|
||||
) {
|
||||
let working_dir = std::env::current_dir().unwrap();
|
||||
if working_dir.file_name().unwrap() != std::path::Path::new("tfhe") {
|
||||
std::env::set_current_dir(working_dir.join("tfhe")).unwrap();
|
||||
}
|
||||
|
||||
println!("Generating keys...");
|
||||
let time = Instant::now();
|
||||
let (client_key, server_key) = IntegerKeyCache.get_from_params(PARAM_MESSAGE_2_CARRY_2);
|
||||
println!("Keys generated in {:?}", time.elapsed());
|
||||
|
||||
let tester = |input_sell_orders: &[u16],
|
||||
input_buy_orders: &[u16],
|
||||
expected_filled_sells: &[u16],
|
||||
expected_filled_buys: &[u16]| {
|
||||
let mut encrypted_sell_orders = input_sell_orders
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|pt| client_key.encrypt_radix(pt as u64, NUMBER_OF_BLOCKS))
|
||||
.collect::<Vec<RadixCiphertextBig>>();
|
||||
let mut encrypted_buy_orders = input_buy_orders
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|pt| client_key.encrypt_radix(pt as u64, NUMBER_OF_BLOCKS))
|
||||
.collect::<Vec<RadixCiphertextBig>>();
|
||||
|
||||
println!("Running FHE implementation...");
|
||||
let time = Instant::now();
|
||||
fhe_function(
|
||||
&mut encrypted_sell_orders,
|
||||
&mut encrypted_buy_orders,
|
||||
&server_key,
|
||||
);
|
||||
println!("Ran FHE implementation in {:?}", time.elapsed());
|
||||
|
||||
let decrypted_filled_sells = encrypted_sell_orders
|
||||
.iter()
|
||||
.map(|ct| client_key.decrypt_radix::<u64, _>(ct) as u16)
|
||||
.collect::<Vec<u16>>();
|
||||
let decrypted_filled_buys = encrypted_buy_orders
|
||||
.iter()
|
||||
.map(|ct| client_key.decrypt_radix::<u64, _>(ct) as u16)
|
||||
.collect::<Vec<u16>>();
|
||||
|
||||
assert_eq!(decrypted_filled_sells, expected_filled_sells);
|
||||
assert_eq!(decrypted_filled_buys, expected_filled_buys);
|
||||
};
|
||||
|
||||
println!("Running test cases for the FHE implementation");
|
||||
run_test_cases(tester);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
for argument in std::env::args() {
|
||||
if argument == "fhe-modified" {
|
||||
println!("Running modified fhe version");
|
||||
test_volume_match_fhe(volume_match_fhe_modified);
|
||||
println!();
|
||||
}
|
||||
if argument == "fhe-parallel" {
|
||||
println!("Running parallelized fhe version");
|
||||
test_volume_match_fhe(volume_match_fhe_parallelized);
|
||||
println!();
|
||||
}
|
||||
if argument == "plain" {
|
||||
println!("Running plain version");
|
||||
test_volume_match_plain();
|
||||
println!();
|
||||
}
|
||||
if argument == "fhe" {
|
||||
println!("Running fhe version");
|
||||
test_volume_match_fhe(volume_match_fhe);
|
||||
println!();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,62 +1,100 @@
|
||||
use clap::{Arg, ArgAction, Command};
|
||||
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS};
|
||||
use tfhe::shortint::parameters::parameters_wopbs_message_carry::{
|
||||
WOPBS_PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_3_CARRY_3,
|
||||
WOPBS_PARAM_MESSAGE_4_CARRY_4,
|
||||
};
|
||||
use tfhe::shortint::parameters::{
|
||||
Parameters, ALL_PARAMETER_VEC, PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2,
|
||||
PARAM_MESSAGE_3_CARRY_3, PARAM_MESSAGE_4_CARRY_4,
|
||||
ClassicPBSParameters, WopbsParameters, ALL_MULTI_BIT_PARAMETER_VEC, ALL_PARAMETER_VEC,
|
||||
PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3,
|
||||
PARAM_MESSAGE_4_CARRY_4,
|
||||
};
|
||||
|
||||
fn client_server_keys() {
|
||||
println!("Generating shortint (ClientKey, ServerKey)");
|
||||
for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() {
|
||||
println!(
|
||||
"Generating [{} / {}] : {}",
|
||||
i + 1,
|
||||
ALL_PARAMETER_VEC.len(),
|
||||
params.name()
|
||||
);
|
||||
let matches = Command::new("test key gen")
|
||||
.arg(
|
||||
Arg::new("multi_bit_only")
|
||||
.long("multi-bit-only")
|
||||
.help("Set to generate only multi bit keys, otherwise only PBS and WoPBS keys are generated")
|
||||
.action(ArgAction::SetTrue),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
// If set using the command line flag "--ladner-fischer" this algorithm will be used in
|
||||
// additions
|
||||
let multi_bit_only: bool = matches.get_flag("multi_bit_only");
|
||||
|
||||
let _ = KEY_CACHE.get_from_param(params);
|
||||
if multi_bit_only {
|
||||
println!("Generating shortint multibit (ClientKey, ServerKey)");
|
||||
for (i, params) in ALL_MULTI_BIT_PARAMETER_VEC.iter().copied().enumerate() {
|
||||
println!(
|
||||
"Generating [{} / {}] : {}",
|
||||
i + 1,
|
||||
ALL_MULTI_BIT_PARAMETER_VEC.len(),
|
||||
params.name()
|
||||
);
|
||||
|
||||
let stop = start.elapsed().as_secs();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
println!("Generation took {stop} seconds");
|
||||
let _ = KEY_CACHE.get_from_param(params);
|
||||
|
||||
// Clear keys as we go to avoid filling the RAM
|
||||
KEY_CACHE.clear_in_memory_cache()
|
||||
}
|
||||
let stop = start.elapsed().as_secs();
|
||||
|
||||
const WOPBS_PARAMS: [(Parameters, Parameters); 4] = [
|
||||
(PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1),
|
||||
(PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2),
|
||||
(PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3),
|
||||
(PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4),
|
||||
];
|
||||
println!("Generation took {stop} seconds");
|
||||
|
||||
println!("Generating woPBS keys");
|
||||
for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() {
|
||||
println!(
|
||||
"Generating [{} / {}] : {}, {}",
|
||||
i + 1,
|
||||
WOPBS_PARAMS.len(),
|
||||
params_shortint.name(),
|
||||
params_wopbs.name(),
|
||||
);
|
||||
// Clear keys as we go to avoid filling the RAM
|
||||
KEY_CACHE.clear_in_memory_cache()
|
||||
}
|
||||
} else {
|
||||
println!("Generating shortint (ClientKey, ServerKey)");
|
||||
for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() {
|
||||
println!(
|
||||
"Generating [{} / {}] : {}",
|
||||
i + 1,
|
||||
ALL_PARAMETER_VEC.len(),
|
||||
params.name()
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs));
|
||||
let _ = KEY_CACHE.get_from_param(params);
|
||||
|
||||
let stop = start.elapsed().as_secs();
|
||||
let stop = start.elapsed().as_secs();
|
||||
|
||||
println!("Generation took {stop} seconds");
|
||||
println!("Generation took {stop} seconds");
|
||||
|
||||
// Clear keys as we go to avoid filling the RAM
|
||||
KEY_CACHE_WOPBS.clear_in_memory_cache()
|
||||
// Clear keys as we go to avoid filling the RAM
|
||||
KEY_CACHE.clear_in_memory_cache()
|
||||
}
|
||||
|
||||
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [
|
||||
(PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1),
|
||||
(PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2),
|
||||
(PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3),
|
||||
(PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4),
|
||||
];
|
||||
|
||||
println!("Generating woPBS keys");
|
||||
for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() {
|
||||
println!(
|
||||
"Generating [{} / {}] : {}, {}",
|
||||
i + 1,
|
||||
WOPBS_PARAMS.len(),
|
||||
params_shortint.name(),
|
||||
params_wopbs.name(),
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs));
|
||||
|
||||
let stop = start.elapsed().as_secs();
|
||||
|
||||
println!("Generation took {stop} seconds");
|
||||
|
||||
// Clear keys as we go to avoid filling the RAM
|
||||
KEY_CACHE_WOPBS.clear_in_memory_cache()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
92
tfhe/examples/integer_compact_pk_ct_sizes.rs
Normal file
92
tfhe/examples/integer_compact_pk_ct_sizes.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
use rand::Rng;
|
||||
|
||||
use tfhe::core_crypto::commons::numeric::Numeric;
|
||||
use tfhe::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
|
||||
use tfhe::integer::public_key::{CompactPublicKeyBig, CompactPublicKeySmall};
|
||||
use tfhe::integer::{gen_keys, U256};
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::parameters_compact_pk::*;
|
||||
|
||||
pub fn main() {
|
||||
fn size_func<Scalar: Numeric + DecomposableInto<u64> + RecomposableFrom<u64> + From<u32>>() {
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_bits = Scalar::BITS;
|
||||
|
||||
let params = PARAM_MESSAGE_2_CARRY_2_COMPACT_PK;
|
||||
{
|
||||
println!("Sizes for: {} and {num_bits} bits", params.name());
|
||||
let (cks, _) = gen_keys(params);
|
||||
let pk = CompactPublicKeyBig::new(&cks);
|
||||
|
||||
println!("PK size: {} bytes", bincode::serialize(&pk).unwrap().len());
|
||||
|
||||
let num_block =
|
||||
(num_bits as f64 / (params.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
const MAX_CT: usize = 20;
|
||||
|
||||
let mut clear_vec = Vec::with_capacity(MAX_CT);
|
||||
// 5 inputs to a smart contract
|
||||
let num_ct_for_this_iter = 5;
|
||||
clear_vec.truncate(0);
|
||||
for _ in 0..num_ct_for_this_iter {
|
||||
let clear = rng.gen::<u32>();
|
||||
clear_vec.push(Scalar::from(clear));
|
||||
}
|
||||
|
||||
let compact_encrypted_list = pk.encrypt_slice_radix_compact(&clear_vec, num_block);
|
||||
|
||||
println!(
|
||||
"Compact CT list for {num_ct_for_this_iter} CTs: {} bytes",
|
||||
bincode::serialize(&compact_encrypted_list).unwrap().len()
|
||||
);
|
||||
|
||||
let ciphertext_vec = compact_encrypted_list.expand();
|
||||
|
||||
for (ciphertext, clear) in ciphertext_vec.iter().zip(clear_vec.iter().copied()) {
|
||||
let decrypted: Scalar = cks.decrypt_radix(ciphertext);
|
||||
assert_eq!(decrypted, clear);
|
||||
}
|
||||
}
|
||||
|
||||
let params = PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_SMALL;
|
||||
{
|
||||
println!("Sizes for: {} and {num_bits} bits", params.name());
|
||||
let (cks, _) = gen_keys(params);
|
||||
let pk = CompactPublicKeySmall::new(&cks);
|
||||
|
||||
println!("PK size: {} bytes", bincode::serialize(&pk).unwrap().len());
|
||||
|
||||
let num_block =
|
||||
(num_bits as f64 / (params.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
const MAX_CT: usize = 20;
|
||||
|
||||
let mut clear_vec = Vec::with_capacity(MAX_CT);
|
||||
// 5 inputs to a smart contract
|
||||
let num_ct_for_this_iter = 5;
|
||||
clear_vec.truncate(0);
|
||||
for _ in 0..num_ct_for_this_iter {
|
||||
let clear = rng.gen::<u32>();
|
||||
clear_vec.push(Scalar::from(clear));
|
||||
}
|
||||
|
||||
let compact_encrypted_list = pk.encrypt_slice_radix_compact(&clear_vec, num_block);
|
||||
|
||||
println!(
|
||||
"Compact CT list for {num_ct_for_this_iter} CTs: {} bytes",
|
||||
bincode::serialize(&compact_encrypted_list).unwrap().len()
|
||||
);
|
||||
|
||||
let ciphertext_vec = compact_encrypted_list.expand();
|
||||
|
||||
for (ciphertext, clear) in ciphertext_vec.iter().zip(clear_vec.iter().copied()) {
|
||||
let decrypted: Scalar = cks.decrypt_radix(ciphertext);
|
||||
assert_eq!(decrypted, clear);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
size_func::<u32>();
|
||||
size_func::<U256>();
|
||||
}
|
||||
22
tfhe/examples/regex_engine/ciphertext.rs
Normal file
22
tfhe/examples/regex_engine/ciphertext.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use tfhe::integer::{gen_keys_radix, RadixCiphertextBig, RadixClientKey, ServerKey};
|
||||
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
|
||||
pub type StringCiphertext = Vec<RadixCiphertextBig>;
|
||||
|
||||
pub fn encrypt_str(
|
||||
client_key: &RadixClientKey,
|
||||
s: &str,
|
||||
) -> Result<StringCiphertext, Box<dyn std::error::Error>> {
|
||||
if !s.is_ascii() {
|
||||
return Err("content contains non-ascii characters".into());
|
||||
}
|
||||
Ok(s.as_bytes()
|
||||
.iter()
|
||||
.map(|byte| client_key.encrypt(*byte as u64))
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub fn gen_keys() -> (RadixClientKey, ServerKey) {
|
||||
let num_block = 4;
|
||||
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block)
|
||||
}
|
||||
263
tfhe/examples/regex_engine/engine.rs
Normal file
263
tfhe/examples/regex_engine/engine.rs
Normal file
@@ -0,0 +1,263 @@
|
||||
use crate::execution::{Executed, Execution, LazyExecution};
|
||||
use crate::parser::{parse, RegExpr};
|
||||
use std::rc::Rc;
|
||||
use tfhe::integer::{RadixCiphertextBig, ServerKey};
|
||||
|
||||
pub fn has_match(
|
||||
sk: &ServerKey,
|
||||
content: &[RadixCiphertextBig],
|
||||
pattern: &str,
|
||||
) -> Result<RadixCiphertextBig, Box<dyn std::error::Error>> {
|
||||
let re = parse(pattern)?;
|
||||
|
||||
let branches: Vec<LazyExecution> = (0..content.len())
|
||||
.flat_map(|i| build_branches(content, &re, i))
|
||||
.map(|(lazy_branch_res, _)| lazy_branch_res)
|
||||
.collect();
|
||||
|
||||
let mut exec = Execution::new(sk.clone());
|
||||
|
||||
let res = if branches.len() <= 1 {
|
||||
branches
|
||||
.get(0)
|
||||
.map_or(exec.ct_false(), |branch| branch(&mut exec))
|
||||
.0
|
||||
} else {
|
||||
branches[1..]
|
||||
.iter()
|
||||
.fold(branches[0](&mut exec), |res, branch| {
|
||||
let branch_res = branch(&mut exec);
|
||||
exec.ct_or(res, branch_res)
|
||||
})
|
||||
.0
|
||||
};
|
||||
info!(
|
||||
"{} ciphertext operations, {} cache hits",
|
||||
exec.ct_operations_count(),
|
||||
exec.cache_hits(),
|
||||
);
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn build_branches(
|
||||
content: &[RadixCiphertextBig],
|
||||
re: &RegExpr,
|
||||
c_pos: usize,
|
||||
) -> Vec<(LazyExecution, usize)> {
|
||||
trace!("program pointer: regex={:?}, content pos={}", re, c_pos);
|
||||
match re {
|
||||
RegExpr::Sof => {
|
||||
if c_pos == 0 {
|
||||
return vec![(Rc::new(|exec| exec.ct_true()), c_pos)];
|
||||
} else {
|
||||
return vec![];
|
||||
}
|
||||
}
|
||||
RegExpr::Eof => {
|
||||
if c_pos == content.len() {
|
||||
return vec![(Rc::new(|exec| exec.ct_true()), c_pos)];
|
||||
} else {
|
||||
return vec![];
|
||||
}
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
if c_pos >= content.len() {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
match re.clone() {
|
||||
RegExpr::Char { c } => {
|
||||
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
|
||||
vec![(
|
||||
Rc::new(move |exec| exec.ct_eq(c_char.clone(), exec.ct_constant(c))),
|
||||
c_pos + 1,
|
||||
)]
|
||||
}
|
||||
RegExpr::AnyChar => vec![(Rc::new(|exec| exec.ct_true()), c_pos + 1)],
|
||||
RegExpr::Not { not_re } => build_branches(content, ¬_re, c_pos)
|
||||
.into_iter()
|
||||
.map(|(branch, c_pos)| {
|
||||
(
|
||||
Rc::new(move |exec: &mut Execution| {
|
||||
let branch_res = branch(exec);
|
||||
exec.ct_not(branch_res)
|
||||
}) as LazyExecution,
|
||||
c_pos,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
RegExpr::Either { l_re, r_re } => {
|
||||
let mut res = build_branches(content, &l_re, c_pos);
|
||||
res.append(&mut build_branches(content, &r_re, c_pos));
|
||||
res
|
||||
}
|
||||
RegExpr::Between { from, to } => {
|
||||
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
|
||||
vec![(
|
||||
Rc::new(move |exec| {
|
||||
let ct_from = exec.ct_constant(from);
|
||||
let ct_to = exec.ct_constant(to);
|
||||
let ge_from = exec.ct_ge(c_char.clone(), ct_from);
|
||||
let le_to = exec.ct_le(c_char.clone(), ct_to);
|
||||
exec.ct_and(ge_from, le_to)
|
||||
}),
|
||||
c_pos + 1,
|
||||
)]
|
||||
}
|
||||
RegExpr::Range { cs } => {
|
||||
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
|
||||
vec![(
|
||||
Rc::new(move |exec| {
|
||||
cs[1..].iter().fold(
|
||||
exec.ct_eq(c_char.clone(), exec.ct_constant(cs[0])),
|
||||
|res, c| {
|
||||
let ct_c_char_eq = exec.ct_eq(c_char.clone(), exec.ct_constant(*c));
|
||||
exec.ct_or(res, ct_c_char_eq)
|
||||
},
|
||||
)
|
||||
}),
|
||||
c_pos + 1,
|
||||
)]
|
||||
}
|
||||
RegExpr::Repeated {
|
||||
repeat_re,
|
||||
at_least,
|
||||
at_most,
|
||||
} => {
|
||||
let at_least = at_least.unwrap_or(0);
|
||||
let at_most = at_most.unwrap_or(content.len() - c_pos);
|
||||
|
||||
if at_least > at_most {
|
||||
return vec![];
|
||||
}
|
||||
|
||||
let mut res = vec![
|
||||
if at_least == 0 {
|
||||
vec![(
|
||||
Rc::new(|exec: &mut Execution| exec.ct_true()) as LazyExecution,
|
||||
c_pos,
|
||||
)]
|
||||
} else {
|
||||
vec![]
|
||||
},
|
||||
build_branches(
|
||||
content,
|
||||
&(RegExpr::Seq {
|
||||
re_xs: std::iter::repeat(*repeat_re.clone())
|
||||
.take(std::cmp::max(1, at_least))
|
||||
.collect(),
|
||||
}),
|
||||
c_pos,
|
||||
),
|
||||
];
|
||||
|
||||
for _ in (at_least + 1)..(at_most + 1) {
|
||||
res.push(
|
||||
res.last()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.flat_map(|(branch_prev, branch_c_pos)| {
|
||||
build_branches(content, &repeat_re, *branch_c_pos)
|
||||
.into_iter()
|
||||
.map(move |(branch_x, branch_x_c_pos)| {
|
||||
let branch_prev = branch_prev.clone();
|
||||
(
|
||||
Rc::new(move |exec: &mut Execution| {
|
||||
let res_prev = branch_prev(exec);
|
||||
let res_x = branch_x(exec);
|
||||
exec.ct_and(res_prev, res_x)
|
||||
}) as LazyExecution,
|
||||
branch_x_c_pos,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
res.into_iter().flatten().collect()
|
||||
}
|
||||
RegExpr::Optional { opt_re } => {
|
||||
let mut res = build_branches(content, &opt_re, c_pos);
|
||||
res.push((Rc::new(|exec| exec.ct_true()), c_pos));
|
||||
res
|
||||
}
|
||||
RegExpr::Seq { re_xs } => re_xs[1..].iter().fold(
|
||||
build_branches(content, &re_xs[0], c_pos),
|
||||
|continuations, re_x| {
|
||||
continuations
|
||||
.into_iter()
|
||||
.flat_map(|(branch_prev, branch_prev_c_pos)| {
|
||||
build_branches(content, re_x, branch_prev_c_pos)
|
||||
.into_iter()
|
||||
.map(move |(branch_x, branch_x_c_pos)| {
|
||||
let branch_prev = branch_prev.clone();
|
||||
(
|
||||
Rc::new(move |exec: &mut Execution| {
|
||||
let res_prev = branch_prev(exec);
|
||||
let res_x = branch_x(exec);
|
||||
exec.ct_and(res_prev, res_x)
|
||||
}) as LazyExecution,
|
||||
branch_x_c_pos,
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
},
|
||||
),
|
||||
_ => panic!("unmatched regex variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::engine::has_match;
|
||||
use test_case::test_case;
|
||||
|
||||
use crate::ciphertext::{encrypt_str, gen_keys, StringCiphertext};
|
||||
use lazy_static::lazy_static;
|
||||
use tfhe::integer::{RadixClientKey, ServerKey};
|
||||
|
||||
lazy_static! {
|
||||
pub static ref KEYS: (RadixClientKey, ServerKey) = gen_keys();
|
||||
}
|
||||
|
||||
#[test_case("ab", "/ab/", 1)]
|
||||
#[test_case("b", "/ab/", 0)]
|
||||
#[test_case("ab", "/a?b/", 1)]
|
||||
#[test_case("b", "/a?b/", 1)]
|
||||
#[test_case("ab", "/^ab|cd$/", 1)]
|
||||
#[test_case(" ab", "/^ab|cd$/", 0)]
|
||||
#[test_case(" cd", "/^ab|cd$/", 0)]
|
||||
#[test_case("cd", "/^ab|cd$/", 1)]
|
||||
#[test_case("abcd", "/^ab|cd$/", 0)]
|
||||
#[test_case("abcd", "/ab|cd$/", 1)]
|
||||
#[test_case("abc", "/abc/", 1)]
|
||||
#[test_case("123abc", "/abc/", 1)]
|
||||
#[test_case("123abc456", "/abc/", 1)]
|
||||
#[test_case("123abdc456", "/abc/", 0)]
|
||||
#[test_case("abc456", "/abc/", 1)]
|
||||
#[test_case("bc", "/a*bc/", 1)]
|
||||
#[test_case("cdaabc", "/a*bc/", 1)]
|
||||
#[test_case("cdbc", "/a+bc/", 0)]
|
||||
#[test_case("bc", "/a+bc/", 0)]
|
||||
#[test_case("Ab", "/ab/i", 1 ; "ab case insensitive")]
|
||||
#[test_case("Ab", "/ab/", 0 ; "ab case sensitive")]
|
||||
#[test_case("cD", "/ab|cd/i", 1)]
|
||||
#[test_case("cD", "/cD/", 1)]
|
||||
#[test_case("test a num 8", "/8/", 1)]
|
||||
#[test_case("test a num 8", "/^8/", 0)]
|
||||
#[test_case("4453", "/^[0-9]*$/", 1)]
|
||||
#[test_case("4453", "/^[09]*$/", 0)]
|
||||
#[test_case("09009", "/^[09]*$/", 1)]
|
||||
#[test_case("de", "/^ab|cd|de$/", 1 ; "multiple or")]
|
||||
#[test_case(" de", "/^ab|cd|de$/", 0 ; "multiple or nests below ^")]
|
||||
fn test_has_match(content: &str, pattern: &str, exp: u64) {
|
||||
let ct_content: StringCiphertext = encrypt_str(&KEYS.0, content).unwrap();
|
||||
let ct_res = has_match(&KEYS.1, &ct_content, pattern).unwrap();
|
||||
|
||||
let got = KEYS.0.decrypt(&ct_res);
|
||||
assert_eq!(exp, got);
|
||||
}
|
||||
}
|
||||
272
tfhe/examples/regex_engine/execution.rs
Normal file
272
tfhe/examples/regex_engine/execution.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use tfhe::integer::{RadixCiphertextBig, ServerKey};
|
||||
|
||||
use crate::parser::u8_to_char;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash)]
|
||||
pub(crate) enum Executed {
|
||||
Constant { c: u8 },
|
||||
CtPos { at: usize },
|
||||
And { a: Box<Executed>, b: Box<Executed> },
|
||||
Or { a: Box<Executed>, b: Box<Executed> },
|
||||
Equal { a: Box<Executed>, b: Box<Executed> },
|
||||
GreaterOrEqual { a: Box<Executed>, b: Box<Executed> },
|
||||
LessOrEqual { a: Box<Executed>, b: Box<Executed> },
|
||||
Not { a: Box<Executed> },
|
||||
}
|
||||
type ExecutedResult = (RadixCiphertextBig, Executed);
|
||||
|
||||
impl Executed {
|
||||
pub(crate) fn ct_pos(at: usize) -> Self {
|
||||
Executed::CtPos { at }
|
||||
}
|
||||
|
||||
fn get_trivial_constant(&self) -> Option<u8> {
|
||||
match self {
|
||||
Self::Constant { c } => Some(*c),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const CT_FALSE: u8 = 0;
|
||||
const CT_TRUE: u8 = 1;
|
||||
|
||||
pub(crate) struct Execution {
|
||||
sk: ServerKey,
|
||||
cache: HashMap<Executed, RadixCiphertextBig>,
|
||||
|
||||
ct_ops: usize,
|
||||
cache_hits: usize,
|
||||
}
|
||||
pub(crate) type LazyExecution = Rc<dyn Fn(&mut Execution) -> ExecutedResult>;
|
||||
|
||||
impl Execution {
|
||||
pub(crate) fn new(sk: ServerKey) -> Self {
|
||||
Self {
|
||||
sk,
|
||||
cache: HashMap::new(),
|
||||
ct_ops: 0,
|
||||
cache_hits: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn ct_operations_count(&self) -> usize {
|
||||
self.ct_ops
|
||||
}
|
||||
|
||||
pub(crate) fn cache_hits(&self) -> usize {
|
||||
self.cache_hits
|
||||
}
|
||||
|
||||
pub(crate) fn ct_eq(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::Equal {
|
||||
a: Box::new(a.1.clone()),
|
||||
b: Box::new(b.1.clone()),
|
||||
};
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec: &mut Execution| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = b.0.clone();
|
||||
(exec.sk.smart_eq(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_ge(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::GreaterOrEqual {
|
||||
a: Box::new(a.1.clone()),
|
||||
b: Box::new(b.1.clone()),
|
||||
};
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = b.0.clone();
|
||||
(exec.sk.smart_gt(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_le(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::LessOrEqual {
|
||||
a: Box::new(a.1.clone()),
|
||||
b: Box::new(b.1.clone()),
|
||||
};
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = b.0.clone();
|
||||
(exec.sk.smart_le(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_and(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::And {
|
||||
a: Box::new(a.1.clone()),
|
||||
b: Box::new(b.1.clone()),
|
||||
};
|
||||
|
||||
let c_a = a.1.get_trivial_constant();
|
||||
let c_b = b.1.get_trivial_constant();
|
||||
if c_a == Some(CT_TRUE) {
|
||||
return (b.0, ctx);
|
||||
}
|
||||
if c_a == Some(CT_FALSE) {
|
||||
return (a.0, ctx);
|
||||
}
|
||||
if c_b == Some(CT_TRUE) {
|
||||
return (a.0, ctx);
|
||||
}
|
||||
if c_b == Some(CT_FALSE) {
|
||||
return (b.0, ctx);
|
||||
}
|
||||
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = b.0.clone();
|
||||
(exec.sk.smart_bitand(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_or(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::Or {
|
||||
a: Box::new(a.1.clone()),
|
||||
b: Box::new(b.1.clone()),
|
||||
};
|
||||
|
||||
let c_a = a.1.get_trivial_constant();
|
||||
let c_b = b.1.get_trivial_constant();
|
||||
if c_a == Some(CT_TRUE) {
|
||||
return (a.0, ctx);
|
||||
}
|
||||
if c_b == Some(CT_TRUE) {
|
||||
return (b.0, ctx);
|
||||
}
|
||||
if c_a == Some(CT_FALSE) && c_b == Some(CT_FALSE) {
|
||||
return (a.0, ctx);
|
||||
}
|
||||
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = b.0.clone();
|
||||
(exec.sk.smart_bitor(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_not(&mut self, a: ExecutedResult) -> ExecutedResult {
|
||||
let ctx = Executed::Not {
|
||||
a: Box::new(a.1.clone()),
|
||||
};
|
||||
self.with_cache(
|
||||
ctx.clone(),
|
||||
Rc::new(move |exec| {
|
||||
exec.ct_ops += 1;
|
||||
|
||||
let mut ct_a = a.0.clone();
|
||||
let mut ct_b = exec.ct_constant(1).0;
|
||||
(exec.sk.smart_bitxor(&mut ct_a, &mut ct_b), ctx.clone())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_false(&self) -> ExecutedResult {
|
||||
self.ct_constant(CT_FALSE)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_true(&self) -> ExecutedResult {
|
||||
self.ct_constant(CT_TRUE)
|
||||
}
|
||||
|
||||
pub(crate) fn ct_constant(&self, c: u8) -> ExecutedResult {
|
||||
(
|
||||
self.sk.create_trivial_radix(c as u64, 4),
|
||||
Executed::Constant { c },
|
||||
)
|
||||
}
|
||||
|
||||
fn with_cache(&mut self, ctx: Executed, f: LazyExecution) -> ExecutedResult {
|
||||
if let Some(res) = self.cache.get(&ctx) {
|
||||
trace!("cache hit: {:?}", &ctx);
|
||||
self.cache_hits += 1;
|
||||
return (res.clone(), ctx);
|
||||
}
|
||||
debug!("evaluation for: {:?}", &ctx);
|
||||
let res = f(self);
|
||||
self.cache.insert(ctx, res.0.clone());
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Executed {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Constant { c } => match c {
|
||||
0 => write!(f, "f"),
|
||||
1 => write!(f, "t"),
|
||||
_ => write!(f, "{}", u8_to_char(*c)),
|
||||
},
|
||||
Self::CtPos { at } => write!(f, "ct_{}", at),
|
||||
Self::And { a, b } => {
|
||||
write!(f, "(")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, "/\\")?;
|
||||
b.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::Or { a, b } => {
|
||||
write!(f, "(")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, "\\/")?;
|
||||
b.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::Equal { a, b } => {
|
||||
write!(f, "(")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, "==")?;
|
||||
b.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::GreaterOrEqual { a, b } => {
|
||||
write!(f, "(")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, ">=")?;
|
||||
b.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::LessOrEqual { a, b } => {
|
||||
write!(f, "(")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, "<=")?;
|
||||
b.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::Not { a } => {
|
||||
write!(f, "(!")?;
|
||||
a.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
30
tfhe/examples/regex_engine/main.rs
Normal file
30
tfhe/examples/regex_engine/main.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
mod ciphertext;
|
||||
mod engine;
|
||||
mod execution;
|
||||
mod parser;
|
||||
|
||||
use env_logger::Env;
|
||||
use std::env;
|
||||
|
||||
fn main() {
|
||||
let env = Env::default().filter_or("RUST_LOG", "info");
|
||||
env_logger::init_from_env(env);
|
||||
|
||||
let args: Vec<String> = env::args().collect();
|
||||
let content = &args[1];
|
||||
let pattern = &args[2];
|
||||
|
||||
let (client_key, server_key) = ciphertext::gen_keys();
|
||||
let ct_content = ciphertext::encrypt_str(&client_key, content).unwrap();
|
||||
|
||||
let ct_res = engine::has_match(&server_key, &ct_content, pattern).unwrap();
|
||||
let res: u64 = client_key.decrypt(&ct_res);
|
||||
if res == 0 {
|
||||
println!("no match");
|
||||
} else {
|
||||
println!("match");
|
||||
}
|
||||
}
|
||||
701
tfhe/examples/regex_engine/parser.rs
Normal file
701
tfhe/examples/regex_engine/parser.rs
Normal file
@@ -0,0 +1,701 @@
|
||||
use combine::parser::byte;
|
||||
use combine::parser::byte::byte;
|
||||
use combine::*;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash)]
|
||||
pub(crate) enum RegExpr {
|
||||
Sof,
|
||||
Eof,
|
||||
Char {
|
||||
c: u8,
|
||||
},
|
||||
AnyChar,
|
||||
Between {
|
||||
from: u8,
|
||||
to: u8,
|
||||
},
|
||||
Range {
|
||||
cs: Vec<u8>,
|
||||
},
|
||||
Not {
|
||||
not_re: Box<RegExpr>,
|
||||
},
|
||||
Either {
|
||||
l_re: Box<RegExpr>,
|
||||
r_re: Box<RegExpr>,
|
||||
},
|
||||
Optional {
|
||||
opt_re: Box<RegExpr>,
|
||||
},
|
||||
Repeated {
|
||||
repeat_re: Box<RegExpr>,
|
||||
at_least: Option<usize>, // if None: no least limit, aka 0 times
|
||||
at_most: Option<usize>, // if None: no most limit
|
||||
},
|
||||
Seq {
|
||||
re_xs: Vec<RegExpr>,
|
||||
},
|
||||
}
|
||||
|
||||
impl RegExpr {
|
||||
fn case_insensitive(self) -> Self {
|
||||
match self {
|
||||
Self::Char { c } => Self::Range {
|
||||
cs: case_insensitive(c),
|
||||
},
|
||||
Self::Not { not_re } => Self::Not {
|
||||
not_re: Box::new(not_re.case_insensitive()),
|
||||
},
|
||||
Self::Either { l_re, r_re } => Self::Either {
|
||||
l_re: Box::new(l_re.case_insensitive()),
|
||||
r_re: Box::new(r_re.case_insensitive()),
|
||||
},
|
||||
Self::Optional { opt_re } => Self::Optional {
|
||||
opt_re: Box::new(opt_re.case_insensitive()),
|
||||
},
|
||||
Self::Repeated {
|
||||
repeat_re,
|
||||
at_least,
|
||||
at_most,
|
||||
} => Self::Repeated {
|
||||
repeat_re: Box::new(repeat_re.case_insensitive()),
|
||||
at_least,
|
||||
at_most,
|
||||
},
|
||||
Self::Seq { re_xs } => Self::Seq {
|
||||
re_xs: re_xs.into_iter().map(|re| re.case_insensitive()).collect(),
|
||||
},
|
||||
_ => self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn case_insensitive(x: u8) -> Vec<u8> {
|
||||
let c = u8_to_char(x);
|
||||
if c.is_ascii_lowercase() {
|
||||
return vec![x, c.to_ascii_uppercase() as u8];
|
||||
}
|
||||
if c.is_ascii_uppercase() {
|
||||
return vec![x, c.to_ascii_lowercase() as u8];
|
||||
}
|
||||
vec![x]
|
||||
}
|
||||
|
||||
pub(crate) fn u8_to_char(c: u8) -> char {
|
||||
char::from_u32(c as u32).unwrap()
|
||||
}
|
||||
|
||||
impl fmt::Debug for RegExpr {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::Sof => write!(f, "^"),
|
||||
Self::Eof => write!(f, "$"),
|
||||
Self::Char { c } => write!(f, "{}", u8_to_char(*c)),
|
||||
Self::AnyChar => write!(f, "."),
|
||||
Self::Not { not_re } => {
|
||||
write!(f, "[^")?;
|
||||
not_re.fmt(f)?;
|
||||
write!(f, "]")
|
||||
}
|
||||
Self::Between { from, to } => {
|
||||
write!(f, "[{}->{}]", u8_to_char(*from), u8_to_char(*to),)
|
||||
}
|
||||
Self::Range { cs } => write!(
|
||||
f,
|
||||
"[{}]",
|
||||
cs.iter().map(|c| u8_to_char(*c)).collect::<String>(),
|
||||
),
|
||||
Self::Either { l_re, r_re } => {
|
||||
write!(f, "(")?;
|
||||
l_re.fmt(f)?;
|
||||
write!(f, "|")?;
|
||||
r_re.fmt(f)?;
|
||||
write!(f, ")")
|
||||
}
|
||||
Self::Repeated {
|
||||
repeat_re,
|
||||
at_least,
|
||||
at_most,
|
||||
} => {
|
||||
let stringify_opt_n = |opt_n: &Option<usize>| -> String {
|
||||
opt_n.map_or("*".to_string(), |n| format!("{:?}", n))
|
||||
};
|
||||
repeat_re.fmt(f)?;
|
||||
write!(
|
||||
f,
|
||||
"{{{},{}}}",
|
||||
stringify_opt_n(at_least),
|
||||
stringify_opt_n(at_most)
|
||||
)
|
||||
}
|
||||
Self::Optional { opt_re } => {
|
||||
opt_re.fmt(f)?;
|
||||
write!(f, "?")
|
||||
}
|
||||
Self::Seq { re_xs } => {
|
||||
write!(f, "<")?;
|
||||
for re_x in re_xs {
|
||||
re_x.fmt(f)?;
|
||||
}
|
||||
write!(f, ">")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn parse(pattern: &str) -> Result<RegExpr, Box<dyn std::error::Error>> {
|
||||
let (parsed, unparsed) = (
|
||||
between(
|
||||
byte(b'/'),
|
||||
byte(b'/'),
|
||||
(optional(byte(b'^')), regex(), optional(byte(b'$'))),
|
||||
)
|
||||
.map(|(sof, re, eof)| {
|
||||
if sof.is_none() && eof.is_none() {
|
||||
return re;
|
||||
}
|
||||
let mut re_xs = vec![];
|
||||
if sof.is_some() {
|
||||
re_xs.push(RegExpr::Sof);
|
||||
}
|
||||
re_xs.push(re);
|
||||
if eof.is_some() {
|
||||
re_xs.push(RegExpr::Eof);
|
||||
}
|
||||
RegExpr::Seq { re_xs }
|
||||
}),
|
||||
optional(byte(b'i')),
|
||||
)
|
||||
.map(|(re, case_insensitive)| {
|
||||
if case_insensitive.is_some() {
|
||||
re.case_insensitive()
|
||||
} else {
|
||||
re
|
||||
}
|
||||
})
|
||||
.parse(pattern.as_bytes())?;
|
||||
if !unparsed.is_empty() {
|
||||
return Err(format!(
|
||||
"failed to parse regular expression, unexpected token at start of: {}",
|
||||
std::str::from_utf8(unparsed).unwrap()
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
|
||||
// based on grammar from: https://matt.might.net/articles/parsing-regex-with-recursive-descent/
|
||||
//
|
||||
// <regex> ::= <term> '|' <regex>
|
||||
// | <term>
|
||||
//
|
||||
// <term> ::= { <factor> }
|
||||
//
|
||||
// <factor> ::= <base> { '*' }
|
||||
//
|
||||
// <base> ::= <char>
|
||||
// | '\' <char>
|
||||
// | '(' <regex> ')'
|
||||
|
||||
parser! {
|
||||
fn regex[Input]()(Input) -> RegExpr
|
||||
where [Input: Stream<Token = u8>]
|
||||
{
|
||||
regex_()
|
||||
}
|
||||
}
|
||||
|
||||
fn regex_<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
choice((
|
||||
attempt(
|
||||
(term(), byte(b'|'), regex()).map(|(l_re, _, r_re)| RegExpr::Either {
|
||||
l_re: Box::new(l_re),
|
||||
r_re: Box::new(r_re),
|
||||
}),
|
||||
),
|
||||
term(),
|
||||
))
|
||||
}
|
||||
|
||||
fn term<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
many(factor()).map(|re_xs: Vec<RegExpr>| {
|
||||
if re_xs.len() == 1 {
|
||||
re_xs[0].clone()
|
||||
} else {
|
||||
RegExpr::Seq { re_xs }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn factor<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
choice((
|
||||
attempt((atom(), byte(b'?'))).map(|(re, _)| RegExpr::Optional {
|
||||
opt_re: Box::new(re),
|
||||
}),
|
||||
attempt(repeated()),
|
||||
atom(),
|
||||
))
|
||||
}
|
||||
|
||||
const NON_ESCAPABLE_SYMBOLS: [u8; 14] = [
|
||||
b'&', b';', b':', b',', b'`', b'~', b'-', b'_', b'!', b'@', b'#', b'%', b'\'', b'\"',
|
||||
];
|
||||
|
||||
fn atom<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
choice((
|
||||
byte(b'.').map(|_| RegExpr::AnyChar),
|
||||
attempt(byte(b'\\').with(parser::token::any())).map(|c| RegExpr::Char { c }),
|
||||
choice((
|
||||
byte::alpha_num(),
|
||||
parser::token::one_of(NON_ESCAPABLE_SYMBOLS),
|
||||
))
|
||||
.map(|c| RegExpr::Char { c }),
|
||||
between(byte(b'['), byte(b']'), range()),
|
||||
between(byte(b'('), byte(b')'), regex()),
|
||||
))
|
||||
}
|
||||
|
||||
parser! {
|
||||
fn range[Input]()(Input) -> RegExpr
|
||||
where [Input: Stream<Token = u8>]
|
||||
{
|
||||
range_()
|
||||
}
|
||||
}
|
||||
|
||||
fn range_<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
choice((
|
||||
byte(b'^').with(range()).map(|re| RegExpr::Not {
|
||||
not_re: Box::new(re),
|
||||
}),
|
||||
attempt(
|
||||
(byte::alpha_num(), byte(b'-'), byte::alpha_num())
|
||||
.map(|(from, _, to)| RegExpr::Between { from, to }),
|
||||
),
|
||||
many1(byte::alpha_num()).map(|cs| RegExpr::Range { cs }),
|
||||
))
|
||||
}
|
||||
|
||||
fn repeated<Input>() -> impl Parser<Input, Output = RegExpr>
|
||||
where
|
||||
Input: Stream<Token = u8>,
|
||||
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
|
||||
{
|
||||
choice((
|
||||
attempt((atom(), choice((byte(b'*'), byte(b'+'))))).map(|(re, c)| RegExpr::Repeated {
|
||||
repeat_re: Box::new(re),
|
||||
at_least: if c == b'*' { None } else { Some(1) },
|
||||
at_most: None,
|
||||
}),
|
||||
attempt((
|
||||
atom(),
|
||||
between(byte(b'{'), byte(b'}'), many::<Vec<u8>, _, _>(byte::digit())),
|
||||
))
|
||||
.map(|(re, repeat_digits)| {
|
||||
let repeat = parse_digits(&repeat_digits);
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(re),
|
||||
at_least: Some(repeat),
|
||||
at_most: Some(repeat),
|
||||
}
|
||||
}),
|
||||
(
|
||||
atom(),
|
||||
between(
|
||||
byte(b'{'),
|
||||
byte(b'}'),
|
||||
(
|
||||
many::<Vec<u8>, _, _>(byte::digit()),
|
||||
byte(b','),
|
||||
many::<Vec<u8>, _, _>(byte::digit()),
|
||||
),
|
||||
),
|
||||
)
|
||||
.map(
|
||||
|(re, (at_least_digits, _, at_most_digits))| RegExpr::Repeated {
|
||||
repeat_re: Box::new(re),
|
||||
at_least: if at_least_digits.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(parse_digits(&at_least_digits))
|
||||
},
|
||||
at_most: if at_most_digits.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(parse_digits(&at_most_digits))
|
||||
},
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_digits(digits: &[u8]) -> usize {
|
||||
std::str::from_utf8(digits).unwrap().parse().unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::parser::{parse, RegExpr};
|
||||
use test_case::test_case;
|
||||
|
||||
#[test_case("/h/", RegExpr::Char { c: b'h' }; "char")]
|
||||
#[test_case("/&/", RegExpr::Char { c: b'&' }; "not necessary to escape ampersand")]
|
||||
#[test_case("/;/", RegExpr::Char { c: b';' }; "not necessary to escape semicolon")]
|
||||
#[test_case("/:/", RegExpr::Char { c: b':' }; "not necessary to escape colon")]
|
||||
#[test_case("/,/", RegExpr::Char { c: b',' }; "not necessary to escape comma")]
|
||||
#[test_case("/`/", RegExpr::Char { c: b'`' }; "not necessary to escape backtick")]
|
||||
#[test_case("/~/", RegExpr::Char { c: b'~' }; "not necessary to escape tilde")]
|
||||
#[test_case("/-/", RegExpr::Char { c: b'-' }; "not necessary to escape minus")]
|
||||
#[test_case("/_/", RegExpr::Char { c: b'_' }; "not necessary to escape underscore")]
|
||||
#[test_case("/%/", RegExpr::Char { c: b'%' }; "not necessary to escape percentage")]
|
||||
#[test_case("/#/", RegExpr::Char { c: b'#' }; "not necessary to escape hashtag")]
|
||||
#[test_case("/@/", RegExpr::Char { c: b'@' }; "not necessary to escape at")]
|
||||
#[test_case("/!/", RegExpr::Char { c: b'!' }; "not necessary to escape exclamation")]
|
||||
#[test_case("/'/", RegExpr::Char { c: b'\'' }; "not necessary to escape single quote")]
|
||||
#[test_case("/\"/", RegExpr::Char { c: b'\"' }; "not necessary to escape double quote")]
|
||||
#[test_case("/\\h/", RegExpr::Char { c: b'h' }; "anything can be escaped")]
|
||||
#[test_case("/./", RegExpr::AnyChar; "any")]
|
||||
#[test_case("/abc/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
RegExpr::Char { c: b'c' },
|
||||
]};
|
||||
"abc")]
|
||||
#[test_case("/^abc/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
]};
|
||||
"<sof>abc")]
|
||||
#[test_case("/abc$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"abc<eof>")]
|
||||
#[test_case("/^abc$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>abc<eof>")]
|
||||
#[test_case("/^ab?c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: b'b' }) },
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<question>c<eof>")]
|
||||
#[test_case("/^ab*c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
at_least: None,
|
||||
at_most: None,
|
||||
},
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<star>c<eof>")]
|
||||
#[test_case("/^ab+c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
at_least: Some(1),
|
||||
at_most: None,
|
||||
},
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<plus>c<eof>")]
|
||||
#[test_case("/^ab{2}c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
at_least: Some(2),
|
||||
at_most: Some(2),
|
||||
},
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<twice>c<eof>")]
|
||||
#[test_case("/^ab{3,}c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
at_least: Some(3),
|
||||
at_most: None,
|
||||
},
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<atleast 3>c<eof>")]
|
||||
#[test_case("/^ab{2,4}c$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
at_least: Some(2),
|
||||
at_most: Some(4),
|
||||
},
|
||||
RegExpr::Char { c: b'c' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>ab<between 2 and 4>c<eof>")]
|
||||
#[test_case("/^.$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::AnyChar,
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof><any><eof>")]
|
||||
#[test_case("/^[abc]$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Range { cs: vec![b'a', b'b', b'c'] },
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof><a or b or c><eof>")]
|
||||
#[test_case("/^[a-d]$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Between { from: b'a', to: b'd' },
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof><between a and d><eof>")]
|
||||
#[test_case("/^[^abc]$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Not { not_re: Box::new(RegExpr::Range { cs: vec![b'a', b'b', b'c'] })},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof><not <a or b or c>><eof>")]
|
||||
#[test_case("/^[^a-d]$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Not { not_re: Box::new(RegExpr::Between { from: b'a', to: b'd' }) },
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof><not <between a and d>><eof>")]
|
||||
#[test_case("/^abc$/i",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Range { cs: vec![b'a', b'A'] },
|
||||
RegExpr::Range { cs: vec![b'b', b'B'] },
|
||||
RegExpr::Range { cs: vec![b'c', b'C'] },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"<sof>abc<eof> (case insensitive)")]
|
||||
#[test_case("/^/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq { re_xs: vec![] }
|
||||
]};
|
||||
"sof")]
|
||||
#[test_case("/$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Seq { re_xs: vec![] },
|
||||
RegExpr::Eof
|
||||
]};
|
||||
"eof")]
|
||||
#[test_case("/a*/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
at_least: None,
|
||||
at_most: None,
|
||||
};
|
||||
"repeat unbounded (w/ *)")]
|
||||
#[test_case("/a+/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
at_least: Some(1),
|
||||
at_most: None,
|
||||
};
|
||||
"repeat bounded at least (w/ +)")]
|
||||
#[test_case("/a{104,}/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
at_least: Some(104),
|
||||
at_most: None,
|
||||
};
|
||||
"repeat bounded at least (w/ {x,}")]
|
||||
#[test_case("/a{,15}/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
at_least: None,
|
||||
at_most: Some(15),
|
||||
};
|
||||
"repeat bounded at most (w/ {,x}")]
|
||||
#[test_case("/a{12,15}/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
at_least: Some(12),
|
||||
at_most: Some(15),
|
||||
};
|
||||
"repeat bounded at least and at most (w/ {x,y}")]
|
||||
#[test_case("/(a|b)*/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Either {
|
||||
l_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
r_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
}),
|
||||
at_least: None,
|
||||
at_most: None,
|
||||
};
|
||||
"repeat complex unbounded")]
|
||||
#[test_case("/(a|b){3,7}/",
|
||||
RegExpr::Repeated {
|
||||
repeat_re: Box::new(RegExpr::Either {
|
||||
l_re: Box::new(RegExpr::Char { c: b'a' }),
|
||||
r_re: Box::new(RegExpr::Char { c: b'b' }),
|
||||
}),
|
||||
at_least: Some(3),
|
||||
at_most: Some(7),
|
||||
};
|
||||
"repeat complex bounded")]
|
||||
#[test_case("/^ab|cd/",
|
||||
RegExpr::Seq { re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Either {
|
||||
l_re: Box::new(RegExpr::Seq { re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
] }),
|
||||
r_re: Box::new(RegExpr::Seq { re_xs: vec![
|
||||
RegExpr::Char { c: b'c' },
|
||||
RegExpr::Char { c: b'd' },
|
||||
]}),
|
||||
},
|
||||
]};
|
||||
"Sof encapsulates full RHS")]
|
||||
#[test_case("/ab|cd$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Either {
|
||||
l_re: Box::new(RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
]}),
|
||||
r_re: Box::new(RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'c' },
|
||||
RegExpr::Char { c: b'd' },
|
||||
]}),
|
||||
},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"Eof encapsulates full RHS" )]
|
||||
#[test_case("/^ab|cd$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Either {
|
||||
l_re: Box::new(RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
]}),
|
||||
r_re: Box::new(RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'c' },
|
||||
RegExpr::Char { c: b'd' },
|
||||
]}),
|
||||
},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"Sof + Eof both encapsulate full center")]
|
||||
#[test_case("/\\^/",
|
||||
RegExpr::Char { c: b'^' };
|
||||
"escaping sof symbol")]
|
||||
#[test_case("/\\./",
|
||||
RegExpr::Char { c: b'.' };
|
||||
"escaping period symbol")]
|
||||
#[test_case("/\\*/",
|
||||
RegExpr::Char { c: b'*' };
|
||||
"escaping star symbol")]
|
||||
#[test_case("/^ca\\^b$/",
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Sof,
|
||||
RegExpr::Seq {re_xs: vec![
|
||||
RegExpr::Char { c: b'c' },
|
||||
RegExpr::Char { c: b'a' },
|
||||
RegExpr::Char { c: b'^' },
|
||||
RegExpr::Char { c: b'b' },
|
||||
]},
|
||||
RegExpr::Eof,
|
||||
]};
|
||||
"escaping, more realistic")]
|
||||
#[test_case("/8/",
|
||||
RegExpr::Char { c: b'8' };
|
||||
"able to match numbers")]
|
||||
#[test_case("/[7-9]/",
|
||||
RegExpr::Between { from: b'7', to: b'9' };
|
||||
"able to match a number range")]
|
||||
#[test_case("/[79]/",
|
||||
RegExpr::Range { cs: vec![b'7', b'9'] };
|
||||
"able to match a number range (part 2)")]
|
||||
fn test_parser(pattern: &str, exp: RegExpr) {
|
||||
match parse(pattern) {
|
||||
Ok(got) => assert_eq!(exp, got),
|
||||
Err(e) => panic!("got err: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
469
tfhe/examples/sha256_bool/boolean_ops.rs
Normal file
469
tfhe/examples/sha256_bool/boolean_ops.rs
Normal file
@@ -0,0 +1,469 @@
|
||||
// This module contains all the operations and functions used in the sha256 function, implemented
|
||||
// with homomorphic boolean operations. Both the bitwise operations, which serve as the building
|
||||
// blocks for other functions, and the adders employ parallel processing techniques.
|
||||
|
||||
use rayon::prelude::*;
|
||||
use std::array;
|
||||
use tfhe::boolean::prelude::{BinaryBooleanGates, Ciphertext, ServerKey};
|
||||
|
||||
// Implementation of a Carry Save Adder, which computes sum and carry sequences very efficiently. We
|
||||
// then add the final sum and carry values to obtain the result. CSAs are useful to speed up
|
||||
// sequential additions
|
||||
pub fn csa(
|
||||
a: &[Ciphertext; 32],
|
||||
b: &[Ciphertext; 32],
|
||||
c: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> ([Ciphertext; 32], [Ciphertext; 32]) {
|
||||
let (carry, sum) = rayon::join(|| maj(a, b, c, sk), || xor(a, &xor(b, c, sk), sk));
|
||||
|
||||
// perform a left shift by one to discard the carry-out and set the carry-in to 0
|
||||
let mut shifted_carry = trivial_bools(&[false; 32], sk);
|
||||
for (i, elem) in carry.into_iter().enumerate() {
|
||||
if i == 0 {
|
||||
continue;
|
||||
} else {
|
||||
shifted_carry[i - 1] = elem;
|
||||
}
|
||||
}
|
||||
|
||||
(sum, shifted_carry)
|
||||
}
|
||||
|
||||
pub fn add(
|
||||
a: &[Ciphertext; 32],
|
||||
b: &[Ciphertext; 32],
|
||||
ladner_fischer_opt: bool,
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
let (propagate, generate) = rayon::join(|| xor(a, b, sk), || and(a, b, sk));
|
||||
|
||||
let carry = if ladner_fischer_opt {
|
||||
ladner_fischer(&propagate, &generate, sk)
|
||||
} else {
|
||||
brent_kung(&propagate, &generate, sk)
|
||||
};
|
||||
|
||||
xor(&propagate, &carry, sk)
|
||||
}
|
||||
|
||||
// Implementation of the Brent Kung parallel prefix algorithm
|
||||
// This function computes the carry signals in parallel while minimizing the number of homomorphic
|
||||
// operations
|
||||
fn brent_kung(
|
||||
propagate: &[Ciphertext; 32],
|
||||
generate: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
let mut propagate = propagate.clone();
|
||||
let mut generate = generate.clone();
|
||||
|
||||
for d in 0..5 {
|
||||
// first 5 stages
|
||||
let stride = 1 << d;
|
||||
|
||||
let indices: Vec<(usize, usize)> = (0..32 - stride)
|
||||
.rev()
|
||||
.step_by(2 * stride)
|
||||
.map(|i| i + 1 - stride)
|
||||
.enumerate()
|
||||
.collect();
|
||||
|
||||
let updates: Vec<(usize, Ciphertext, Ciphertext)> = indices
|
||||
.into_par_iter()
|
||||
.map(|(n, index)| {
|
||||
let new_p;
|
||||
let new_g;
|
||||
|
||||
if n == 0 {
|
||||
// grey cell
|
||||
new_p = propagate[index].clone();
|
||||
new_g = sk.or(
|
||||
&generate[index],
|
||||
&sk.and(&generate[index + stride], &propagate[index]),
|
||||
);
|
||||
} else {
|
||||
// black cell
|
||||
new_p = sk.and(&propagate[index], &propagate[index + stride]);
|
||||
new_g = sk.or(
|
||||
&generate[index],
|
||||
&sk.and(&generate[index + stride], &propagate[index]),
|
||||
);
|
||||
}
|
||||
|
||||
(index, new_p, new_g)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (index, new_p, new_g) in updates {
|
||||
propagate[index] = new_p;
|
||||
generate[index] = new_g;
|
||||
}
|
||||
|
||||
if d == 4 {
|
||||
let mut cells = 0;
|
||||
for d_2 in 0..4 {
|
||||
// last 4 stages
|
||||
let stride = 1 << (4 - d_2 - 1);
|
||||
cells += 1 << d_2;
|
||||
|
||||
let indices: Vec<(usize, usize)> = (0..cells)
|
||||
.map(|cell| (cell, stride + 2 * stride * cell))
|
||||
.collect();
|
||||
|
||||
let updates: Vec<(usize, Ciphertext)> = indices
|
||||
.into_par_iter()
|
||||
.map(|(_, index)| {
|
||||
let new_g = sk.or(
|
||||
&generate[index],
|
||||
&sk.and(&generate[index + stride], &propagate[index]),
|
||||
);
|
||||
|
||||
(index, new_g)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (index, new_g) in updates {
|
||||
generate[index] = new_g;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut carry = trivial_bools(&[false; 32], sk);
|
||||
carry[..31].clone_from_slice(&generate[1..(31 + 1)]);
|
||||
|
||||
carry
|
||||
}
|
||||
|
||||
// Implementation of the Ladner Fischer parallel prefix algorithm
|
||||
// This function may perform better than the previous one when many threads are available as it has
|
||||
// less stages
|
||||
fn ladner_fischer(
|
||||
propagate: &[Ciphertext; 32],
|
||||
generate: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
let mut propagate = propagate.clone();
|
||||
let mut generate = generate.clone();
|
||||
|
||||
for d in 0..5 {
|
||||
let stride = 1 << d;
|
||||
|
||||
let indices: Vec<(usize, usize)> = (0..32 - stride)
|
||||
.rev()
|
||||
.step_by(2 * stride)
|
||||
.flat_map(|i| (0..stride).map(move |count| (i, count)))
|
||||
.collect();
|
||||
|
||||
let updates: Vec<(usize, Ciphertext, Ciphertext)> = indices
|
||||
.into_par_iter()
|
||||
.map(|(i, count)| {
|
||||
let index = i - count; // current column
|
||||
|
||||
let p = propagate[i + 1].clone(); // propagate from a previous column
|
||||
let g = generate[i + 1].clone(); // generate from a previous column
|
||||
let new_p;
|
||||
let new_g;
|
||||
|
||||
if index < 32 - (2 * stride) {
|
||||
// black cell
|
||||
new_p = sk.and(&propagate[index], &p);
|
||||
new_g = sk.or(&generate[index], &sk.and(&g, &propagate[index]));
|
||||
} else {
|
||||
// grey cell
|
||||
new_p = propagate[index].clone();
|
||||
new_g = sk.or(&generate[index], &sk.and(&g, &propagate[index]));
|
||||
}
|
||||
(index, new_p, new_g)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for (index, new_p, new_g) in updates {
|
||||
propagate[index] = new_p;
|
||||
generate[index] = new_g;
|
||||
}
|
||||
}
|
||||
|
||||
let mut carry = trivial_bools(&[false; 32], sk);
|
||||
carry[..31].clone_from_slice(&generate[1..(31 + 1)]);
|
||||
|
||||
carry
|
||||
}
|
||||
|
||||
// 2 (homomorphic) bitwise ops
|
||||
pub fn sigma0(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let a = rotate_right(x, 7);
|
||||
let b = rotate_right(x, 18);
|
||||
let c = shift_right(x, 3, sk);
|
||||
xor(&xor(&a, &b, sk), &c, sk)
|
||||
}
|
||||
|
||||
pub fn sigma1(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let a = rotate_right(x, 17);
|
||||
let b = rotate_right(x, 19);
|
||||
let c = shift_right(x, 10, sk);
|
||||
xor(&xor(&a, &b, sk), &c, sk)
|
||||
}
|
||||
|
||||
pub fn sigma_upper_case_0(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let a = rotate_right(x, 2);
|
||||
let b = rotate_right(x, 13);
|
||||
let c = rotate_right(x, 22);
|
||||
xor(&xor(&a, &b, sk), &c, sk)
|
||||
}
|
||||
|
||||
pub fn sigma_upper_case_1(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let a = rotate_right(x, 6);
|
||||
let b = rotate_right(x, 11);
|
||||
let c = rotate_right(x, 25);
|
||||
xor(&xor(&a, &b, sk), &c, sk)
|
||||
}
|
||||
|
||||
// 0 bitwise ops
|
||||
fn rotate_right(x: &[Ciphertext; 32], n: usize) -> [Ciphertext; 32] {
|
||||
let mut result = x.clone();
|
||||
result.rotate_right(n);
|
||||
result
|
||||
}
|
||||
|
||||
fn shift_right(x: &[Ciphertext; 32], n: usize, sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut result = x.clone();
|
||||
result.rotate_right(n);
|
||||
result[..n].fill_with(|| sk.trivial_encrypt(false));
|
||||
result
|
||||
}
|
||||
|
||||
// 1 bitwise op
|
||||
pub fn ch(
|
||||
x: &[Ciphertext; 32],
|
||||
y: &[Ciphertext; 32],
|
||||
z: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
mux(x, y, z, sk)
|
||||
}
|
||||
|
||||
// 4 bitwise ops
|
||||
pub fn maj(
|
||||
x: &[Ciphertext; 32],
|
||||
y: &[Ciphertext; 32],
|
||||
z: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
let (lhs, rhs) = rayon::join(|| and(x, &xor(y, z, sk), sk), || and(y, z, sk));
|
||||
xor(&lhs, &rhs, sk)
|
||||
}
|
||||
|
||||
// Parallelized homomorphic bitwise ops
|
||||
// Building block for most of the previous functions
|
||||
fn xor(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut result = a.clone();
|
||||
result
|
||||
.par_iter_mut()
|
||||
.zip(a.par_iter().zip(b.par_iter()))
|
||||
.for_each(|(dst, (lhs, rhs))| *dst = sk.xor(lhs, rhs));
|
||||
result
|
||||
}
|
||||
|
||||
fn and(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
let mut result = a.clone();
|
||||
result
|
||||
.par_iter_mut()
|
||||
.zip(a.par_iter().zip(b.par_iter()))
|
||||
.for_each(|(dst, (lhs, rhs))| *dst = sk.and(lhs, rhs));
|
||||
result
|
||||
}
|
||||
|
||||
fn mux(
|
||||
condition: &[Ciphertext; 32],
|
||||
then: &[Ciphertext; 32],
|
||||
otherwise: &[Ciphertext; 32],
|
||||
sk: &ServerKey,
|
||||
) -> [Ciphertext; 32] {
|
||||
let mut result = condition.clone();
|
||||
result
|
||||
.par_iter_mut()
|
||||
.zip(
|
||||
condition
|
||||
.par_iter()
|
||||
.zip(then.par_iter().zip(otherwise.par_iter())),
|
||||
)
|
||||
.for_each(|(dst, (condition, (then, other)))| *dst = sk.mux(condition, then, other));
|
||||
result
|
||||
}
|
||||
|
||||
// Trivial encryption of 32 bools
|
||||
pub fn trivial_bools(bools: &[bool; 32], sk: &ServerKey) -> [Ciphertext; 32] {
|
||||
array::from_fn(|i| sk.trivial_encrypt(bools[i]))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tfhe::boolean::prelude::*;
|
||||
|
||||
fn to_bool_array(arr: [i32; 32]) -> [bool; 32] {
|
||||
let mut bool_arr = [false; 32];
|
||||
for i in 0..32 {
|
||||
if arr[i] == 1 {
|
||||
bool_arr[i] = true;
|
||||
}
|
||||
}
|
||||
bool_arr
|
||||
}
|
||||
fn encrypt(bools: &[bool; 32], ck: &ClientKey) -> [Ciphertext; 32] {
|
||||
array::from_fn(|i| ck.encrypt(bools[i]))
|
||||
}
|
||||
|
||||
fn decrypt(bools: &[Ciphertext; 32], ck: &ClientKey) -> [bool; 32] {
|
||||
array::from_fn(|i| ck.decrypt(&bools[i]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_modulo_2_32() {
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let a = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1,
|
||||
1, 0, 0, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let b = encrypt(
|
||||
&to_bool_array([
|
||||
0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
|
||||
1, 0, 1, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let c = encrypt(
|
||||
&to_bool_array([
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0,
|
||||
1, 1, 0, 0,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let d = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
|
||||
1, 0, 0, 0,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let e = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
|
||||
1, 1, 0, 0,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
|
||||
let (sum, carry) = csa(&c, &d, &e, &sk);
|
||||
let (sum, carry) = csa(&b, &sum, &carry, &sk);
|
||||
let (sum, carry) = csa(&a, &sum, &carry, &sk);
|
||||
let output = add(&sum, &carry, false, &sk);
|
||||
|
||||
let result = decrypt(&output, &ck);
|
||||
let expected = to_bool_array([
|
||||
0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0,
|
||||
1, 0, 0,
|
||||
]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sigma0() {
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let input = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,
|
||||
1, 1, 1, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let output = sigma0(&input, &sk);
|
||||
let result = decrypt(&output, &ck);
|
||||
let expected = to_bool_array([
|
||||
1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1,
|
||||
0, 1, 1,
|
||||
]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
} //the other sigmas are implemented in the same way
|
||||
|
||||
#[test]
|
||||
fn test_ch() {
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let e = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
|
||||
1, 1, 1, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let f = encrypt(
|
||||
&to_bool_array([
|
||||
1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
|
||||
1, 1, 0, 0,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let g = encrypt(
|
||||
&to_bool_array([
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0,
|
||||
1, 0, 1, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
|
||||
let output = ch(&e, &f, &g, &sk);
|
||||
let result = decrypt(&output, &ck);
|
||||
let expected = to_bool_array([
|
||||
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1,
|
||||
1, 0, 0,
|
||||
]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_maj() {
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let a = encrypt(
|
||||
&to_bool_array([
|
||||
0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
|
||||
0, 1, 1, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let b = encrypt(
|
||||
&to_bool_array([
|
||||
1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0,
|
||||
0, 1, 0, 1,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
let c = encrypt(
|
||||
&to_bool_array([
|
||||
0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,
|
||||
0, 0, 1, 0,
|
||||
]),
|
||||
&ck,
|
||||
);
|
||||
|
||||
let output = maj(&a, &b, &c, &sk);
|
||||
let result = decrypt(&output, &ck);
|
||||
let expected = to_bool_array([
|
||||
0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
|
||||
1, 1, 1,
|
||||
]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
74
tfhe/examples/sha256_bool/main.rs
Normal file
74
tfhe/examples/sha256_bool/main.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
mod boolean_ops;
|
||||
mod padding;
|
||||
mod sha256_function;
|
||||
|
||||
use clap::{Arg, ArgAction, Command};
|
||||
use padding::pad_sha256_input;
|
||||
use sha256_function::{bools_to_hex, sha256_fhe};
|
||||
use std::io;
|
||||
use tfhe::boolean::prelude::*;
|
||||
|
||||
fn main() {
|
||||
let matches = Command::new("Homomorphic sha256")
|
||||
.arg(
|
||||
Arg::new("ladner_fischer")
|
||||
.long("ladner-fischer")
|
||||
.help("Use the Ladner Fischer parallel prefix algorithm for additions")
|
||||
.action(ArgAction::SetTrue),
|
||||
)
|
||||
.get_matches();
|
||||
|
||||
// If set using the command line flag "--ladner-fischer" this algorithm will be used in
|
||||
// additions
|
||||
let ladner_fischer: bool = matches.get_flag("ladner_fischer");
|
||||
|
||||
// INTRODUCE INPUT FROM STDIN
|
||||
|
||||
let mut input = String::new();
|
||||
println!("Write input to hash:");
|
||||
|
||||
io::stdin()
|
||||
.read_line(&mut input)
|
||||
.expect("Failed to read line");
|
||||
|
||||
input = input.trim_end_matches('\n').to_string();
|
||||
|
||||
println!("You entered: \"{}\"", input);
|
||||
|
||||
// CLIENT PADS DATA AND ENCRYPTS IT
|
||||
|
||||
let (ck, sk) = gen_keys();
|
||||
|
||||
let padded_input = pad_sha256_input(&input);
|
||||
let encrypted_input = encrypt_bools(&padded_input, &ck);
|
||||
|
||||
// SERVER COMPUTES OVER THE ENCRYPTED PADDED DATA
|
||||
|
||||
println!("Computing the hash");
|
||||
let encrypted_output = sha256_fhe(encrypted_input, ladner_fischer, &sk);
|
||||
|
||||
// CLIENT DECRYPTS THE OUTPUT
|
||||
|
||||
let output = decrypt_bools(&encrypted_output, &ck);
|
||||
let outhex = bools_to_hex(output);
|
||||
|
||||
println!("{}", outhex);
|
||||
}
|
||||
|
||||
fn encrypt_bools(bools: &Vec<bool>, ck: &ClientKey) -> Vec<Ciphertext> {
|
||||
let mut ciphertext = vec![];
|
||||
|
||||
for bool in bools {
|
||||
ciphertext.push(ck.encrypt(*bool));
|
||||
}
|
||||
ciphertext
|
||||
}
|
||||
|
||||
fn decrypt_bools(ciphertext: &Vec<Ciphertext>, ck: &ClientKey) -> Vec<bool> {
|
||||
let mut bools = vec![];
|
||||
|
||||
for cipher in ciphertext {
|
||||
bools.push(ck.decrypt(cipher));
|
||||
}
|
||||
bools
|
||||
}
|
||||
70
tfhe/examples/sha256_bool/padding.rs
Normal file
70
tfhe/examples/sha256_bool/padding.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
// This module contains the padding function, which is computed by the client over the plain text.
|
||||
// The function returns the padded data as a vector of bools, for later encryption. Note that
|
||||
// padding could also be performed by the server, by appending trivially encrypted bools. However,
|
||||
// in our implementation, the exact length of the pre-image (hashed message) is not revealed.
|
||||
|
||||
// If input starts with "0x" and following characters are valid hexadecimal values, it's interpreted
|
||||
// as hex, otherwise input is interpreted as text
|
||||
pub fn pad_sha256_input(input: &str) -> Vec<bool> {
|
||||
let bytes = if input.starts_with("0x") && is_valid_hex(&input[2..]) {
|
||||
let no_prefix = &input[2..];
|
||||
let hex_input = if no_prefix.len() % 2 == 0 {
|
||||
// hex value can be converted to bytes
|
||||
no_prefix.to_string()
|
||||
} else {
|
||||
format!("0{}", no_prefix) // pad hex value to ensure a correct conversion to bytes
|
||||
};
|
||||
hex_input
|
||||
.as_bytes()
|
||||
.chunks(2)
|
||||
.map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 16).unwrap())
|
||||
.collect::<Vec<u8>>()
|
||||
} else {
|
||||
input.as_bytes().to_vec()
|
||||
};
|
||||
|
||||
pad_sha256_data(&bytes)
|
||||
}
|
||||
|
||||
fn is_valid_hex(hex: &str) -> bool {
|
||||
hex.chars().all(|c| c.is_ascii_hexdigit())
|
||||
}
|
||||
|
||||
fn pad_sha256_data(data: &[u8]) -> Vec<bool> {
|
||||
let mut bits: Vec<bool> = data
|
||||
.iter()
|
||||
.flat_map(|byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1))
|
||||
.collect();
|
||||
|
||||
// Append a single '1' bit
|
||||
bits.push(true);
|
||||
|
||||
// Calculate the number of padding zeros required
|
||||
let padding_zeros = (512 - ((bits.len() + 64) % 512)) % 512;
|
||||
bits.extend(std::iter::repeat(false).take(padding_zeros));
|
||||
|
||||
// Append a 64-bit big-endian representation of the original message length
|
||||
let data_len_bits = (data.len() as u64) * 8;
|
||||
bits.extend((0..64).rev().map(|i| (data_len_bits >> i) & 1 == 1));
|
||||
|
||||
bits
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sha256_function::bools_to_hex;
|
||||
|
||||
#[test]
|
||||
fn test_pad_sha256_input() {
|
||||
let input = "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq";
|
||||
let expected_output = "6162636462636465636465666465666765666768666768696768696a68696a6\
|
||||
b696a6b6c6a6b6c6d6b6c6d6e6c6d6e6f6d6e6f706e6f70718000000000000000000000000000000000000000000\
|
||||
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0";
|
||||
|
||||
let result = pad_sha256_input(input);
|
||||
let hex_result = bools_to_hex(result);
|
||||
|
||||
assert_eq!(hex_result, expected_output);
|
||||
}
|
||||
}
|
||||
236
tfhe/examples/sha256_bool/sha256_function.rs
Normal file
236
tfhe/examples/sha256_bool/sha256_function.rs
Normal file
@@ -0,0 +1,236 @@
|
||||
// This module implements the main sha256 homomorphic function using parallel processing when
|
||||
// possible and some helper functions
|
||||
|
||||
use crate::boolean_ops::{
|
||||
add, ch, csa, maj, sigma0, sigma1, sigma_upper_case_0, sigma_upper_case_1, trivial_bools,
|
||||
};
|
||||
use std::array;
|
||||
use tfhe::boolean::prelude::*;
|
||||
|
||||
pub fn sha256_fhe(
|
||||
padded_input: Vec<Ciphertext>,
|
||||
ladner_fischer: bool,
|
||||
sk: &ServerKey,
|
||||
) -> Vec<Ciphertext> {
|
||||
assert_eq!(
|
||||
padded_input.len() % 512,
|
||||
0,
|
||||
"padded input length is not a multiple of 512"
|
||||
);
|
||||
|
||||
// Initialize hash values
|
||||
let mut hash: [[Ciphertext; 32]; 8] = [
|
||||
trivial_bools(&hex_to_bools(0x6a09e667), sk),
|
||||
trivial_bools(&hex_to_bools(0xbb67ae85), sk),
|
||||
trivial_bools(&hex_to_bools(0x3c6ef372), sk),
|
||||
trivial_bools(&hex_to_bools(0xa54ff53a), sk),
|
||||
trivial_bools(&hex_to_bools(0x510e527f), sk),
|
||||
trivial_bools(&hex_to_bools(0x9b05688c), sk),
|
||||
trivial_bools(&hex_to_bools(0x1f83d9ab), sk),
|
||||
trivial_bools(&hex_to_bools(0x5be0cd19), sk),
|
||||
];
|
||||
|
||||
let chunks = padded_input.chunks_exact(512);
|
||||
|
||||
for chunk in chunks {
|
||||
// Compute the 64 words
|
||||
let mut w = initialize_w(sk);
|
||||
|
||||
for i in 0..16 {
|
||||
w[i].clone_from_slice(&chunk[i * 32..(i + 1) * 32]);
|
||||
}
|
||||
|
||||
for i in (16..64).step_by(2) {
|
||||
let u = i + 1;
|
||||
|
||||
let (word_i, word_u) = rayon::join(
|
||||
|| {
|
||||
let (s0, s1) = rayon::join(|| sigma0(&w[i - 15], sk), || sigma1(&w[i - 2], sk));
|
||||
|
||||
let (sum, carry) = csa(&s0, &w[i - 7], &w[i - 16], sk);
|
||||
let (sum, carry) = csa(&s1, &sum, &carry, sk);
|
||||
add(&sum, &carry, ladner_fischer, sk)
|
||||
},
|
||||
|| {
|
||||
let (s0, s1) = rayon::join(|| sigma0(&w[u - 15], sk), || sigma1(&w[u - 2], sk));
|
||||
|
||||
let (sum, carry) = csa(&s0, &w[u - 7], &w[u - 16], sk);
|
||||
let (sum, carry) = csa(&s1, &sum, &carry, sk);
|
||||
add(&sum, &carry, ladner_fischer, sk)
|
||||
},
|
||||
);
|
||||
|
||||
w[i] = word_i;
|
||||
w[u] = word_u;
|
||||
}
|
||||
|
||||
let mut a = hash[0].clone();
|
||||
let mut b = hash[1].clone();
|
||||
let mut c = hash[2].clone();
|
||||
let mut d = hash[3].clone();
|
||||
let mut e = hash[4].clone();
|
||||
let mut f = hash[5].clone();
|
||||
let mut g = hash[6].clone();
|
||||
let mut h = hash[7].clone();
|
||||
|
||||
// Compression loop
|
||||
for i in 0..64 {
|
||||
let (temp1, temp2) = rayon::join(
|
||||
|| {
|
||||
let ((sum, carry), s1) = rayon::join(
|
||||
|| {
|
||||
let ((sum, carry), ch) = rayon::join(
|
||||
|| csa(&h, &w[i], &trivial_bools(&hex_to_bools(K[i]), sk), sk),
|
||||
|| ch(&e, &f, &g, sk),
|
||||
);
|
||||
csa(&sum, &carry, &ch, sk)
|
||||
},
|
||||
|| sigma_upper_case_1(&e, sk),
|
||||
);
|
||||
|
||||
let (sum, carry) = csa(&sum, &carry, &s1, sk);
|
||||
add(&sum, &carry, ladner_fischer, sk)
|
||||
},
|
||||
|| {
|
||||
add(
|
||||
&sigma_upper_case_0(&a, sk),
|
||||
&maj(&a, &b, &c, sk),
|
||||
ladner_fischer,
|
||||
sk,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let (temp_e, temp_a) = rayon::join(
|
||||
|| add(&d, &temp1, ladner_fischer, sk),
|
||||
|| add(&temp1, &temp2, ladner_fischer, sk),
|
||||
);
|
||||
|
||||
h = g;
|
||||
g = f;
|
||||
f = e;
|
||||
e = temp_e;
|
||||
d = c;
|
||||
c = b;
|
||||
b = a;
|
||||
a = temp_a;
|
||||
}
|
||||
|
||||
hash[0] = add(&hash[0], &a, ladner_fischer, sk);
|
||||
hash[1] = add(&hash[1], &b, ladner_fischer, sk);
|
||||
hash[2] = add(&hash[2], &c, ladner_fischer, sk);
|
||||
hash[3] = add(&hash[3], &d, ladner_fischer, sk);
|
||||
hash[4] = add(&hash[4], &e, ladner_fischer, sk);
|
||||
hash[5] = add(&hash[5], &f, ladner_fischer, sk);
|
||||
hash[6] = add(&hash[6], &g, ladner_fischer, sk);
|
||||
hash[7] = add(&hash[7], &h, ladner_fischer, sk);
|
||||
}
|
||||
|
||||
// Concatenate the final hash values to produce a 256-bit hash
|
||||
let mut output = vec![];
|
||||
|
||||
for item in &hash {
|
||||
for j in item.iter().take(32) {
|
||||
output.push(j.clone());
|
||||
}
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
// Initialize the 64 words with trivial encryption
|
||||
fn initialize_w(sk: &ServerKey) -> [[Ciphertext; 32]; 64] {
|
||||
array::from_fn(|_| trivial_bools(&[false; 32], sk))
|
||||
}
|
||||
|
||||
// To represent decrypted digest bools as hexadecimal String
|
||||
pub fn bools_to_hex(bools: Vec<bool>) -> String {
|
||||
let mut hex_string = String::new();
|
||||
let mut byte = 0u8;
|
||||
let mut counter = 0;
|
||||
|
||||
for bit in bools {
|
||||
byte <<= 1;
|
||||
if bit {
|
||||
byte |= 1;
|
||||
}
|
||||
|
||||
counter += 1;
|
||||
|
||||
if counter == 8 {
|
||||
hex_string.push_str(&format!("{:02x}", byte));
|
||||
byte = 0;
|
||||
counter = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle any remaining bits in case the bools vector length is not a multiple of 8
|
||||
if counter > 0 {
|
||||
byte <<= 8 - counter;
|
||||
hex_string.push_str(&format!("{:02x}", byte));
|
||||
}
|
||||
|
||||
hex_string
|
||||
}
|
||||
|
||||
// To represent constant values as bool arrays
|
||||
fn hex_to_bools(hex_value: u32) -> [bool; 32] {
|
||||
let mut bool_array = [false; 32];
|
||||
let mut mask = 0x8000_0000;
|
||||
|
||||
for item in &mut bool_array {
|
||||
*item = (hex_value & mask) != 0;
|
||||
mask >>= 1;
|
||||
}
|
||||
|
||||
bool_array
|
||||
}
|
||||
|
||||
const K: [u32; 64] = [
|
||||
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
|
||||
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
|
||||
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
|
||||
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
|
||||
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
|
||||
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
|
||||
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
|
||||
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn to_bool_array(arr: [i32; 32]) -> [bool; 32] {
|
||||
let mut bool_arr = [false; 32];
|
||||
for i in 0..32 {
|
||||
if arr[i] == 1 {
|
||||
bool_arr[i] = true;
|
||||
}
|
||||
}
|
||||
bool_arr
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bools_to_hex() {
|
||||
let bools = to_bool_array([
|
||||
1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
0, 1, 0,
|
||||
]);
|
||||
let hex_bools = bools_to_hex(bools.to_vec());
|
||||
|
||||
assert_eq!(hex_bools, "90befffa");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hex_to_bools() {
|
||||
let hex = 0x428a2f98;
|
||||
let result = hex_to_bools(hex);
|
||||
let expected = to_bool_array([
|
||||
0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,
|
||||
0, 0, 0,
|
||||
]);
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
555
tfhe/js_on_wasm_tests/test-hlapi.js
Normal file
555
tfhe/js_on_wasm_tests/test-hlapi.js
Normal file
@@ -0,0 +1,555 @@
|
||||
const test = require('node:test');
|
||||
const assert = require('node:assert').strict;
|
||||
const { performance } = require('perf_hooks');
|
||||
const {
|
||||
ShortintParametersName,
|
||||
ShortintParameters,
|
||||
TfheClientKey,
|
||||
TfhePublicKey,
|
||||
TfheCompressedPublicKey,
|
||||
TfheCompactPublicKey,
|
||||
TfheCompressedServerKey,
|
||||
TfheConfigBuilder,
|
||||
CompressedFheUint8,
|
||||
FheUint8,
|
||||
FheUint32,
|
||||
CompactFheUint32,
|
||||
CompactFheUint32List,
|
||||
CompressedFheUint128,
|
||||
FheUint128,
|
||||
CompressedFheUint256,
|
||||
CompactFheUint256,
|
||||
CompactFheUint256List,
|
||||
FheUint256
|
||||
} = require("../pkg/tfhe.js");
|
||||
|
||||
|
||||
const U256_MAX = BigInt("115792089237316195423570985008687907853269984665640564039457584007913129639935");
|
||||
const U128_MAX = BigInt("340282366920938463463374607431768211455");
|
||||
const U32_MAX = 4294967295;
|
||||
|
||||
// Here integers are not enabled
|
||||
// but we try to use them, so an error should be returned
|
||||
// as the underlying panic should have been trapped
|
||||
test('hlapi_panic', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.build();
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let clear = 73;
|
||||
try {
|
||||
let _ = FheUint8.encrypt_with_client_key(clear, clientKey);
|
||||
assert(false);
|
||||
} catch (e) {
|
||||
assert(true);
|
||||
}
|
||||
});
|
||||
|
||||
test('hlapi_key_gen_big', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let compressedServerKey = TfheCompressedServerKey.new(clientKey);
|
||||
try {
|
||||
let publicKey = TfhePublicKey.new(clientKey);
|
||||
assert(false);
|
||||
} catch (e) {
|
||||
assert(true)
|
||||
}
|
||||
|
||||
let serializedClientKey = clientKey.serialize();
|
||||
let serializedCompressedServerKey = compressedServerKey.serialize();
|
||||
});
|
||||
|
||||
test('hlapi_key_gen_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let compressedServerKey = TfheCompressedServerKey.new(clientKey);
|
||||
let publicKey = TfhePublicKey.new(clientKey);
|
||||
|
||||
let serializedClientKey = clientKey.serialize();
|
||||
let serializedCompressedServerKey = compressedServerKey.serialize();
|
||||
let serializedPublicKey = publicKey.serialize();
|
||||
});
|
||||
|
||||
test('hlapi_client_key_encrypt_decrypt_uint8_big', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let clear = 73;
|
||||
let encrypted = FheUint8.encrypt_with_client_key(clear, clientKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, clear);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint8.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, clear);
|
||||
});
|
||||
|
||||
test('hlapi_public_key_encrypt_decrypt_uint32_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfhePublicKey.new(clientKey);
|
||||
|
||||
let encrypted = FheUint32.encrypt_with_public_key(U32_MAX, publicKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U32_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint32.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_decompress_public_key_then_encrypt_decrypt_uint32_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
var startTime = performance.now()
|
||||
let compressedPublicKey = TfheCompressedPublicKey.new(clientKey);
|
||||
var endTime = performance.now()
|
||||
|
||||
let data = compressedPublicKey.serialize()
|
||||
|
||||
let publicKey = compressedPublicKey.decompress();
|
||||
|
||||
|
||||
var startTime = performance.now()
|
||||
let encrypted = FheUint8.encrypt_with_public_key(255, publicKey);
|
||||
var endTime = performance.now()
|
||||
|
||||
let ser = encrypted.serialize();
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, 255);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint32.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, 255);
|
||||
});
|
||||
|
||||
test('hlapi_compressed_public_client_uint8_big', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers()
|
||||
.build();
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let clear = 73;
|
||||
let compressed_encrypted = CompressedFheUint8.encrypt_with_client_key(clear, clientKey);
|
||||
let compressed_serialized = compressed_encrypted.serialize();
|
||||
let compressed_deserialized = CompressedFheUint8.deserialize(compressed_serialized);
|
||||
let decompressed = compressed_deserialized.decompress()
|
||||
|
||||
let decrypted = decompressed.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, clear);
|
||||
});
|
||||
|
||||
test('hlapi_client_key_encrypt_decrypt_uint128_big', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let encrypted = FheUint128.encrypt_with_client_key(U128_MAX, clientKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U128_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint128.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U128_MAX);
|
||||
|
||||
// Compressed
|
||||
let compressed_encrypted = CompressedFheUint128.encrypt_with_client_key(U128_MAX, clientKey);
|
||||
let compressed_serialized = compressed_encrypted.serialize();
|
||||
let compressed_deserialized = CompressedFheUint128.deserialize(compressed_serialized);
|
||||
let decompressed = compressed_deserialized.decompress()
|
||||
|
||||
decrypted = decompressed.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U128_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_client_key_encrypt_decrypt_uint128_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let encrypted = FheUint128.encrypt_with_client_key(U128_MAX, clientKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U128_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint128.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U128_MAX);
|
||||
|
||||
// Compressed
|
||||
let compressed_encrypted = CompressedFheUint128.encrypt_with_client_key(U128_MAX, clientKey);
|
||||
let compressed_serialized = compressed_encrypted.serialize();
|
||||
let compressed_deserialized = CompressedFheUint128.deserialize(compressed_serialized);
|
||||
let decompressed = compressed_deserialized.decompress()
|
||||
|
||||
decrypted = decompressed.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U128_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_client_key_encrypt_decrypt_uint256_big', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let encrypted = FheUint256.encrypt_with_client_key(U256_MAX, clientKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
|
||||
// Compressed
|
||||
let compressed_encrypted = CompressedFheUint256.encrypt_with_client_key(U256_MAX, clientKey);
|
||||
let compressed_serialized = compressed_encrypted.serialize();
|
||||
let compressed_deserialized = CompressedFheUint256.deserialize(compressed_serialized);
|
||||
let decompressed = compressed_deserialized.decompress()
|
||||
|
||||
decrypted = decompressed.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_client_key_encrypt_decrypt_uint256_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
|
||||
let encrypted = FheUint256.encrypt_with_client_key(U256_MAX, clientKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
|
||||
// Compressed
|
||||
let compressed_encrypted = CompressedFheUint256.encrypt_with_client_key(U256_MAX, clientKey);
|
||||
let compressed_serialized = compressed_encrypted.serialize();
|
||||
let compressed_deserialized = CompressedFheUint256.deserialize(compressed_serialized);
|
||||
let decompressed = compressed_deserialized.decompress()
|
||||
|
||||
decrypted = decompressed.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_decompress_public_key_then_encrypt_decrypt_uint256_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let compressedPublicKey = TfheCompressedPublicKey.new(clientKey);
|
||||
let publicKey = compressedPublicKey.decompress();
|
||||
|
||||
|
||||
let encrypted = FheUint256.encrypt_with_public_key(U256_MAX, publicKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
});
|
||||
|
||||
test('hlapi_public_key_encrypt_decrypt_uint256_small', (t) => {
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_default_integers_small()
|
||||
.build();
|
||||
|
||||
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfhePublicKey.new(clientKey);
|
||||
|
||||
let encrypted = FheUint256.encrypt_with_public_key(U256_MAX, publicKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
});
|
||||
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// 32 bits compact
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint32_single(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let encrypted = FheUint32.encrypt_with_compact_public_key(U32_MAX, publicKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U32_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint32.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_single', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_single(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_single', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_single(config);
|
||||
});
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let compact_encrypted = CompactFheUint32.encrypt_with_compact_public_key(U32_MAX, publicKey);
|
||||
let encrypted = compact_encrypted.expand();
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U32_MAX);
|
||||
|
||||
let serialized = compact_encrypted.serialize();
|
||||
let deserialized = CompactFheUint32.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.expand().decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_single_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_single_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config);
|
||||
});
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let values = [0, 1, 2394, U32_MAX];
|
||||
|
||||
let compact_list = CompactFheUint32List.encrypt_with_compact_public_key(values, publicKey);
|
||||
|
||||
{
|
||||
let encrypted_list = compact_list.expand();
|
||||
|
||||
assert.deepStrictEqual(encrypted_list.length, values.length);
|
||||
|
||||
for (let i = 0; i < values.length; i++)
|
||||
{
|
||||
let decrypted = encrypted_list[i].decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
let serialized_list = compact_list.serialize();
|
||||
let deserialized_list = CompactFheUint32List.deserialize(serialized_list);
|
||||
let encrypted_list = deserialized_list.expand();
|
||||
assert.deepStrictEqual(encrypted_list.length, values.length);
|
||||
|
||||
for (let i = 0; i < values.length; i++)
|
||||
{
|
||||
let decrypted = encrypted_list[i].decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_list_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_list_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config);
|
||||
});
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
/// 256 bits compact
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint256_single(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let encrypted = FheUint256.encrypt_with_compact_public_key(U256_MAX, publicKey);
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = encrypted.serialize();
|
||||
let deserialized = FheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_single', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_single(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_single', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_single(config);
|
||||
});
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let compact_encrypted = CompactFheUint256.encrypt_with_compact_public_key(U256_MAX, publicKey);
|
||||
let encrypted = compact_encrypted.expand();
|
||||
let decrypted = encrypted.decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, U256_MAX);
|
||||
|
||||
let serialized = compact_encrypted.serialize();
|
||||
let deserialized = CompactFheUint256.deserialize(serialized);
|
||||
let deserialized_decrypted = deserialized.expand().decrypt(clientKey);
|
||||
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_single_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_single_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config);
|
||||
});
|
||||
|
||||
function hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config) {
|
||||
let clientKey = TfheClientKey.generate(config);
|
||||
let publicKey = TfheCompactPublicKey.new(clientKey);
|
||||
|
||||
let values = [BigInt(0), BigInt(1), BigInt(2394), BigInt(2309840239), BigInt(U32_MAX), U256_MAX, U128_MAX];
|
||||
|
||||
let compact_list = CompactFheUint256List.encrypt_with_compact_public_key(values, publicKey);
|
||||
|
||||
{
|
||||
let encrypted_list = compact_list.expand();
|
||||
|
||||
assert.deepStrictEqual(encrypted_list.length, values.length);
|
||||
|
||||
for (let i = 0; i < values.length; i++)
|
||||
{
|
||||
let decrypted = encrypted_list[i].decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
let serialized_list = compact_list.serialize();
|
||||
let deserialized_list = CompactFheUint256List.deserialize(serialized_list);
|
||||
let encrypted_list = deserialized_list.expand();
|
||||
assert.deepStrictEqual(encrypted_list.length, values.length);
|
||||
|
||||
for (let i = 0; i < values.length; i++)
|
||||
{
|
||||
let decrypted = encrypted_list[i].decrypt(clientKey);
|
||||
assert.deepStrictEqual(decrypted, values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_list_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config);
|
||||
});
|
||||
|
||||
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_list_compact', (t) => {
|
||||
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
|
||||
let config = TfheConfigBuilder.all_disabled()
|
||||
.enable_custom_integers(block_params)
|
||||
.build();
|
||||
|
||||
hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config);
|
||||
});
|
||||
@@ -94,7 +94,7 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
pub fn bootstrapping_key_size_bytes(&self) -> usize {
|
||||
self.bootstrapping_key_size_elements() * std::mem::size_of::<concrete_fft::c64>()
|
||||
std::mem::size_of_val(self.bootstrapping_key.as_view().data())
|
||||
}
|
||||
|
||||
pub fn key_switching_key_size_elements(&self) -> usize {
|
||||
@@ -102,7 +102,7 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
pub fn key_switching_key_size_bytes(&self) -> usize {
|
||||
self.key_switching_key_size_elements() * std::mem::size_of::<u64>()
|
||||
std::mem::size_of_val(self.key_switching_key.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
use crate::boolean::ciphertext::{Ciphertext, CompressedCiphertext};
|
||||
use crate::boolean::parameters::BooleanParameters;
|
||||
use crate::boolean::{ClientKey, PublicKey, PLAINTEXT_FALSE, PLAINTEXT_TRUE};
|
||||
use crate::boolean::{ClientKey, CompressedPublicKey, PublicKey, PLAINTEXT_FALSE, PLAINTEXT_TRUE};
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use std::cell::RefCell;
|
||||
@@ -144,6 +144,38 @@ impl BooleanEngine {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_compressed_public_key(&mut self, client_key: &ClientKey) -> CompressedPublicKey {
|
||||
let client_parameters = client_key.parameters;
|
||||
|
||||
// Formula is (n + 1) * log2(q) + 128
|
||||
let zero_encryption_count = LwePublicKeyZeroEncryptionCount(
|
||||
client_parameters.lwe_dimension.to_lwe_size().0 * LOG2_Q_32 + 128,
|
||||
);
|
||||
|
||||
#[cfg(not(feature = "__wasm_api"))]
|
||||
let compressed_lwe_public_key = par_allocate_and_generate_new_seeded_lwe_public_key(
|
||||
&client_key.lwe_secret_key,
|
||||
zero_encryption_count,
|
||||
client_key.parameters.lwe_modular_std_dev,
|
||||
CiphertextModulus::new_native(),
|
||||
&mut self.bootstrapper.seeder,
|
||||
);
|
||||
|
||||
#[cfg(feature = "__wasm_api")]
|
||||
let compressed_lwe_public_key = allocate_and_generate_new_seeded_lwe_public_key(
|
||||
&client_key.lwe_secret_key,
|
||||
zero_encryption_count,
|
||||
client_key.parameters.lwe_modular_std_dev,
|
||||
CiphertextModulus::new_native(),
|
||||
&mut self.bootstrapper.seeder,
|
||||
);
|
||||
|
||||
CompressedPublicKey {
|
||||
compressed_lwe_public_key,
|
||||
parameters: client_parameters,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn trivial_encrypt(&mut self, message: bool) -> Ciphertext {
|
||||
Ciphertext::Trivial(message)
|
||||
}
|
||||
@@ -212,6 +244,32 @@ impl BooleanEngine {
|
||||
|
||||
Ciphertext::Encrypted(output)
|
||||
}
|
||||
pub fn encrypt_with_compressed_public_key(
|
||||
&mut self,
|
||||
message: bool,
|
||||
compressed_pk: &CompressedPublicKey,
|
||||
) -> Ciphertext {
|
||||
let plain: Plaintext<u32> = if message {
|
||||
Plaintext(PLAINTEXT_TRUE)
|
||||
} else {
|
||||
Plaintext(PLAINTEXT_FALSE)
|
||||
};
|
||||
|
||||
let mut output = LweCiphertext::new(
|
||||
0u32,
|
||||
compressed_pk.parameters.lwe_dimension.to_lwe_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
encrypt_lwe_ciphertext_with_seeded_public_key(
|
||||
&compressed_pk.compressed_lwe_public_key,
|
||||
&mut output,
|
||||
plain,
|
||||
&mut self.secret_generator,
|
||||
);
|
||||
|
||||
Ciphertext::Encrypted(output)
|
||||
}
|
||||
|
||||
pub fn decrypt(&mut self, ct: &Ciphertext, cks: &ClientKey) -> bool {
|
||||
match ct {
|
||||
|
||||
@@ -54,7 +54,7 @@
|
||||
|
||||
use crate::boolean::client_key::ClientKey;
|
||||
use crate::boolean::parameters::DEFAULT_PARAMETERS;
|
||||
use crate::boolean::public_key::PublicKey;
|
||||
use crate::boolean::public_key::{CompressedPublicKey, PublicKey};
|
||||
use crate::boolean::server_key::ServerKey;
|
||||
#[cfg(test)]
|
||||
use rand::Rng;
|
||||
|
||||
@@ -7,5 +7,5 @@ pub use super::ciphertext::{Ciphertext, CompressedCiphertext};
|
||||
pub use super::client_key::ClientKey;
|
||||
pub use super::gen_keys;
|
||||
pub use super::parameters::*;
|
||||
pub use super::public_key::PublicKey;
|
||||
pub use super::public_key::{CompressedPublicKey, PublicKey};
|
||||
pub use super::server_key::{BinaryBooleanGates, ServerKey};
|
||||
|
||||
132
tfhe/src/boolean/public_key/compressed.rs
Normal file
132
tfhe/src/boolean/public_key/compressed.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
|
||||
use crate::boolean::prelude::{BooleanParameters, Ciphertext, ClientKey};
|
||||
use crate::core_crypto::prelude::SeededLwePublicKeyOwned;
|
||||
|
||||
/// A structure containing a compressed public key.
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
pub struct CompressedPublicKey {
|
||||
pub(crate) compressed_lwe_public_key: SeededLwePublicKeyOwned<u32>,
|
||||
pub parameters: BooleanParameters,
|
||||
}
|
||||
|
||||
impl CompressedPublicKey {
|
||||
/// Generates a new public key that is compressed
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let cpks = CompressedPublicKey::new(&cks);
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// Decompressing the key
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let cpks = CompressedPublicKey::new(&cks);
|
||||
/// let pks = PublicKey::from(cpks);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(client_key: &ClientKey) -> Self {
|
||||
BooleanEngine::with_thread_local_mut(|engine| {
|
||||
engine.create_compressed_public_key(client_key)
|
||||
})
|
||||
}
|
||||
|
||||
/// Encrypt a Boolean message using the compressed public key.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// It is recommended to use the compressed
|
||||
/// public key to save on storage / tranfert
|
||||
/// and decompress it in the program before doing encryptions.
|
||||
///
|
||||
/// This is because encrypting using the compressed public key
|
||||
/// will require to lazyly decompress parts of the key
|
||||
/// for each encryption.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let cpks = CompressedPublicKey::new(&cks);
|
||||
///
|
||||
/// // Encryption of one message:
|
||||
/// let ct1 = cpks.encrypt(true);
|
||||
/// let ct2 = cpks.encrypt(false);
|
||||
/// let ct_res = sks.and(&ct1, &ct2);
|
||||
///
|
||||
/// // Decryption:
|
||||
/// let dec = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(false, dec);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn encrypt(&self, message: bool) -> Ciphertext {
|
||||
BooleanEngine::with_thread_local_mut(|engine| {
|
||||
engine.encrypt_with_compressed_public_key(message, self)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::boolean::prelude::{
|
||||
BinaryBooleanGates, BooleanParameters, ClientKey, CompressedPublicKey, ServerKey,
|
||||
DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS,
|
||||
};
|
||||
use crate::boolean::random_boolean;
|
||||
const NB_TEST: usize = 32;
|
||||
|
||||
#[test]
|
||||
fn test_compressed_public_key_default_parameters() {
|
||||
test_compressed_public_key(DEFAULT_PARAMETERS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compressed_public_key_tfhe_lib_parameters() {
|
||||
test_compressed_public_key(TFHE_LIB_PARAMETERS);
|
||||
}
|
||||
|
||||
fn test_compressed_public_key(parameters: BooleanParameters) {
|
||||
let cks = ClientKey::new(¶meters);
|
||||
let sks = ServerKey::new(&cks);
|
||||
let cpks = CompressedPublicKey::new(&cks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
|
||||
let ct1 = cpks.encrypt(b1);
|
||||
let ct2 = cpks.encrypt(b2);
|
||||
|
||||
let ct_res = sks.nand(&ct1, &ct2);
|
||||
|
||||
let dec_ct1 = cks.decrypt(&ct1);
|
||||
let dec_ct2 = cks.decrypt(&ct2);
|
||||
let dec_nand = cks.decrypt(&ct_res);
|
||||
|
||||
assert_eq!(dec_ct1, b1);
|
||||
assert_eq!(dec_ct2, b2);
|
||||
assert_eq!(dec_nand, expected_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,62 +1,5 @@
|
||||
//! Module with the definition of the encryption PublicKey.
|
||||
mod compressed;
|
||||
mod standard;
|
||||
|
||||
use crate::boolean::ciphertext::Ciphertext;
|
||||
use crate::boolean::client_key::ClientKey;
|
||||
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
|
||||
use crate::boolean::parameters::BooleanParameters;
|
||||
use crate::core_crypto::entities::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A structure containing a public key.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PublicKey {
|
||||
pub(crate) lwe_public_key: LwePublicKeyOwned<u32>,
|
||||
pub(crate) parameters: BooleanParameters,
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Encrypt a Boolean message using the client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let pks = PublicKey::new(&cks);
|
||||
///
|
||||
/// // Encryption of one message:
|
||||
/// let ct1 = pks.encrypt(true);
|
||||
/// let ct2 = pks.encrypt(false);
|
||||
/// let ct_res = sks.and(&ct1, &ct2);
|
||||
///
|
||||
/// // Decryption:
|
||||
/// let dec = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(false, dec);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn encrypt(&self, message: bool) -> Ciphertext {
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.encrypt_with_public_key(message, self))
|
||||
}
|
||||
|
||||
/// Allocate and generate a client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let pks = PublicKey::new(&cks);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(client_key: &ClientKey) -> PublicKey {
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.create_public_key(client_key))
|
||||
}
|
||||
}
|
||||
pub use compressed::CompressedPublicKey;
|
||||
pub use standard::PublicKey;
|
||||
|
||||
162
tfhe/src/boolean/public_key/standard.rs
Normal file
162
tfhe/src/boolean/public_key/standard.rs
Normal file
@@ -0,0 +1,162 @@
|
||||
//! Module with the definition of the encryption PublicKey.
|
||||
|
||||
use crate::boolean::ciphertext::Ciphertext;
|
||||
use crate::boolean::client_key::ClientKey;
|
||||
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
|
||||
use crate::boolean::parameters::BooleanParameters;
|
||||
use crate::core_crypto::entities::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::compressed::CompressedPublicKey;
|
||||
|
||||
/// A structure containing a public key.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PublicKey {
|
||||
pub(crate) lwe_public_key: LwePublicKeyOwned<u32>,
|
||||
pub(crate) parameters: BooleanParameters,
|
||||
}
|
||||
|
||||
impl PublicKey {
|
||||
/// Encrypt a Boolean message using the public key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let pks = PublicKey::new(&cks);
|
||||
///
|
||||
/// // Encryption of one message:
|
||||
/// let ct1 = pks.encrypt(true);
|
||||
/// let ct2 = pks.encrypt(false);
|
||||
/// let ct_res = sks.and(&ct1, &ct2);
|
||||
///
|
||||
/// // Decryption:
|
||||
/// let dec = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(false, dec);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn encrypt(&self, message: bool) -> Ciphertext {
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.encrypt_with_public_key(message, self))
|
||||
}
|
||||
|
||||
/// Allocate and generate a client key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # fn main() {
|
||||
/// use tfhe::boolean::prelude::*;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys();
|
||||
///
|
||||
/// let pks = PublicKey::new(&cks);
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(client_key: &ClientKey) -> PublicKey {
|
||||
BooleanEngine::with_thread_local_mut(|engine| engine.create_public_key(client_key))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CompressedPublicKey> for PublicKey {
|
||||
fn from(compressed_public_key: CompressedPublicKey) -> Self {
|
||||
let parameters = compressed_public_key.parameters;
|
||||
|
||||
let decompressed_public_key = compressed_public_key
|
||||
.compressed_lwe_public_key
|
||||
.decompress_into_lwe_public_key();
|
||||
|
||||
Self {
|
||||
lwe_public_key: decompressed_public_key,
|
||||
parameters,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::boolean::prelude::{
|
||||
BinaryBooleanGates, BooleanParameters, ClientKey, CompressedPublicKey, ServerKey,
|
||||
DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS,
|
||||
};
|
||||
use crate::boolean::random_boolean;
|
||||
|
||||
use super::PublicKey;
|
||||
const NB_TEST: usize = 32;
|
||||
|
||||
#[test]
|
||||
fn test_public_key_default_parameters() {
|
||||
test_public_key(DEFAULT_PARAMETERS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_public_key_tfhe_lib_parameters() {
|
||||
test_public_key(TFHE_LIB_PARAMETERS);
|
||||
}
|
||||
|
||||
fn test_public_key(parameters: BooleanParameters) {
|
||||
let cks = ClientKey::new(¶meters);
|
||||
let sks = ServerKey::new(&cks);
|
||||
let pks = PublicKey::new(&cks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
|
||||
let ct1 = pks.encrypt(b1);
|
||||
let ct2 = pks.encrypt(b2);
|
||||
|
||||
let ct_res = sks.nand(&ct1, &ct2);
|
||||
|
||||
let dec_ct1 = cks.decrypt(&ct1);
|
||||
let dec_ct2 = cks.decrypt(&ct2);
|
||||
let dec_nand = cks.decrypt(&ct_res);
|
||||
|
||||
assert_eq!(dec_ct1, b1);
|
||||
assert_eq!(dec_ct2, b2);
|
||||
assert_eq!(dec_nand, expected_result);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompressing_public_key_default_parameters() {
|
||||
test_public_key(DEFAULT_PARAMETERS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decompressing_public_key_tfhe_lib_parameters() {
|
||||
test_decompressing_public_key(TFHE_LIB_PARAMETERS);
|
||||
}
|
||||
|
||||
fn test_decompressing_public_key(parameters: BooleanParameters) {
|
||||
let cks = ClientKey::new(¶meters);
|
||||
let sks = ServerKey::new(&cks);
|
||||
let cpks = CompressedPublicKey::new(&cks);
|
||||
let pks = PublicKey::from(cpks);
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
let b1 = random_boolean();
|
||||
let b2 = random_boolean();
|
||||
let expected_result = !(b1 && b2);
|
||||
|
||||
let ct1 = pks.encrypt(b1);
|
||||
let ct2 = pks.encrypt(b2);
|
||||
|
||||
let ct_res = sks.nand(&ct1, &ct2);
|
||||
|
||||
let dec_ct1 = cks.decrypt(&ct1);
|
||||
let dec_ct2 = cks.decrypt(&ct2);
|
||||
let dec_nand = cks.decrypt(&ct_res);
|
||||
|
||||
assert_eq!(dec_ct1, b1);
|
||||
assert_eq!(dec_ct2, b2);
|
||||
assert_eq!(dec_nand, expected_result);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -72,6 +72,20 @@ define_enable_default_fn!(integers);
|
||||
#[cfg(feature = "integer")]
|
||||
define_enable_default_fn!(integers @small);
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn config_builder_enable_custom_integers(
|
||||
builder: *mut *mut ConfigBuilder,
|
||||
shortint_block_parameters: crate::c_api::shortint::parameters::ShortintPBSParameters,
|
||||
) -> ::std::os::raw::c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(builder).unwrap();
|
||||
|
||||
let inner = Box::from_raw(*builder)
|
||||
.0
|
||||
.enable_custom_integers(shortint_block_parameters.into(), None);
|
||||
*builder = Box::into_raw(Box::new(ConfigBuilder(inner)));
|
||||
})
|
||||
}
|
||||
/// Takes ownership of the builder
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn config_builder_build(
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use crate::c_api::high_level_api::keys::{ClientKey, PublicKey};
|
||||
use crate::c_api::high_level_api::keys::{ClientKey, CompactPublicKey, PublicKey};
|
||||
use crate::high_level_api::prelude::*;
|
||||
use std::ops::{
|
||||
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, MulAssign,
|
||||
Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
|
||||
};
|
||||
|
||||
use crate::c_api::high_level_api::u128::U128;
|
||||
use crate::c_api::high_level_api::u256::U256;
|
||||
use crate::c_api::utils::*;
|
||||
use std::os::raw::c_int;
|
||||
@@ -67,6 +68,49 @@ macro_rules! create_integer_wrapper_type {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// The compact list version of the ciphertext type
|
||||
::paste::paste! {
|
||||
pub struct [<Compact $name List>]($crate::high_level_api::[<Compact $name List>]);
|
||||
|
||||
impl_destroy_on_type!([<Compact $name List>]);
|
||||
|
||||
impl_clone_on_type!([<Compact $name List>]);
|
||||
|
||||
impl_serialize_deserialize_on_type!([<Compact $name List>]);
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn [<compact_ $name:snake _list_len>](
|
||||
sself: *const [<Compact $name List>],
|
||||
result: *mut usize,
|
||||
) -> ::std::os::raw::c_int {
|
||||
$crate::c_api::utils::catch_panic(|| {
|
||||
let list = $crate::c_api::utils::get_ref_checked(sself).unwrap();
|
||||
|
||||
*result = list.0.len();
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn [<compact_ $name:snake _list_expand>](
|
||||
sself: *const [<Compact $name List>],
|
||||
output: *mut *mut $name,
|
||||
output_len: usize
|
||||
) -> ::std::os::raw::c_int {
|
||||
$crate::c_api::utils::catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(output).unwrap();
|
||||
let list = $crate::c_api::utils::get_ref_checked(sself).unwrap();
|
||||
let expanded = list.0.expand();
|
||||
|
||||
let num_to_take = output_len.max(list.0.len());
|
||||
let iter = expanded.into_iter().take(num_to_take).enumerate();
|
||||
for (i, fhe_uint) in iter {
|
||||
let ptr = output.wrapping_add(i);
|
||||
*ptr = Box::into_raw(Box::new($name(fhe_uint)));
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -84,52 +128,65 @@ impl_decrypt_on_type!(FheUint8, u8);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint8{crate::high_level_api::CompressedFheUint8}, u8);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint8List{crate::high_level_api::CompactFheUint8List}, u8);
|
||||
|
||||
impl_decrypt_on_type!(FheUint10, u16);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint10{crate::high_level_api::CompressedFheUint10}, u16);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint10List{crate::high_level_api::CompactFheUint10List}, u16);
|
||||
|
||||
impl_decrypt_on_type!(FheUint12, u16);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint12{crate::high_level_api::CompressedFheUint12}, u16);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint12List{crate::high_level_api::CompactFheUint12List}, u16);
|
||||
|
||||
impl_decrypt_on_type!(FheUint14, u16);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint14{crate::high_level_api::CompressedFheUint14}, u16);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint14List{crate::high_level_api::CompactFheUint14List}, u16);
|
||||
|
||||
impl_decrypt_on_type!(FheUint16, u16);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint16{crate::high_level_api::CompressedFheUint16}, u16);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint16List{crate::high_level_api::CompactFheUint16List}, u16);
|
||||
|
||||
impl_decrypt_on_type!(FheUint32, u32);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint32{crate::high_level_api::CompressedFheUint32}, u32);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint32List{crate::high_level_api::CompactFheUint32List}, u32);
|
||||
|
||||
impl_decrypt_on_type!(FheUint64, u64);
|
||||
impl_try_encrypt_trivial_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
|
||||
impl_try_encrypt_with_client_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
|
||||
impl_try_encrypt_with_public_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
|
||||
impl_try_encrypt_with_compact_public_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
|
||||
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint64{crate::high_level_api::CompressedFheUint64}, u64);
|
||||
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint64List{crate::high_level_api::CompactFheUint64List}, u64);
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint128_try_encrypt_trivial_u128(
|
||||
low_word: u64,
|
||||
high_word: u64,
|
||||
value: U128,
|
||||
result: *mut *mut FheUint128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let value = ((high_word as u128) << 64u128) | low_word as u128;
|
||||
let value = u128::from(value);
|
||||
|
||||
let inner = <crate::high_level_api::FheUint128>::try_encrypt_trivial(value).unwrap();
|
||||
|
||||
@@ -139,15 +196,14 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_trivial_u128(
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_client_key_u128(
|
||||
low_word: u64,
|
||||
high_word: u64,
|
||||
value: U128,
|
||||
client_key: *const ClientKey,
|
||||
result: *mut *mut FheUint128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
|
||||
let value = ((high_word as u128) << 64u128) | low_word as u128;
|
||||
let value = u128::from(value);
|
||||
|
||||
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &client_key.0).unwrap();
|
||||
|
||||
@@ -157,15 +213,14 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_client_key_u128(
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compressed_fhe_uint128_try_encrypt_with_client_key_u128(
|
||||
low_word: u64,
|
||||
high_word: u64,
|
||||
value: U128,
|
||||
client_key: *const ClientKey,
|
||||
result: *mut *mut CompressedFheUint128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
|
||||
let value = ((high_word as u128) << 64u128) | low_word as u128;
|
||||
let value = u128::from(value);
|
||||
|
||||
let inner =
|
||||
<crate::high_level_api::CompressedFheUint128>::try_encrypt(value, &client_key.0)
|
||||
@@ -177,15 +232,14 @@ pub unsafe extern "C" fn compressed_fhe_uint128_try_encrypt_with_client_key_u128
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_public_key_u128(
|
||||
low_word: u64,
|
||||
high_word: u64,
|
||||
value: U128,
|
||||
public_key: *const PublicKey,
|
||||
result: *mut *mut FheUint128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let value = ((high_word as u128) << 64u128) | low_word as u128;
|
||||
let value = u128::from(value);
|
||||
|
||||
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &public_key.0).unwrap();
|
||||
|
||||
@@ -193,12 +247,48 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_public_key_u128(
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_compact_public_key_u128(
|
||||
value: U128,
|
||||
public_key: *const CompactPublicKey,
|
||||
result: *mut *mut FheUint128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let value = u128::from(value);
|
||||
|
||||
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &public_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(FheUint128(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u128(
|
||||
input: *const U128,
|
||||
input_len: usize,
|
||||
public_key: *const CompactPublicKey,
|
||||
result: *mut *mut CompactFheUint256List,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let slc = ::std::slice::from_raw_parts(input, input_len);
|
||||
let values = slc.iter().copied().map(u128::from).collect::<Vec<_>>();
|
||||
let inner =
|
||||
<crate::high_level_api::CompactFheUint256List>::try_encrypt(&values, &public_key.0)
|
||||
.unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(CompactFheUint256List(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint128_decrypt(
|
||||
encrypted_value: *const FheUint128,
|
||||
client_key: *const ClientKey,
|
||||
low_word: *mut u64,
|
||||
high_word: *mut u64,
|
||||
result: *mut U128,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
@@ -206,18 +296,18 @@ pub unsafe extern "C" fn fhe_uint128_decrypt(
|
||||
|
||||
let inner: u128 = encrypted_value.0.decrypt(&client_key.0);
|
||||
|
||||
*low_word = (inner & (u64::MAX as u128)) as u64;
|
||||
*high_word = (inner >> 64) as u64;
|
||||
*result = U128::from(inner);
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint256_try_encrypt_trivial_u256(
|
||||
value: *const U256,
|
||||
value: U256,
|
||||
result: *mut *mut FheUint256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let inner = <crate::high_level_api::FheUint256>::try_encrypt_trivial((*value).0).unwrap();
|
||||
let value = crate::integer::U256::from(value);
|
||||
let inner = <crate::high_level_api::FheUint256>::try_encrypt_trivial(value).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(FheUint256(inner)));
|
||||
})
|
||||
@@ -225,15 +315,15 @@ pub unsafe extern "C" fn fhe_uint256_try_encrypt_trivial_u256(
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_client_key_u256(
|
||||
value: *const U256,
|
||||
value: U256,
|
||||
client_key: *const ClientKey,
|
||||
result: *mut *mut FheUint256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
|
||||
let inner =
|
||||
<crate::high_level_api::FheUint256>::try_encrypt((*value).0, &client_key.0).unwrap();
|
||||
let value = crate::integer::U256::from(value);
|
||||
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &client_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(FheUint256(inner)));
|
||||
})
|
||||
@@ -241,15 +331,16 @@ pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_client_key_u256(
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compressed_fhe_uint256_try_encrypt_with_client_key_u256(
|
||||
value: *const U256,
|
||||
value: U256,
|
||||
client_key: *const ClientKey,
|
||||
result: *mut *mut CompressedFheUint256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
|
||||
let value = crate::integer::U256::from(value);
|
||||
let inner =
|
||||
<crate::high_level_api::CompressedFheUint256>::try_encrypt((*value).0, &client_key.0)
|
||||
<crate::high_level_api::CompressedFheUint256>::try_encrypt(value, &client_key.0)
|
||||
.unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(CompressedFheUint256(inner)));
|
||||
@@ -258,32 +349,72 @@ pub unsafe extern "C" fn compressed_fhe_uint256_try_encrypt_with_client_key_u256
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_public_key_u256(
|
||||
value: *const U256,
|
||||
value: U256,
|
||||
public_key: *const PublicKey,
|
||||
result: *mut *mut FheUint256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let inner =
|
||||
<crate::high_level_api::FheUint256>::try_encrypt((*value).0, &public_key.0).unwrap();
|
||||
let value = crate::integer::U256::from(value);
|
||||
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &public_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(FheUint256(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_compact_public_key_u256(
|
||||
value: U256,
|
||||
public_key: *const CompactPublicKey,
|
||||
result: *mut *mut FheUint256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let value = crate::integer::U256::from(value);
|
||||
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &public_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(FheUint256(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u256(
|
||||
input: *const U256,
|
||||
input_len: usize,
|
||||
public_key: *const CompactPublicKey,
|
||||
result: *mut *mut CompactFheUint256List,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
let slc = ::std::slice::from_raw_parts(input, input_len);
|
||||
let values = slc
|
||||
.iter()
|
||||
.copied()
|
||||
.map(crate::integer::U256::from)
|
||||
.collect::<Vec<_>>();
|
||||
let inner =
|
||||
<crate::high_level_api::CompactFheUint256List>::try_encrypt(&values, &public_key.0)
|
||||
.unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new(CompactFheUint256List(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn fhe_uint256_decrypt(
|
||||
encrypted_value: *const FheUint256,
|
||||
client_key: *const ClientKey,
|
||||
result: *mut *mut U256,
|
||||
result: *mut U256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
let encrypted_value = get_ref_checked(encrypted_value).unwrap();
|
||||
|
||||
let inner: crate::integer::U256 = encrypted_value.0.decrypt(&client_key.0);
|
||||
*result = Box::into_raw(Box::new(U256(inner)));
|
||||
*result = U256::from(inner);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -3,14 +3,20 @@ use std::os::raw::c_int;
|
||||
|
||||
pub struct ClientKey(pub(crate) crate::high_level_api::ClientKey);
|
||||
pub struct PublicKey(pub(crate) crate::high_level_api::PublicKey);
|
||||
pub struct CompactPublicKey(pub(crate) crate::high_level_api::CompactPublicKey);
|
||||
pub struct CompressedCompactPublicKey(pub(crate) crate::high_level_api::CompressedCompactPublicKey);
|
||||
pub struct ServerKey(pub(crate) crate::high_level_api::ServerKey);
|
||||
|
||||
impl_destroy_on_type!(ClientKey);
|
||||
impl_destroy_on_type!(PublicKey);
|
||||
impl_destroy_on_type!(CompactPublicKey);
|
||||
impl_destroy_on_type!(CompressedCompactPublicKey);
|
||||
impl_destroy_on_type!(ServerKey);
|
||||
|
||||
impl_serialize_deserialize_on_type!(ClientKey);
|
||||
impl_serialize_deserialize_on_type!(PublicKey);
|
||||
impl_serialize_deserialize_on_type!(CompactPublicKey);
|
||||
impl_serialize_deserialize_on_type!(CompressedCompactPublicKey);
|
||||
impl_serialize_deserialize_on_type!(ServerKey);
|
||||
|
||||
#[no_mangle]
|
||||
@@ -69,3 +75,43 @@ pub unsafe extern "C" fn public_key_new(
|
||||
*result_public_key = Box::into_raw(Box::new(PublicKey(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compact_public_key_new(
|
||||
client_key: *const ClientKey,
|
||||
result_public_key: *mut *mut CompactPublicKey,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
let inner = crate::high_level_api::CompactPublicKey::new(&client_key.0);
|
||||
|
||||
*result_public_key = Box::into_raw(Box::new(CompactPublicKey(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compressed_compact_public_key_new(
|
||||
client_key: *const ClientKey,
|
||||
result_public_key: *mut *mut CompressedCompactPublicKey,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let client_key = get_ref_checked(client_key).unwrap();
|
||||
let inner = crate::high_level_api::CompressedCompactPublicKey::new(&client_key.0);
|
||||
|
||||
*result_public_key = Box::into_raw(Box::new(CompressedCompactPublicKey(inner)));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn compressed_compact_public_key_decompress(
|
||||
public_key: *const CompressedCompactPublicKey,
|
||||
result_public_key: *mut *mut CompactPublicKey,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let public_key = get_ref_checked(public_key).unwrap();
|
||||
|
||||
*result_public_key = Box::into_raw(Box::new(CompactPublicKey(
|
||||
public_key.0.clone().decompress(),
|
||||
)));
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,4 +7,6 @@ pub mod config;
|
||||
pub mod integers;
|
||||
pub mod keys;
|
||||
#[cfg(feature = "integer")]
|
||||
pub mod u128;
|
||||
#[cfg(feature = "integer")]
|
||||
pub mod u256;
|
||||
|
||||
20
tfhe/src/c_api/high_level_api/u128.rs
Normal file
20
tfhe/src/c_api/high_level_api/u128.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct U128 {
|
||||
pub w0: u64,
|
||||
pub w1: u64,
|
||||
}
|
||||
|
||||
impl From<u128> for U128 {
|
||||
fn from(value: u128) -> Self {
|
||||
let w0 = (value & (u64::MAX as u128)) as u64;
|
||||
let w1 = (value >> 64) as u64;
|
||||
Self { w0, w1 }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<U128> for u128 {
|
||||
fn from(value: U128) -> Self {
|
||||
((value.w1 as u128) << 64u128) | value.w0 as u128
|
||||
}
|
||||
}
|
||||
@@ -1,47 +1,30 @@
|
||||
use crate::c_api::utils::*;
|
||||
use std::os::raw::c_int;
|
||||
|
||||
pub struct U256(pub(in crate::c_api) crate::integer::U256);
|
||||
|
||||
impl_destroy_on_type!(U256);
|
||||
|
||||
/// w0 is the least significant, w4 is the most significant
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn u256_from_u64_words(
|
||||
w0: u64,
|
||||
w1: u64,
|
||||
w2: u64,
|
||||
w3: u64,
|
||||
result: *mut *mut U256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let inner = crate::integer::U256::from((w0, w1, w2, w3));
|
||||
*result = Box::into_raw(Box::new(U256(inner)));
|
||||
})
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct U256 {
|
||||
pub w0: u64,
|
||||
pub w1: u64,
|
||||
pub w2: u64,
|
||||
pub w3: u64,
|
||||
}
|
||||
|
||||
/// w0 is the least significant, w4 is the most significant
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn u256_to_u64_words(
|
||||
input: *const U256,
|
||||
w0: *mut u64,
|
||||
w1: *mut u64,
|
||||
w2: *mut u64,
|
||||
w3: *mut u64,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let input = get_ref_checked(input).unwrap();
|
||||
impl From<crate::integer::U256> for U256 {
|
||||
fn from(value: crate::integer::U256) -> Self {
|
||||
Self {
|
||||
w0: value.0[0],
|
||||
w1: value.0[1],
|
||||
w2: value.0[2],
|
||||
w3: value.0[3],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
check_ptr_is_non_null_and_aligned(w0).unwrap();
|
||||
check_ptr_is_non_null_and_aligned(w1).unwrap();
|
||||
check_ptr_is_non_null_and_aligned(w2).unwrap();
|
||||
check_ptr_is_non_null_and_aligned(w3).unwrap();
|
||||
|
||||
*w0 = input.0 .0[0];
|
||||
*w1 = input.0 .0[1];
|
||||
*w2 = input.0 .0[2];
|
||||
*w3 = input.0 .0[3];
|
||||
})
|
||||
impl From<U256> for crate::integer::U256 {
|
||||
fn from(value: U256) -> Self {
|
||||
Self([value.w0, value.w1, value.w2, value.w3])
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a U256 from little endian bytes
|
||||
@@ -51,7 +34,7 @@ pub unsafe extern "C" fn u256_to_u64_words(
|
||||
pub unsafe extern "C" fn u256_from_little_endian_bytes(
|
||||
input: *const u8,
|
||||
len: usize,
|
||||
result: *mut *mut U256,
|
||||
result: *mut U256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let mut inner = crate::integer::U256::default();
|
||||
@@ -59,7 +42,7 @@ pub unsafe extern "C" fn u256_from_little_endian_bytes(
|
||||
let input = std::slice::from_raw_parts(input, len);
|
||||
inner.copy_from_le_byte_slice(input);
|
||||
|
||||
*result = Box::into_raw(Box::new(U256(inner)));
|
||||
*result = U256::from(inner)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -70,7 +53,7 @@ pub unsafe extern "C" fn u256_from_little_endian_bytes(
|
||||
pub unsafe extern "C" fn u256_from_big_endian_bytes(
|
||||
input: *const u8,
|
||||
len: usize,
|
||||
result: *mut *mut U256,
|
||||
result: *mut U256,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
let mut inner = crate::integer::U256::default();
|
||||
@@ -78,38 +61,32 @@ pub unsafe extern "C" fn u256_from_big_endian_bytes(
|
||||
let input = std::slice::from_raw_parts(input, len);
|
||||
inner.copy_from_be_byte_slice(input);
|
||||
|
||||
*result = Box::into_raw(Box::new(U256(inner)));
|
||||
*result = U256::from(inner)
|
||||
})
|
||||
}
|
||||
|
||||
/// len must be 32
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn u256_little_endian_bytes(
|
||||
input: *const U256,
|
||||
input: U256,
|
||||
result: *mut u8,
|
||||
len: usize,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
let input = get_ref_checked(input).unwrap();
|
||||
|
||||
let bytes = std::slice::from_raw_parts_mut(result, len);
|
||||
input.0.copy_to_le_byte_slice(bytes);
|
||||
crate::integer::U256::from(input).copy_to_le_byte_slice(bytes);
|
||||
})
|
||||
}
|
||||
|
||||
/// len must be 32
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn u256_big_endian_bytes(
|
||||
input: *const U256,
|
||||
result: *mut u8,
|
||||
len: usize,
|
||||
) -> c_int {
|
||||
pub unsafe extern "C" fn u256_big_endian_bytes(input: U256, result: *mut u8, len: usize) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
let input = get_ref_checked(input).unwrap();
|
||||
|
||||
let bytes = std::slice::from_raw_parts_mut(result, len);
|
||||
input.0.copy_to_be_byte_slice(bytes);
|
||||
crate::integer::U256::from(input).copy_to_be_byte_slice(bytes);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -59,6 +59,49 @@ macro_rules! impl_try_encrypt_with_public_key_on_type {
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_try_encrypt_with_compact_public_key_on_type {
|
||||
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
|
||||
::paste::paste! {
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn [<$wrapper_type:snake _try_encrypt_with_compact_public_key_ $input_type:snake>](
|
||||
value: $input_type,
|
||||
public_key: *const $crate::c_api::high_level_api::keys::CompactPublicKey,
|
||||
result: *mut *mut $wrapper_type,
|
||||
) -> ::std::os::raw::c_int {
|
||||
$crate::c_api::utils::catch_panic(|| {
|
||||
let public_key = $crate::c_api::utils::get_ref_checked(public_key).unwrap();
|
||||
|
||||
let inner = <$wrapped_type>::try_encrypt(value, &public_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new($wrapper_type(inner)));
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_try_encrypt_list_with_compact_public_key_on_type {
|
||||
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
|
||||
::paste::paste! {
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn [<$wrapper_type:snake _try_encrypt_with_compact_public_key_ $input_type:snake>](
|
||||
input: *const $input_type,
|
||||
input_len: usize,
|
||||
public_key: *const $crate::c_api::high_level_api::keys::CompactPublicKey,
|
||||
result: *mut *mut $wrapper_type,
|
||||
) -> ::std::os::raw::c_int {
|
||||
$crate::c_api::utils::catch_panic(|| {
|
||||
let public_key = $crate::c_api::utils::get_ref_checked(public_key).unwrap();
|
||||
let slc = ::std::slice::from_raw_parts(input, input_len);
|
||||
let inner = <$wrapped_type>::try_encrypt(slc, &public_key.0).unwrap();
|
||||
|
||||
*result = Box::into_raw(Box::new($wrapper_type(inner)));
|
||||
})
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_try_encrypt_trivial_on_type {
|
||||
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
|
||||
::paste::paste! {
|
||||
|
||||
@@ -10,7 +10,7 @@ pub struct ShortintClientKey(pub(in crate::c_api) shortint::client_key::ClientKe
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn shortint_gen_client_key(
|
||||
shortint_parameters: *const super::parameters::ShortintParameters,
|
||||
shortint_parameters: super::parameters::ShortintPBSParameters,
|
||||
result_client_key: *mut *mut ShortintClientKey,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
@@ -20,9 +20,10 @@ pub unsafe extern "C" fn shortint_gen_client_key(
|
||||
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
|
||||
*result_client_key = std::ptr::null_mut();
|
||||
|
||||
let shortint_parameters = get_ref_checked(shortint_parameters).unwrap();
|
||||
let shortint_parameters: crate::shortint::parameters::ClassicPBSParameters =
|
||||
shortint_parameters.into();
|
||||
|
||||
let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned());
|
||||
let client_key = shortint::client_key::ClientKey::new(shortint_parameters);
|
||||
|
||||
let heap_allocated_client_key = Box::new(ShortintClientKey(client_key));
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use crate::c_api::utils::*;
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use super::parameters::ShortintParameters;
|
||||
use super::{
|
||||
ShortintBivariatePBSLookupTable, ShortintCiphertext, ShortintClientKey,
|
||||
ShortintCompressedCiphertext, ShortintCompressedPublicKey, ShortintCompressedServerKey,
|
||||
@@ -57,17 +56,6 @@ pub unsafe extern "C" fn destroy_shortint_compressed_public_key(
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn destroy_shortint_parameters(
|
||||
shortint_parameters: *mut ShortintParameters,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(shortint_parameters).unwrap();
|
||||
|
||||
drop(Box::from_raw(shortint_parameters));
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn destroy_shortint_ciphertext(
|
||||
shortint_ciphertext: *mut ShortintCiphertext,
|
||||
|
||||
@@ -19,7 +19,7 @@ pub use server_key::{ShortintCompressedServerKey, ShortintServerKey};
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn shortint_gen_keys_with_parameters(
|
||||
shortint_parameters: *const parameters::ShortintParameters,
|
||||
shortint_parameters: parameters::ShortintPBSParameters,
|
||||
result_client_key: *mut *mut ShortintClientKey,
|
||||
result_server_key: *mut *mut ShortintServerKey,
|
||||
) -> c_int {
|
||||
@@ -32,9 +32,10 @@ pub unsafe extern "C" fn shortint_gen_keys_with_parameters(
|
||||
*result_client_key = std::ptr::null_mut();
|
||||
*result_server_key = std::ptr::null_mut();
|
||||
|
||||
let shortint_parameters = get_ref_checked(shortint_parameters).unwrap();
|
||||
let shortint_parameters: crate::shortint::parameters::ClassicPBSParameters =
|
||||
shortint_parameters.into();
|
||||
|
||||
let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned());
|
||||
let client_key = shortint::client_key::ClientKey::new(shortint_parameters);
|
||||
let server_key = shortint::server_key::ServerKey::new(&client_key);
|
||||
|
||||
let heap_allocated_client_key = Box::new(ShortintClientKey(client_key));
|
||||
|
||||
@@ -3,11 +3,14 @@ pub use crate::core_crypto::commons::dispersion::StandardDev;
|
||||
pub use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
|
||||
};
|
||||
pub use crate::shortint::parameters::parameters_compact_pk::*;
|
||||
pub use crate::shortint::parameters::*;
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use crate::shortint;
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub enum ShortintEncryptionKeyChoice {
|
||||
ShortintEncryptionKeyChoiceBig,
|
||||
ShortintEncryptionKeyChoiceSmall,
|
||||
@@ -26,13 +29,185 @@ impl From<ShortintEncryptionKeyChoice> for crate::shortint::parameters::Encrypti
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ShortintParameters(pub(in crate::c_api) shortint::parameters::Parameters);
|
||||
#[repr(C)]
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct ShortintPBSParameters {
|
||||
pub lwe_dimension: usize,
|
||||
pub glwe_dimension: usize,
|
||||
pub polynomial_size: usize,
|
||||
pub lwe_modular_std_dev: f64,
|
||||
pub glwe_modular_std_dev: f64,
|
||||
pub pbs_base_log: usize,
|
||||
pub pbs_level: usize,
|
||||
pub ks_base_log: usize,
|
||||
pub ks_level: usize,
|
||||
pub message_modulus: usize,
|
||||
pub carry_modulus: usize,
|
||||
pub modulus_power_of_2_exponent: usize,
|
||||
pub encryption_key_choice: ShortintEncryptionKeyChoice,
|
||||
}
|
||||
|
||||
impl From<ShortintPBSParameters> for crate::shortint::ClassicPBSParameters {
|
||||
fn from(c_params: ShortintPBSParameters) -> Self {
|
||||
Self {
|
||||
lwe_dimension: LweDimension(c_params.lwe_dimension),
|
||||
glwe_dimension: GlweDimension(c_params.glwe_dimension),
|
||||
polynomial_size: PolynomialSize(c_params.polynomial_size),
|
||||
lwe_modular_std_dev: StandardDev(c_params.lwe_modular_std_dev),
|
||||
glwe_modular_std_dev: StandardDev(c_params.glwe_modular_std_dev),
|
||||
pbs_base_log: DecompositionBaseLog(c_params.pbs_base_log),
|
||||
pbs_level: DecompositionLevelCount(c_params.pbs_level),
|
||||
ks_base_log: DecompositionBaseLog(c_params.ks_base_log),
|
||||
ks_level: DecompositionLevelCount(c_params.ks_level),
|
||||
message_modulus: crate::shortint::parameters::MessageModulus(c_params.message_modulus),
|
||||
carry_modulus: crate::shortint::parameters::CarryModulus(c_params.carry_modulus),
|
||||
ciphertext_modulus: crate::shortint::parameters::CiphertextModulus::try_new_power_of_2(
|
||||
c_params.modulus_power_of_2_exponent,
|
||||
)
|
||||
.unwrap(),
|
||||
encryption_key_choice: c_params.encryption_key_choice.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::shortint::ClassicPBSParameters> for ShortintPBSParameters {
|
||||
fn from(rust_params: crate::shortint::ClassicPBSParameters) -> Self {
|
||||
Self::convert(rust_params)
|
||||
}
|
||||
}
|
||||
|
||||
impl ShortintEncryptionKeyChoice {
|
||||
// From::from cannot be marked as const, so we have to have
|
||||
// our own function
|
||||
const fn convert(rust_choice: crate::shortint::EncryptionKeyChoice) -> Self {
|
||||
match rust_choice {
|
||||
crate::shortint::EncryptionKeyChoice::Big => Self::ShortintEncryptionKeyChoiceBig,
|
||||
crate::shortint::EncryptionKeyChoice::Small => Self::ShortintEncryptionKeyChoiceSmall,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const fn convert_modulus(rust_modulus: crate::shortint::CiphertextModulus) -> usize {
|
||||
if rust_modulus.is_native_modulus() {
|
||||
64 // shortints are on 64 bits
|
||||
} else {
|
||||
assert!(rust_modulus.is_power_of_two());
|
||||
let modulus = rust_modulus.get_custom_modulus();
|
||||
let exponent = modulus.ilog2() as usize;
|
||||
assert!(exponent <= 64);
|
||||
exponent
|
||||
}
|
||||
}
|
||||
|
||||
impl ShortintPBSParameters {
|
||||
const fn convert(rust_params: crate::shortint::ClassicPBSParameters) -> Self {
|
||||
Self {
|
||||
lwe_dimension: rust_params.lwe_dimension.0,
|
||||
glwe_dimension: rust_params.glwe_dimension.0,
|
||||
polynomial_size: rust_params.polynomial_size.0,
|
||||
lwe_modular_std_dev: rust_params.lwe_modular_std_dev.0,
|
||||
glwe_modular_std_dev: rust_params.glwe_modular_std_dev.0,
|
||||
pbs_base_log: rust_params.pbs_base_log.0,
|
||||
pbs_level: rust_params.pbs_level.0,
|
||||
ks_base_log: rust_params.ks_base_log.0,
|
||||
ks_level: rust_params.ks_level.0,
|
||||
message_modulus: rust_params.message_modulus.0,
|
||||
carry_modulus: rust_params.carry_modulus.0,
|
||||
modulus_power_of_2_exponent: convert_modulus(rust_params.ciphertext_modulus),
|
||||
encryption_key_choice: ShortintEncryptionKeyChoice::convert(
|
||||
rust_params.encryption_key_choice,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! expose_as_shortint_pbs_parameters(
|
||||
(
|
||||
$(
|
||||
$param_name:ident
|
||||
),*
|
||||
$(,)?
|
||||
) => {
|
||||
::paste::paste!{
|
||||
$(
|
||||
#[no_mangle]
|
||||
pub static [<SHORTINT_ $param_name>]: ShortintPBSParameters =
|
||||
ShortintPBSParameters::convert($param_name);
|
||||
|
||||
)*
|
||||
}
|
||||
|
||||
// Test that converting a param from its rust struct
|
||||
// to the c struct and then to the rust struct again
|
||||
// yields the same values as the original struct
|
||||
//
|
||||
// This is what will essentially happen in the real code
|
||||
#[test]
|
||||
fn test_shortint_pbs_parameters_roundtrip_c_rust() {
|
||||
$(
|
||||
// 1 scope for each parameters
|
||||
{
|
||||
let rust_params = crate::shortint::parameters::$param_name;
|
||||
let c_params = ShortintPBSParameters::from(rust_params);
|
||||
let rust_params_from_c = crate::shortint::parameters::ClassicPBSParameters::from(c_params);
|
||||
assert_eq!(rust_params, rust_params_from_c);
|
||||
}
|
||||
)*
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
expose_as_shortint_pbs_parameters!(
|
||||
PARAM_MESSAGE_1_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_1,
|
||||
PARAM_MESSAGE_2_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_2,
|
||||
PARAM_MESSAGE_2_CARRY_1,
|
||||
PARAM_MESSAGE_3_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_3,
|
||||
PARAM_MESSAGE_2_CARRY_2,
|
||||
PARAM_MESSAGE_3_CARRY_1,
|
||||
PARAM_MESSAGE_4_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_4,
|
||||
PARAM_MESSAGE_2_CARRY_3,
|
||||
PARAM_MESSAGE_3_CARRY_2,
|
||||
PARAM_MESSAGE_4_CARRY_1,
|
||||
PARAM_MESSAGE_5_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_5,
|
||||
PARAM_MESSAGE_2_CARRY_4,
|
||||
PARAM_MESSAGE_3_CARRY_3,
|
||||
PARAM_MESSAGE_4_CARRY_2,
|
||||
PARAM_MESSAGE_5_CARRY_1,
|
||||
PARAM_MESSAGE_6_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_6,
|
||||
PARAM_MESSAGE_2_CARRY_5,
|
||||
PARAM_MESSAGE_3_CARRY_4,
|
||||
PARAM_MESSAGE_4_CARRY_3,
|
||||
PARAM_MESSAGE_5_CARRY_2,
|
||||
PARAM_MESSAGE_6_CARRY_1,
|
||||
PARAM_MESSAGE_7_CARRY_0,
|
||||
PARAM_MESSAGE_1_CARRY_7,
|
||||
PARAM_MESSAGE_2_CARRY_6,
|
||||
PARAM_MESSAGE_3_CARRY_5,
|
||||
PARAM_MESSAGE_4_CARRY_4,
|
||||
PARAM_MESSAGE_5_CARRY_3,
|
||||
PARAM_MESSAGE_6_CARRY_2,
|
||||
PARAM_MESSAGE_7_CARRY_1,
|
||||
PARAM_MESSAGE_8_CARRY_0,
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK,
|
||||
// Small params
|
||||
PARAM_SMALL_MESSAGE_1_CARRY_1,
|
||||
PARAM_SMALL_MESSAGE_2_CARRY_2,
|
||||
PARAM_SMALL_MESSAGE_3_CARRY_3,
|
||||
PARAM_SMALL_MESSAGE_4_CARRY_4,
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_SMALL,
|
||||
);
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn shortint_get_parameters(
|
||||
message_bits: u32,
|
||||
carry_bits: u32,
|
||||
result: *mut *mut ShortintParameters,
|
||||
result: *mut ShortintPBSParameters,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
@@ -76,12 +251,8 @@ pub unsafe extern "C" fn shortint_get_parameters(
|
||||
_ => None,
|
||||
};
|
||||
|
||||
match params {
|
||||
Some(params) => {
|
||||
let params = Box::new(ShortintParameters(params));
|
||||
*result = Box::into_raw(params);
|
||||
}
|
||||
None => *result = std::ptr::null_mut(),
|
||||
if let Some(params) = params {
|
||||
*result = params.into();
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -90,7 +261,7 @@ pub unsafe extern "C" fn shortint_get_parameters(
|
||||
pub unsafe extern "C" fn shortint_get_parameters_small(
|
||||
message_bits: u32,
|
||||
carry_bits: u32,
|
||||
result: *mut *mut ShortintParameters,
|
||||
result: *mut ShortintPBSParameters,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
@@ -102,71 +273,8 @@ pub unsafe extern "C" fn shortint_get_parameters_small(
|
||||
_ => None,
|
||||
};
|
||||
|
||||
match params {
|
||||
Some(params) => {
|
||||
let params = Box::new(ShortintParameters(params));
|
||||
*result = Box::into_raw(params);
|
||||
}
|
||||
None => *result = std::ptr::null_mut(),
|
||||
if let Some(params) = params {
|
||||
*result = params.into();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub unsafe extern "C" fn shortint_create_parameters(
|
||||
lwe_dimension: usize,
|
||||
glwe_dimension: usize,
|
||||
polynomial_size: usize,
|
||||
lwe_modular_std_dev: f64,
|
||||
glwe_modular_std_dev: f64,
|
||||
pbs_base_log: usize,
|
||||
pbs_level: usize,
|
||||
ks_base_log: usize,
|
||||
ks_level: usize,
|
||||
pfks_level: usize,
|
||||
pfks_base_log: usize,
|
||||
pfks_modular_std_dev: f64,
|
||||
cbs_level: usize,
|
||||
cbs_base_log: usize,
|
||||
message_modulus: usize,
|
||||
carry_modulus: usize,
|
||||
modulus_power_of_2_exponent: usize,
|
||||
encryption_key_choice: ShortintEncryptionKeyChoice,
|
||||
result: *mut *mut ShortintParameters,
|
||||
) -> c_int {
|
||||
catch_panic(|| {
|
||||
check_ptr_is_non_null_and_aligned(result).unwrap();
|
||||
|
||||
// First fill the result with a null ptr so that if we fail and the return code is not
|
||||
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
|
||||
*result = std::ptr::null_mut();
|
||||
|
||||
let heap_allocated_parameters =
|
||||
Box::new(ShortintParameters(shortint::parameters::Parameters {
|
||||
lwe_dimension: LweDimension(lwe_dimension),
|
||||
glwe_dimension: GlweDimension(glwe_dimension),
|
||||
polynomial_size: PolynomialSize(polynomial_size),
|
||||
lwe_modular_std_dev: StandardDev(lwe_modular_std_dev),
|
||||
glwe_modular_std_dev: StandardDev(glwe_modular_std_dev),
|
||||
pbs_base_log: DecompositionBaseLog(pbs_base_log),
|
||||
pbs_level: DecompositionLevelCount(pbs_level),
|
||||
ks_base_log: DecompositionBaseLog(ks_base_log),
|
||||
ks_level: DecompositionLevelCount(ks_level),
|
||||
pfks_level: DecompositionLevelCount(pfks_level),
|
||||
pfks_base_log: DecompositionBaseLog(pfks_base_log),
|
||||
pfks_modular_std_dev: StandardDev(pfks_modular_std_dev),
|
||||
cbs_level: DecompositionLevelCount(cbs_level),
|
||||
cbs_base_log: DecompositionBaseLog(cbs_base_log),
|
||||
message_modulus: crate::shortint::parameters::MessageModulus(message_modulus),
|
||||
carry_modulus: crate::shortint::parameters::CarryModulus(carry_modulus),
|
||||
ciphertext_modulus:
|
||||
crate::shortint::parameters::CiphertextModulus::try_new_power_of_2(
|
||||
modulus_power_of_2_exponent,
|
||||
)
|
||||
.unwrap(),
|
||||
encryption_key_choice: encryption_key_choice.into(),
|
||||
}));
|
||||
|
||||
*result = Box::into_raw(heap_allocated_parameters);
|
||||
})
|
||||
}
|
||||
|
||||
@@ -111,6 +111,8 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
let decomp_base_log = output.decomposition_base_log();
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
for (level_index, (mut level_matrix, mut generator)) in
|
||||
output.iter_mut().zip(gen_iter).enumerate()
|
||||
{
|
||||
@@ -121,7 +123,7 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = generator
|
||||
@@ -249,6 +251,8 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
let decomp_base_log = output.decomposition_base_log();
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|
||||
|(level_index, (mut level_matrix, mut generator))| {
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
@@ -258,7 +262,7 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special
|
||||
// treatment
|
||||
@@ -367,6 +371,8 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
let decomp_base_log = output.decomposition_base_log();
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
for (level_index, (mut level_matrix, mut loop_generator)) in
|
||||
output.iter_mut().zip(gen_iter).enumerate()
|
||||
{
|
||||
@@ -377,7 +383,7 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = loop_generator
|
||||
@@ -542,6 +548,8 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
let decomp_base_log = output.decomposition_base_log();
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|
||||
|(level_index, (mut level_matrix, mut generator))| {
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
@@ -551,7 +559,7 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = generator
|
||||
@@ -832,13 +840,13 @@ where
|
||||
|
||||
let plaintext_ref = decrypted_plaintext_list.get(0);
|
||||
|
||||
let ciphertext_modulus = ggsw_ciphertext.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// Glwe decryption maps to a smaller torus potentially, map back to the native torus
|
||||
let rounded = decomposer.closest_representable(
|
||||
(*plaintext_ref.0).wrapping_mul(
|
||||
ggsw_ciphertext
|
||||
.ciphertext_modulus()
|
||||
.get_scaling_to_native_torus(),
|
||||
),
|
||||
(*plaintext_ref.0)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
let decoded =
|
||||
rounded.wrapping_div(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));
|
||||
|
||||
@@ -37,6 +37,8 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
|
||||
output_body.as_mut(),
|
||||
@@ -45,7 +47,7 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
|
||||
}
|
||||
@@ -250,6 +252,8 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.fill_slice_with_random_noise_custom_mod(
|
||||
output_body.as_mut(),
|
||||
@@ -263,7 +267,7 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
|
||||
}
|
||||
@@ -572,6 +576,8 @@ pub fn decrypt_glwe_ciphertext<Scalar, KeyCont, InputCont, OutputCont>(
|
||||
|
||||
let ciphertext_modulus = input_glwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let (mask, body) = input_glwe_ciphertext.get_mask_and_body();
|
||||
output_plaintext_list
|
||||
.as_mut()
|
||||
@@ -585,7 +591,7 @@ pub fn decrypt_glwe_ciphertext<Scalar, KeyCont, InputCont, OutputCont>(
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_div_assign(
|
||||
output_plaintext_list.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -720,10 +726,12 @@ pub fn trivially_encrypt_glwe_ciphertext<Scalar, InputCont, OutputCont>(
|
||||
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
body.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -801,6 +809,8 @@ where
|
||||
Scalar: UnsignedTorus,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let polynomial_size = PolynomialSize(encoded.plaintext_count().0);
|
||||
|
||||
let mut new_ct =
|
||||
@@ -812,7 +822,7 @@ where
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
body.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
//! Module with primitives pertaining to [`LweCompactCiphertextList`] expansion.
|
||||
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign;
|
||||
use crate::core_crypto::commons::parameters::MonomialDegree;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Expand an [`LweCompactCiphertextList`] into an [`LweCiphertextList`].
|
||||
///
|
||||
/// Consider using [`par_expand_lwe_compact_ciphertext_list`] for better performance.
|
||||
pub fn expand_lwe_compact_ciphertext_list<Scalar, InputCont, OutputCont>(
|
||||
output_lwe_ciphertext_list: &mut LweCiphertextList<OutputCont>,
|
||||
input_lwe_compact_ciphertext_list: &LweCompactCiphertextList<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output_lwe_ciphertext_list.entity_count()
|
||||
== input_lwe_compact_ciphertext_list.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
assert!(output_lwe_ciphertext_list.lwe_size() == input_lwe_compact_ciphertext_list.lwe_size());
|
||||
|
||||
let (input_mask_list, input_body_list) =
|
||||
input_lwe_compact_ciphertext_list.get_mask_and_body_list();
|
||||
|
||||
let lwe_dimension = input_mask_list.lwe_dimension();
|
||||
let max_ciphertext_per_bin = lwe_dimension.0;
|
||||
|
||||
for (input_mask, (mut output_ct_chunk, input_body_chunk)) in input_mask_list.iter().zip(
|
||||
output_lwe_ciphertext_list
|
||||
.chunks_mut(max_ciphertext_per_bin)
|
||||
.zip(input_body_list.chunks(max_ciphertext_per_bin)),
|
||||
) {
|
||||
for (ct_idx, (mut out_ct, input_body)) in output_ct_chunk
|
||||
.iter_mut()
|
||||
.zip(input_body_chunk.iter())
|
||||
.enumerate()
|
||||
{
|
||||
let (mut out_mask, out_body) = out_ct.get_mut_mask_and_body();
|
||||
out_mask.as_mut().copy_from_slice(input_mask.as_ref());
|
||||
|
||||
let mut out_mask_as_polynomial = Polynomial::from_container(out_mask.as_mut());
|
||||
|
||||
// This the Psi_jl from the paper, it's equivalent to a multiplication in the X^N + 1
|
||||
// ring for our choice of i == n
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut out_mask_as_polynomial,
|
||||
MonomialDegree(lwe_dimension.0 - (ct_idx + 1)),
|
||||
);
|
||||
|
||||
*out_body.data = *input_body.data;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parallel version of [`expand_lwe_compact_ciphertext_list`].
|
||||
pub fn par_expand_lwe_compact_ciphertext_list<Scalar, InputCont, OutputCont>(
|
||||
output_lwe_ciphertext_list: &mut LweCiphertextList<OutputCont>,
|
||||
input_lwe_compact_ciphertext_list: &LweCompactCiphertextList<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger + Send + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output_lwe_ciphertext_list.entity_count()
|
||||
== input_lwe_compact_ciphertext_list.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
assert!(output_lwe_ciphertext_list.lwe_size() == input_lwe_compact_ciphertext_list.lwe_size());
|
||||
|
||||
let (input_mask_list, input_body_list) =
|
||||
input_lwe_compact_ciphertext_list.get_mask_and_body_list();
|
||||
|
||||
let lwe_dimension = input_mask_list.lwe_dimension();
|
||||
let max_ciphertext_per_bin = lwe_dimension.0;
|
||||
|
||||
input_mask_list
|
||||
.par_iter()
|
||||
.zip(
|
||||
output_lwe_ciphertext_list
|
||||
.par_chunks_mut(max_ciphertext_per_bin)
|
||||
.zip(input_body_list.par_chunks(max_ciphertext_per_bin)),
|
||||
)
|
||||
.for_each(|(input_mask, (mut output_ct_chunk, input_body_chunk))| {
|
||||
output_ct_chunk
|
||||
.par_iter_mut()
|
||||
.zip(input_body_chunk.par_iter())
|
||||
.enumerate()
|
||||
.for_each(|(ct_idx, (mut out_ct, input_body))| {
|
||||
let (mut out_mask, out_body) = out_ct.get_mut_mask_and_body();
|
||||
out_mask.as_mut().copy_from_slice(input_mask.as_ref());
|
||||
|
||||
let mut out_mask_as_polynomial = Polynomial::from_container(out_mask.as_mut());
|
||||
|
||||
// This is the Psi_jl from the paper, it's equivalent to a multiplication in the
|
||||
// X^N + 1 ring for our choice of i == n
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut out_mask_as_polynomial,
|
||||
MonomialDegree(lwe_dimension.0 - (ct_idx + 1)),
|
||||
);
|
||||
|
||||
*out_body.data = *input_body.data;
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
//! Module containing primitives pertaining to [`LWE compact public key
|
||||
//! generation`](`LweCompactPublicKey`).
|
||||
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::prelude::ActivatedRandomGenerator;
|
||||
use slice_algorithms::*;
|
||||
|
||||
/// Fill an [`LWE compact public key`](`LweCompactPublicKey`) with an actual public key constructed
|
||||
/// from a private [`LWE secret key`](`LweSecretKey`).
|
||||
pub fn generate_lwe_compact_public_key<Scalar, InputKeyCont, OutputKeyCont, Gen>(
|
||||
lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output: &mut LweCompactPublicKey<OutputKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
assert!(
|
||||
lwe_secret_key.lwe_dimension() == output.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LweSecretKey {:?} \
|
||||
and ouptut LweCompactPublicKey {:?}",
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
output.lwe_dimension()
|
||||
);
|
||||
|
||||
let (mut mask, mut body) = output.get_mut_mask_and_body();
|
||||
generator.fill_slice_with_random_mask(mask.as_mut());
|
||||
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
body.as_mut(),
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
);
|
||||
|
||||
generator
|
||||
.unsigned_torus_slice_wrapping_add_random_noise_assign(body.as_mut(), noise_parameters);
|
||||
}
|
||||
|
||||
/// Allocate a new [`LWE compact public key`](`LweCompactPublicKey`) and fill it with an actual
|
||||
/// public key constructed from a private [`LWE secret key`](`LweSecretKey`).
|
||||
///
|
||||
/// See [`encrypt_lwe_ciphertext_with_compact_public_key`] for usage.
|
||||
pub fn allocate_and_generate_new_lwe_compact_public_key<Scalar, InputKeyCont, Gen>(
|
||||
lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) -> LweCompactPublicKeyOwned<Scalar>
|
||||
where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
let mut pk = LweCompactPublicKeyOwned::new(
|
||||
Scalar::ZERO,
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
generate_lwe_compact_public_key(lwe_secret_key, &mut pk, noise_parameters, generator);
|
||||
|
||||
pk
|
||||
}
|
||||
|
||||
/// Fill a [`seeded LWE compact public key`](`LweCompactPublicKey`) with an actual public key
|
||||
/// constructed from a private [`LWE secret key`](`LweSecretKey`).
|
||||
pub fn generate_seeded_lwe_compact_public_key<Scalar, InputKeyCont, OutputKeyCont, NoiseSeeder>(
|
||||
lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output: &mut SeededLweCompactPublicKey<OutputKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
assert!(
|
||||
output.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
assert!(
|
||||
lwe_secret_key.lwe_dimension() == output.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LweSecretKey {:?} \
|
||||
and ouptut LweCompactPublicKey {:?}",
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
output.lwe_dimension()
|
||||
);
|
||||
|
||||
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
output.compression_seed().seed,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
let mut tmp_mask = vec![Scalar::ZERO; output.lwe_dimension().0];
|
||||
generator.fill_slice_with_random_mask(tmp_mask.as_mut());
|
||||
|
||||
let mut body = output.get_mut_body();
|
||||
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
body.as_mut(),
|
||||
tmp_mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
);
|
||||
|
||||
generator
|
||||
.unsigned_torus_slice_wrapping_add_random_noise_assign(body.as_mut(), noise_parameters);
|
||||
}
|
||||
|
||||
/// Allocate a new [`seeded LWE compact public key`](`SeededLweCompactPublicKey`) and fill it with
|
||||
/// an actual public key constructed from a private [`LWE secret key`](`LweSecretKey`).
|
||||
pub fn allocate_and_generate_new_seeded_lwe_compact_public_key<Scalar, InputKeyCont, NoiseSeeder>(
|
||||
lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) -> SeededLweCompactPublicKeyOwned<Scalar>
|
||||
where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
let mut pk = SeededLweCompactPublicKey::new(
|
||||
Scalar::ZERO,
|
||||
lwe_secret_key.lwe_dimension(),
|
||||
noise_seeder.seed().into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
generate_seeded_lwe_compact_public_key(lwe_secret_key, &mut pk, noise_parameters, noise_seeder);
|
||||
|
||||
pk
|
||||
}
|
||||
@@ -36,6 +36,8 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
|
||||
// generate an error from the normal distribution described by std_dev
|
||||
@@ -43,7 +45,7 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
*output_body.data = (*output_body.data).wrapping_add(encoded.0);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
*output_body.data = (*output_body.data).wrapping_mul(torus_scaling);
|
||||
}
|
||||
@@ -296,9 +298,10 @@ pub fn trivially_encrypt_lwe_ciphertext<Scalar, OutputCont>(
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data =
|
||||
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,9 +377,10 @@ where
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data =
|
||||
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
new_ct
|
||||
@@ -409,6 +413,8 @@ where
|
||||
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let (mask, body) = lwe_ciphertext.get_mask_and_body();
|
||||
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
@@ -423,7 +429,7 @@ where
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -777,6 +783,8 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
output.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let mut tmp_zero_encryption =
|
||||
@@ -806,7 +814,7 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -912,6 +920,8 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let mut tmp_zero_encryption =
|
||||
LweCiphertext::new(Scalar::ZERO, lwe_public_key.lwe_size(), ciphertext_modulus);
|
||||
|
||||
@@ -926,7 +936,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*body.data = *public_encryption_of_zero_body.data;
|
||||
@@ -945,7 +955,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1473,3 +1483,791 @@ where
|
||||
|
||||
seeded_ct
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using an
|
||||
/// [`LWE compact public key`](`LweCompactPublicKey`). The ciphertext can be decrypted using the
|
||||
/// [`LWE secret key`](`LweSecretKey`) that was used to generate the public key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweCiphertext creation
|
||||
/// let lwe_dimension = LweDimension(2048);
|
||||
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator =
|
||||
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
///
|
||||
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
|
||||
/// &lwe_secret_key,
|
||||
/// glwe_modular_std_dev,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// // Create the plaintext
|
||||
/// let msg = 3u64;
|
||||
/// let plaintext = Plaintext(msg << 60);
|
||||
///
|
||||
/// // Create a new LweCiphertext
|
||||
/// let mut lwe = LweCiphertext::new(0u64, lwe_dimension.to_lwe_size(), ciphertext_modulus);
|
||||
///
|
||||
/// encrypt_lwe_ciphertext_with_compact_public_key(
|
||||
/// &lwe_compact_public_key,
|
||||
/// &mut lwe,
|
||||
/// plaintext,
|
||||
/// glwe_modular_std_dev,
|
||||
/// glwe_modular_std_dev,
|
||||
/// &mut secret_generator,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&lwe_secret_key, &lwe);
|
||||
///
|
||||
/// // Round and remove encoding
|
||||
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
|
||||
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
|
||||
///
|
||||
/// // Remove the encoding
|
||||
/// let cleartext = rounded >> 60;
|
||||
///
|
||||
/// // Check we recovered the original message
|
||||
/// assert_eq!(cleartext, msg);
|
||||
/// ```
|
||||
pub fn encrypt_lwe_ciphertext_with_compact_public_key<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
OutputCont,
|
||||
SecretGen,
|
||||
EncryptionGen,
|
||||
>(
|
||||
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
mask_noise_parameters: impl DispersionParameter,
|
||||
body_noise_parameters: impl DispersionParameter,
|
||||
secret_generator: &mut SecretRandomGenerator<SecretGen>,
|
||||
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
SecretGen: ByteRandomGenerator,
|
||||
EncryptionGen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
|
||||
"Mismatch between LweDimension of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.lwe_size().to_lwe_dimension(),
|
||||
lwe_compact_public_key.lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
|
||||
"Mismatch between CiphertextModulus of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.ciphertext_modulus(),
|
||||
lwe_compact_public_key.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
let mut binary_random_vector = vec![Scalar::ZERO; lwe_compact_public_key.lwe_dimension().0];
|
||||
|
||||
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
|
||||
|
||||
let (mut ct_mask, ct_body) = output.get_mut_mask_and_body();
|
||||
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
|
||||
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
ct_mask.as_mut(),
|
||||
pk_mask.as_ref(),
|
||||
&binary_random_vector,
|
||||
);
|
||||
|
||||
// Noise from Chi_1 for the mask part of the encryption
|
||||
encryption_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
|
||||
ct_mask.as_mut(),
|
||||
mask_noise_parameters,
|
||||
);
|
||||
|
||||
*ct_body.data = slice_wrapping_dot_product(pk_body.as_ref(), &binary_random_vector);
|
||||
// Noise from Chi_2 for the body part of the encryption
|
||||
*ct_body.data =
|
||||
(*ct_body.data).wrapping_add(encryption_generator.random_noise(body_noise_parameters));
|
||||
*ct_body.data = (*ct_body.data).wrapping_add(encoded.0);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext list in an output [`LWE compact ciphertext
|
||||
/// list`](`LweCompactCiphertextList`) using an [`LWE compact public key`](`LweCompactPublicKey`).
|
||||
/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that
|
||||
/// was used to generate the public key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweCiphertext creation
|
||||
/// let lwe_dimension = LweDimension(2048);
|
||||
/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4);
|
||||
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator =
|
||||
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
///
|
||||
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
|
||||
/// &lwe_secret_key,
|
||||
/// glwe_modular_std_dev,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
|
||||
/// input_plaintext_list
|
||||
/// .iter_mut()
|
||||
/// .enumerate()
|
||||
/// .for_each(|(idx, x)| {
|
||||
/// *x.0 = (idx as u64 % 16) << 60;
|
||||
/// });
|
||||
///
|
||||
/// // Create a new LweCompactCiphertextList
|
||||
/// let mut output_compact_ct_list = LweCompactCiphertextList::new(
|
||||
/// 0u64,
|
||||
/// lwe_dimension.to_lwe_size(),
|
||||
/// lwe_ciphertext_count,
|
||||
/// ciphertext_modulus,
|
||||
/// );
|
||||
///
|
||||
/// encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
|
||||
/// &lwe_compact_public_key,
|
||||
/// &mut output_compact_ct_list,
|
||||
/// &input_plaintext_list,
|
||||
/// glwe_modular_std_dev,
|
||||
/// glwe_modular_std_dev,
|
||||
/// &mut secret_generator,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
/// output_plaintext_list.as_mut().fill(0u64);
|
||||
///
|
||||
/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list();
|
||||
///
|
||||
/// decrypt_lwe_ciphertext_list(
|
||||
/// &lwe_secret_key,
|
||||
/// &lwe_ciphertext_list,
|
||||
/// &mut output_plaintext_list,
|
||||
/// );
|
||||
///
|
||||
/// let signed_decomposer =
|
||||
/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// // Round the plaintexts
|
||||
/// output_plaintext_list
|
||||
/// .iter_mut()
|
||||
/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
///
|
||||
/// // Check we recovered the original messages
|
||||
/// assert_eq!(input_plaintext_list, output_plaintext_list);
|
||||
/// ```
|
||||
pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
SecretGen,
|
||||
EncryptionGen,
|
||||
>(
|
||||
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
|
||||
output: &mut LweCompactCiphertextList<OutputCont>,
|
||||
encoded: &PlaintextList<InputCont>,
|
||||
mask_noise_parameters: impl DispersionParameter,
|
||||
body_noise_parameters: impl DispersionParameter,
|
||||
secret_generator: &mut SecretRandomGenerator<SecretGen>,
|
||||
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
SecretGen: ByteRandomGenerator,
|
||||
EncryptionGen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
|
||||
"Mismatch between LweDimension of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.lwe_size().to_lwe_dimension(),
|
||||
lwe_compact_public_key.lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
|
||||
"Mismatch between CiphertextModulus of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.ciphertext_modulus(),
|
||||
lwe_compact_public_key.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.lwe_ciphertext_count().0 == encoded.plaintext_count().0,
|
||||
"Mismatch between LweCiphertextCount of output cipertext and \
|
||||
PlaintextCount of input list. Got {:?} in output, and {:?} in input plaintext list.",
|
||||
output.lwe_ciphertext_count(),
|
||||
encoded.plaintext_count()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list();
|
||||
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
|
||||
|
||||
let lwe_mask_count = output_mask_list.lwe_mask_count();
|
||||
let lwe_dimension = output_mask_list.lwe_dimension();
|
||||
|
||||
let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()];
|
||||
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
|
||||
|
||||
let max_ciphertext_per_bin = lwe_dimension.0;
|
||||
|
||||
let gen_iter = encryption_generator
|
||||
.fork_lwe_compact_ciphertext_list_to_bin::<Scalar>(lwe_mask_count, lwe_dimension)
|
||||
.expect("Failed to split generator into lwe compact ciphertext bins");
|
||||
|
||||
// Loop over the ciphertext "bins"
|
||||
output_mask_list
|
||||
.iter_mut()
|
||||
.zip(
|
||||
output_body_list
|
||||
.chunks_mut(max_ciphertext_per_bin)
|
||||
.zip(encoded.chunks(max_ciphertext_per_bin))
|
||||
.zip(binary_random_vector.chunks(max_ciphertext_per_bin))
|
||||
.zip(gen_iter),
|
||||
)
|
||||
.for_each(
|
||||
|(
|
||||
mut output_mask,
|
||||
(
|
||||
((mut output_body_chunk, input_plaintext_chunk), binary_random_slice),
|
||||
mut loop_generator,
|
||||
),
|
||||
)| {
|
||||
// output_body_chunk may not be able to fit the full convolution result so we create
|
||||
// a temp buffer to compute the full convolution
|
||||
let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0];
|
||||
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
output_mask.as_mut(),
|
||||
pk_mask.as_ref(),
|
||||
binary_random_slice,
|
||||
);
|
||||
|
||||
// Fill the temp buffer with b convolved with r
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
pk_body_convolved.as_mut_slice(),
|
||||
pk_body.as_ref(),
|
||||
binary_random_slice,
|
||||
);
|
||||
|
||||
// Noise from Chi_1 for the mask part of the encryption
|
||||
loop_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
|
||||
output_mask.as_mut(),
|
||||
mask_noise_parameters,
|
||||
);
|
||||
|
||||
// Fill the body chunk afterwards manually as it most likely will be smaller than
|
||||
// the full convolution result. b convolved r + Delta * m + e2
|
||||
// taking noise from Chi_2 for the body part of the encryption
|
||||
output_body_chunk
|
||||
.iter_mut()
|
||||
.zip(pk_body_convolved.iter().zip(input_plaintext_chunk.iter()))
|
||||
.for_each(|(dst, (&src, plaintext))| {
|
||||
*dst.data = src
|
||||
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
|
||||
.wrapping_add(*plaintext.0)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`encrypt_lwe_compact_ciphertext_list_with_compact_public_key`]. Encrypt an
|
||||
/// input plaintext list in an output [`LWE compact ciphertext list`](`LweCompactCiphertextList`)
|
||||
/// using an [`LWE compact public key`](`LweCompactPublicKey`). The expanded ciphertext list can be
|
||||
/// decrypted using the [`LWE secret key`](`LweSecretKey`) that was used to generate the public key.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweCiphertext creation
|
||||
/// let lwe_dimension = LweDimension(2048);
|
||||
/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4);
|
||||
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator =
|
||||
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
///
|
||||
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
|
||||
/// &lwe_secret_key,
|
||||
/// glwe_modular_std_dev,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
|
||||
/// input_plaintext_list
|
||||
/// .iter_mut()
|
||||
/// .enumerate()
|
||||
/// .for_each(|(idx, x)| {
|
||||
/// *x.0 = (idx as u64 % 16) << 60;
|
||||
/// });
|
||||
///
|
||||
/// // Create a new LweCompactCiphertextList
|
||||
/// let mut output_compact_ct_list = LweCompactCiphertextList::new(
|
||||
/// 0u64,
|
||||
/// lwe_dimension.to_lwe_size(),
|
||||
/// lwe_ciphertext_count,
|
||||
/// ciphertext_modulus,
|
||||
/// );
|
||||
///
|
||||
/// par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
|
||||
/// &lwe_compact_public_key,
|
||||
/// &mut output_compact_ct_list,
|
||||
/// &input_plaintext_list,
|
||||
/// glwe_modular_std_dev,
|
||||
/// glwe_modular_std_dev,
|
||||
/// &mut secret_generator,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
/// output_plaintext_list.as_mut().fill(0u64);
|
||||
///
|
||||
/// let lwe_ciphertext_list = output_compact_ct_list.par_expand_into_lwe_ciphertext_list();
|
||||
///
|
||||
/// decrypt_lwe_ciphertext_list(
|
||||
/// &lwe_secret_key,
|
||||
/// &lwe_ciphertext_list,
|
||||
/// &mut output_plaintext_list,
|
||||
/// );
|
||||
///
|
||||
/// let signed_decomposer =
|
||||
/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// // Round the plaintexts
|
||||
/// output_plaintext_list
|
||||
/// .iter_mut()
|
||||
/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
///
|
||||
/// // Check we recovered the original messages
|
||||
/// assert_eq!(input_plaintext_list, output_plaintext_list);
|
||||
/// ```
|
||||
pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
SecretGen,
|
||||
EncryptionGen,
|
||||
>(
|
||||
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
|
||||
output: &mut LweCompactCiphertextList<OutputCont>,
|
||||
encoded: &PlaintextList<InputCont>,
|
||||
mask_noise_parameters: impl DispersionParameter + Sync,
|
||||
body_noise_parameters: impl DispersionParameter + Sync,
|
||||
secret_generator: &mut SecretRandomGenerator<SecretGen>,
|
||||
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Sync + Send,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
SecretGen: ByteRandomGenerator,
|
||||
EncryptionGen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
|
||||
"Mismatch between LweDimension of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.lwe_size().to_lwe_dimension(),
|
||||
lwe_compact_public_key.lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
|
||||
"Mismatch between CiphertextModulus of output cipertext and input public key. \
|
||||
Got {:?} in output, and {:?} in public key.",
|
||||
output.ciphertext_modulus(),
|
||||
lwe_compact_public_key.ciphertext_modulus()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.lwe_ciphertext_count().0 == encoded.plaintext_count().0,
|
||||
"Mismatch between LweCiphertextCount of output cipertext and \
|
||||
PlaintextCount of input list. Got {:?} in output, and {:?} in input plaintext list.",
|
||||
output.lwe_ciphertext_count(),
|
||||
encoded.plaintext_count()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.ciphertext_modulus().is_native_modulus(),
|
||||
"This operation only supports native moduli"
|
||||
);
|
||||
|
||||
let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list();
|
||||
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
|
||||
|
||||
let lwe_mask_count = output_mask_list.lwe_mask_count();
|
||||
let lwe_dimension = output_mask_list.lwe_dimension();
|
||||
|
||||
let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()];
|
||||
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
|
||||
|
||||
let max_ciphertext_per_bin = lwe_dimension.0;
|
||||
|
||||
let gen_iter = encryption_generator
|
||||
.par_fork_lwe_compact_ciphertext_list_to_bin::<Scalar>(lwe_mask_count, lwe_dimension)
|
||||
.expect("Failed to split generator into lwe compact ciphertext bins");
|
||||
|
||||
// Loop over the ciphertext "bins"
|
||||
output_mask_list
|
||||
.par_iter_mut()
|
||||
.zip(
|
||||
output_body_list
|
||||
.par_chunks_mut(max_ciphertext_per_bin)
|
||||
.zip(encoded.par_chunks(max_ciphertext_per_bin))
|
||||
.zip(binary_random_vector.par_chunks(max_ciphertext_per_bin))
|
||||
.zip(gen_iter),
|
||||
)
|
||||
.for_each(
|
||||
|(
|
||||
mut output_mask,
|
||||
(
|
||||
((mut output_body_chunk, input_plaintext_chunk), binary_random_slice),
|
||||
mut loop_generator,
|
||||
),
|
||||
)| {
|
||||
// output_body_chunk may not be able to fit the full convolution result so we create
|
||||
// a temp buffer to compute the full convolution
|
||||
let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0];
|
||||
|
||||
rayon::join(
|
||||
|| {
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
output_mask.as_mut(),
|
||||
pk_mask.as_ref(),
|
||||
binary_random_slice,
|
||||
);
|
||||
},
|
||||
|| {
|
||||
// Fill the temp buffer with b convolved with r
|
||||
slice_semi_reverse_negacyclic_convolution(
|
||||
pk_body_convolved.as_mut_slice(),
|
||||
pk_body.as_ref(),
|
||||
binary_random_slice,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
// Noise from Chi_1 for the mask part of the encryption
|
||||
loop_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
|
||||
output_mask.as_mut(),
|
||||
mask_noise_parameters,
|
||||
);
|
||||
|
||||
// Fill the body chunk afterwards manually as it most likely will be smaller than
|
||||
// the full convolution result. b convolved r + Delta * m + e2
|
||||
// taking noise from Chi_2 for the body part of the encryption
|
||||
output_body_chunk
|
||||
.iter_mut()
|
||||
.zip(pk_body_convolved.iter().zip(input_plaintext_chunk.iter()))
|
||||
.for_each(|(dst, (&src, plaintext))| {
|
||||
*dst.data = src
|
||||
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
|
||||
.wrapping_add(*plaintext.0)
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::core_crypto::commons::test_tools;
|
||||
use crate::core_crypto::prelude::*;
|
||||
|
||||
use crate::core_crypto::commons::generators::{
|
||||
DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
|
||||
#[test]
|
||||
fn test_compact_public_key_encryption() {
|
||||
use rand::Rng;
|
||||
|
||||
let lwe_dimension = LweDimension(2048);
|
||||
let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
|
||||
let mut secret_random_generator = test_tools::new_secret_random_generator();
|
||||
let mut encryption_random_generator = test_tools::new_encryption_random_generator();
|
||||
|
||||
let mut thread_rng = rand::thread_rng();
|
||||
|
||||
for _ in 0..10_000 {
|
||||
let lwe_sk =
|
||||
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
|
||||
|
||||
let mut compact_lwe_pk =
|
||||
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
|
||||
|
||||
generate_lwe_compact_public_key(
|
||||
&lwe_sk,
|
||||
&mut compact_lwe_pk,
|
||||
glwe_modular_std_dev,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let msg: u64 = thread_rng.gen();
|
||||
let msg = msg % 16;
|
||||
|
||||
let plaintext = Plaintext(msg << 60);
|
||||
|
||||
let mut output_ct = LweCiphertext::new(
|
||||
0u64,
|
||||
lwe_dimension.to_lwe_size(),
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
encrypt_lwe_ciphertext_with_compact_public_key(
|
||||
&compact_lwe_pk,
|
||||
&mut output_ct,
|
||||
plaintext,
|
||||
glwe_modular_std_dev,
|
||||
glwe_modular_std_dev,
|
||||
&mut secret_random_generator,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let decrypted_plaintext = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
|
||||
|
||||
let signed_decomposer =
|
||||
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
|
||||
let cleartext = signed_decomposer.closest_representable(decrypted_plaintext.0) >> 60;
|
||||
|
||||
assert_eq!(cleartext, msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_par_compact_lwe_list_public_key_encryption_equivalence() {
|
||||
use rand::Rng;
|
||||
|
||||
let lwe_dimension = LweDimension(2048);
|
||||
let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
|
||||
let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
|
||||
let mut thread_rng = rand::thread_rng();
|
||||
|
||||
for _ in 0..100 {
|
||||
// We'll encrypt between 1 and 4 * lwe_dimension ciphertexts
|
||||
let ct_count: usize = thread_rng.gen();
|
||||
let ct_count = ct_count % (lwe_dimension.0 * 4) + 1;
|
||||
let lwe_ciphertext_count = LweCiphertextCount(ct_count);
|
||||
|
||||
println!("{lwe_dimension:?} {ct_count:?}");
|
||||
|
||||
let seed = test_tools::random_seed();
|
||||
let mut input_plaintext_list =
|
||||
PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
|
||||
input_plaintext_list.iter_mut().for_each(|x| {
|
||||
let msg: u64 = thread_rng.gen();
|
||||
*x.0 = (msg % 16) << 60;
|
||||
});
|
||||
|
||||
let par_lwe_ct_list = {
|
||||
let mut deterministic_seeder =
|
||||
DeterministicSeeder::<ActivatedRandomGenerator>::new(seed);
|
||||
let mut secret_random_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
deterministic_seeder.seed(),
|
||||
);
|
||||
let mut encryption_random_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
deterministic_seeder.seed(),
|
||||
&mut deterministic_seeder,
|
||||
);
|
||||
|
||||
let lwe_sk =
|
||||
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
|
||||
|
||||
let mut compact_lwe_pk =
|
||||
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
|
||||
|
||||
generate_lwe_compact_public_key(
|
||||
&lwe_sk,
|
||||
&mut compact_lwe_pk,
|
||||
glwe_modular_std_dev,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_compact_ct_list = LweCompactCiphertextList::new(
|
||||
0u64,
|
||||
lwe_dimension.to_lwe_size(),
|
||||
lwe_ciphertext_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
|
||||
&compact_lwe_pk,
|
||||
&mut output_compact_ct_list,
|
||||
&input_plaintext_list,
|
||||
glwe_modular_std_dev,
|
||||
glwe_modular_std_dev,
|
||||
&mut secret_random_generator,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
output_plaintext_list.as_mut().fill(0u64);
|
||||
|
||||
let lwe_ciphertext_list =
|
||||
output_compact_ct_list.par_expand_into_lwe_ciphertext_list();
|
||||
|
||||
decrypt_lwe_ciphertext_list(
|
||||
&lwe_sk,
|
||||
&lwe_ciphertext_list,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
let signed_decomposer =
|
||||
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
|
||||
output_plaintext_list
|
||||
.iter_mut()
|
||||
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
|
||||
assert_eq!(input_plaintext_list, output_plaintext_list);
|
||||
|
||||
lwe_ciphertext_list
|
||||
};
|
||||
|
||||
let ser_lwe_ct_list = {
|
||||
let mut deterministic_seeder =
|
||||
DeterministicSeeder::<ActivatedRandomGenerator>::new(seed);
|
||||
let mut secret_random_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
deterministic_seeder.seed(),
|
||||
);
|
||||
let mut encryption_random_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
deterministic_seeder.seed(),
|
||||
&mut deterministic_seeder,
|
||||
);
|
||||
|
||||
let lwe_sk =
|
||||
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
|
||||
|
||||
let mut compact_lwe_pk =
|
||||
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
|
||||
|
||||
generate_lwe_compact_public_key(
|
||||
&lwe_sk,
|
||||
&mut compact_lwe_pk,
|
||||
glwe_modular_std_dev,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_compact_ct_list = LweCompactCiphertextList::new(
|
||||
0u64,
|
||||
lwe_dimension.to_lwe_size(),
|
||||
lwe_ciphertext_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
|
||||
&compact_lwe_pk,
|
||||
&mut output_compact_ct_list,
|
||||
&input_plaintext_list,
|
||||
glwe_modular_std_dev,
|
||||
glwe_modular_std_dev,
|
||||
&mut secret_random_generator,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
output_plaintext_list.as_mut().fill(0u64);
|
||||
|
||||
let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list();
|
||||
|
||||
decrypt_lwe_ciphertext_list(
|
||||
&lwe_sk,
|
||||
&lwe_ciphertext_list,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
let signed_decomposer =
|
||||
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
|
||||
output_plaintext_list
|
||||
.iter_mut()
|
||||
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
|
||||
|
||||
assert_eq!(input_plaintext_list, output_plaintext_list);
|
||||
|
||||
lwe_ciphertext_list
|
||||
};
|
||||
|
||||
assert_eq!(ser_lwe_ct_list, par_lwe_ct_list);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,6 +93,7 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
@@ -115,7 +116,7 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
|
||||
// of mapping that back to the native torus
|
||||
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
|
||||
.to_recomposition_summand()
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
encrypt_lwe_ciphertext_list(
|
||||
@@ -255,6 +256,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
@@ -282,7 +284,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
|
||||
// of mapping that back to the native torus
|
||||
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
|
||||
.to_recomposition_summand()
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
encrypt_seeded_lwe_ciphertext_list_with_existing_generator(
|
||||
|
||||
@@ -238,12 +238,13 @@ pub fn lwe_ciphertext_plaintext_add_assign<Scalar, InCont>(
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -313,12 +314,13 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_sub(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_sub(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -450,3 +451,324 @@ where
|
||||
|
||||
bsk
|
||||
}
|
||||
|
||||
/// Fill a [`seeded LWE bootstrap key`](`SeededLweMultiBitBootstrapKey`) with an actual seeded
|
||||
/// bootstrapping key constructed from an input key [`LWE secret key`](`LweSecretKey`) and an output
|
||||
/// key [`GLWE secret key`](`GlweSecretKey`)
|
||||
///
|
||||
/// Consider using [`par_generate_seeded_lwe_multi_bit_bootstrap_key`] for better key generation
|
||||
/// times.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn generate_seeded_lwe_multi_bit_bootstrap_key<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
OutputCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
|
||||
output: &mut SeededLweMultiBitBootstrapKey<OutputCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastFrom<usize>,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
assert!(
|
||||
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
|
||||
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
output.input_lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output.glwe_size()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
|
||||
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
output.polynomial_size()
|
||||
);
|
||||
|
||||
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
output.compression_seed().seed,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
|
||||
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
output.input_lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output.glwe_size()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
|
||||
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
output.polynomial_size()
|
||||
);
|
||||
|
||||
let gen_iter = generator
|
||||
.fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
|
||||
output.input_lwe_dimension(),
|
||||
output.decomposition_level_count(),
|
||||
output.glwe_size(),
|
||||
output.polynomial_size(),
|
||||
output.grouping_factor(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let grouping_factor = output.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
|
||||
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
|
||||
.zip(
|
||||
input_lwe_secret_key
|
||||
.as_ref()
|
||||
.chunks_exact(grouping_factor.0),
|
||||
)
|
||||
.zip(gen_iter)
|
||||
{
|
||||
let gen_iter = loop_generator.fork_n(ggsw_per_multi_bit_element.0).unwrap();
|
||||
for ((bit_inversion_idx, mut ggsw), mut inner_loop_generator) in
|
||||
ggsw_group.iter_mut().enumerate().zip(gen_iter)
|
||||
{
|
||||
// Use the index of the ggsw as a way to know which bit to invert
|
||||
let key_bits_plaintext = combine_key_bits(bit_inversion_idx, input_key_elements);
|
||||
|
||||
encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator(
|
||||
output_glwe_secret_key,
|
||||
&mut ggsw,
|
||||
Plaintext(key_bits_plaintext),
|
||||
noise_parameters,
|
||||
&mut inner_loop_generator,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a new [`seeded LWE bootstrap key`](`SeededLweMultiBitBootstrapKey`) and fill it with an
|
||||
/// actual seeded bootstrapping key constructed from an input key [`LWE secret key`](`LweSecretKey`)
|
||||
/// and an output key [`GLWE secret key`](`GlweSecretKey`)
|
||||
///
|
||||
/// Consider using [`par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key`] for better
|
||||
/// key generation times.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
|
||||
decomp_base_log: DecompositionBaseLog,
|
||||
decomp_level_count: DecompositionLevelCount,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
|
||||
where
|
||||
Scalar: UnsignedTorus + CastFrom<usize>,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
let mut bsk = SeededLweMultiBitBootstrapKeyOwned::new(
|
||||
Scalar::ZERO,
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
grouping_factor,
|
||||
noise_seeder.seed().into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
generate_seeded_lwe_multi_bit_bootstrap_key(
|
||||
input_lwe_secret_key,
|
||||
output_glwe_secret_key,
|
||||
&mut bsk,
|
||||
noise_parameters,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
bsk
|
||||
}
|
||||
|
||||
/// Parallel variant of [`generate_seeded_lwe_multi_bit_bootstrap_key`], it is recommended to use
|
||||
/// this function for better key generation times as LWE bootstrapping keys can be quite large.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn par_generate_seeded_lwe_multi_bit_bootstrap_key<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
OutputCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
|
||||
output: &mut SeededLweMultiBitBootstrapKey<OutputCont>,
|
||||
noise_parameters: impl DispersionParameter + Sync,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
assert!(
|
||||
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
|
||||
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
output.input_lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output.glwe_size()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
|
||||
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
output.polynomial_size()
|
||||
);
|
||||
|
||||
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
output.compression_seed().seed,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
let gen_iter = generator
|
||||
.par_fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
|
||||
output.input_lwe_dimension(),
|
||||
output.decomposition_level_count(),
|
||||
output.glwe_size(),
|
||||
output.polynomial_size(),
|
||||
output.grouping_factor(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let grouping_factor = output.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
output
|
||||
.par_iter_mut()
|
||||
.chunks(ggsw_per_multi_bit_element.0)
|
||||
.zip(
|
||||
input_lwe_secret_key
|
||||
.as_ref()
|
||||
.par_chunks_exact(grouping_factor.0),
|
||||
)
|
||||
.zip(gen_iter)
|
||||
.for_each(
|
||||
|((mut ggsw_group, input_key_elements), mut loop_generator)| {
|
||||
let gen_iter = loop_generator
|
||||
.par_fork_n(ggsw_per_multi_bit_element.0)
|
||||
.unwrap();
|
||||
ggsw_group
|
||||
.par_iter_mut()
|
||||
.enumerate()
|
||||
.zip(gen_iter)
|
||||
.for_each(|((bit_inversion_idx, ggsw), mut inner_loop_generator)| {
|
||||
// Use the index of the ggsw as a way to know which bit to invert
|
||||
let key_bits_plaintext =
|
||||
combine_key_bits(bit_inversion_idx, input_key_elements);
|
||||
|
||||
par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator(
|
||||
output_glwe_secret_key,
|
||||
ggsw,
|
||||
Plaintext(key_bits_plaintext),
|
||||
noise_parameters,
|
||||
&mut inner_loop_generator,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key`], it is
|
||||
/// recommended to use this function for better key generation times as LWE bootstrapping keys can
|
||||
/// be quite large.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
|
||||
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
|
||||
decomp_base_log: DecompositionBaseLog,
|
||||
decomp_level_count: DecompositionLevelCount,
|
||||
noise_parameters: impl DispersionParameter + Sync,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
|
||||
where
|
||||
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar> + Sync,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
let mut bsk = SeededLweMultiBitBootstrapKeyOwned::new(
|
||||
Scalar::ZERO,
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
grouping_factor,
|
||||
noise_seeder.seed().into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
par_generate_seeded_lwe_multi_bit_bootstrap_key(
|
||||
input_lwe_secret_key,
|
||||
output_glwe_secret_key,
|
||||
&mut bsk,
|
||||
noise_parameters,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
bsk
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
@@ -9,11 +10,89 @@ use crate::core_crypto::fft_impl::common::pbs_modulus_switch;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
|
||||
add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use concrete_fft::c64;
|
||||
use std::sync::{mpsc, Condvar, Mutex};
|
||||
use std::thread;
|
||||
|
||||
pub fn prepare_multi_bit_ggsw_mem_optimized<
|
||||
Scalar,
|
||||
GgswBufferCont,
|
||||
GgswGroupCont,
|
||||
PolyCont,
|
||||
FourierPolyCont,
|
||||
>(
|
||||
fourier_ggsw_buffer: &mut FourierGgswCiphertext<GgswBufferCont>,
|
||||
ggsw_group: &[FourierGgswCiphertext<GgswGroupCont>],
|
||||
lwe_mask_elements: &[Scalar],
|
||||
a_monomial: &mut Polynomial<PolyCont>,
|
||||
fourier_a_monomial: &mut FourierPolynomial<FourierPolyCont>,
|
||||
fft: FftView<'_>,
|
||||
buffers: &mut ComputationBuffers,
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
|
||||
GgswBufferCont: ContainerMut<Element = c64>,
|
||||
GgswGroupCont: Container<Element = c64>,
|
||||
PolyCont: ContainerMut<Element = Scalar>,
|
||||
FourierPolyCont: ContainerMut<Element = c64>,
|
||||
{
|
||||
let mut ggsw_group_iter = ggsw_group.iter();
|
||||
|
||||
// Keygen guarantees the first term is a constant term of the polynomial, no
|
||||
// polynomial multiplication required
|
||||
let ggsw_a_none = ggsw_group_iter.next().unwrap();
|
||||
|
||||
fourier_ggsw_buffer
|
||||
.as_mut_view()
|
||||
.data()
|
||||
.copy_from_slice(ggsw_a_none.as_view().data());
|
||||
|
||||
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
|
||||
|
||||
let polynomial_size = a_monomial.polynomial_size();
|
||||
|
||||
for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() {
|
||||
// We already processed the first ggsw, advance the index by 1
|
||||
let ggsw_idx = ggsw_idx + 1;
|
||||
|
||||
// Select the proper mask elements to build the monomial degree depending on
|
||||
// the order the GGSW were generated in, using the bits from mask_idx and
|
||||
// ggsw_idx as selector bits
|
||||
let mut monomial_degree = Scalar::ZERO;
|
||||
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
|
||||
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
|
||||
let selection_bit: Scalar = Scalar::cast_from((ggsw_idx >> mask_position) & 1);
|
||||
monomial_degree =
|
||||
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
|
||||
}
|
||||
|
||||
let switched_degree = pbs_modulus_switch(
|
||||
monomial_degree,
|
||||
polynomial_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
a_monomial.as_mut()[0] = Scalar::ONE;
|
||||
a_monomial.as_mut()[1..].fill(Scalar::ZERO);
|
||||
polynomial_wrapping_monic_monomial_mul_assign(a_monomial, MonomialDegree(switched_degree));
|
||||
|
||||
fft.forward_as_integer(
|
||||
fourier_a_monomial.as_mut_view(),
|
||||
a_monomial.as_view(),
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
update_with_fmadd(
|
||||
multi_bit_fourier_ggsw,
|
||||
fourier_ggsw.as_view().data(),
|
||||
fourier_a_monomial.as_view().data,
|
||||
false,
|
||||
polynomial_size.to_fourier_polynomial_size().0,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
|
||||
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain.
|
||||
@@ -270,6 +349,15 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert!(
|
||||
thread_count.0 != 0,
|
||||
"Got thread_count == 0, this is not supported"
|
||||
);
|
||||
|
||||
assert!(accumulator
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus());
|
||||
|
||||
let (lwe_mask, lwe_body) = input.get_mask_and_body();
|
||||
|
||||
// No way to chunk the result of ggsw_iter at the moment
|
||||
@@ -338,10 +426,7 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
|
||||
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
|
||||
|
||||
let mut unit_polynomial =
|
||||
Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
|
||||
unit_polynomial.as_mut()[0] = Scalar::ONE;
|
||||
let mut a_monomial = unit_polynomial.clone();
|
||||
let mut a_monomial = Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
|
||||
let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size());
|
||||
|
||||
let work_queue = &work_queue;
|
||||
@@ -367,64 +452,15 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
|
||||
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
|
||||
|
||||
let mut ggsw_group_iter = ggsw_group.iter();
|
||||
|
||||
// Keygen guarantees the first term is a constant term of the polynomial, no
|
||||
// polynomial multiplication required
|
||||
let ggsw_a_none = ggsw_group_iter.next().unwrap();
|
||||
|
||||
fourier_ggsw_buffer
|
||||
.as_mut_view()
|
||||
.data()
|
||||
.copy_from_slice(ggsw_a_none.as_view().data());
|
||||
|
||||
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
|
||||
|
||||
for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() {
|
||||
// We already processed the first ggsw, advance the index by 1
|
||||
let ggsw_idx = ggsw_idx + 1;
|
||||
|
||||
// Select the proper mask elements to build the monomial degree depending on
|
||||
// the order the GGSW were generated in, using the bits from mask_idx and
|
||||
// ggsw_idx as selector bits
|
||||
let mut monomial_degree = Scalar::ZERO;
|
||||
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
|
||||
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
|
||||
let selection_bit: Scalar =
|
||||
Scalar::cast_from((ggsw_idx >> mask_position) & 1);
|
||||
monomial_degree =
|
||||
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
|
||||
}
|
||||
|
||||
let switched_degree = pbs_modulus_switch(
|
||||
monomial_degree,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
a_monomial
|
||||
.as_mut()
|
||||
.copy_from_slice(unit_polynomial.as_ref());
|
||||
polynomial_wrapping_monic_monomial_mul_assign(
|
||||
&mut a_monomial,
|
||||
MonomialDegree(switched_degree),
|
||||
);
|
||||
|
||||
fft.forward_as_integer(
|
||||
fourier_a_monomial.as_mut_view(),
|
||||
a_monomial.as_view(),
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
update_with_fmadd(
|
||||
multi_bit_fourier_ggsw,
|
||||
fourier_ggsw.as_view().data(),
|
||||
fourier_a_monomial.as_view().data,
|
||||
false,
|
||||
lut_poly_size.to_fourier_polynomial_size().0,
|
||||
);
|
||||
}
|
||||
prepare_multi_bit_ggsw_mem_optimized(
|
||||
&mut fourier_ggsw_buffer,
|
||||
ggsw_group,
|
||||
lwe_mask_elements,
|
||||
&mut a_monomial,
|
||||
&mut fourier_a_monomial,
|
||||
fft,
|
||||
&mut buffers,
|
||||
);
|
||||
|
||||
// Drop the lock before we wake other threads
|
||||
drop(fourier_ggsw_buffer);
|
||||
@@ -801,6 +837,11 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert!(
|
||||
thread_count.0 != 0,
|
||||
"Got thread_count == 0, this is not supported"
|
||||
);
|
||||
|
||||
let mut local_accumulator = GlweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
accumulator.glwe_size(),
|
||||
@@ -815,3 +856,430 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
|
||||
pub fn std_prepare_multi_bit_ggsw<Scalar, GgswBufferCont, TmpGgswBufferCont, GgswGroupCont>(
|
||||
multi_bit_ggsw: &mut GgswCiphertext<GgswBufferCont>,
|
||||
tmp_ggsw_buffer: &mut GgswCiphertext<TmpGgswBufferCont>,
|
||||
ggsw_group: &GgswCiphertextList<GgswGroupCont>,
|
||||
lwe_mask_elements: &[Scalar],
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
|
||||
GgswBufferCont: ContainerMut<Element = Scalar>,
|
||||
TmpGgswBufferCont: ContainerMut<Element = Scalar>,
|
||||
GgswGroupCont: Container<Element = Scalar>,
|
||||
{
|
||||
let mut ggsw_group_iter = ggsw_group.iter();
|
||||
|
||||
// Keygen guarantees the first term is a constant term of the polynomial, no
|
||||
// polynomial multiplication required
|
||||
let ggsw_a_none = ggsw_group_iter.next().unwrap();
|
||||
|
||||
multi_bit_ggsw
|
||||
.as_mut()
|
||||
.copy_from_slice(ggsw_a_none.as_ref());
|
||||
|
||||
let polynomial_size = multi_bit_ggsw.polynomial_size();
|
||||
|
||||
for (ggsw_idx, std_ggsw) in ggsw_group_iter.enumerate() {
|
||||
// We already processed the first ggsw, advance the index by 1
|
||||
let ggsw_idx = ggsw_idx + 1;
|
||||
|
||||
// Select the proper mask elements to build the monomial degree depending on
|
||||
// the order the GGSW were generated in, using the bits from mask_idx and
|
||||
// ggsw_idx as selector bits
|
||||
let mut monomial_degree = Scalar::ZERO;
|
||||
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
|
||||
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
|
||||
let selection_bit: Scalar = Scalar::cast_from((ggsw_idx >> mask_position) & 1);
|
||||
monomial_degree =
|
||||
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
|
||||
}
|
||||
|
||||
let switched_degree = pbs_modulus_switch(
|
||||
monomial_degree,
|
||||
polynomial_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
tmp_ggsw_buffer
|
||||
.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.zip(std_ggsw.as_polynomial_list().iter())
|
||||
.for_each(|(mut tmp_polynomial, input_polynomial)| {
|
||||
polynomial_wrapping_monic_monomial_mul(
|
||||
&mut tmp_polynomial,
|
||||
&input_polynomial,
|
||||
MonomialDegree(switched_degree),
|
||||
);
|
||||
});
|
||||
|
||||
slice_wrapping_add_assign(multi_bit_ggsw.as_mut(), tmp_ggsw_buffer.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn std_multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
accumulator: &mut GlweCiphertext<OutputCont>,
|
||||
multi_bit_bsk: &LweMultiBitBootstrapKey<KeyCont>,
|
||||
thread_count: ThreadCount,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
|
||||
InputCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
KeyCont: Container<Element = Scalar> + Sync,
|
||||
{
|
||||
assert_eq!(
|
||||
input.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.input_lwe_dimension(),
|
||||
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
|
||||
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
|
||||
input.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.input_lwe_dimension(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.glwe_size(),
|
||||
multi_bit_bsk.glwe_size(),
|
||||
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
|
||||
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
|
||||
accumulator.glwe_size(),
|
||||
multi_bit_bsk.glwe_size(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.polynomial_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
|
||||
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
|
||||
accumulator.polynomial_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
"Mimatched CiphertextModulus. Accumulator CiphertextModulus {:?}. \
|
||||
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
|
||||
accumulator.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
"Mimatched CiphertextModulus. LweCiphertext CiphertextModulus {:?}. \
|
||||
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
|
||||
input.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let (lwe_mask, lwe_body) = input.get_mask_and_body();
|
||||
|
||||
// No way to chunk the result of ggsw_iter at the moment
|
||||
// let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();
|
||||
let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0);
|
||||
|
||||
let grouping_factor = multi_bit_bsk.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
for (lwe_mask_elements, ggsw_group) in lwe_mask
|
||||
.as_ref()
|
||||
.chunks_exact(grouping_factor.0)
|
||||
.zip(multi_bit_bsk.chunks_exact(ggsw_per_multi_bit_element.0))
|
||||
{
|
||||
work_queue.push((lwe_mask_elements, ggsw_group));
|
||||
}
|
||||
|
||||
assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0);
|
||||
|
||||
let work_queue = Mutex::new(work_queue);
|
||||
|
||||
// Each producer thread works in a dedicated slot of the buffer
|
||||
let thread_buffers: usize = thread_count.0;
|
||||
|
||||
let lut_poly_size = accumulator.polynomial_size();
|
||||
let monomial_degree = pbs_modulus_switch(
|
||||
*lwe_body.data,
|
||||
lut_poly_size,
|
||||
ModulusSwitchOffset(0),
|
||||
LutCountLog(0),
|
||||
);
|
||||
|
||||
// Modulus switching
|
||||
accumulator
|
||||
.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut poly,
|
||||
MonomialDegree(monomial_degree),
|
||||
)
|
||||
});
|
||||
|
||||
let fourier_multi_bit_ggsw_buffers = (0..thread_buffers)
|
||||
.map(|_| {
|
||||
(
|
||||
Mutex::new(false),
|
||||
Condvar::new(),
|
||||
Mutex::new(FourierGgswCiphertext::new(
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
multi_bit_bsk.decomposition_base_log(),
|
||||
multi_bit_bsk.decomposition_level_count(),
|
||||
)),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let (tx, rx) = mpsc::channel::<usize>();
|
||||
|
||||
let fft = Fft::new(multi_bit_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
thread::scope(|s| {
|
||||
let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender<usize>| {
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
|
||||
|
||||
let mut std_ggsw_buffer = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
multi_bit_bsk.decomposition_base_log(),
|
||||
multi_bit_bsk.decomposition_level_count(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut tmp_ggsw_buffer = GgswCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
multi_bit_bsk.decomposition_base_log(),
|
||||
multi_bit_bsk.decomposition_level_count(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let work_queue = &work_queue;
|
||||
|
||||
let dest_idx = thread_id;
|
||||
let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) =
|
||||
&fourier_multi_bit_ggsw_buffers[dest_idx];
|
||||
|
||||
loop {
|
||||
let maybe_work = {
|
||||
let mut queue_lock = work_queue.lock().unwrap();
|
||||
queue_lock.pop()
|
||||
};
|
||||
|
||||
let Some((lwe_mask_elements, ggsw_group)) = maybe_work else {break};
|
||||
let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();
|
||||
|
||||
// Wait while the buffer is not ready for processing and wait on the condvar
|
||||
// to get notified when we can start processing again
|
||||
while *ready_for_consumer {
|
||||
ready_for_consumer = condvar.wait(ready_for_consumer).unwrap();
|
||||
}
|
||||
|
||||
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
|
||||
|
||||
std_prepare_multi_bit_ggsw(
|
||||
&mut std_ggsw_buffer,
|
||||
&mut tmp_ggsw_buffer,
|
||||
&ggsw_group,
|
||||
lwe_mask_elements,
|
||||
);
|
||||
|
||||
fourier_ggsw_buffer.as_mut_view().fill_with_forward_fourier(
|
||||
std_ggsw_buffer.as_view(),
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
|
||||
// Drop the lock before we wake other threads
|
||||
drop(fourier_ggsw_buffer);
|
||||
|
||||
*ready_for_consumer = true;
|
||||
tx.send(dest_idx).unwrap();
|
||||
|
||||
// Wake threads waiting on the condvar
|
||||
condvar.notify_all();
|
||||
}
|
||||
};
|
||||
|
||||
let threads: Vec<_> = (0..thread_count.0)
|
||||
.map(|id| {
|
||||
let tx = tx.clone();
|
||||
s.spawn(move || produce_multi_bit_fourier_ggsw(id, tx))
|
||||
})
|
||||
.collect();
|
||||
|
||||
// We initialize ct0 for the successive external products
|
||||
let ct0 = accumulator;
|
||||
let mut ct1 = GlweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
ct0.glwe_size(),
|
||||
ct0.polynomial_size(),
|
||||
ct0.ciphertext_modulus(),
|
||||
);
|
||||
let ct1 = &mut ct1;
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(
|
||||
add_external_product_assign_scratch::<Scalar>(
|
||||
multi_bit_bsk.glwe_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
let mut src_idx = 1usize;
|
||||
|
||||
for _ in 0..multi_bit_bsk.multi_bit_input_lwe_dimension().0 {
|
||||
src_idx ^= 1;
|
||||
let idx = rx.recv().unwrap();
|
||||
let (ready_lock, condvar, multi_bit_fourier_ggsw) =
|
||||
&fourier_multi_bit_ggsw_buffers[idx];
|
||||
|
||||
let (src_ct, mut dst_ct) = if src_idx == 0 {
|
||||
(ct0.as_view(), ct1.as_mut_view())
|
||||
} else {
|
||||
(ct1.as_view(), ct0.as_mut_view())
|
||||
};
|
||||
|
||||
dst_ct.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let mut ready = ready_lock.lock().unwrap();
|
||||
assert!(*ready);
|
||||
|
||||
let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap();
|
||||
add_external_product_assign(
|
||||
dst_ct,
|
||||
multi_bit_fourier_ggsw.as_view(),
|
||||
src_ct,
|
||||
fft,
|
||||
buffers.stack(),
|
||||
);
|
||||
drop(multi_bit_fourier_ggsw);
|
||||
|
||||
*ready = false;
|
||||
// Wake a single producer thread sleeping on the condvar (only one will get to work
|
||||
// anyways)
|
||||
condvar.notify_one();
|
||||
}
|
||||
|
||||
if src_idx == 0 {
|
||||
ct0.as_mut().copy_from_slice(ct1.as_ref());
|
||||
}
|
||||
|
||||
let ciphertext_modulus = ct0.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
ct0.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
threads.into_iter().for_each(|t| t.join().unwrap());
|
||||
});
|
||||
}
|
||||
|
||||
pub fn std_multi_bit_programmable_bootstrap_lwe_ciphertext<
|
||||
Scalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
AccCont,
|
||||
KeyCont,
|
||||
>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
multi_bit_bsk: &LweMultiBitBootstrapKey<KeyCont>,
|
||||
thread_count: ThreadCount,
|
||||
) where
|
||||
// CastInto required for PBS modulus switch which returns a usize
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
|
||||
InputCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
AccCont: Container<Element = Scalar>,
|
||||
KeyCont: Container<Element = Scalar> + Sync,
|
||||
{
|
||||
assert_eq!(
|
||||
input.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.input_lwe_dimension(),
|
||||
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
|
||||
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
|
||||
input.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.input_lwe_dimension(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
output.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.output_lwe_dimension(),
|
||||
"Mimatched output LweDimension. LweCiphertext output LweDimension {:?}. \
|
||||
FourierLweMultiBitBootstrapKey output LweDimension {:?}.",
|
||||
output.lwe_size().to_lwe_dimension(),
|
||||
multi_bit_bsk.output_lwe_dimension(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.glwe_size(),
|
||||
multi_bit_bsk.glwe_size(),
|
||||
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
|
||||
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
|
||||
accumulator.glwe_size(),
|
||||
multi_bit_bsk.glwe_size(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.polynomial_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
|
||||
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
|
||||
accumulator.polynomial_size(),
|
||||
multi_bit_bsk.polynomial_size(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
input.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
"Mimatched CiphertextModulus. LweCiphertext CiphertextModulus {:?}. \
|
||||
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
|
||||
input.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
accumulator.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
"Mimatched CiphertextModulus. Accumulator CiphertextModulus {:?}. \
|
||||
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
|
||||
accumulator.ciphertext_modulus(),
|
||||
multi_bit_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut local_accumulator = GlweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
accumulator.glwe_size(),
|
||||
accumulator.polynomial_size(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
local_accumulator
|
||||
.as_mut()
|
||||
.copy_from_slice(accumulator.as_ref());
|
||||
|
||||
std_multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
|
||||
@@ -349,8 +349,6 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
@@ -487,6 +485,8 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
|
||||
InputGlweCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(out.ciphertext_modulus(), glwe.ciphertext_modulus());
|
||||
let ciphertext_modulus = out.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
impl_add_external_product_assign(
|
||||
out.as_mut_view(),
|
||||
@@ -496,7 +496,6 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
|
||||
stack,
|
||||
);
|
||||
|
||||
let ciphertext_modulus = out.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
@@ -776,6 +775,8 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
|
||||
GgswCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(ct0.ciphertext_modulus(), ct1.ciphertext_modulus());
|
||||
let ciphertext_modulus = ct0.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
cmux(
|
||||
ct0.as_mut_view(),
|
||||
@@ -785,7 +786,6 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
|
||||
stack,
|
||||
);
|
||||
|
||||
let ciphertext_modulus = ct0.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
|
||||
86
tfhe/src/core_crypto/algorithms/misc.rs
Normal file
86
tfhe/src/core_crypto/algorithms/misc.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
//! Miscellaneous algorithms.
|
||||
|
||||
use crate::core_crypto::prelude::*;
|
||||
|
||||
#[inline]
|
||||
pub fn divide_round_to_u128<Scalar>(numerator: Scalar, denominator: Scalar) -> u128
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let numerator_128: u128 = numerator.cast_into();
|
||||
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
|
||||
let denominator_128: u128 = denominator.cast_into();
|
||||
// That's the rounding
|
||||
(numerator_128 + half_denominator) / denominator_128
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn divide_round_to_u128_custom_mod<Scalar>(
|
||||
numerator: Scalar,
|
||||
denominator: Scalar,
|
||||
modulus: u128,
|
||||
) -> u128
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let numerator_128: u128 = numerator.cast_into();
|
||||
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
|
||||
let denominator_128: u128 = denominator.cast_into();
|
||||
// That's the rounding
|
||||
((numerator_128 + half_denominator) % modulus) / denominator_128
|
||||
}
|
||||
|
||||
pub fn odd_modular_inverse_pow_2<Scalar>(odd_value_to_invert: Scalar, log2_modulo: usize) -> Scalar
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let t = log2_modulo.ilog2() + if log2_modulo.is_power_of_two() { 0 } else { 1 };
|
||||
let mut y = Scalar::ONE;
|
||||
let e = odd_value_to_invert;
|
||||
|
||||
for i in 1..=t {
|
||||
// 1 << (1 << i) == 2 ^ {2 ^ i}
|
||||
let curr_mod = Scalar::ONE.shl(1 << i);
|
||||
// y = y * (2 - y * e) mod 2 ^ {2 ^ i}
|
||||
// Here using wrapping ops is ok as the modulus used is a power of 2, as long as 2 ^ {2 ^ i}
|
||||
// is smaller than Scalar::BITS, we are good to go, the discarded values would not have been
|
||||
// Used anyways, and 2 ^ {2 ^ i} is compatible with a native modulus
|
||||
y = (y.wrapping_mul(Scalar::TWO.wrapping_sub(y.wrapping_mul(e)))).wrapping_rem(curr_mod);
|
||||
}
|
||||
|
||||
y.wrapping_rem(Scalar::ONE.shl(log2_modulo))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_divide_round() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const NB_TESTS: usize = 1_000_000_000;
|
||||
const SCALING: f64 = u64::MAX as f64;
|
||||
for _ in 0..NB_TESTS {
|
||||
let num: f64 = rng.gen();
|
||||
let mut denom = 0.0f64;
|
||||
while denom == 0.0f64 {
|
||||
denom = rng.gen();
|
||||
}
|
||||
|
||||
let num = (num * SCALING).round();
|
||||
let denom = (denom * SCALING).round();
|
||||
|
||||
let rounded = (num / denom).round();
|
||||
let expected_rounded_u64: u64 = rounded as u64;
|
||||
|
||||
let num_u64: u64 = num as u64;
|
||||
let denom_u64: u64 = denom as u64;
|
||||
|
||||
// sanity check
|
||||
assert_eq!(num, num_u64 as f64);
|
||||
assert_eq!(denom, denom_u64 as f64);
|
||||
|
||||
let rounded_u128 = divide_round_to_u128(num_u64, denom_u64);
|
||||
|
||||
assert_eq!(expected_rounded_u64, rounded_u128 as u64);
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,8 @@ pub mod glwe_sample_extraction;
|
||||
pub mod glwe_secret_key_generation;
|
||||
pub mod lwe_bootstrap_key_conversion;
|
||||
pub mod lwe_bootstrap_key_generation;
|
||||
pub mod lwe_compact_ciphertext_list_expansion;
|
||||
pub mod lwe_compact_public_key_generation;
|
||||
pub mod lwe_encryption;
|
||||
pub mod lwe_keyswitch;
|
||||
pub mod lwe_keyswitch_key_generation;
|
||||
@@ -22,6 +24,7 @@ pub mod lwe_programmable_bootstrapping;
|
||||
pub mod lwe_public_key_generation;
|
||||
pub mod lwe_secret_key_generation;
|
||||
pub mod lwe_wopbs;
|
||||
pub mod misc;
|
||||
pub mod polynomial_algorithms;
|
||||
pub mod seeded_ggsw_ciphertext_decompression;
|
||||
pub mod seeded_ggsw_ciphertext_list_decompression;
|
||||
@@ -30,7 +33,9 @@ pub mod seeded_glwe_ciphertext_list_decompression;
|
||||
pub mod seeded_lwe_bootstrap_key_decompression;
|
||||
pub mod seeded_lwe_ciphertext_decompression;
|
||||
pub mod seeded_lwe_ciphertext_list_decompression;
|
||||
pub mod seeded_lwe_compact_public_key_decompression;
|
||||
pub mod seeded_lwe_keyswitch_key_decompression;
|
||||
pub mod seeded_lwe_multi_bit_bootstrap_key_decompression;
|
||||
pub mod seeded_lwe_public_key_decompression;
|
||||
pub mod slice_algorithms;
|
||||
|
||||
@@ -46,6 +51,8 @@ pub use glwe_sample_extraction::*;
|
||||
pub use glwe_secret_key_generation::*;
|
||||
pub use lwe_bootstrap_key_conversion::*;
|
||||
pub use lwe_bootstrap_key_generation::*;
|
||||
pub use lwe_compact_ciphertext_list_expansion::*;
|
||||
pub use lwe_compact_public_key_generation::*;
|
||||
pub use lwe_encryption::*;
|
||||
pub use lwe_keyswitch::*;
|
||||
pub use lwe_keyswitch_key_generation::*;
|
||||
@@ -66,5 +73,7 @@ pub use seeded_glwe_ciphertext_list_decompression::*;
|
||||
pub use seeded_lwe_bootstrap_key_decompression::*;
|
||||
pub use seeded_lwe_ciphertext_decompression::*;
|
||||
pub use seeded_lwe_ciphertext_list_decompression::*;
|
||||
pub use seeded_lwe_compact_public_key_decompression::*;
|
||||
pub use seeded_lwe_keyswitch_key_decompression::*;
|
||||
pub use seeded_lwe_multi_bit_bootstrap_key_decompression::*;
|
||||
pub use seeded_lwe_public_key_decompression::*;
|
||||
|
||||
@@ -259,6 +259,124 @@ pub fn polynomial_wrapping_monic_monomial_mul_assign<Scalar, OutputCont>(
|
||||
.for_each(|a| *a = a.wrapping_neg());
|
||||
}
|
||||
|
||||
/// Divides (mod $(X^{N}+1)$), the input polynomial with a monic monomial of a given degree i.e.
|
||||
/// $X^{degree}$.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// Computations wrap around (similar to computing modulo $2^{n\_{bits}}$) when exceeding the
|
||||
/// unsigned integer capacity.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
/// use tfhe::core_crypto::commons::parameters::*;
|
||||
/// use tfhe::core_crypto::entities::*;
|
||||
/// let input = Polynomial::from_container(vec![1u8, 2, 3]);
|
||||
/// let mut output = Polynomial::from_container(vec![0, 0, 0]);
|
||||
/// polynomial_wrapping_monic_monomial_div(&mut output, &input, MonomialDegree(2));
|
||||
/// assert_eq!(output.as_ref(), &[3, 255, 254]);
|
||||
/// ```
|
||||
pub fn polynomial_wrapping_monic_monomial_div<Scalar, OutputCont, InputCont>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
input: &Polynomial<InputCont>,
|
||||
monomial_degree: MonomialDegree,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == input.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
|
||||
output.polynomial_size(),
|
||||
input.polynomial_size(),
|
||||
);
|
||||
|
||||
let polynomial_len = output.container_len();
|
||||
let remaining_degree = monomial_degree.0 % output.as_ref().container_len();
|
||||
|
||||
let src_slice = &input[remaining_degree..];
|
||||
let src_slice_len = src_slice.len();
|
||||
let dst_slice = &mut output[..src_slice_len];
|
||||
dst_slice.copy_from_slice(src_slice);
|
||||
|
||||
for (dst, &src) in output[polynomial_len - remaining_degree..]
|
||||
.iter_mut()
|
||||
.zip(input[..remaining_degree].iter())
|
||||
{
|
||||
*dst = src.wrapping_neg();
|
||||
}
|
||||
|
||||
let full_cycles_count = monomial_degree.0 / polynomial_len;
|
||||
if full_cycles_count % 2 != 0 {
|
||||
output
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|a| *a = a.wrapping_neg());
|
||||
}
|
||||
}
|
||||
|
||||
/// Multiply (mod $(X^{N}+1)$), the input polynomial with a monic monomial of a given degree i.e.
|
||||
/// $X^{degree}$.
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// Computations wrap around (similar to computing modulo $2^{n\_{bits}}$) when exceeding the
|
||||
/// unsigned integer capacity.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
/// use tfhe::core_crypto::commons::parameters::*;
|
||||
/// use tfhe::core_crypto::entities::*;
|
||||
/// let input = Polynomial::from_container(vec![1u8, 2, 3]);
|
||||
/// let mut output = Polynomial::from_container(vec![0, 0, 0]);
|
||||
/// polynomial_wrapping_monic_monomial_mul(&mut output, &input, MonomialDegree(2));
|
||||
/// assert_eq!(output.as_ref(), &[254, 253, 1]);
|
||||
/// ```
|
||||
pub fn polynomial_wrapping_monic_monomial_mul<Scalar, OutputCont, InputCont>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
input: &Polynomial<InputCont>,
|
||||
monomial_degree: MonomialDegree,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == input.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
|
||||
output.polynomial_size(),
|
||||
input.polynomial_size(),
|
||||
);
|
||||
|
||||
let polynomial_len = output.container_len();
|
||||
let remaining_degree = monomial_degree.0 % output.as_ref().container_len();
|
||||
|
||||
for (dst, &src) in output[..remaining_degree]
|
||||
.iter_mut()
|
||||
.zip(input[polynomial_len - remaining_degree..].iter())
|
||||
{
|
||||
*dst = src.wrapping_neg();
|
||||
}
|
||||
|
||||
let dst_slice = &mut output[remaining_degree..];
|
||||
let dst_slice_len = dst_slice.len();
|
||||
let src_slice = &input[..dst_slice_len];
|
||||
dst_slice.copy_from_slice(src_slice);
|
||||
|
||||
let full_cycles_count = monomial_degree.0 / polynomial_len;
|
||||
if full_cycles_count % 2 != 0 {
|
||||
output
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|a| *a = a.wrapping_neg());
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtract the sum of the element-wise product between two lists of polynomials, to the output
|
||||
/// polynomial.
|
||||
///
|
||||
|
||||
@@ -34,13 +34,14 @@ pub fn decompress_seeded_glwe_ciphertext_with_existing_generator<
|
||||
let (mut output_mask, mut output_body) = output_glwe.get_mut_mask_and_body();
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// generate a uniformly random mask
|
||||
generator.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
output_body
|
||||
|
||||
@@ -32,6 +32,7 @@ pub fn decompress_seeded_glwe_ciphertext_list_with_existing_generator<
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_list.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
for (mut glwe_out, body_in) in output_list.iter_mut().zip(input_seeded_list.iter()) {
|
||||
let (mut output_mask, mut output_body) = glwe_out.get_mut_mask_and_body();
|
||||
@@ -42,7 +43,7 @@ pub fn decompress_seeded_glwe_ciphertext_list_with_existing_generator<
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
output_body.as_mut().copy_from_slice(body_in.as_ref());
|
||||
|
||||
@@ -26,6 +26,7 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_lwe.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
let (mut output_mask, output_body) = output_lwe.get_mut_mask_and_body();
|
||||
|
||||
// generate a uniformly random mask
|
||||
@@ -33,7 +34,7 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*output_body.data = *input_seeded_lwe.get_body().data;
|
||||
|
||||
@@ -5,6 +5,8 @@ use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Convenience function to share the core logic of the decompression algorithm for
|
||||
/// [`SeededLweCiphertextList`] between all functions needing it.
|
||||
pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
@@ -32,6 +34,7 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_list.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
for (mut lwe_out, body_in) in output_list.iter_mut().zip(input_seeded_list.iter()) {
|
||||
let (mut output_mask, output_body) = lwe_out.get_mut_mask_and_body();
|
||||
@@ -42,13 +45,96 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*output_body.data = *body_in.data;
|
||||
}
|
||||
}
|
||||
|
||||
/// Convenience function to share the core logic of the decompression algorithm for
|
||||
/// [`SeededLweCiphertextList`] between all functions needing it.
|
||||
pub fn par_decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
Scalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
Gen,
|
||||
>(
|
||||
output_list: &mut LweCiphertextList<OutputCont>,
|
||||
input_seeded_list: &SeededLweCiphertextList<InputCont>,
|
||||
generator: &mut RandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Send + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_list.ciphertext_modulus(),
|
||||
input_seeded_list.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus \
|
||||
between input SeededLweCiphertextList ({:?}) and output LweCiphertextList ({:?})",
|
||||
input_seeded_list.ciphertext_modulus(),
|
||||
output_list.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_list.ciphertext_modulus();
|
||||
|
||||
let bytes_per_child = (Scalar::BITS / 8) * (output_list.lwe_size().to_lwe_dimension().0);
|
||||
let generators = generator
|
||||
.par_try_fork(output_list.lwe_ciphertext_count().0, bytes_per_child)
|
||||
.expect("Failed to fork generator");
|
||||
|
||||
output_list
|
||||
.par_iter_mut()
|
||||
.zip(input_seeded_list.par_iter())
|
||||
.zip(generators)
|
||||
.for_each(|((mut lwe_out, body_in), mut generator)| {
|
||||
let (mut output_mask, output_body) = lwe_out.get_mut_mask_and_body();
|
||||
|
||||
// generate a uniformly random mask
|
||||
generator.fill_slice_with_random_uniform_custom_mod(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*output_body.data = *body_in.data;
|
||||
});
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweCiphertextList`], without consuming it, into a standard
|
||||
/// [`LweCiphertextList`].
|
||||
pub fn par_decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_list: &mut LweCiphertextList<OutputCont>,
|
||||
input_seeded_list: &SeededLweCiphertextList<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Send + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_list.ciphertext_modulus(),
|
||||
input_seeded_list.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus \
|
||||
between input SeededLweCiphertextList ({:?}) and output LweCiphertextList ({:?})",
|
||||
input_seeded_list.ciphertext_modulus(),
|
||||
output_list.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut generator = RandomGenerator::<Gen>::new(input_seeded_list.compression_seed().seed);
|
||||
par_decompress_seeded_lwe_ciphertext_list_with_existing_generator::<_, _, _, Gen>(
|
||||
output_list,
|
||||
input_seeded_list,
|
||||
&mut generator,
|
||||
)
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweCiphertextList`], without consuming it, into a standard
|
||||
/// [`LweCiphertextList`].
|
||||
pub fn decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen>(
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
//! Module with primitives pertaining to [`SeededLweCompactPublicKey`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
|
||||
/// Convenience function to share the core logic of the decompression algorithm for
|
||||
/// [`SeededLweCompactPublicKey`] between all functions needing it.
|
||||
pub fn decompress_seeded_lwe_compact_public_key_with_existing_generator<
|
||||
Scalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
Gen,
|
||||
>(
|
||||
output_cpk: &mut LweCompactPublicKey<OutputCont>,
|
||||
input_seeded_cpk: &SeededLweCompactPublicKey<InputCont>,
|
||||
generator: &mut RandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
decompress_seeded_glwe_ciphertext_with_existing_generator(
|
||||
&mut output_cpk.as_mut_glwe_ciphertext(),
|
||||
&input_seeded_cpk.as_seeded_glwe_ciphertext(),
|
||||
generator,
|
||||
);
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweCompactPublicKey`], without consuming it, into a standard
|
||||
/// [`LweCompactPublicKey`].
|
||||
pub fn decompress_seeded_lwe_compact_public_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_cpk: &mut LweCompactPublicKey<OutputCont>,
|
||||
input_seeded_cpk: &SeededLweCompactPublicKey<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
let mut generator = RandomGenerator::<Gen>::new(input_seeded_cpk.compression_seed().seed);
|
||||
decompress_seeded_lwe_compact_public_key_with_existing_generator::<_, _, _, Gen>(
|
||||
output_cpk,
|
||||
input_seeded_cpk,
|
||||
&mut generator,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
//! Module with primitives pertaining to [`SeededLweMultiBitBootstrapKey`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
|
||||
/// Convenience function to share the core logic of the decompression algorithm for
|
||||
/// [`SeededLweMultiBitBootstrapKey`] between all functions needing it.
|
||||
pub fn decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator<
|
||||
Scalar,
|
||||
InputCont,
|
||||
OutputCont,
|
||||
Gen,
|
||||
>(
|
||||
output_bsk: &mut LweMultiBitBootstrapKey<OutputCont>,
|
||||
input_bsk: &SeededLweMultiBitBootstrapKey<InputCont>,
|
||||
generator: &mut RandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_bsk.ciphertext_modulus(),
|
||||
input_bsk.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus \
|
||||
between input SeededLweMultiBitBootstrapKey ({:?}) and output LweMultiBitBootstrapKey ({:?})",
|
||||
input_bsk.ciphertext_modulus(),
|
||||
output_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
decompress_seeded_ggsw_ciphertext_list_with_existing_generator(output_bsk, input_bsk, generator)
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLweMultiBitBootstrapKey`], without consuming it, into a standard
|
||||
/// [`LweMultiBitBootstrapKey`].
|
||||
pub fn decompress_seeded_lwe_multi_bit_bootstrap_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_bsk: &mut LweMultiBitBootstrapKey<OutputCont>,
|
||||
input_bsk: &SeededLweMultiBitBootstrapKey<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_bsk.ciphertext_modulus(),
|
||||
input_bsk.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus \
|
||||
between input SeededLweMultiBitBootstrapKey ({:?}) and output LweMultiBitBootstrapKey ({:?})",
|
||||
input_bsk.ciphertext_modulus(),
|
||||
output_bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut generator = RandomGenerator::<Gen>::new(input_bsk.compression_seed().seed);
|
||||
decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator::<_, _, _, Gen>(
|
||||
output_bsk,
|
||||
input_bsk,
|
||||
&mut generator,
|
||||
)
|
||||
}
|
||||
@@ -32,3 +32,31 @@ pub fn decompress_seeded_lwe_public_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
&mut generator,
|
||||
);
|
||||
}
|
||||
|
||||
/// Decompress a [`SeededLwePublicKey`], without consuming it, into a standard
|
||||
/// [`LwePublicKey`] using mutliple threads.
|
||||
pub fn par_decompress_seeded_lwe_public_key<Scalar, InputCont, OutputCont, Gen>(
|
||||
output_pk: &mut LwePublicKey<OutputCont>,
|
||||
input_pk: &SeededLwePublicKey<InputCont>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Send + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_pk.ciphertext_modulus(),
|
||||
input_pk.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus \
|
||||
between input SeededLwePublicKey ({:?}) and output LwePublicKey ({:?})",
|
||||
output_pk.ciphertext_modulus(),
|
||||
input_pk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut generator = RandomGenerator::<Gen>::new(input_pk.compression_seed().seed);
|
||||
par_decompress_seeded_lwe_ciphertext_list_with_existing_generator::<_, _, _, Gen>(
|
||||
output_pk,
|
||||
input_pk,
|
||||
&mut generator,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
//! Module providing algorithms to perform computations on raw slices.
|
||||
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::polynomial_wrapping_add_mul_assign;
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::entities::Polynomial;
|
||||
|
||||
/// Compute a dot product between two slices containing unsigned integers.
|
||||
///
|
||||
@@ -309,3 +311,54 @@ where
|
||||
lhs.iter_mut()
|
||||
.for_each(|lhs| *lhs = (*lhs).wrapping_div(rhs));
|
||||
}
|
||||
|
||||
/// Primitive for compact LWE public key
|
||||
///
|
||||
/// Here $i$ from section 3 of <https://eprint.iacr.org/2023/603> is taken equal to $n$.
|
||||
/// ```
|
||||
/// use tfhe::core_crypto::algorithms::slice_algorithms::*;
|
||||
/// let lhs = vec![1u8, 2u8, 3u8];
|
||||
/// let rhs = vec![4u8, 5u8, 6u8];
|
||||
/// let mut output = vec![0u8; 3];
|
||||
/// slice_semi_reverse_negacyclic_convolution(&mut output, &lhs, &rhs);
|
||||
/// assert_eq!(&output, &[(-17i8) as u8, 5, 32]);
|
||||
/// ```
|
||||
pub fn slice_semi_reverse_negacyclic_convolution<Scalar>(
|
||||
output: &mut [Scalar],
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
assert!(
|
||||
output.len() == lhs.len(),
|
||||
"output (len: {}) and lhs (len: {}) must have the same length",
|
||||
output.len(),
|
||||
lhs.len()
|
||||
);
|
||||
|
||||
// Apply phi_1 to the rhs term
|
||||
let mut phi_1_rhs: Vec<_> = rhs.to_vec();
|
||||
phi_1_rhs.reverse();
|
||||
|
||||
let phi_1_rhs_as_polynomial = Polynomial::from_container(phi_1_rhs.as_slice());
|
||||
|
||||
// Clear output as we'll add the multiplication result
|
||||
output.fill(Scalar::ZERO);
|
||||
let mut output_as_polynomial = Polynomial::from_container(output);
|
||||
let lhs_as_polynomial = Polynomial::from_container(lhs);
|
||||
|
||||
// Apply the classic negacyclic convolution via polynomial mul in the X^N + 1 ring, with the
|
||||
// phi_1 rhs it is equivalent to the operator we need
|
||||
polynomial_wrapping_add_mul_assign(
|
||||
&mut output_as_polynomial,
|
||||
&lhs_as_polynomial,
|
||||
&phi_1_rhs_as_polynomial,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
use super::*;
|
||||
|
||||
use crate::core_crypto::commons::generators::{
|
||||
DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
|
||||
fn test_seeded_lwe_cpk_gen_equivalence<Scalar: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) {
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
// computations
|
||||
// Define parameters for LweCompactPublicKey creation
|
||||
let lwe_dimension = LweDimension(1024);
|
||||
let lwe_modular_std_dev = StandardDev(0.00000004990272175010415);
|
||||
|
||||
// Create the PRNG
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mask_seed = seeder.seed();
|
||||
let deterministic_seeder_seed = seeder.seed();
|
||||
let mut secret_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
|
||||
const NB_TEST: usize = 10;
|
||||
|
||||
for _ in 0..NB_TEST {
|
||||
// Create the LweSecretKey
|
||||
let input_lwe_secret_key =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
|
||||
|
||||
let mut cpk = LweCompactPublicKey::new(Scalar::ZERO, lwe_dimension, ciphertext_modulus);
|
||||
|
||||
let mut deterministic_seeder =
|
||||
DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed);
|
||||
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
mask_seed,
|
||||
&mut deterministic_seeder,
|
||||
);
|
||||
|
||||
generate_lwe_compact_public_key(
|
||||
&input_lwe_secret_key,
|
||||
&mut cpk,
|
||||
lwe_modular_std_dev,
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&cpk, ciphertext_modulus));
|
||||
|
||||
let mut seeded_cpk = SeededLweCompactPublicKey::new(
|
||||
Scalar::ZERO,
|
||||
lwe_dimension,
|
||||
mask_seed.into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut deterministic_seeder =
|
||||
DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed);
|
||||
|
||||
generate_seeded_lwe_compact_public_key(
|
||||
&input_lwe_secret_key,
|
||||
&mut seeded_cpk,
|
||||
lwe_modular_std_dev,
|
||||
&mut deterministic_seeder,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&seeded_cpk, ciphertext_modulus));
|
||||
|
||||
let decompressed_cpk = seeded_cpk.decompress_into_lwe_compact_public_key();
|
||||
|
||||
assert_eq!(cpk, decompressed_cpk);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seeded_lwe_cpk_gen_equivalence_u32_native_mod() {
|
||||
test_seeded_lwe_cpk_gen_equivalence::<u32>(CiphertextModulus::new_native())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seeded_lwe_cpk_gen_equivalence_u64_naive_mod() {
|
||||
test_seeded_lwe_cpk_gen_equivalence::<u64>(CiphertextModulus::new_native())
|
||||
}
|
||||
@@ -827,3 +827,67 @@ fn test_u128_encryption() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lwe_compact_public_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
|
||||
params: TestParams<Scalar>,
|
||||
) {
|
||||
let lwe_dimension = LweDimension(params.polynomial_size.0);
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
const NB_TESTS: usize = 10;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let mut msg = msg_modulus;
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
msg = msg.wrapping_sub(Scalar::ONE);
|
||||
for _ in 0..NB_TESTS {
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
|
||||
let pk = allocate_and_generate_new_lwe_compact_public_key(
|
||||
&lwe_sk,
|
||||
glwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
lwe_dimension.to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
|
||||
encrypt_lwe_ciphertext_with_compact_public_key(
|
||||
&pk,
|
||||
&mut ct,
|
||||
plaintext,
|
||||
glwe_modular_std_dev,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.secret_random_generator,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&ct, ciphertext_modulus));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
|
||||
|
||||
assert_eq!(msg, decoded);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(lwe_compact_public_encrypt_decrypt_custom_mod {
|
||||
TEST_PARAMS_4_BITS_NATIVE_U64
|
||||
});
|
||||
|
||||
@@ -0,0 +1,176 @@
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::dispersion::StandardDev;
|
||||
use crate::core_crypto::commons::generators::{DeterministicSeeder, EncryptionRandomGenerator};
|
||||
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, Seed};
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension,
|
||||
LweBskGroupingFactor, LweDimension, PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::prelude::CastFrom;
|
||||
|
||||
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence<
|
||||
T: UnsignedTorus + CastFrom<usize> + Sync + Send,
|
||||
>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) {
|
||||
for _ in 0..10 {
|
||||
let mut lwe_dim =
|
||||
LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
|
||||
let glwe_dim =
|
||||
GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
|
||||
let poly_size =
|
||||
PolynomialSize(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
|
||||
let level = DecompositionLevelCount(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
|
||||
);
|
||||
let base_log = DecompositionBaseLog(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
|
||||
);
|
||||
let grouping_factor = LweBskGroupingFactor(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(2..4),
|
||||
);
|
||||
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
|
||||
let deterministic_seeder_seed =
|
||||
Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
|
||||
|
||||
while lwe_dim.0 % grouping_factor.0 != 0 {
|
||||
lwe_dim = LweDimension(lwe_dim.0 + 1);
|
||||
}
|
||||
|
||||
let mut secret_generator = new_secret_random_generator();
|
||||
let lwe_sk =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dim, &mut secret_generator);
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dim,
|
||||
poly_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
let mut parallel_multi_bit_bsk = LweMultiBitBootstrapKeyOwned::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
grouping_factor,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
mask_seed,
|
||||
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
|
||||
);
|
||||
|
||||
par_generate_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut parallel_multi_bit_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut sequential_multi_bit_bsk = LweMultiBitBootstrapKeyOwned::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
grouping_factor,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
mask_seed,
|
||||
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
|
||||
);
|
||||
|
||||
generate_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut sequential_multi_bit_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
assert_eq!(parallel_multi_bit_bsk, sequential_multi_bit_bsk);
|
||||
|
||||
let mut sequential_seeded_multi_bit_bsk = SeededLweMultiBitBootstrapKey::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
grouping_factor,
|
||||
mask_seed.into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
generate_seeded_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut sequential_seeded_multi_bit_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
|
||||
);
|
||||
|
||||
let mut parallel_seeded_multi_bit_bsk = SeededLweMultiBitBootstrapKey::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
grouping_factor,
|
||||
mask_seed.into(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
par_generate_seeded_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut parallel_seeded_multi_bit_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
sequential_seeded_multi_bit_bsk,
|
||||
parallel_seeded_multi_bit_bsk
|
||||
);
|
||||
|
||||
let decompressed_multi_bit_bsk =
|
||||
sequential_seeded_multi_bit_bsk.decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
|
||||
assert_eq!(decompressed_multi_bit_bsk, sequential_multi_bit_bsk);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u32_native_mod() {
|
||||
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u32>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u32_custom_mod() {
|
||||
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u32>(
|
||||
CiphertextModulus::try_new_power_of_2(31).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u64_native_mod() {
|
||||
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u64>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u64_custom_mod() {
|
||||
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u64>(
|
||||
CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
);
|
||||
}
|
||||
@@ -146,6 +146,125 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod<
|
||||
}
|
||||
}
|
||||
|
||||
fn lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod<
|
||||
Scalar: UnsignedTorus + Sync + Send + CastFrom<usize> + CastInto<usize>,
|
||||
>(
|
||||
params: MultiBitParams<Scalar>,
|
||||
) {
|
||||
let input_lwe_dimension = params.input_lwe_dimension;
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev;
|
||||
let glwe_modular_std_dev = params.glwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
let decomp_base_log = params.decomp_base_log;
|
||||
let decomp_level_count = params.decomp_level_count;
|
||||
let grouping_factor = params.grouping_factor;
|
||||
let thread_count = params.thread_count;
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let f = |x: Scalar| {
|
||||
x.wrapping_mul(Scalar::TWO)
|
||||
.wrapping_sub(Scalar::ONE)
|
||||
.wrapping_rem(msg_modulus)
|
||||
};
|
||||
|
||||
let delta: Scalar = encoding_with_padding / msg_modulus;
|
||||
let mut msg = msg_modulus;
|
||||
const NB_TESTS: usize = 10;
|
||||
|
||||
let accumulator = generate_accumulator(
|
||||
polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
msg_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&accumulator, ciphertext_modulus));
|
||||
|
||||
// Keygen is a bit slow on this one so we keep it out of the testing loop
|
||||
// Create the LweSecretKey
|
||||
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
input_lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
&mut rsc.secret_random_generator,
|
||||
);
|
||||
let output_lwe_secret_key = output_glwe_secret_key.clone().into_lwe_secret_key();
|
||||
|
||||
let mut bsk = LweMultiBitBootstrapKey::new(
|
||||
Scalar::ZERO,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
polynomial_size,
|
||||
decomp_base_log,
|
||||
decomp_level_count,
|
||||
input_lwe_dimension,
|
||||
grouping_factor,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
par_generate_lwe_multi_bit_bootstrap_key(
|
||||
&input_lwe_secret_key,
|
||||
&output_glwe_secret_key,
|
||||
&mut bsk,
|
||||
glwe_modular_std_dev,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&*bsk, ciphertext_modulus));
|
||||
|
||||
while msg != Scalar::ZERO {
|
||||
msg = msg.wrapping_sub(Scalar::ONE);
|
||||
for _ in 0..NB_TESTS {
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
|
||||
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&input_lwe_secret_key,
|
||||
plaintext,
|
||||
lwe_modular_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(
|
||||
&lwe_ciphertext_in,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut out_pbs_ct = LweCiphertext::new(
|
||||
Scalar::ZERO,
|
||||
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
std_multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
&lwe_ciphertext_in,
|
||||
&mut out_pbs_ct,
|
||||
&accumulator,
|
||||
&bsk,
|
||||
thread_count,
|
||||
);
|
||||
|
||||
assert!(check_content_respects_mod(&out_pbs_ct, ciphertext_modulus));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
|
||||
|
||||
assert_eq!(decoded, f(msg));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_lwe_encrypt_multi_bit_pbs_decrypt_factor_2_thread_5_native_mod() {
|
||||
lwe_encrypt_multi_bit_pbs_decrypt_custom_mod::<u64>(
|
||||
@@ -229,3 +348,87 @@ pub fn test_lwe_encrypt_multi_bit_pbs_decrypt_factor_3_thread_12_custom_mod() {
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_2_thread_5_native_mod() {
|
||||
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
|
||||
// correct computations
|
||||
MultiBitParams {
|
||||
input_lwe_dimension: LweDimension(786),
|
||||
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
|
||||
decomp_base_log: DecompositionBaseLog(23),
|
||||
decomp_level_count: DecompositionLevelCount(1),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
|
||||
message_modulus_log: CiphertextModulusLog(4),
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
grouping_factor: LweBskGroupingFactor(2),
|
||||
thread_count: ThreadCount(5),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_3_thread_12_native_mod() {
|
||||
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
|
||||
// correct computations
|
||||
MultiBitParams {
|
||||
input_lwe_dimension: LweDimension(786),
|
||||
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
|
||||
decomp_base_log: DecompositionBaseLog(22),
|
||||
decomp_level_count: DecompositionLevelCount(1),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
|
||||
message_modulus_log: CiphertextModulusLog(4),
|
||||
ciphertext_modulus: CiphertextModulus::new_native(),
|
||||
grouping_factor: LweBskGroupingFactor(3),
|
||||
thread_count: ThreadCount(12),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_2_thread_5_custom_mod() {
|
||||
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
|
||||
// correct computations
|
||||
MultiBitParams {
|
||||
input_lwe_dimension: LweDimension(786),
|
||||
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
|
||||
decomp_base_log: DecompositionBaseLog(23),
|
||||
decomp_level_count: DecompositionLevelCount(1),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
|
||||
message_modulus_log: CiphertextModulusLog(3),
|
||||
ciphertext_modulus: CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
grouping_factor: LweBskGroupingFactor(2),
|
||||
thread_count: ThreadCount(5),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_3_thread_12_custom_mod() {
|
||||
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
|
||||
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
|
||||
// correct computations
|
||||
MultiBitParams {
|
||||
input_lwe_dimension: LweDimension(786),
|
||||
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
|
||||
decomp_base_log: DecompositionBaseLog(22),
|
||||
decomp_level_count: DecompositionLevelCount(1),
|
||||
glwe_dimension: GlweDimension(2),
|
||||
polynomial_size: PolynomialSize(1024),
|
||||
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
|
||||
message_modulus_log: CiphertextModulusLog(3),
|
||||
ciphertext_modulus: CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
grouping_factor: LweBskGroupingFactor(3),
|
||||
thread_count: ThreadCount(12),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,10 +4,12 @@ use paste::paste;
|
||||
mod ggsw_encryption;
|
||||
mod glwe_encryption;
|
||||
mod lwe_bootstrap_key_generation;
|
||||
mod lwe_compact_public_key_generation;
|
||||
mod lwe_encryption;
|
||||
mod lwe_keyswitch;
|
||||
mod lwe_keyswitch_key_generation;
|
||||
mod lwe_linear_algebra;
|
||||
mod lwe_multi_bit_bootstrap_key_generation;
|
||||
mod lwe_multi_bit_programmable_bootstrapping;
|
||||
mod lwe_programmable_bootstrapping;
|
||||
|
||||
@@ -137,7 +139,7 @@ pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]
|
||||
if !modulus.is_native_modulus() {
|
||||
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
|
||||
// we want the bits under the mask to be 0
|
||||
let power_2_diff_mask = modulus.get_scaling_to_native_torus() - Scalar::ONE;
|
||||
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
|
||||
return input
|
||||
.as_ref()
|
||||
.iter()
|
||||
@@ -153,7 +155,7 @@ pub fn check_scalar_respects_mod<Scalar: UnsignedInteger>(
|
||||
modulus: CiphertextModulus<Scalar>,
|
||||
) -> bool {
|
||||
if !modulus.is_native_modulus() {
|
||||
let power_2_diff_mask = modulus.get_scaling_to_native_torus() - Scalar::ONE;
|
||||
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
|
||||
return (input & power_2_diff_mask) == Scalar::ZERO;
|
||||
}
|
||||
|
||||
|
||||
@@ -121,6 +121,26 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn try_new(modulus: u128) -> Result<Self, &'static str> {
|
||||
if Scalar::BITS < 128 && modulus > (1 << Scalar::BITS) {
|
||||
Err("Modulus is bigger than the maximum value of the associated Scalar type")
|
||||
} else {
|
||||
let res = match modulus {
|
||||
0 => CiphertextModulus::new_native(),
|
||||
modulus => {
|
||||
let Some(non_zero_modulus) = NonZeroU128::new(modulus) else {
|
||||
panic!("Got zero modulus for CiphertextModulusInner::Custom variant",)
|
||||
};
|
||||
CiphertextModulus {
|
||||
inner: CiphertextModulusInner::Custom(non_zero_modulus),
|
||||
_scalar: PhantomData,
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(res.canonicalize())
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn canonicalize(self) -> Self {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => self,
|
||||
@@ -151,10 +171,14 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
|
||||
res.canonicalize()
|
||||
}
|
||||
|
||||
pub fn get_scaling_to_native_torus(&self) -> Scalar {
|
||||
pub fn get_power_of_two_scaling_to_native_torus(&self) -> Scalar {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => Scalar::ONE,
|
||||
CiphertextModulusInner::Custom(modulus) => {
|
||||
assert!(
|
||||
modulus.is_power_of_two(),
|
||||
"Cannot get scaling for non power of two modulus {modulus:}"
|
||||
);
|
||||
Scalar::ONE.wrapping_shl(Scalar::BITS as u32 - modulus.ilog2())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
CiphertextModulus, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweDimension,
|
||||
GlweSize, LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweSize, PolynomialSize,
|
||||
GlweSize, LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweMaskCount, LweSize,
|
||||
PolynomialSize,
|
||||
};
|
||||
use concrete_csprng::generators::ForkError;
|
||||
use rayon::prelude::*;
|
||||
@@ -169,6 +170,16 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
|
||||
self.try_fork(lwe_size.0, mask_bytes, noise_bytes)
|
||||
}
|
||||
|
||||
pub(crate) fn fork_lwe_compact_ciphertext_list_to_bin<T: UnsignedInteger>(
|
||||
&mut self,
|
||||
lwe_mask_count: LweMaskCount,
|
||||
lwe_dimension: LweDimension,
|
||||
) -> Result<impl Iterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
|
||||
let mask_bytes = mask_bytes_per_lwe_compact_ciphertext_bin::<T>(lwe_dimension);
|
||||
let noise_bytes = noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension);
|
||||
self.try_fork(lwe_mask_count.0, mask_bytes, noise_bytes)
|
||||
}
|
||||
|
||||
// Forks both generators into an iterator
|
||||
fn try_fork(
|
||||
&mut self,
|
||||
@@ -431,6 +442,16 @@ impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
|
||||
self.par_try_fork(lwe_size.0, mask_bytes, noise_bytes)
|
||||
}
|
||||
|
||||
pub(crate) fn par_fork_lwe_compact_ciphertext_list_to_bin<T: UnsignedInteger>(
|
||||
&mut self,
|
||||
lwe_mask_count: LweMaskCount,
|
||||
lwe_dimension: LweDimension,
|
||||
) -> Result<impl IndexedParallelIterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
|
||||
let mask_bytes = mask_bytes_per_lwe_compact_ciphertext_bin::<T>(lwe_dimension);
|
||||
let noise_bytes = noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension);
|
||||
self.par_try_fork(lwe_mask_count.0, mask_bytes, noise_bytes)
|
||||
}
|
||||
|
||||
// Forks both generators into a parallel iterator.
|
||||
fn par_try_fork(
|
||||
&mut self,
|
||||
@@ -504,6 +525,12 @@ fn mask_bytes_per_pfpksk<T: UnsignedInteger>(
|
||||
lwe_size.0 * mask_bytes_per_pfpksk_chunk::<T>(level, glwe_size, poly_size)
|
||||
}
|
||||
|
||||
fn mask_bytes_per_lwe_compact_ciphertext_bin<T: UnsignedInteger>(
|
||||
lwe_dimension: LweDimension,
|
||||
) -> usize {
|
||||
lwe_dimension.0 * mask_bytes_per_coef::<T>()
|
||||
}
|
||||
|
||||
fn noise_bytes_per_coef() -> usize {
|
||||
// We use f64 to sample the noise for every precision, and we need 4/pi inputs to generate
|
||||
// such an output (here we take 32 to keep a safety margin).
|
||||
@@ -553,6 +580,10 @@ fn noise_bytes_per_pfpksk(
|
||||
lwe_size.0 * noise_bytes_per_pfpksk_chunk(level, poly_size)
|
||||
}
|
||||
|
||||
fn noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension: LweDimension) -> usize {
|
||||
lwe_dimension.0 * noise_bytes_per_coef()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::core_crypto::algorithms::*;
|
||||
@@ -562,7 +593,7 @@ mod test {
|
||||
PolynomialSize,
|
||||
};
|
||||
use crate::core_crypto::commons::test_tools::{
|
||||
new_encryption_random_generator, new_secret_random_generator,
|
||||
new_encryption_random_generator, new_secret_random_generator, normality_test_f64,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::UnsignedTorus;
|
||||
|
||||
@@ -728,6 +759,119 @@ mod test {
|
||||
noise_gen_slice_native::<u128>();
|
||||
}
|
||||
|
||||
fn test_normal_random_encryption_native<Scalar: UnsignedTorus>() {
|
||||
const RUNS: usize = 10000;
|
||||
const SAMPLES_PER_RUN: usize = 1000;
|
||||
let mut rng = new_encryption_random_generator();
|
||||
let failures: f64 = (0..RUNS)
|
||||
.map(|_| {
|
||||
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
|
||||
|
||||
rng.fill_slice_with_random_noise(&mut samples, StandardDev(f64::powi(2., -20)));
|
||||
|
||||
let samples: Vec<f64> = samples
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
let torus = x.into_torus();
|
||||
// The upper half of the torus corresponds to the negative domain when
|
||||
// mapping unsigned integer back to float (MSB or
|
||||
// sign bit is set)
|
||||
if torus > 0.5 {
|
||||
torus - 1.0
|
||||
} else {
|
||||
torus
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
|
||||
// If we are normal return 0, it's not a failure
|
||||
0.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
})
|
||||
.sum::<f64>();
|
||||
let failure_rate = failures / (RUNS as f64);
|
||||
println!("failure_rate: {failure_rate}");
|
||||
// The expected failure rate even on proper gaussian is 5%, so we take a small safety margin
|
||||
assert!(failure_rate <= 0.065);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_u32() {
|
||||
test_normal_random_encryption_native::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_u64() {
|
||||
test_normal_random_encryption_native::<u64>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_u128() {
|
||||
test_normal_random_encryption_native::<u128>();
|
||||
}
|
||||
|
||||
fn test_normal_random_encryption_add_assign_native<Scalar: UnsignedTorus>() {
|
||||
const RUNS: usize = 10000;
|
||||
const SAMPLES_PER_RUN: usize = 1000;
|
||||
let mut rng = new_encryption_random_generator();
|
||||
let failures: f64 = (0..RUNS)
|
||||
.map(|_| {
|
||||
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
|
||||
|
||||
rng.unsigned_torus_slice_wrapping_add_random_noise_assign(
|
||||
&mut samples,
|
||||
StandardDev(f64::powi(2., -20)),
|
||||
);
|
||||
|
||||
let samples: Vec<f64> = samples
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
let torus = x.into_torus();
|
||||
// The upper half of the torus corresponds to the negative domain when
|
||||
// mapping unsigned integer back to float (MSB or
|
||||
// sign bit is set)
|
||||
if torus > 0.5 {
|
||||
torus - 1.0
|
||||
} else {
|
||||
torus
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
|
||||
// If we are normal return 0, it's not a failure
|
||||
0.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
})
|
||||
.sum::<f64>();
|
||||
let failure_rate = failures / (RUNS as f64);
|
||||
println!("failure_rate: {failure_rate}");
|
||||
// The expected failure rate even on proper gaussian is 5%, so we take a small safety margin
|
||||
assert!(failure_rate <= 0.065);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_u32() {
|
||||
test_normal_random_encryption_add_assign_native::<u32>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_u64() {
|
||||
test_normal_random_encryption_add_assign_native::<u64>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_u128() {
|
||||
test_normal_random_encryption_add_assign_native::<u128>();
|
||||
}
|
||||
|
||||
fn noise_gen_slice_custom_mod<Scalar: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) {
|
||||
@@ -783,6 +927,170 @@ mod test {
|
||||
noise_gen_slice_custom_mod::<u128>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
fn test_normal_random_encryption_custom_mod<Scalar: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) {
|
||||
const RUNS: usize = 10000;
|
||||
const SAMPLES_PER_RUN: usize = 1000;
|
||||
let mut rng = new_encryption_random_generator();
|
||||
let failures: f64 = (0..RUNS)
|
||||
.map(|_| {
|
||||
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
|
||||
|
||||
rng.fill_slice_with_random_noise_custom_mod(
|
||||
&mut samples,
|
||||
StandardDev(f64::powi(2., -20)),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let samples: Vec<f64> = samples
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
let torus = x.into_torus();
|
||||
// The upper half of the torus corresponds to the negative domain when
|
||||
// mapping unsigned integer back to float (MSB or
|
||||
// sign bit is set)
|
||||
if torus > 0.5 {
|
||||
torus - 1.0
|
||||
} else {
|
||||
torus
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
|
||||
// If we are normal return 0, it's not a failure
|
||||
0.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
})
|
||||
.sum::<f64>();
|
||||
let failure_rate = failures / (RUNS as f64);
|
||||
println!("failure_rate: {failure_rate}");
|
||||
// The expected failure rate even on proper gaussian is 5%, so we take a small safety margin
|
||||
assert!(failure_rate <= 0.065);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_custom_mod_u32() {
|
||||
test_normal_random_encryption_custom_mod::<u32>(
|
||||
CiphertextModulus::try_new_power_of_2(31).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_custom_mod_u64() {
|
||||
test_normal_random_encryption_custom_mod::<u64>(
|
||||
CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_custom_mod_u128() {
|
||||
test_normal_random_encryption_custom_mod::<u128>(
|
||||
CiphertextModulus::try_new_power_of_2(127).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_custom_mod_u32() {
|
||||
test_normal_random_encryption_custom_mod::<u32>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_custom_mod_u64() {
|
||||
test_normal_random_encryption_custom_mod::<u64>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_native_custom_mod_u128() {
|
||||
test_normal_random_encryption_custom_mod::<u128>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
fn test_normal_random_encryption_add_assign_custom_mod<Scalar: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) {
|
||||
const RUNS: usize = 10000;
|
||||
const SAMPLES_PER_RUN: usize = 1000;
|
||||
let mut rng = new_encryption_random_generator();
|
||||
let failures: f64 = (0..RUNS)
|
||||
.map(|_| {
|
||||
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
|
||||
|
||||
rng.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
|
||||
&mut samples,
|
||||
StandardDev(f64::powi(2., -20)),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let samples: Vec<f64> = samples
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
let torus = x.into_torus();
|
||||
// The upper half of the torus corresponds to the negative domain when
|
||||
// mapping unsigned integer back to float (MSB or
|
||||
// sign bit is set)
|
||||
if torus > 0.5 {
|
||||
torus - 1.0
|
||||
} else {
|
||||
torus
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
|
||||
// If we are normal return 0, it's not a failure
|
||||
0.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
})
|
||||
.sum::<f64>();
|
||||
let failure_rate = failures / (RUNS as f64);
|
||||
println!("failure_rate: {failure_rate}");
|
||||
// The expected failure rate even on proper gaussian is 5%, so we take a small safety margin
|
||||
assert!(failure_rate <= 0.065);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_custom_mod_u32() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u32>(
|
||||
CiphertextModulus::try_new_power_of_2(31).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_custom_mod_u64() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u64>(
|
||||
CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_custom_mod_u128() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u128>(
|
||||
CiphertextModulus::try_new_power_of_2(127).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_custom_mod_u32() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u32>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_custom_mod_u64() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u64>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normal_random_encryption_add_assign_native_custom_mod_u128() {
|
||||
test_normal_random_encryption_add_assign_custom_mod::<u128>(CiphertextModulus::new_native());
|
||||
}
|
||||
|
||||
fn mask_gen_slice_native<Scalar: UnsignedTorus>() {
|
||||
let mut gen = new_encryption_random_generator();
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecompositionIter;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
SignedDecompositionIter, SignedDecompositionIterNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
use crate::core_crypto::prelude::misc::divide_round_to_u128_custom_mod;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A structure which allows to decompose unsigned integers into a set of smaller terms.
|
||||
@@ -174,3 +178,215 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A structure which allows to decompose unsigned integers into a set of smaller terms for moduli
|
||||
/// which are non power of 2.
|
||||
///
|
||||
/// See the [module level](super) documentation for a description of the signed decomposition.
|
||||
#[derive(Debug)]
|
||||
pub struct SignedDecomposerNonNative<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
pub(crate) base_log: usize,
|
||||
pub(crate) level_count: usize,
|
||||
pub(crate) ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
}
|
||||
|
||||
impl<Scalar> SignedDecomposerNonNative<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
/// Create a new decomposer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
|
||||
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn new(
|
||||
base_log: DecompositionBaseLog,
|
||||
level_count: DecompositionLevelCount,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) -> SignedDecomposerNonNative<Scalar> {
|
||||
debug_assert!(
|
||||
Scalar::BITS > base_log.0 * level_count.0,
|
||||
"Decomposed bits exceeds the size of the integer to be decomposed"
|
||||
);
|
||||
SignedDecomposerNonNative {
|
||||
base_log: base_log.0,
|
||||
level_count: level_count.0,
|
||||
ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the logarithm in base two of the base of this decomposer.
|
||||
///
|
||||
/// If the decomposer uses a base $B=2^b$, this returns $b$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn base_log(&self) -> DecompositionBaseLog {
|
||||
DecompositionBaseLog(self.base_log)
|
||||
}
|
||||
|
||||
/// Return the number of levels of this decomposer.
|
||||
///
|
||||
/// If the decomposer uses $l$ levels, this returns $l$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
|
||||
/// ```
|
||||
pub fn level_count(&self) -> DecompositionLevelCount {
|
||||
DecompositionLevelCount(self.level_count)
|
||||
}
|
||||
|
||||
/// Return the ciphertext modulus of this decomposer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(
|
||||
/// decomposer.ciphertext_modulus(),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap()
|
||||
/// );
|
||||
/// ```
|
||||
pub fn ciphertext_modulus(&self) -> CiphertextModulus<Scalar> {
|
||||
self.ciphertext_modulus
|
||||
}
|
||||
|
||||
/// Return the closet value representable by the decomposition.
|
||||
///
|
||||
/// For some input integer `k`, decomposition base `B`, decomposition level count `l` and given
|
||||
/// ciphertext modulus `q` the performed operation is the following:
|
||||
///
|
||||
/// $$
|
||||
/// \lfloor \frac{k\cdot q}{B^{l}} \rceil \cdot B^{l}
|
||||
/// $$
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let closest = decomposer.closest_representable(16982820785129133100u64);
|
||||
/// assert_eq!(closest, 16983074190859960320u64);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn closest_representable(&self, input: Scalar) -> Scalar {
|
||||
let ciphertext_modulus = self.ciphertext_modulus.get_custom_modulus();
|
||||
// Floored approach
|
||||
// B^l
|
||||
let base_to_level_count = 1 << (self.base_log * self.level_count);
|
||||
// sr = floor(q/(B^l))
|
||||
let smallest_representable = ciphertext_modulus / base_to_level_count;
|
||||
|
||||
let input_128: u128 = input.cast_into();
|
||||
// rounded = round(input/sr)
|
||||
let rounded =
|
||||
divide_round_to_u128_custom_mod(input_128, smallest_representable, ciphertext_modulus);
|
||||
// rounded * sr
|
||||
let closest_representable = rounded * smallest_representable;
|
||||
Scalar::cast_from(closest_representable)
|
||||
}
|
||||
|
||||
/// Generate an iterator over the terms of the decomposition of the input.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::numeric::UnsignedInteger;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
///
|
||||
/// // These two values allow to take each arm of the half basis check below
|
||||
/// for value in [1u64 << 63, 16982820785129133100u64] {
|
||||
/// for term in decomposer.decompose(value) {
|
||||
/// assert!(1 <= term.level().0);
|
||||
/// assert!(term.level().0 <= 3);
|
||||
/// let term = term.value();
|
||||
/// let abs_term = if term < decomposer.ciphertext_modulus().get_custom_modulus() as u64 / 2
|
||||
/// {
|
||||
/// term
|
||||
/// } else {
|
||||
/// decomposer.ciphertext_modulus().get_custom_modulus() as u64 - term
|
||||
/// };
|
||||
/// println!("abs_term: {abs_term}");
|
||||
/// let half_basis = 2u64.pow(4) / 2u64;
|
||||
/// println!("half_basis: {half_basis}");
|
||||
/// assert!(abs_term <= half_basis);
|
||||
/// }
|
||||
/// assert_eq!(decomposer.decompose(1).count(), 3);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn decompose(&self, input: Scalar) -> SignedDecompositionIterNonNative<Scalar> {
|
||||
// Note that there would be no sense of making the decomposition on an input which was
|
||||
// not rounded to the closest representable first. We then perform it before decomposing.
|
||||
SignedDecompositionIterNonNative::new(
|
||||
self.closest_representable(input),
|
||||
DecompositionBaseLog(self.base_log),
|
||||
DecompositionLevelCount(self.level_count),
|
||||
self.ciphertext_modulus,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm};
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
|
||||
@@ -122,3 +125,158 @@ fn decompose_one_level<S: UnsignedInteger>(base_log: usize, state: &mut S, mod_b
|
||||
*state += carry;
|
||||
res.wrapping_sub(carry << base_log)
|
||||
}
|
||||
|
||||
/// An iterator that yields the terms of the signed decomposition of an integer.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This iterator yields the decomposition in reverse order. That means that the highest level
|
||||
/// will be yielded first.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
// The base log of the decomposition
|
||||
base_log: usize,
|
||||
// The number of levels of the decomposition
|
||||
level_count: usize,
|
||||
// The internal state of the decomposition
|
||||
state: T,
|
||||
// The current level
|
||||
current_level: usize,
|
||||
// A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form:
|
||||
// ...0001111
|
||||
mod_b_mask: T,
|
||||
// Ciphertext modulus
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
// A flag which store whether the iterator is a fresh one (for the recompose method)
|
||||
fresh: bool,
|
||||
}
|
||||
|
||||
impl<T> SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
input: T,
|
||||
base_log: DecompositionBaseLog,
|
||||
level: DecompositionLevelCount,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> SignedDecompositionIterNonNative<T> {
|
||||
let base_to_the_level = 1 << (base_log.0 * level.0);
|
||||
let smallest_representable = ciphertext_modulus.get_custom_modulus() / base_to_the_level;
|
||||
|
||||
let input_128: u128 = input.cast_into();
|
||||
let state = T::cast_from(input_128 / smallest_representable);
|
||||
|
||||
SignedDecompositionIterNonNative {
|
||||
base_log: base_log.0,
|
||||
level_count: level.0,
|
||||
state,
|
||||
current_level: level.0,
|
||||
mod_b_mask: (T::ONE << base_log.0) - T::ONE,
|
||||
ciphertext_modulus,
|
||||
fresh: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_fresh(&self) -> bool {
|
||||
self.fresh
|
||||
}
|
||||
|
||||
/// Return the logarithm in base two of the base of this decomposition.
|
||||
///
|
||||
/// If the decomposition uses a base $B=2^b$, this returns $b$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let val = 9_223_372_036_854_775_808u64;
|
||||
/// let decomp = decomposer.decompose(val);
|
||||
/// assert_eq!(decomp.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn base_log(&self) -> DecompositionBaseLog {
|
||||
DecompositionBaseLog(self.base_log)
|
||||
}
|
||||
|
||||
/// Return the number of levels of this decomposition.
|
||||
///
|
||||
/// If the decomposition uses $l$ levels, this returns $l$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let val = 9_223_372_036_854_775_808u64;
|
||||
/// let decomp = decomposer.decompose(val);
|
||||
/// assert_eq!(decomp.level_count(), DecompositionLevelCount(3));
|
||||
/// ```
|
||||
pub fn level_count(&self) -> DecompositionLevelCount {
|
||||
DecompositionLevelCount(self.level_count)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Iterator for SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
type Item = DecompositionTermNonNative<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// The iterator is not fresh anymore
|
||||
self.fresh = false;
|
||||
// We check if the decomposition is over
|
||||
if self.current_level == 0 {
|
||||
return None;
|
||||
}
|
||||
// We decompose the current level
|
||||
let output = decompose_one_level_non_native(
|
||||
self.base_log,
|
||||
&mut self.state,
|
||||
self.mod_b_mask,
|
||||
T::cast_from(self.ciphertext_modulus.get_custom_modulus()),
|
||||
);
|
||||
self.current_level -= 1;
|
||||
// We return the output for this level
|
||||
Some(DecompositionTermNonNative::new(
|
||||
DecompositionLevel(self.current_level + 1),
|
||||
DecompositionBaseLog(self.base_log),
|
||||
output,
|
||||
self.ciphertext_modulus,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn decompose_one_level_non_native<S: UnsignedInteger>(
|
||||
base_log: usize,
|
||||
state: &mut S,
|
||||
mod_b_mask: S,
|
||||
ciphertext_modulus: S,
|
||||
) -> S {
|
||||
let res = *state & mod_b_mask;
|
||||
*state >>= base_log;
|
||||
let mut carry = (res.wrapping_sub(S::ONE) | *state) & res;
|
||||
carry >>= base_log - 1;
|
||||
*state += carry;
|
||||
res.wrapping_add(ciphertext_modulus)
|
||||
.wrapping_sub(carry << base_log)
|
||||
.wrapping_rem(ciphertext_modulus)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
|
||||
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::DecompositionBaseLog;
|
||||
@@ -91,3 +92,117 @@ where
|
||||
DecompositionLevel(self.level)
|
||||
}
|
||||
}
|
||||
|
||||
/// A member of the decomposition.
|
||||
///
|
||||
/// If we decompose a value $\theta$ as a sum $\sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i}$, this
|
||||
/// represents a $\tilde{\theta}\_i$.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct DecompositionTermNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
level: usize,
|
||||
base_log: usize,
|
||||
value: T,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
}
|
||||
|
||||
impl<T> DecompositionTermNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
// Creates a new decomposition term.
|
||||
pub(crate) fn new(
|
||||
level: DecompositionLevel,
|
||||
base_log: DecompositionBaseLog,
|
||||
value: T,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> DecompositionTermNonNative<T> {
|
||||
DecompositionTermNonNative {
|
||||
level: level.0,
|
||||
base_log: base_log.0,
|
||||
value,
|
||||
ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn this term into a summand.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$ of the decomposition, this method returns
|
||||
/// $\tilde{\theta}\_i\frac{q}{B^i}$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.to_recomposition_summand(), 1048576);
|
||||
/// ```
|
||||
pub fn to_recomposition_summand(&self) -> T {
|
||||
// Floored approach
|
||||
// * floor(q / B^j)
|
||||
let base_to_the_level = 1 << (self.base_log * self.level);
|
||||
let digit_radix = self.ciphertext_modulus.get_custom_modulus() / base_to_the_level;
|
||||
|
||||
let value_u128: u128 = self.value.cast_into();
|
||||
let summand_u128 = value_u128 * digit_radix;
|
||||
T::cast_from(summand_u128)
|
||||
}
|
||||
|
||||
/// Return the value of the term.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$, this returns its actual value.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.value(), 1);
|
||||
/// ```
|
||||
pub fn value(&self) -> T {
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Return the level of the term.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$, this returns the value of $i$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::{
|
||||
/// DecompositionLevel, SignedDecomposerNonNative,
|
||||
/// };
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.level(), DecompositionLevel(3));
|
||||
/// ```
|
||||
pub fn level(&self) -> DecompositionLevel {
|
||||
DecompositionLevel(self.level)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
SignedDecomposer, SignedDecomposerNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::{RandomGenerable, Uniform};
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{Numeric, SignedInteger, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
use crate::core_crypto::commons::test_tools::{any_uint, any_usize, random_usize_between};
|
||||
use crate::core_crypto::commons::traits::CastInto;
|
||||
use std::fmt::Debug;
|
||||
|
||||
// Return a random decomposition valid for the size of the T type.
|
||||
@@ -69,7 +73,7 @@ fn test_round_to_closest_representable<T: UnsignedTorus>() {
|
||||
let bit: usize = log_b * level_max;
|
||||
|
||||
let val = val << (bits - bit);
|
||||
let delta = delta >> (bits - (bits - bit - 1));
|
||||
let delta = delta >> (bit + 1);
|
||||
|
||||
let decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
@@ -117,3 +121,171 @@ fn test_round_to_closest_twice_u32() {
|
||||
fn test_round_to_closest_twice_u64() {
|
||||
test_round_to_closest_twice::<u64>();
|
||||
}
|
||||
|
||||
// Return a random decomposition valid for the size of the T type.
|
||||
fn random_decomp_non_native<T: UnsignedInteger>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> SignedDecomposerNonNative<T> {
|
||||
let mut base_log;
|
||||
let mut level_count;
|
||||
loop {
|
||||
base_log = random_usize_between(2..T::BITS);
|
||||
level_count = random_usize_between(2..T::BITS);
|
||||
if base_log * level_count < T::BITS {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(base_log),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
}
|
||||
|
||||
fn test_round_to_closest_representable_non_native<T: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) {
|
||||
// Manage limit cases
|
||||
{
|
||||
let log_b = any_usize();
|
||||
let level_max = any_usize();
|
||||
let bits = T::BITS;
|
||||
let log_b = (log_b % ((bits / 4) - 1)) + 1;
|
||||
let level_count = (level_max % 4) + 1;
|
||||
let rep_bits: usize = log_b * level_count;
|
||||
|
||||
let base_to_the_level_u128 = 1u128 << rep_bits;
|
||||
let smallest_representable_u128 =
|
||||
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
|
||||
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
|
||||
// Compute an epsilon that should not change the result of a closest representable
|
||||
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
|
||||
|
||||
// Around 0
|
||||
let val = T::ZERO;
|
||||
|
||||
let decomposer = SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let val_u128: u128 = val.cast_into();
|
||||
let val_plus_epsilon: T = val_u128
|
||||
.wrapping_add(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_plus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
|
||||
let val_minus_epsilon: T = val_u128
|
||||
.wrapping_add(ciphertext_modulus.get_custom_modulus())
|
||||
.wrapping_sub(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_minus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
}
|
||||
|
||||
for _ in 0..1000 {
|
||||
let log_b = any_usize();
|
||||
let level_max = any_usize();
|
||||
let bits = T::BITS;
|
||||
let log_b = (log_b % ((bits / 4) - 1)) + 1;
|
||||
let level_count = (level_max % 4) + 1;
|
||||
let rep_bits: usize = log_b * level_count;
|
||||
|
||||
let base_to_the_level_u128 = 1u128 << rep_bits;
|
||||
let base_to_the_level = T::ONE << rep_bits;
|
||||
let smallest_representable_u128 =
|
||||
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
|
||||
let smallest_representable: T = smallest_representable_u128.cast_into();
|
||||
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
|
||||
// Compute an epsilon that should not change the result of a closest representable
|
||||
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
|
||||
|
||||
let multiple_of_smallest_representable = any_uint::<T>() % base_to_the_level;
|
||||
let val = multiple_of_smallest_representable * smallest_representable;
|
||||
|
||||
let decomposer = SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let val_u128: u128 = val.cast_into();
|
||||
let val_plus_epsilon: T = val_u128
|
||||
.wrapping_add(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_plus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
|
||||
let val_minus_epsilon: T = val_u128
|
||||
.wrapping_add(ciphertext_modulus.get_custom_modulus())
|
||||
.wrapping_sub(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_minus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_to_closest_representable_non_native_u64() {
|
||||
test_round_to_closest_representable_non_native::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
fn test_round_to_closest_twice_non_native<T: UnsignedTorus + Debug>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) {
|
||||
for _ in 0..1000 {
|
||||
let decomp = random_decomp_non_native(ciphertext_modulus);
|
||||
let input: T = any_uint();
|
||||
|
||||
let rounded_once = decomp.closest_representable(input);
|
||||
let rounded_twice = decomp.closest_representable(rounded_once);
|
||||
assert_eq!(rounded_once, rounded_twice);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_to_closest_twice_non_native_u64() {
|
||||
test_round_to_closest_twice_non_native::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user