Compare commits

...

52 Commits

Author SHA1 Message Date
J-B Orfila
00fc67e84c chore(shortint): updates param compact key 2023-06-13 13:44:46 +02:00
aquint-zama
a5906bb7cb chore(tfhe): add a Code of Conduct 2023-06-08 14:06:29 +02:00
Jeremy Shulman
90b7494acd chore(doc): attach tutorials to doc 2023-06-08 14:05:46 +02:00
Arthur Meyre
3508019cd2 feat(core): Add Compact Public Key
- Based on "TFHE Public-Key Encryption Revisited "
  https://eprint.iacr.org/2023/603.pdf

Co-authored-by: tmontaigu <thomas.montaigu@laposte.net>
2023-06-07 19:47:50 +02:00
Arthur Meyre
200c8a177a feat(core): add std multi-bit bootstrapping 2023-06-07 16:12:37 +02:00
Arthur Meyre
2f6c1cf0b5 chore(ci): add docs alias make target for doc 2023-06-07 14:18:49 +02:00
tmontaigu
b96027f417 feat(integer): improve default sub latency 2023-06-07 11:04:11 +02:00
tmontaigu
90c850ca0d feat(integer): improve scalar add,sub and negation
- scalar_add now uses the same parallel carry propagation algorithm
  as the add function.

- scalar_sub now uses the same parallel carry propagation algorithm
  as the sub function.

- the 'default' negation function uses the now-improved scalar_add
  to be faster

- unchecked_scalar_add, smart_scalar_add, checked_scalar_add, scalar_add
  have been updated to work on a generic scalar type, so they should work
  on u32, u64, u128, U256, etc.

- unchecked_scalar_sub, smart_scalar_sub, checked_scalar_sub, scalar_sub
  have been updated to work on a generic scalar type, so they should work
  on u32, u64, u128.
  As U256 does not yet implement the UnsignedInteger trait, it's not
  usable yet as a scalar type for the sub operation.

- The HLAPI is still locked to u64 scalars; it will be updated
  when most / all scalar ops are ready
2023-06-06 19:56:56 +02:00
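
For illustration, a minimal sketch of driving these scalar operations through the integer API; it assumes the gen_keys_radix / PARAM_MESSAGE_2_CARRY_2 / scalar_add_parallelized entry points that appear elsewhere in this compare, and the clear values are arbitrary:

    use tfhe::integer::gen_keys_radix;
    use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;

    fn main() {
        // 8 blocks of 2 message bits -> 16-bit radix integers
        let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);

        let ct = cks.encrypt(1234u64);
        // The scalar can now be any supported unsigned type (u32, u64, ...)
        let ct_res = sks.scalar_add_parallelized(&ct, 4321u64);

        let dec: u64 = cks.decrypt(&ct_res);
        assert_eq!(dec, 5555);
    }
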
Arthur Meyre
c8d3008a8d chore(shortint): proper ThreadCount serialization for bootstrapping key
- skip thread_count on serialization, deserialize using the function to
properly populate thread_count
2023-06-06 16:58:23 +02:00
David Testé
08c264f193 chore(ci): put wasm tests in their own workflow
This is mostly done to avoid failures in the AWS tests workflow (core,
boolean, shortint, ...) due to flaky WASM tests.
2023-06-06 14:02:52 +02:00
twiby
4ae202d8a4 refactor(tfhe): provide CiphertextBase with functions to convert from a generic type OpOrder to a specific struct.
This allows removing all calls to std::mem::transmute in shortint/engine/server_side/mod.rs, isolating unsafe blocks in the conversion functions. This makes the code safer and more likely to panic! in case of an error.
2023-06-06 12:19:56 +02:00
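
The pattern described above, sketched with purely hypothetical types (none of these are the tfhe-rs definitions): a single conversion function checks the order marker and panics on mismatch, instead of scattering std::mem::transmute calls:

    use std::any::TypeId;
    use std::marker::PhantomData;

    // Hypothetical marker standing in for the PBS-order type parameter
    // (a second marker type would exist in the same way).
    struct KeyswitchBootstrap;

    struct CiphertextBase<OpOrder> {
        data: Vec<u64>,
        _order: PhantomData<OpOrder>,
    }

    // The generic-to-specific conversion lives in a single function, so a
    // checked panic replaces ad-hoc transmute calls spread across the code.
    fn into_keyswitch_bootstrap<OpOrder: 'static>(
        ct: CiphertextBase<OpOrder>,
    ) -> CiphertextBase<KeyswitchBootstrap> {
        assert_eq!(
            TypeId::of::<OpOrder>(),
            TypeId::of::<KeyswitchBootstrap>(),
            "ciphertext does not use the KeyswitchBootstrap order"
        );
        CiphertextBase { data: ct.data, _order: PhantomData }
    }

    fn main() {
        let ct = CiphertextBase::<KeyswitchBootstrap> {
            data: vec![0; 4],
            _order: PhantomData,
        };
        let _specific = into_keyswitch_bootstrap(ct);
    }
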
dependabot[bot]
7eb8601540 chore(deps): bump JS-DevTools/npm-publish from 2.1.0 to 2.2.0
Bumps [JS-DevTools/npm-publish](https://github.com/JS-DevTools/npm-publish) from 2.1.0 to 2.2.0.
- [Release notes](https://github.com/JS-DevTools/npm-publish/releases)
- [Changelog](https://github.com/JS-DevTools/npm-publish/blob/main/CHANGELOG.md)
- [Commits](541aa6b21b...a25b4180b7)

---
updated-dependencies:
- dependency-name: JS-DevTools/npm-publish
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-06 10:23:30 +02:00
tmontaigu
8a1691c536 chore(wasm): remove serialization in web test
In the web wasm test we serialize the public key
to print its size (38_931_6265 bytes); this
means we hold the public key twice in RAM.

I suspect this causes frequent out-of-memory
errors, which then result in the test timing out.

So we remove that, hoping it has a positive impact
2023-06-02 17:19:04 +02:00
Arthur Meyre
d1cb55ba24 chore(tfhe): add multi bit shortint and integer tests
- default tests do not run multi bit PBS as it's not yet deterministic
- only radix parallel currently uses multi bit PBS in integer
- remove determinism checks for some unchecked ops
- 4_4 multi bit parameters are disabled for now as they seem to introduce
too much noise
2023-06-02 16:00:28 +02:00
Arthur Meyre
2b9a49db87 chore(tfhe): switch to using Into for PBS parameters conversion
- for "self conversions", Into<A> for A seems to work better
  than From<A> for A
2023-06-02 16:00:28 +02:00
Arthur Meyre
62ddb24f00 chore(ci): add multibit to key cache generation 2023-06-02 16:00:28 +02:00
Arthur Meyre
c6ae463b41 feat(shortint): add the possibility to use multi bit PBS 2023-06-02 16:00:28 +02:00
tmontaigu
4947eefad4 fix(u256): align with rust for shift behaviours 2023-06-02 12:00:42 +02:00
tmontaigu
71209e3927 feat(integer): make scalar shift match Rust when shift >= bit size
When the scalar value denoting the shift was greater than or equal to
the total number of bits in the ciphertext, we would return zeros.

To better match the Rust behaviour, as well as the behaviour of the
non-scalar shift / rotate, the scalar shift now discards
any higher bits of the clear shift value
2023-06-02 11:35:54 +02:00
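
In clear-value terms the new behaviour corresponds to masking the shift amount, as Rust's wrapping shifts do; a small plain-Rust (non-FHE) illustration:

    fn main() {
        let x: u16 = 0b1011;

        // Shifting a 16-bit value by 18 keeps only the low bits of the
        // shift amount (18 % 16 == 2) instead of returning zero.
        assert_eq!(x.wrapping_shl(18), x << 2);
        assert_eq!(x.wrapping_shl(18), 0b10_1100_u16);
    }
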
tmontaigu
2a66ea3d16 feat(integer): add shifts and rotates on encrypted values
This implementation is based on barrel shifters,
which are used in hardware
2023-06-02 11:35:54 +02:00
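
For reference, a clear-value sketch of the barrel-shifter idea: each bit of the shift amount selects whether to shift by the corresponding power of two (in the encrypted version each selection is performed homomorphically):

    fn barrel_shift_left(value: u64, shift: u32) -> u64 {
        let mut result = value;
        // Stage i either shifts by 2^i or keeps the current value, depending
        // on bit i of the shift amount; 6 stages cover shifts of 0..=63.
        for i in 0..6 {
            if (shift >> i) & 1 == 1 {
                result <<= 1u32 << i;
            }
        }
        result
    }

    fn main() {
        assert_eq!(barrel_shift_left(3, 5), 3u64 << 5);
        assert_eq!(barrel_shift_left(1, 63), 1u64 << 63);
    }
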
tmontaigu
d4ff1f5595 feat(wasm): add parallelism in wasm API and add wasm for HLAPI
Co-authored-by: David Testé <david.teste@zama.ai>
2023-06-02 11:13:12 +02:00
Arthur Meyre
8ae92a960d chore(ci): add multibit workflow 2023-06-02 08:55:42 +02:00
tmontaigu
b042c2f7d6 refactor(integer): improve decomposition/recomposition into blocks
This new implementation should hopefully be a little easier to understand.

More importantly, it is more general/generic:
the previous implementation required the input type to be describable as u64 words,
while the new one works for any type (as long as the needed traits are implemented).

Also, the new implementation is separated from the encryption code,
meaning it will be usable by scalar operations, which will allow us
to deduplicate code and start making scalar ops support scalar values
that are more than 64 bits wide.
2023-06-01 18:13:34 +02:00
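
To make the idea concrete, a clear-value sketch of decomposing a value into 2-bit blocks and recomposing it; the actual implementation is generic over the block size and input type, as described above:

    fn main() {
        let value: u64 = 0b11_10_01_00; // 228
        let bits_per_block = 2u32;
        let num_blocks = 4u32;

        // Decompose, least-significant block first.
        let blocks: Vec<u64> = (0..num_blocks)
            .map(|i| (value >> (i * bits_per_block)) & ((1u64 << bits_per_block) - 1))
            .collect();
        assert_eq!(blocks, vec![0b00, 0b01, 0b10, 0b11]);

        // Recompose by shifting each block back into place and summing.
        let recomposed: u64 = blocks
            .iter()
            .enumerate()
            .map(|(i, block)| block << (i as u32 * bits_per_block))
            .sum();
        assert_eq!(recomposed, value);
    }
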
tmontaigu
e307da5c7f feat(integer): make eq (==) faster and add ne (!=) 2023-05-31 19:03:02 +02:00
Arthur Meyre
3d5b88d608 chore(core): encode the proper expectation wrt ciphertext modulus
- we don't manage any non-native moduli, but rather native-compatible moduli,
so update the asserts accordingly
2023-05-30 15:39:14 +02:00
Arthur Meyre
4fbf0691c5 chore(core): rename get_scaling_to_native_torus
- function now named get_power_of_two_scaling_to_native_torus to emphasize
that it is reserved for power-of-2 moduli
2023-05-30 15:39:14 +02:00
Arthur Meyre
5d277e85b9 feat(core): add non native decomposer 2023-05-30 15:39:14 +02:00
Arthur Meyre
778eea30e9 chore(tfhe): remove anyhow, just use Box<dyn std::error::Error> 2023-05-30 11:55:43 +02:00
tmontaigu
63247fa227 chore(sha256_example): use array_fn 2023-05-25 00:22:01 +02:00
David Testé
799291a1f0 docs(tfhe): format sha256_bool and add make recipes to run it 2023-05-25 00:22:01 +02:00
Sexosexosexo
509fe7a63e docs(tfhe): add boolean sha256 tutorial
Clap dev dependency added
2023-05-25 00:22:01 +02:00
tmontaigu
4eac45f0c6 fix(dark_market): fix change cwd logic 2023-05-24 23:30:26 +02:00
David Testé
ddb3451087 docs(tfhe): format dark market example and add make recipe to run it 2023-05-24 23:30:26 +02:00
Yagiz Senal
e66a329e33 docs(tfhe): add dark market tutorial 2023-05-24 23:30:26 +02:00
David Testé
d79b1d9b19 docs(tfhe): format regex_engine and add make recipes to run it 2023-05-24 22:11:53 +02:00
Rick Klomp
b501cc078a docs(tfhe): add FHE Regex Pattern Matching Engine
this includes a tutorial and an example implementation for the regex bounty
2023-05-24 22:11:53 +02:00
tmontaigu
800878d89e feat(hlapi): add CompressedPublicKey decompression 2023-05-23 14:19:35 +02:00
tmontaigu
20d0e81bae feat(boolean): add CompressedPublicKey 2023-05-19 19:07:16 +02:00
tmontaigu
d3dbf4ecc9 feat(integer): allow decompressing CompressedPublicKey 2023-05-19 15:32:25 +02:00
tmontaigu
c20ca07cd3 chore(ci): reduce number of test-threads
Reduce the number of test-threads being spawned
to reduce the probability of tests getting killed due
to out-of-memory errors
2023-05-17 15:58:27 +02:00
tmontaigu
9f6c7e9139 feat(hlapi): add CompressedServerKey
Now that WopPBS keys are optional in the HLAPI,
we can have a CompressedServerKey.
If a user tries to create a CompressedServerKey
but has enabled function evaluation on integers
(WopPBS), it will panic, as WopPBS keys are not yet compressible.
'Stuffing' the non-compressed WopPBS key into the
compressed server key would defeat the purpose of a
compressed server key, as the WopPBS key makes up
the vast majority of the space used.

Also, having a CompressedServerKey is required to
be able to have a wasm API for the HLAPI,
as wasm cannot generate a normal server key.
2023-05-17 11:15:37 +02:00
David Testé
3c8d6a6f8b chore(ci): handle aws tests in pull request from forked repository 2023-05-17 08:42:19 +02:00
Arthur Meyre
1c837fa6f0 test(core): add normality test based on Shapiro-Francia 2023-05-16 10:12:28 +02:00
tmontaigu
1ec7e4762a feat(integer): make wopbs compile on wasm
The goal here is just to make the code compile,
not to allow the JS API to generate WopPBS keys yet.
2023-05-15 22:06:36 +02:00
tmontaigu
20fb697d57 refactor(hlapi): disable WopPBS by default in hlapi
In the HLAPI, WopPBS is enabled by default,
meaning the WopPBS key is generated when integers
are enabled.

This is not really good, as the WopPBS key is huge
(~700MB with PARAM_2_2), is only used for function evaluation
(which does not scale to all types exposed by the HLAPI),
and is still a bit experimental, so not really advertised in the docs.

Also, keys for WopPBS are not compressible yet
(that is why the HLAPI does not yet have a CompressedServerKey).

So disabling WopPBS by default makes it possible to have a compressed
server key that actually compresses things.
2023-05-15 19:01:53 +02:00
tmontaigu
0429d56cf3 chore(U256): add small tests 2023-05-15 11:40:44 +02:00
tmontaigu
509bf3e284 docs(bench): update results of benchmarks in the docs 2023-05-12 21:58:47 +02:00
Arthur Meyre
b2fc1d5266 refactor(shortint): differentiate between PBS and Wopbs parameters
- preparatory work to manage several PBS implementations and harmonize
parameters management

BREAKING CHANGE:
- parameters structures changed
- gen_keys for integer now takes parameters by value, for uniformity with
  shortint
2023-05-12 17:20:05 +02:00
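
The README diff further down in this compare shows the concrete impact of the second breaking change; integer key generation now takes the parameters by value:

    // Before (parameters passed by reference):
    let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, 8);
    // After (parameters passed by value):
    let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);
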
Arthur Meyre
62d94dbee8 chore(tfhe): fix double Example heading in docstring 2023-05-12 17:20:05 +02:00
Agnes Leroy
fbe911d7db chore(tfhe): hard set number of threads to 10 for the multi-bit PBS
It's the optimal value measured on an m6i.metal instance where we run the benchmarks
2023-05-12 15:12:11 +02:00
tmontaigu
ba72faf828 chore(readme): remove non-needed mut in boolean example 2023-05-11 22:25:12 +02:00
tmontaigu
c387b9340f feat(integer): improve mul and scalar mul
This improves the mul and scalar_mul algorithms
to be faster

The improvement is made within the code
that was responsible for summing up all
the terms by making better use of carries
and avoiding unnecessary propagations.

The scalar mul forwards the call to a left shift
when the scalar is a power of two, as it just costs one
PBS, so it will always be faster.

For 64-bits, target-cpu=native + avx512:
- mul before: 3.4s
- mul after: 900ms
2023-05-11 14:07:11 +02:00
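
As a quick clear-value reminder of why the power-of-two case can be forwarded to a shift rather than a full multiplication:

    fn main() {
        let x: u64 = 57;
        // Multiplying by a power of two is equivalent to a left shift,
        // which is why that case can be forwarded to the shift operation.
        assert_eq!(x * 8, x << 3);
    }
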
274 changed files with 31229 additions and 4765 deletions

View File

@@ -25,26 +25,35 @@ on:
request_id:
description: 'Slab request ID'
type: string
matrix_item:
description: 'Build matrix item'
fork_repo:
description: 'Name of forked repo as user/repo'
type: string
fork_git_sha:
description: 'Git SHA to checkout from fork'
type: string
jobs:
integer-tests:
concurrency:
group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }}
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
cancel-in-progress: true
runs-on: ${{ github.event.inputs.runner_name }}
runs-on: ${{ inputs.runner_name }}
steps:
# Step used for log purpose.
- name: Instance configuration used
run: |
echo "ID: ${{ github.event.inputs.instance_id }}"
echo "AMI: ${{ github.event.inputs.instance_image_id }}"
echo "Type: ${{ github.event.inputs.instance_type }}"
echo "Request ID: ${{ github.event.inputs.request_id }}"
echo "ID: ${{ inputs.instance_id }}"
echo "AMI: ${{ inputs.instance_image_id }}"
echo "Type: ${{ inputs.instance_type }}"
echo "Request ID: ${{ inputs.request_id }}"
echo "Fork repo: ${{ inputs.fork_repo }}"
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
- uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
- name: Checkout tfhe-rs
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
with:
repository: ${{ inputs.fork_repo }}
ref: ${{ inputs.fork_git_sha }}
- name: Set up home
run: |

View File

@@ -0,0 +1,90 @@
name: AWS Multi Bit Tests on CPU
env:
CARGO_TERM_COLOR: always
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
RUSTFLAGS: "-C target-cpu=native"
on:
# Allows you to run this workflow manually from the Actions tab as an alternative.
workflow_dispatch:
# All the inputs are provided by Slab
inputs:
instance_id:
description: "AWS instance ID"
type: string
instance_image_id:
description: "AWS instance AMI ID"
type: string
instance_type:
description: "AWS instance product type"
type: string
runner_name:
description: "Action runner name"
type: string
request_id:
description: 'Slab request ID'
type: string
fork_repo:
description: 'Name of forked repo as user/repo'
type: string
fork_git_sha:
description: 'Git SHA to checkout from fork'
type: string
jobs:
shortint-tests:
concurrency:
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
cancel-in-progress: true
runs-on: ${{ inputs.runner_name }}
steps:
# Step used for log purpose.
- name: Instance configuration used
run: |
echo "ID: ${{ inputs.instance_id }}"
echo "AMI: ${{ inputs.instance_image_id }}"
echo "Type: ${{ inputs.instance_type }}"
echo "Request ID: ${{ inputs.request_id }}"
echo "Fork repo: ${{ inputs.fork_repo }}"
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
- name: Checkout tfhe-rs
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
with:
repository: ${{ inputs.fork_repo }}
ref: ${{ inputs.fork_git_sha }}
- name: Set up home
run: |
echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}"
- name: Install latest stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af
with:
toolchain: stable
default: true
- name: Gen Keys if required
run: |
MULTI_BIT_ONLY=TRUE make gen_key_cache
- name: Run shortint multi-bit tests
run: |
make test_shortint_multi_bit_ci
- name: Run integer multi-bit tests
run: |
make test_integer_multi_bit_ci
- name: Slack Notification
if: ${{ always() }}
continue-on-error: true
uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7
env:
SLACK_COLOR: ${{ job.status }}
SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }}
SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png
SLACK_MESSAGE: "Shortint tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})"
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}

View File

@@ -25,26 +25,35 @@ on:
request_id:
description: 'Slab request ID'
type: string
matrix_item:
description: 'Build matrix item'
fork_repo:
description: 'Name of forked repo as user/repo'
type: string
fork_git_sha:
description: 'Git SHA to checkout from fork'
type: string
jobs:
shortint-tests:
concurrency:
group: ${{ github.workflow }}_${{ github.ref }}_${{ github.event.inputs.instance_image_id }}_${{ github.event.inputs.instance_type }}
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
cancel-in-progress: true
runs-on: ${{ github.event.inputs.runner_name }}
runs-on: ${{ inputs.runner_name }}
steps:
# Step used for log purpose.
- name: Instance configuration used
run: |
echo "ID: ${{ github.event.inputs.instance_id }}"
echo "AMI: ${{ github.event.inputs.instance_image_id }}"
echo "Type: ${{ github.event.inputs.instance_type }}"
echo "Request ID: ${{ github.event.inputs.request_id }}"
echo "ID: ${{ inputs.instance_id }}"
echo "AMI: ${{ inputs.instance_image_id }}"
echo "Type: ${{ inputs.instance_type }}"
echo "Request ID: ${{ inputs.request_id }}"
echo "Fork repo: ${{ inputs.fork_repo }}"
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
- uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
- name: Checkout tfhe-rs
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
with:
repository: ${{ inputs.fork_repo }}
ref: ${{ inputs.fork_git_sha }}
- name: Set up home
run: |
@@ -72,10 +81,6 @@ jobs:
run: |
make test_user_doc
- name: Run js on wasm API tests
run: |
make test_nodejs_wasm_api_in_docker
- name: Gen Keys if required
run: |
make gen_key_cache

View File

@@ -0,0 +1,87 @@
name: AWS WASM Tests on CPU
env:
CARGO_TERM_COLOR: always
ACTION_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
RUSTFLAGS: "-C target-cpu=native"
on:
# Allows you to run this workflow manually from the Actions tab as an alternative.
workflow_dispatch:
# All the inputs are provided by Slab
inputs:
instance_id:
description: "AWS instance ID"
type: string
instance_image_id:
description: "AWS instance AMI ID"
type: string
instance_type:
description: "AWS instance product type"
type: string
runner_name:
description: "Action runner name"
type: string
request_id:
description: 'Slab request ID'
type: string
fork_repo:
description: 'Name of forked repo as user/repo'
type: string
fork_git_sha:
description: 'Git SHA to checkout from fork'
type: string
jobs:
wasm-tests:
concurrency:
group: ${{ github.workflow }}_${{ github.ref }}_${{ inputs.instance_image_id }}_${{ inputs.instance_type }}
cancel-in-progress: true
runs-on: ${{ inputs.runner_name }}
steps:
# Step used for log purpose.
- name: Instance configuration used
run: |
echo "ID: ${{ inputs.instance_id }}"
echo "AMI: ${{ inputs.instance_image_id }}"
echo "Type: ${{ inputs.instance_type }}"
echo "Request ID: ${{ inputs.request_id }}"
echo "Fork repo: ${{ inputs.fork_repo }}"
echo "Fork git sha: ${{ inputs.fork_git_sha }}"
- name: Checkout tfhe-rs
uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab
with:
repository: ${{ inputs.fork_repo }}
ref: ${{ inputs.fork_git_sha }}
- name: Set up home
run: |
echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}"
- name: Install latest stable
uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af
with:
toolchain: stable
default: true
- name: Run js on wasm API tests
run: |
make test_nodejs_wasm_api_in_docker
- name: Run parallel wasm tests
run: |
make install_node
make ci_test_web_js_api_parallel
- name: Slack Notification
if: ${{ always() }}
continue-on-error: true
uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7
env:
SLACK_COLOR: ${{ job.status }}
SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }}
SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png
SLACK_MESSAGE: "WASM tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})"
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}

View File

@@ -31,7 +31,7 @@ jobs:
make build_web_js_api
- name: Publish web package
uses: JS-DevTools/npm-publish@541aa6b21b4a1e9990c95a92c21adc16b35e9551
uses: JS-DevTools/npm-publish@a25b4180b728b0279fca97d4e5bccf391685aead
with:
token: ${{ secrets.NPM_TOKEN }}
package: tfhe/pkg/package.json
@@ -45,7 +45,7 @@ jobs:
sed -i 's/"tfhe"/"node-tfhe"/g' tfhe/pkg/package.json
- name: Publish Node package
uses: JS-DevTools/npm-publish@541aa6b21b4a1e9990c95a92c21adc16b35e9551
uses: JS-DevTools/npm-publish@a25b4180b728b0279fca97d4e5bccf391685aead
with:
token: ${{ secrets.NPM_TOKEN }}
package: tfhe/pkg/package.json

View File

@@ -16,3 +16,5 @@ jobs:
message: |
@slab-ci cpu_test
@slab-ci cpu_integer_test
@slab-ci cpu_multi_bit_test
@slab-ci cpu_wasm_test

1
.gitignore vendored
View File

@@ -12,3 +12,4 @@ target/
# Some of our bench outputs
/tfhe/benchmarks_parameters
/tfhe/shortint_key_sizes.csv

131
CODE_OF_CONDUCT.md Normal file
View File

@@ -0,0 +1,131 @@
# Contributor Covenant Code of Conduct
## Our pledge
We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.
We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.
## Our standards
Examples of behavior that contributes to a positive environment for our
community include:
- Demonstrating empathy and kindness toward other people
- Being respectful of differing opinions, viewpoints, and experiences
- Giving and gracefully accepting constructive feedback
- Accepting responsibility and apologizing to those affected by our mistakes,
and learning from the experience
- Focusing on what is best not just for us as individuals, but for the overall
community
Examples of unacceptable behavior include:
- The use of sexualized language or imagery, and sexual attention or advances of
any kind
- Trolling, insulting or derogatory comments, and personal or political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or email address,
without their explicit permission
- Other conduct which could reasonably be considered inappropriate in a
professional setting
## Enforcement responsibilities
Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.
Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.
## Scope
This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting us anonymously through [this form](https://forms.gle/569j3cZqGRFgrR3u9).
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the
reporter of any incident.
## Enforcement guidelines
Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of
actions.
**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.
### 3. Temporary ban
**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.
### 4. Permanent ban
**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the
community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][mozilla coc].
For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][faq]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].
[faq]: https://www.contributor-covenant.org/faq
[homepage]: https://www.contributor-covenant.org
[mozilla coc]: https://github.com/mozilla/diversity
[translations]: https://www.contributor-covenant.org/translations
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html

119
Makefile
View File

@@ -21,6 +21,10 @@ else
AVX512_FEATURE=
endif
# Variables used only for regex_engine example
REGEX_STRING?=''
REGEX_PATTERN?=''
.PHONY: rs_check_toolchain # Echo the rust toolchain used for checks
rs_check_toolchain:
@echo $(RS_CHECK_TOOLCHAIN)
@@ -58,6 +62,13 @@ install_wasm_pack: install_rs_build_toolchain
cargo $(CARGO_RS_BUILD_TOOLCHAIN) install wasm-pack || \
( echo "Unable to install cargo wasm-pack, unknown error." && exit 1 )
.PHONY: install_node # Install last version of NodeJS via nvm
install_node:
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh | $(SHELL)
source ~/.bashrc
$(SHELL) -i -c 'nvm install node' || \
( echo "Unable to install node, unknown error." && exit 1 )
.PHONY: fmt # Format rust code
fmt: install_rs_check_toolchain
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt
@@ -105,10 +116,10 @@ clippy_c_api: install_rs_check_toolchain
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint and the js wasm API
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API
clippy_js_wasm_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_tasks # Run clippy lints on helper tasks crate.
@@ -131,9 +142,13 @@ clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_ta
.PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests
gen_key_cache: install_rs_build_toolchain
if [[ "$${MULTI_BIT_ONLY}" == TRUE ]]; then \
multi_bit_flag="--multi-bit-only"; \
fi && \
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
--example generates_test_keys \
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe
--features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe -- \
$${multi_bit_flag:+"$${multi_bit_flag}"}
.PHONY: build_core # Build core_crypto without experimental features
build_core: install_rs_build_toolchain install_rs_check_toolchain
@@ -184,14 +199,23 @@ build_web_js_api: install_rs_build_toolchain install_wasm_pack
cd tfhe && \
RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \
wasm-pack build --release --target=web \
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api
.PHONY: build_web_js_api_parallel # Build the js API targeting the web browser with parallelism support
build_web_js_api_parallel: install_rs_check_toolchain install_wasm_pack
cd tfhe && \
rustup component add rust-src --toolchain $(RS_CHECK_TOOLCHAIN) && \
RUSTFLAGS="$(WASM_RUSTFLAGS) -C target-feature=+atomics,+bulk-memory,+mutable-globals" rustup run $(RS_CHECK_TOOLCHAIN) \
wasm-pack build --release --target=web \
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,parallel-wasm-api \
-Z build-std=panic_abort,std
.PHONY: build_node_js_api # Build the js API targeting nodejs
build_node_js_api: install_rs_build_toolchain install_wasm_pack
cd tfhe && \
RUSTFLAGS="$(WASM_RUSTFLAGS)" rustup run "$(RS_BUILD_TOOLCHAIN)" \
wasm-pack build --release --target=nodejs \
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api
-- --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api
.PHONY: test_core_crypto # Run the tests of the core_crypto module including experimental ones
test_core_crypto: install_rs_build_toolchain install_rs_check_toolchain
@@ -208,13 +232,24 @@ test_boolean: install_rs_build_toolchain
--features=$(TARGET_ARCH_FEATURE),boolean -p tfhe -- boolean::
.PHONY: test_c_api # Run the tests for the C API
test_c_api: build_c_api
test_c_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api \
-p tfhe \
c_api
"$(MAKE)" build_c_api
./scripts/c_api_tests.sh
.PHONY: test_shortint_ci # Run the tests for shortint ci
test_shortint_ci: install_rs_build_toolchain install_cargo_nextest
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
./scripts/shortint-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN)
./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN)
.PHONY: test_shortint_multi_bit_ci # Run the tests for shortint ci running only multibit tests
test_shortint_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit
.PHONY: test_shortint # Run all the tests for shortint
test_shortint: install_rs_build_toolchain
@@ -224,7 +259,12 @@ test_shortint: install_rs_build_toolchain
.PHONY: test_integer_ci # Run the tests for integer ci
test_integer_ci: install_rs_build_toolchain install_cargo_nextest
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
./scripts/integer-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN)
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN)
.PHONY: test_integer_multi_bit_ci # Run the tests for integer ci running only multibit tests
test_integer_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest
BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \
./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit
.PHONY: test_integer # Run all the tests for integer
test_integer: install_rs_build_toolchain
@@ -242,12 +282,27 @@ test_user_doc: install_rs_build_toolchain
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache -p tfhe \
-- test_user_docs::
.PHONY: test_regex_engine # Run tests for regex_engine example
test_regex_engine: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--example regex_engine \
--features=$(TARGET_ARCH_FEATURE),integer
.PHONY: test_sha256_bool # Run tests for sha256_bool example
test_sha256_bool: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--example sha256_bool \
--features=$(TARGET_ARCH_FEATURE),boolean
.PHONY: doc # Build rust doc
doc: install_rs_check_toolchain
RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps
.PHONY: docs # Build rust doc alias for doc
docs: doc
.PHONY: format_doc_latex # Format the documentation latex equations to avoid broken rendering.
format_doc_latex:
cargo xtask format_latex_doc
@@ -258,13 +313,15 @@ format_doc_latex:
@printf "\n===============================\n"
.PHONY: check_compile_tests # Build tests in debug without running them
check_compile_tests: build_c_api
check_compile_tests:
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --no-run \
--features=$(TARGET_ARCH_FEATURE),experimental,boolean,shortint,integer,internal-keycache \
-p tfhe
@if [[ "$(OS)" == "Linux" || "$(OS)" == "Darwin" ]]; then \
./scripts/c_api_tests.sh --build-only; \
fi
@if [[ "$(OS)" == "Linux" || "$(OS)" == "Darwin" ]]; then \
"$(MAKE)" build_c_api; \
./scripts/c_api_tests.sh --build-only; \
fi
.PHONY: build_nodejs_test_docker # Build a docker image with tools to run nodejs tests for wasm API
build_nodejs_test_docker:
@@ -286,6 +343,24 @@ test_nodejs_wasm_api_in_docker: build_nodejs_test_docker
test_nodejs_wasm_api: build_node_js_api
cd tfhe && node --test js_on_wasm_tests
.PHONY: test_web_js_api_parallel # Run tests for the web wasm api
test_web_js_api_parallel: build_web_js_api_parallel
$(MAKE) -C tfhe/web_wasm_parallel_tests test
.PHONY: ci_test_web_js_api_parallel # Run tests for the web wasm api
ci_test_web_js_api_parallel: build_web_js_api_parallel
# Auto-retry since WASM tests can be flaky
@for i in 1 2 3 ; do \
source ~/.nvm/nvm.sh && \
nvm use node && \
$(MAKE) -C tfhe/web_wasm_parallel_tests test-ci | tee web_js_tests_output; \
if grep -q -i "timeout" web_js_tests_output; then \
echo "Timeout occurred starting attempt #${i}"; \
else \
break; \
fi; \
done
.PHONY: no_tfhe_typo # Check we did not invert the h and f in tfhe
no_tfhe_typo:
@./scripts/no_tfhe_typo.sh
@@ -326,6 +401,26 @@ measure_boolean_key_sizes: install_rs_check_toolchain
--example boolean_key_sizes \
--features=$(TARGET_ARCH_FEATURE),boolean,internal-keycache
.PHONY: regex_engine # Run regex_engine example
regex_engine: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
--example regex_engine \
--features=$(TARGET_ARCH_FEATURE),integer \
-- $(REGEX_STRING) $(REGEX_PATTERN)
.PHONY: dark_market # Run dark market example
dark_market: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
--example dark_market \
--features=$(TARGET_ARCH_FEATURE),integer,internal-keycache \
-- fhe-modified fhe-parallel plain fhe
.PHONY: sha256_bool # Run sha256_bool example
sha256_bool: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) run --profile $(CARGO_PROFILE) \
--example sha256_bool \
--features=$(TARGET_ARCH_FEATURE),boolean
.PHONY: pcc # pcc stands for pre commit checks
pcc: no_tfhe_typo check_fmt doc clippy_all check_compile_tests

View File

@@ -68,7 +68,7 @@ use tfhe::boolean::prelude::*;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let (mut client_key, mut server_key) = gen_keys();
let (client_key, server_key) = gen_keys();
// We use the client secret key to encrypt two messages:
let ct_1 = client_key.encrypt(true);
@@ -132,7 +132,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We create keys to create 16 bits integers
// using 8 blocks of 2 bits
let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, 8);
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);
let clear_a = 2382u16;
let clear_b = 29374u16;

View File

@@ -3,6 +3,11 @@ region = "eu-west-3"
image_id = "ami-04deffe45b5b236fd"
instance_type = "m6i.32xlarge"
[profile.cpu-small]
region = "eu-west-3"
image_id = "ami-04deffe45b5b236fd"
instance_type = "m6i.4xlarge"
[profile.bench]
region = "eu-west-3"
image_id = "ami-04deffe45b5b236fd"
@@ -18,6 +23,16 @@ workflow = "aws_tfhe_integer_tests.yml"
profile = "cpu-big"
check_run_name = "CPU Integer AWS Tests"
[command.cpu_multi_bit_test]
workflow = "aws_tfhe_multi_bit_tests.yml"
profile = "cpu-big"
check_run_name = "CPU AWS Multi Bit Tests"
[command.cpu_wasm_test]
workflow = "aws_tfhe_wasm_tests.yml"
profile = "cpu-small"
check_run_name = "CPU AWS WASM Tests"
[command.integer_bench]
workflow = "integer_benchmark.yml"
profile = "bench"

View File

@@ -2,6 +2,49 @@
set -e
function usage() {
echo "$0: shortint test runner"
echo
echo "--help Print this message"
echo "--rust-toolchain The toolchain to run the tests with default: stable"
echo "--multi-bit Run multi-bit tests only: default off"
echo
}
RUST_TOOLCHAIN="+stable"
multi_bit=""
not_multi_bit="_multi_bit"
while [ -n "$1" ]
do
case "$1" in
"--help" | "-h" )
usage
exit 0
;;
"--rust-toolchain" )
shift
RUST_TOOLCHAIN="$1"
;;
"--multi-bit" )
multi_bit="_multi_bit"
not_multi_bit=""
;;
*)
echo "Unknown param : $1"
exit 1
;;
esac
shift
done
if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
fi
CURR_DIR="$(dirname "$0")"
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
@@ -29,14 +72,15 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
# mul_crt_4_4 is extremely flaky (~80% failure)
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
# test_integer_smart_mul_param_message_4_carry_4 is too slow
filter_expression=''\
'test(/^integer::.*$/)'\
'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\
'and not test(~mul_crt_param_message_4_carry_4)'\
'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\
'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)'
filter_expression="""\
test(/^integer::.*${multi_bit}/) \
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \
and not test(~mul_crt_param_message_4_carry_4) \
and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \
and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)"""
cargo ${1:+"${1}"} nextest run \
cargo "${RUST_TOOLCHAIN}" nextest run \
--tests \
--release \
--package tfhe \
@@ -45,39 +89,46 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
--test-threads "${n_threads}" \
-E "$filter_expression"
cargo ${1:+"${1}"} test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",integer,internal-keycache \
--doc \
integer::
if [[ "${multi_bit}" == "" ]]; then
cargo "${RUST_TOOLCHAIN}" test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",integer,internal-keycache \
--doc \
integer::
fi
else
# block pbs are too slow for high params
# mul_crt_4_4 is extremely flaky (~80% failure)
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
# test_integer_smart_mul_param_message_4_carry_4 is too slow
filter_expression=''\
'test(/^integer::.*$/)'\
'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\
'and not test(~mul_crt_param_message_4_carry_4)'\
'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\
'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)'
filter_expression="""\
test(/^integer::.*${multi_bit}/) \
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \
and not test(~mul_crt_param_message_4_carry_4) \
and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \
and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)"""
cargo ${1:+"${1}"} nextest run \
num_cpu_threads="$(${nproc_bin})"
num_threads=$((num_cpu_threads * 2 / 3))
cargo "${RUST_TOOLCHAIN}" nextest run \
--tests \
--release \
--package tfhe \
--profile ci \
--features="${ARCH_FEATURE}",integer,internal-keycache \
--test-threads "$(${nproc_bin})" \
--test-threads $num_threads \
-E "$filter_expression"
cargo ${1:+"${1}"} test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",integer,internal-keycache \
--doc \
integer:: -- --test-threads="$(${nproc_bin})"
if [[ "${multi_bit}" == "" ]]; then
cargo "${RUST_TOOLCHAIN}" test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",integer,internal-keycache \
--doc \
integer:: -- --test-threads="$(${nproc_bin})"
fi
fi
echo "Test ran in $SECONDS seconds"

View File

@@ -2,6 +2,47 @@
set -e
function usage() {
echo "$0: shortint test runner"
echo
echo "--help Print this message"
echo "--rust-toolchain The toolchain to run the tests with default: stable"
echo "--multi-bit Run multi-bit tests only: default off"
echo
}
RUST_TOOLCHAIN="+stable"
multi_bit=""
while [ -n "$1" ]
do
case "$1" in
"--help" | "-h" )
usage
exit 0
;;
"--rust-toolchain" )
shift
RUST_TOOLCHAIN="$1"
;;
"--multi-bit" )
multi_bit="_multi_bit"
;;
*)
echo "Unknown param : $1"
exit 1
;;
esac
shift
done
if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
fi
CURR_DIR="$(dirname "$0")"
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
@@ -31,25 +72,25 @@ else
fi
if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
filter_expression_small_params=''\
'('\
' test(/^shortint::.*_param_message_1_carry_1$/)'\
'or test(/^shortint::.*_param_message_1_carry_2$/)'\
'or test(/^shortint::.*_param_message_1_carry_3$/)'\
'or test(/^shortint::.*_param_message_1_carry_4$/)'\
'or test(/^shortint::.*_param_message_1_carry_5$/)'\
'or test(/^shortint::.*_param_message_1_carry_6$/)'\
'or test(/^shortint::.*_param_message_2_carry_1$/)'\
'or test(/^shortint::.*_param_message_2_carry_2$/)'\
'or test(/^shortint::.*_param_message_2_carry_3$/)'\
'or test(/^shortint::.*_param_message_3_carry_1$/)'\
'or test(/^shortint::.*_param_message_3_carry_2$/)'\
'or test(/^shortint::.*_param_message_3_carry_3$/)'\
')'\
'and not test(~smart_add_and_mul)' # This test is too slow
filter_expression_small_params="""\
(\
test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \
) \
and not test(~smart_add_and_mul)""" # This test is too slow
# Run tests only no examples or benches with small params and more threads
cargo ${1:+"${1}"} nextest run \
cargo "${RUST_TOOLCHAIN}" nextest run \
--tests \
--release \
--package tfhe \
@@ -58,14 +99,14 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
--test-threads "${n_threads_small}" \
-E "${filter_expression_small_params}"
filter_expression_big_params=''\
'('\
' test(/^shortint::.*_param_message_4_carry_4$/)'\
')'\
'and not test(~smart_add_and_mul)'
filter_expression_big_params="""\
(\
test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \
) \
and not test(~smart_add_and_mul)"""
# Run tests only no examples or benches with big params and less threads
cargo ${1:+"${1}"} nextest run \
cargo "${RUST_TOOLCHAIN}" nextest run \
--tests \
--release \
--package tfhe \
@@ -74,33 +115,35 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
--test-threads "${n_threads_big}" \
-E "${filter_expression_big_params}"
cargo ${1:+"${1}"} test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--doc \
shortint::
if [[ "${multi_bit}" == "" ]]; then
cargo "${RUST_TOOLCHAIN}" test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--doc \
shortint::
fi
else
filter_expression=''\
'('\
' test(/^shortint::.*_param_message_1_carry_1$/)'\
'or test(/^shortint::.*_param_message_1_carry_2$/)'\
'or test(/^shortint::.*_param_message_1_carry_3$/)'\
'or test(/^shortint::.*_param_message_1_carry_4$/)'\
'or test(/^shortint::.*_param_message_1_carry_5$/)'\
'or test(/^shortint::.*_param_message_1_carry_6$/)'\
'or test(/^shortint::.*_param_message_2_carry_1$/)'\
'or test(/^shortint::.*_param_message_2_carry_2$/)'\
'or test(/^shortint::.*_param_message_2_carry_3$/)'\
'or test(/^shortint::.*_param_message_3_carry_1$/)'\
'or test(/^shortint::.*_param_message_3_carry_2$/)'\
'or test(/^shortint::.*_param_message_3_carry_3$/)'\
'or test(/^shortint::.*_param_message_4_carry_4$/)'\
')'\
'and not test(~smart_add_and_mul)' # This test is too slow
filter_expression="""\
(\
test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \
or test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \
)\
and not test(~smart_add_and_mul)""" # This test is too slow
# Run tests only no examples or benches with small params and more threads
cargo ${1:+"${1}"} nextest run \
cargo "${RUST_TOOLCHAIN}" nextest run \
--tests \
--release \
--package tfhe \
@@ -109,12 +152,14 @@ else
--test-threads "$(${nproc_bin})" \
-E "${filter_expression}"
cargo ${1:+"${1}"} test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--doc \
shortint:: -- --test-threads="$(${nproc_bin})"
if [[ "${multi_bit}" == "" ]]; then
cargo "${RUST_TOOLCHAIN}" test \
--release \
--package tfhe \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--doc \
shortint:: -- --test-threads="$(${nproc_bin})"
fi
fi
echo "Test ran in $SECONDS seconds"

View File

@@ -18,17 +18,23 @@ rust-version = "1.67"
[dev-dependencies]
rand = "0.8.5"
rand_distr = "0.4.3"
kolmogorov_smirnov = "1.1.0"
paste = "1.0.7"
lazy_static = { version = "1.4.0" }
criterion = "0.4.0"
doc-comment = "0.3.3"
serde_json = "1.0.94"
clap = "4.2.7"
# Used in user documentation
bincode = "1.3.3"
fs2 = { version = "0.4.3" }
itertools = "0.10.5"
num_cpus = "1.15"
# For erf and normality test
libm = "0.2.6"
test-case = "*"
combine = "*"
env_logger = "*"
log = "*"
[build-dependencies]
cbindgen = { version = "0.24.3", optional = true }
@@ -53,9 +59,10 @@ fs2 = { version = "0.4.3", optional = true }
itertools = "0.10.5"
# wasm deps
wasm-bindgen = { version = "0.2.63", features = [
wasm-bindgen = { version = "0.2.86", features = [
"serde-serialize",
], optional = true }
wasm-bindgen-rayon = { version = "1.0", optional = true }
js-sys = { version = "0.3", optional = true }
console_error_panic_hook = { version = "0.1.7", optional = true }
serde-wasm-bindgen = { version = "0.4", optional = true }
@@ -76,7 +83,12 @@ experimental-force_fft_algo_dif4 = []
__c_api = ["cbindgen", "bincode"]
boolean-c-api = ["boolean", "__c_api"]
shortint-c-api = ["shortint", "__c_api"]
high-level-c-api = ["boolean", "shortint", "integer", "__c_api"]
high-level-c-api = [
"boolean-c-api",
"shortint-c-api",
"integer",
"__c_api"
]
__wasm_api = [
"wasm-bindgen",
@@ -89,6 +101,9 @@ __wasm_api = [
]
boolean-client-js-wasm-api = ["boolean", "__wasm_api"]
shortint-client-js-wasm-api = ["shortint", "__wasm_api"]
integer-client-js-wasm-api = ["integer", "__wasm_api"]
high-level-client-js-wasm-api = ["boolean", "shortint", "integer", "__wasm_api"]
parallel-wasm-api = ["wasm-bindgen-rayon"]
nightly-avx512 = ["concrete-fft/nightly", "pulp/nightly"]
@@ -182,13 +197,29 @@ required-features = ["shortint", "internal-keycache"]
name = "boolean_key_sizes"
required-features = ["boolean", "internal-keycache"]
[[example]]
name = "dark_market"
required-features = ["integer", "internal-keycache"]
[[example]]
name = "shortint_key_sizes"
required-features = ["shortint", "internal-keycache"]
[[example]]
name = "integer_compact_pk_ct_sizes"
required-features = ["integer", "internal-keycache"]
[[example]]
name = "micro_bench_and"
required-features = ["boolean"]
[[example]]
name = "regex_engine"
required-features = ["integer"]
[[example]]
name = "sha256_bool"
required-features = ["boolean"]
[lib]
crate-type = ["lib", "staticlib", "cdylib"]

View File

@@ -7,9 +7,9 @@ use tfhe::boolean::parameters::{BooleanParameters, DEFAULT_PARAMETERS, TFHE_LIB_
use tfhe::core_crypto::prelude::*;
use tfhe::shortint::keycache::NamedParam;
use tfhe::shortint::parameters::*;
use tfhe::shortint::Parameters;
use tfhe::shortint::ClassicPBSParameters;
const SHORTINT_BENCH_PARAMS: [Parameters; 15] = [
const SHORTINT_BENCH_PARAMS: [ClassicPBSParameters; 15] = [
PARAM_MESSAGE_1_CARRY_0,
PARAM_MESSAGE_1_CARRY_1,
PARAM_MESSAGE_2_CARRY_0,
@@ -125,7 +125,7 @@ fn multi_bit_benchmark_parameters<Scalar: Numeric>(
(
CryptoParametersRecord {
lwe_dimension: Some(LweDimension(888)),
lwe_modular_std_dev: Some(StandardDev(0.000002226459789930014)),
lwe_modular_std_dev: Some(StandardDev(0.0000006125031601933181)),
pbs_base_log: Some(DecompositionBaseLog(21)),
pbs_level: Some(DecompositionLevelCount(1)),
glwe_dimension: Some(GlweDimension(1)),
@@ -400,7 +400,7 @@ fn multi_bit_pbs<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Syn
&mut out_pbs_ct,
&accumulator.as_view(),
&multi_bit_bsk,
ThreadCount(num_cpus::get()),
ThreadCount(10),
);
black_box(&mut out_pbs_ct);
})

View File

@@ -23,13 +23,13 @@ use tfhe::shortint::parameters::{
/// in radix decomposition
struct ParamsAndNumBlocksIter {
params_and_bit_sizes:
itertools::Product<IntoIter<tfhe::shortint::Parameters, 1>, IntoIter<usize, 7>>,
itertools::Product<IntoIter<tfhe::shortint::ClassicPBSParameters, 1>, IntoIter<usize, 7>>,
}
impl Default for ParamsAndNumBlocksIter {
fn default() -> Self {
// FIXME One set of parameter is tested since we want to benchmark only quickest operations.
const PARAMS: [tfhe::shortint::Parameters; 1] = [
const PARAMS: [tfhe::shortint::ClassicPBSParameters; 1] = [
PARAM_MESSAGE_2_CARRY_2,
// PARAM_MESSAGE_3_CARRY_3,
// PARAM_MESSAGE_4_CARRY_4,
@@ -42,7 +42,7 @@ impl Default for ParamsAndNumBlocksIter {
}
}
impl Iterator for ParamsAndNumBlocksIter {
type Item = (tfhe::shortint::Parameters, usize, usize);
type Item = (tfhe::shortint::ClassicPBSParameters, usize, usize);
fn next(&mut self) -> Option<Self::Item> {
let (param, bit_size) = self.params_and_bit_sizes.next()?;
@@ -664,11 +664,29 @@ define_server_key_bench_fn!(
define_server_key_bench_default_fn!(method_name: max_parallelized, display_name: max);
define_server_key_bench_default_fn!(method_name: min_parallelized, display_name: min);
define_server_key_bench_default_fn!(method_name: eq_parallelized, display_name: equal);
define_server_key_bench_default_fn!(method_name: ne_parallelized, display_name: not_equal);
define_server_key_bench_default_fn!(method_name: lt_parallelized, display_name: less_than);
define_server_key_bench_default_fn!(method_name: le_parallelized, display_name: less_or_equal);
define_server_key_bench_default_fn!(method_name: gt_parallelized, display_name: greater_than);
define_server_key_bench_default_fn!(method_name: ge_parallelized, display_name: greater_or_equal);
define_server_key_bench_default_fn!(
method_name: left_shift_parallelized,
display_name: left_shift
);
define_server_key_bench_default_fn!(
method_name: right_shift_parallelized,
display_name: right_shift
);
define_server_key_bench_default_fn!(
method_name: rotate_left_parallelized,
display_name: rotate_left
);
define_server_key_bench_default_fn!(
method_name: rotate_right_parallelized,
display_name: rotate_right
);
criterion_group!(
smart_arithmetic_operation,
smart_neg,
@@ -793,10 +811,15 @@ criterion_group!(
min_parallelized,
max_parallelized,
eq_parallelized,
ne_parallelized,
lt_parallelized,
le_parallelized,
gt_parallelized,
ge_parallelized,
left_shift_parallelized,
right_shift_parallelized,
rotate_left_parallelized,
rotate_right_parallelized,
scalar_add_parallelized,
scalar_sub_parallelized,
scalar_mul_parallelized,

View File

@@ -5,7 +5,7 @@ use crate::utilities::{write_to_json, OperatorType};
use criterion::{criterion_group, criterion_main, Criterion};
use tfhe::shortint::keycache::NamedParam;
use tfhe::shortint::parameters::*;
use tfhe::shortint::{CiphertextBig, Parameters, ServerKey};
use tfhe::shortint::{CiphertextBig, ClassicPBSParameters, ServerKey, ShortintParameterSet};
use rand::Rng;
use tfhe::shortint::keycache::KEY_CACHE;
@@ -13,14 +13,14 @@ use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6;
const SERVER_KEY_BENCH_PARAMS: [Parameters; 4] = [
const SERVER_KEY_BENCH_PARAMS: [ClassicPBSParameters; 4] = [
PARAM_MESSAGE_1_CARRY_1,
PARAM_MESSAGE_2_CARRY_2,
PARAM_MESSAGE_3_CARRY_3,
PARAM_MESSAGE_4_CARRY_4,
];
const SERVER_KEY_BENCH_PARAMS_EXTENDED: [Parameters; 15] = [
const SERVER_KEY_BENCH_PARAMS_EXTENDED: [ClassicPBSParameters; 15] = [
PARAM_MESSAGE_1_CARRY_0,
PARAM_MESSAGE_1_CARRY_1,
PARAM_MESSAGE_2_CARRY_0,
@@ -43,7 +43,7 @@ fn bench_server_key_unary_function<F>(
bench_name: &str,
display_name: &str,
unary_op: F,
params: &[Parameters],
params: &[ClassicPBSParameters],
) where
F: Fn(&ServerKey, &mut CiphertextBig),
{
@@ -55,7 +55,7 @@ fn bench_server_key_unary_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
let clear_text = rng.gen::<u64>() % modulus;
@@ -87,7 +87,7 @@ fn bench_server_key_binary_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[Parameters],
params: &[ClassicPBSParameters],
) where
F: Fn(&ServerKey, &mut CiphertextBig, &mut CiphertextBig),
{
@@ -99,7 +99,7 @@ fn bench_server_key_binary_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
@@ -133,7 +133,7 @@ fn bench_server_key_binary_scalar_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[Parameters],
params: &[ClassicPBSParameters],
) where
F: Fn(&ServerKey, &mut CiphertextBig, u8),
{
@@ -145,7 +145,7 @@ fn bench_server_key_binary_scalar_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
let clear_0 = rng.gen::<u64>() % modulus;
let clear_1 = rng.gen::<u64>() % modulus;
@@ -178,7 +178,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
bench_name: &str,
display_name: &str,
binary_op: F,
params: &[Parameters],
params: &[ClassicPBSParameters],
) where
F: Fn(&ServerKey, &mut CiphertextBig, u8),
{
@@ -190,7 +190,7 @@ fn bench_server_key_binary_scalar_division_function<F>(
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
assert_ne!(modulus, 1);
let clear_0 = rng.gen::<u64>() % modulus;
@@ -231,7 +231,7 @@ fn carry_extract(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
let clear_0 = rng.gen::<u64>() % modulus;
@@ -267,7 +267,7 @@ fn programmable_bootstrapping(c: &mut Criterion) {
let mut rng = rand::thread_rng();
let modulus = cks.parameters.message_modulus.0 as u64;
let modulus = cks.parameters.message_modulus().0 as u64;
let acc = sks.generate_accumulator(|x| x);
@@ -297,12 +297,15 @@ fn programmable_bootstrapping(c: &mut Criterion) {
bench_group.finish();
}
// TODO: remove?
fn _bench_wopbs_param_message_8_norm2_5(c: &mut Criterion) {
let mut bench_group = c.benchmark_group("programmable_bootstrap");
let param = WOPBS_PARAM_MESSAGE_4_NORM2_6;
let param_set: ShortintParameterSet = param.try_into().unwrap();
let pbs_params = param_set.pbs_parameters().unwrap();
let keys = KEY_CACHE_WOPBS.get_from_param((param, param));
let keys = KEY_CACHE_WOPBS.get_from_param((pbs_params, param));
let (cks, _, wopbs_key) = (keys.client_key(), keys.server_key(), keys.wopbs_key());
let mut rng = rand::thread_rng();

View File

@@ -5,7 +5,7 @@ use std::path::PathBuf;
use tfhe::boolean::parameters::BooleanParameters;
use tfhe::core_crypto::prelude::*;
#[cfg(feature = "shortint")]
use tfhe::shortint::Parameters;
use tfhe::shortint::ClassicPBSParameters;
#[derive(Clone, Copy, Default, Serialize)]
pub struct CryptoParametersRecord {
@@ -52,8 +52,8 @@ impl From<BooleanParameters> for CryptoParametersRecord {
}
#[cfg(feature = "shortint")]
impl From<Parameters> for CryptoParametersRecord {
fn from(params: Parameters) -> Self {
impl From<ClassicPBSParameters> for CryptoParametersRecord {
fn from(params: ClassicPBSParameters) -> Self {
CryptoParametersRecord {
lwe_dimension: Some(params.lwe_dimension),
glwe_dimension: Some(params.glwe_dimension),
@@ -64,11 +64,11 @@ impl From<Parameters> for CryptoParametersRecord {
pbs_level: Some(params.pbs_level),
ks_base_log: Some(params.ks_base_log),
ks_level: Some(params.ks_level),
pfks_level: Some(params.pfks_level),
pfks_base_log: Some(params.pfks_base_log),
pfks_modular_std_dev: Some(params.pfks_modular_std_dev),
cbs_level: Some(params.cbs_level),
cbs_base_log: Some(params.cbs_base_log),
pfks_level: None,
pfks_base_log: None,
pfks_modular_std_dev: None,
cbs_level: None,
cbs_base_log: None,
message_modulus: Some(params.message_modulus.0),
carry_modulus: Some(params.carry_modulus.0),
}

View File

@@ -9,22 +9,24 @@ int uint128_client_key(const ClientKey *client_key) {
FheUint128 *lhs = NULL;
FheUint128 *rhs = NULL;
FheUint128 *result = NULL;
U128 lhs_clear = {10, 20};
U128 rhs_clear = {1, 2};
U128 result_clear = {0};
ok = fhe_uint128_try_encrypt_with_client_key_u128(10, 20, client_key, &lhs);
ok = fhe_uint128_try_encrypt_with_client_key_u128(lhs_clear, client_key, &lhs);
assert(ok == 0);
ok = fhe_uint128_try_encrypt_with_client_key_u128(1, 2, client_key, &rhs);
ok = fhe_uint128_try_encrypt_with_client_key_u128(rhs_clear, client_key, &rhs);
assert(ok == 0);
ok = fhe_uint128_sub(lhs, rhs, &result);
assert(ok == 0);
uint64_t w0, w1;
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(w0 == 9);
assert(w1 == 18);
assert(result_clear.w0 == 9);
assert(result_clear.w1 == 18);
fhe_uint128_destroy(lhs);
fhe_uint128_destroy(rhs);
@@ -37,22 +39,24 @@ int uint128_encrypt_trivial(const ClientKey *client_key) {
FheUint128 *lhs = NULL;
FheUint128 *rhs = NULL;
FheUint128 *result = NULL;
ok = fhe_uint128_try_encrypt_trivial_u128(10, 20, &lhs);
U128 lhs_clear = {10, 20};
U128 rhs_clear = {1, 2};
U128 result_clear = {0};
ok = fhe_uint128_try_encrypt_trivial_u128(lhs_clear, &lhs);
assert(ok == 0);
ok = fhe_uint128_try_encrypt_trivial_u128(1, 2, &rhs);
ok = fhe_uint128_try_encrypt_trivial_u128(rhs_clear, &rhs);
assert(ok == 0);
ok = fhe_uint128_sub(lhs, rhs, &result);
assert(ok == 0);
uint64_t w0, w1;
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(w0 == 9);
assert(w1 == 18);
assert(result_clear.w0 == 9);
assert(result_clear.w1 == 18);
fhe_uint128_destroy(lhs);
fhe_uint128_destroy(rhs);
@@ -65,22 +69,24 @@ int uint128_public_key(const ClientKey *client_key, const PublicKey *public_key)
FheUint128 *lhs = NULL;
FheUint128 *rhs = NULL;
FheUint128 *result = NULL;
U128 lhs_clear = {10, 20};
U128 rhs_clear = {1, 2};
U128 result_clear = {0};
ok = fhe_uint128_try_encrypt_with_public_key_u128(1, 2, public_key, &lhs);
ok = fhe_uint128_try_encrypt_with_public_key_u128(lhs_clear, public_key, &lhs);
assert(ok == 0);
ok = fhe_uint128_try_encrypt_with_public_key_u128(10, 20, public_key, &rhs);
ok = fhe_uint128_try_encrypt_with_public_key_u128(rhs_clear, public_key, &rhs);
assert(ok == 0);
ok = fhe_uint128_add(lhs, rhs, &result);
assert(ok == 0);
uint64_t w0, w1;
ok = fhe_uint128_decrypt(result, client_key, &w0, &w1);
ok = fhe_uint128_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(w0 == 11);
assert(w1 == 22);
assert(result_clear.w0 == 11);
assert(result_clear.w1 == 22);
fhe_uint128_destroy(lhs);
fhe_uint128_destroy(rhs);

View File

@@ -10,14 +10,9 @@ int uint256_client_key(const ClientKey *client_key) {
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
FheUint64 *cast_result = NULL;
U256 *lhs_clear = NULL;
U256 *rhs_clear = NULL;
U256 *result_clear = NULL;
ok = u256_from_u64_words(1, 2, 3, 4, &lhs_clear);
assert(ok == 0);
ok = u256_from_u64_words(5, 6, 7, 8, &rhs_clear);
assert(ok == 0);
U256 lhs_clear = {1 , 2, 3, 4};
U256 rhs_clear = {5, 6, 7, 8};
U256 result_clear = { 0 };
ok = fhe_uint256_try_encrypt_with_client_key_u256(lhs_clear, client_key, &lhs);
assert(ok == 0);
@@ -31,14 +26,10 @@ int uint256_client_key(const ClientKey *client_key) {
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
uint64_t w0, w1, w2, w3;
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
assert(ok == 0);
assert(w0 == 6);
assert(w1 == 8);
assert(w2 == 10);
assert(w3 == 12);
assert(result_clear.w0 == 6);
assert(result_clear.w1 == 8);
assert(result_clear.w2 == 10);
assert(result_clear.w3 == 12);
// try some casting
ok = fhe_uint256_cast_into_fhe_uint64(result, &cast_result);
@@ -48,9 +39,6 @@ int uint256_client_key(const ClientKey *client_key) {
assert(ok == 0);
assert(u64_clear == 6);
u256_destroy(lhs_clear);
u256_destroy(rhs_clear);
u256_destroy(result_clear);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
@@ -63,14 +51,9 @@ int uint256_encrypt_trivial(const ClientKey *client_key) {
FheUint256 *lhs = NULL;
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
U256 *lhs_clear = NULL;
U256 *rhs_clear = NULL;
U256 *result_clear = NULL;
ok = u256_from_u64_words(1, 2, 3, 4, &lhs_clear);
assert(ok == 0);
ok = u256_from_u64_words(5, 6, 7, 8, &rhs_clear);
assert(ok == 0);
U256 lhs_clear = {1 , 2, 3, 4};
U256 rhs_clear = {5, 6, 7, 8};
U256 result_clear = { 0 };
ok = fhe_uint256_try_encrypt_trivial_u256(lhs_clear, &lhs);
assert(ok == 0);
@@ -84,18 +67,11 @@ int uint256_encrypt_trivial(const ClientKey *client_key) {
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
uint64_t w0, w1, w2, w3;
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
assert(ok == 0);
assert(result_clear.w0 == 6);
assert(result_clear.w1 == 8);
assert(result_clear.w2 == 10);
assert(result_clear.w3 == 12);
assert(w0 == 6);
assert(w1 == 8);
assert(w2 == 10);
assert(w3 == 12);
u256_destroy(lhs_clear);
u256_destroy(rhs_clear);
u256_destroy(result_clear);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
@@ -107,14 +83,9 @@ int uint256_public_key(const ClientKey *client_key, const PublicKey *public_key)
FheUint256 *lhs = NULL;
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
U256 *lhs_clear = NULL;
U256 *rhs_clear = NULL;
U256 *result_clear = NULL;
ok = u256_from_u64_words(5, 6, 7, 8, &lhs_clear);
assert(ok == 0);
ok = u256_from_u64_words(1, 2, 3, 4, &rhs_clear);
assert(ok == 0);
U256 lhs_clear = {5, 6, 7, 8};
U256 rhs_clear = {1 , 2, 3, 4};
U256 result_clear = { 0 };
ok = fhe_uint256_try_encrypt_with_public_key_u256(lhs_clear, public_key, &lhs);
assert(ok == 0);
@@ -128,18 +99,11 @@ int uint256_public_key(const ClientKey *client_key, const PublicKey *public_key)
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
uint64_t w0, w1, w2, w3;
ok = u256_to_u64_words(result_clear, &w0, &w1, &w2, &w3);
assert(ok == 0);
assert(result_clear.w0 == 4);
assert(result_clear.w1 == 4);
assert(result_clear.w2 == 4);
assert(result_clear.w3 == 4);
assert(w0 == 4);
assert(w1 == 4);
assert(w2 == 4);
assert(w3 == 4);
u256_destroy(lhs_clear);
u256_destroy(rhs_clear);
u256_destroy(result_clear);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);

View File

@@ -0,0 +1,213 @@
#include <tfhe.h>
#include <assert.h>
#include <inttypes.h>
#include <stdio.h>
int uint256_client_key(const ClientKey *client_key) {
int ok;
FheUint256 *lhs = NULL;
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
FheUint64 *cast_result = NULL;
U256 lhs_clear = {1, 2, 3, 4};
U256 rhs_clear = {5, 6, 7, 8};
U256 result_clear = {0};
ok = fhe_uint256_try_encrypt_with_client_key_u256(lhs_clear, client_key, &lhs);
assert(ok == 0);
ok = fhe_uint256_try_encrypt_with_client_key_u256(rhs_clear, client_key, &rhs);
assert(ok == 0);
ok = fhe_uint256_add(lhs, rhs, &result);
assert(ok == 0);
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(result_clear.w0 == 6);
assert(result_clear.w1 == 8);
assert(result_clear.w2 == 10);
assert(result_clear.w3 == 12);
// try some casting
ok = fhe_uint256_cast_into_fhe_uint64(result, &cast_result);
assert(ok == 0);
uint64_t u64_clear;
ok = fhe_uint64_decrypt(cast_result, client_key, &u64_clear);
assert(ok == 0);
assert(u64_clear == 6);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
fhe_uint64_destroy(cast_result);
return ok;
}
int uint256_encrypt_trivial(const ClientKey *client_key) {
int ok;
FheUint256 *lhs = NULL;
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
U256 lhs_clear = {1, 2, 3, 4};
U256 rhs_clear = {5, 6, 7, 8};
U256 result_clear = {0};
ok = fhe_uint256_try_encrypt_trivial_u256(lhs_clear, &lhs);
assert(ok == 0);
ok = fhe_uint256_try_encrypt_trivial_u256(rhs_clear, &rhs);
assert(ok == 0);
ok = fhe_uint256_add(lhs, rhs, &result);
assert(ok == 0);
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(result_clear.w0 == 6);
assert(result_clear.w1 == 8);
assert(result_clear.w2 == 10);
assert(result_clear.w3 == 12);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
return ok;
}
int uint256_public_key(const ClientKey *client_key,
const CompressedCompactPublicKey *compressed_public_key) {
int ok;
CompactPublicKey *public_key = NULL;
FheUint256 *lhs = NULL;
FheUint256 *rhs = NULL;
FheUint256 *result = NULL;
CompactFheUint256List *list = NULL;
U256 result_clear = {0};
U256 clears[2] = {{5, 6, 7, 8}, {1, 2, 3, 4}};
ok = compressed_compact_public_key_decompress(compressed_public_key, &public_key);
assert(ok == 0);
// Compact list example
{
ok = compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u256(
&clears[0], 2, public_key, &list);
assert(ok == 0);
size_t len = 0;
ok = compact_fhe_uint256_list_len(list, &len);
assert(ok == 0);
assert(len == 2);
FheUint256 *expand_output[2] = {NULL};
ok = compact_fhe_uint256_list_expand(list, &expand_output[0], 2);
assert(ok == 0);
// transfer ownership
lhs = expand_output[0];
rhs = expand_output[1];
ok = fhe_uint256_sub(lhs, rhs, &result);
assert(ok == 0);
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(result_clear.w0 == 4);
assert(result_clear.w1 == 4);
assert(result_clear.w2 == 4);
assert(result_clear.w3 == 4);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
}
{
ok = fhe_uint256_try_encrypt_with_compact_public_key_u256(clears[0], public_key, &lhs);
assert(ok == 0);
ok = fhe_uint256_try_encrypt_with_compact_public_key_u256(clears[1], public_key, &rhs);
assert(ok == 0);
ok = fhe_uint256_sub(lhs, rhs, &result);
assert(ok == 0);
ok = fhe_uint256_decrypt(result, client_key, &result_clear);
assert(ok == 0);
assert(result_clear.w0 == 4);
assert(result_clear.w1 == 4);
assert(result_clear.w2 == 4);
assert(result_clear.w3 == 4);
fhe_uint256_destroy(lhs);
fhe_uint256_destroy(rhs);
fhe_uint256_destroy(result);
}
compact_public_key_destroy(public_key);
return ok;
}
int main(void) {
int ok = 0;
{
ConfigBuilder *builder;
Config *config;
config_builder_all_disabled(&builder);
config_builder_enable_custom_integers(&builder, SHORTINT_PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
config_builder_build(builder, &config);
ClientKey *client_key = NULL;
ServerKey *server_key = NULL;
CompressedCompactPublicKey *compressed_public_key = NULL;
generate_keys(config, &client_key, &server_key);
compressed_compact_public_key_new(client_key, &compressed_public_key);
set_server_key(server_key);
uint256_client_key(client_key);
uint256_encrypt_trivial(client_key);
uint256_public_key(client_key, compressed_public_key);
client_key_destroy(client_key);
compressed_compact_public_key_destroy(compressed_public_key);
server_key_destroy(server_key);
}
{
ConfigBuilder *builder;
Config *config;
config_builder_all_disabled(&builder);
config_builder_enable_custom_integers(&builder,
SHORTINT_PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
config_builder_build(builder, &config);
ClientKey *client_key = NULL;
ServerKey *server_key = NULL;
CompressedCompactPublicKey *compressed_public_key = NULL;
generate_keys(config, &client_key, &server_key);
compressed_compact_public_key_new(client_key, &compressed_public_key);
set_server_key(server_key);
uint256_client_key(client_key);
uint256_encrypt_trivial(client_key);
uint256_public_key(client_key, compressed_public_key);
client_key_destroy(client_key);
compressed_compact_public_key_destroy(compressed_public_key);
server_key_destroy(server_key);
}
return ok;
}

View File

@@ -43,6 +43,8 @@ void micro_bench_and() {
destroy_boolean_client_key(cks);
destroy_boolean_server_key(sks);
destroy_boolean_ciphertext(ct_left);
destroy_boolean_ciphertext(ct_right);
}
int main(void) {

View File

@@ -8,16 +8,13 @@
void test_predefined_keygen_w_serde(void) {
ShortintClientKey *cks = NULL;
ShortintServerKey *sks = NULL;
ShortintParameters *params = NULL;
ShortintCiphertext *ct = NULL;
Buffer ct_ser_buffer = {.pointer = NULL, .length = 0};
ShortintCiphertext *deser_ct = NULL;
ShortintCompressedCiphertext *cct = NULL;
ShortintCompressedCiphertext *deser_cct = NULL;
ShortintCiphertext *decompressed_ct = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
assert(gen_keys_ok == 0);
@@ -67,7 +64,6 @@ void test_predefined_keygen_w_serde(void) {
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
destroy_shortint_parameters(params);
destroy_shortint_ciphertext(ct);
destroy_shortint_ciphertext(deser_ct);
destroy_shortint_compressed_ciphertext(cct);
@@ -79,11 +75,8 @@ void test_predefined_keygen_w_serde(void) {
void test_server_key_trivial_encrypt(void) {
ShortintClientKey *cks = NULL;
ShortintServerKey *sks = NULL;
ShortintParameters *params = NULL;
ShortintCiphertext *ct = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
assert(gen_keys_ok == 0);
@@ -99,25 +92,32 @@ void test_server_key_trivial_encrypt(void) {
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
destroy_shortint_parameters(params);
destroy_shortint_ciphertext(ct);
}
void test_custom_keygen(void) {
ShortintClientKey *cks = NULL;
ShortintServerKey *sks = NULL;
ShortintParameters *params = NULL;
int params_ok =
shortint_create_parameters(10, 1, 1024, 10e-100, 10e-100, 2, 3, 2, 3, 2, 3, 10e-100, 2, 3, 2,
2, 64, ShortintEncryptionKeyChoiceBig, &params);
assert(params_ok == 0);
ShortintPBSParameters params = {
.lwe_dimension = 10,
.glwe_dimension = 1,
.polynomial_size = 1024,
.lwe_modular_std_dev = 10e-100,
.glwe_modular_std_dev = 10e-100,
.pbs_base_log = 2,
.pbs_level = 3,
.ks_base_log = 2,
.ks_level = 3,
.message_modulus = 2,
.carry_modulus = 2,
.modulus_power_of_2_exponent = 64,
.encryption_key_choice = ShortintEncryptionKeyChoiceBig,
};
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
assert(gen_keys_ok == 0);
destroy_shortint_parameters(params);
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
}
@@ -126,12 +126,9 @@ void test_public_keygen(ShortintPublicKeyKind pk_kind) {
ShortintClientKey *cks = NULL;
ShortintPublicKey *pks = NULL;
ShortintPublicKey *pks_deser = NULL;
ShortintParameters *params = NULL;
ShortintCiphertext *ct = NULL;
Buffer pks_ser_buff = {.pointer = NULL, .length = 0};
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_client_key(params, &cks);
assert(gen_keys_ok == 0);
@@ -157,7 +154,6 @@ void test_public_keygen(ShortintPublicKeyKind pk_kind) {
assert(result == 2);
destroy_shortint_parameters(params);
destroy_shortint_client_key(cks);
destroy_shortint_public_key(pks);
destroy_shortint_public_key(pks_deser);
@@ -169,11 +165,8 @@ void test_compressed_public_keygen(ShortintPublicKeyKind pk_kind) {
ShortintClientKey *cks = NULL;
ShortintCompressedPublicKey *cpks = NULL;
ShortintPublicKey *pks = NULL;
ShortintParameters *params = NULL;
ShortintCiphertext *ct = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_client_key(params, &cks);
assert(gen_keys_ok == 0);
@@ -204,7 +197,6 @@ void test_compressed_public_keygen(ShortintPublicKeyKind pk_kind) {
assert(result == 2);
destroy_shortint_parameters(params);
destroy_shortint_client_key(cks);
destroy_shortint_compressed_public_key(cpks);
destroy_shortint_public_key(pks);

View File

@@ -41,10 +41,7 @@ void test_shortint_pbs_2_bits_message(void) {
ShortintPBSLookupTable *accumulator = NULL;
ShortintClientKey *cks = NULL;
ShortintServerKey *sks = NULL;
ShortintParameters *params = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
assert(gen_keys_ok == 0);
@@ -111,17 +108,13 @@ void test_shortint_pbs_2_bits_message(void) {
destroy_shortint_pbs_accumulator(accumulator);
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
destroy_shortint_parameters(params);
}
void test_shortint_bivariate_pbs_2_bits_message(void) {
ShortintBivariatePBSLookupTable *accumulator = NULL;
ShortintClientKey *cks = NULL;
ShortintServerKey *sks = NULL;
ShortintParameters *params = NULL;
int get_params_ok = shortint_get_parameters(2, 2, &params);
assert(get_params_ok == 0);
ShortintPBSParameters params = SHORTINT_PARAM_MESSAGE_2_CARRY_2;
int gen_keys_ok = shortint_gen_keys_with_parameters(params, &cks, &sks);
assert(gen_keys_ok == 0);
@@ -187,7 +180,6 @@ void test_shortint_bivariate_pbs_2_bits_message(void) {
destroy_shortint_bivariate_pbs_accumulator(accumulator);
destroy_shortint_client_key(cks);
destroy_shortint_server_key(sks);
destroy_shortint_parameters(params);
}
int main(void) {

View File

@@ -419,10 +419,10 @@ void test_server_key(void) {
ShortintCompressedServerKey *deser_csks = NULL;
Buffer sks_ser_buffer = {.pointer = NULL, .length = 0};
ShortintServerKey *deser_sks = NULL;
ShortintParameters *params = NULL;
ShortintClientKey *cks_small = NULL;
ShortintServerKey *sks_small = NULL;
ShortintParameters *params_small = NULL;
ShortintPBSParameters params = { 0 };
ShortintPBSParameters params_small = { 0 };
const uint32_t message_bits = 2;
const uint32_t carry_bits = 2;
@@ -736,8 +736,6 @@ void test_server_key(void) {
destroy_shortint_client_key(deser_cks);
destroy_shortint_compressed_server_key(deser_csks);
destroy_shortint_server_key(deser_sks);
destroy_shortint_parameters(params);
destroy_shortint_parameters(params_small);
destroy_buffer(&cks_ser_buffer);
destroy_buffer(&csks_ser_buffer);
destroy_buffer(&sks_ser_buffer);

View File

@@ -3,7 +3,6 @@
* [What is TFHE-rs?](README.md)
## Getting Started
* [Installation](getting\_started/installation.md)
* [Quick Start](getting\_started/quick\_start.md)
* [Supported Operations](getting\_started/operations.md)
@@ -33,6 +32,12 @@
* [Cryptographic Parameters](integer/parameters.md)
* [Serialization/Deserialization](integer/serialization.md)
## Tutorials for real-life applications
* [Dark Market](tutorial/dark_market.md)
* [SHA256](tutorial/sha256_bool.md)
* [Homomorphic Regular Expressions](tutorial/regex/tutorial.md)
## C API
* [High-Level API](c_api/high-level-api.md)
* [Shortint API](c_api/shortint-api.md)

View File

@@ -58,11 +58,11 @@ All timings are related to parallelized Radix-based integer operations, where ea
To ensure predictable timings, the operation flavor is the `default` one: a carry propagation is computed after each operation. Operation cost could be reduced by using `unchecked`, `checked`, or `smart`.
| Plaintext size | add | mul | greater\_than (gt) | min |
| -------------------| -------------- | ------------------- | --------- | ------- |
| 8 bits | 129.0 ms | 227.2 ms | 111.9 ms | 186.8 ms |
| 16 bits | 256.3 ms | 756.0 ms | 145.3 ms | 233.1 ms |
| 32 bits | 469.4 ms | 2.10 s | 192.0 ms | 282.9 ms |
| 40 bits | 608.0 ms | 3.37 s | 228.4 ms | 318.6 ms |
| 64 bits | 959.9 ms | 5.53 s | 249.0 ms | 336.5 ms |
| 128 bits | 1.88 s | 14.1 s | 294.7 ms | 398.6 ms |
| 256 bits | 3.66 s | 29.2 s | 361.8 ms | 509.1 ms |
| -------------------| ---------------| --------------------| ---------------------| -------------|
| 8 bits | 129.0 ms | 227 ms | 111.9 ms | 186.8 ms |
| 16 bits | 195 ms | 369 ms | 145.3 ms | 233.1 ms |
| 32 bits | 238 ms | 519 ms | 192.0 ms | 282.9 ms |
| 40 bits | 283 ms | 754 ms | 228.4 ms | 318.6 ms |
| 64 bits | 297 ms | 1.18 s | 249.0 ms | 336.5 ms |
| 128 bits | 424 ms | 3.13 s | 294.7 ms | 398.6 ms |
| 256 bits | 500 ms | 11 s | 361.8 ms | 509.1 ms |

View File

@@ -115,7 +115,7 @@ fn main() {
let msg1 = 1;
let msg2 = 0;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -143,7 +143,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We create keys for radix representation to create 16 bits integers
// using 8 blocks of 2 bits
let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, 8);
let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, 8);
let clear_a = 2382u16;
let clear_b = 29374u16;

View File

@@ -26,7 +26,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
}
```
@@ -105,7 +105,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 12u64;
let msg2 = 11u64;
@@ -113,7 +113,7 @@ fn main() {
let scalar = 3u64;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -143,7 +143,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
let num_block = 2;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 12u64;
let msg2 = 11u64;
@@ -151,7 +151,7 @@ fn main() {
let scalar = 3u64;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -181,7 +181,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 12;
let msg2 = 11;
@@ -189,7 +189,7 @@ fn main() {
let scalar = 3;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -220,7 +220,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 12;
let msg2 = 11;
@@ -228,7 +228,7 @@ fn main() {
let scalar = 3;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);

View File

@@ -27,13 +27,13 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() -> Result<(), Box<dyn std::error::Error>> {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 201;
let msg2 = 12;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
let ct_1 = client_key.encrypt(msg1);
let ct_2 = client_key.encrypt(msg2);

View File

@@ -34,7 +34,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
}
```
@@ -49,7 +49,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 128u64;
let msg2 = 13u64;
@@ -72,7 +72,7 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, _) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, _) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
//We generate the public key from the secret client key:
let public_key = PublicKeyBig::new(&client_key);
@@ -98,13 +98,13 @@ use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
fn main() {
// We generate a set of client/server keys, using the default parameters:
let num_block = 4;
let (client_key, server_key) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, num_block);
let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block);
let msg1 = 128;
let msg2 = 13;
// message_modulus^vec_length
let modulus = client_key.parameters().message_modulus.0.pow(num_block as u32) as u64;
let modulus = client_key.parameters().message_modulus().0.pow(num_block as u32) as u64;
// We use the client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -55,7 +55,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -88,7 +88,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -131,7 +131,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -165,7 +165,7 @@ fn main() {
let msg2 = 3;
let scalar = 4;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let mut ct_1 = client_key.encrypt(msg1);
@@ -241,7 +241,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -272,7 +272,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -303,7 +303,7 @@ fn main() {
let msg1 = 2;
let msg2 = 1;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -331,7 +331,7 @@ fn main() {
let msg1 = 3;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);
@@ -365,7 +365,7 @@ fn main() {
let msg1 = 3;
let msg2 = 2;
let modulus = client_key.parameters.message_modulus.0 as u64;
let modulus = client_key.parameters.message_modulus().0 as u64;
// We use the private client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -44,7 +44,7 @@ In the case of multiplication, two algorithms are implemented: the first one rel
## User-defined parameter sets
It is possible to define new parameter sets. To do so, it is sufficient to use the function `unsecure_parameters()` or to manually fill the `Parameter` structure fields.
It is possible to define new parameter sets. To do so, it is sufficient to use the function `unsecure_parameters()` or to manually fill the `ClassicPBSParameters` structure fields.
For instance:
@@ -53,7 +53,7 @@ use tfhe::shortint::prelude::*;
fn main() {
let param = unsafe {
Parameters::new(
ClassicPBSParameters::new(
LweDimension(656),
GlweDimension(2),
PolynomialSize(512),
@@ -63,11 +63,6 @@ fn main() {
DecompositionLevelCount(2),
DecompositionBaseLog(3),
DecompositionLevelCount(4),
StandardDev(0.00000000037411618952047216),
DecompositionBaseLog(15),
DecompositionLevelCount(1),
DecompositionLevelCount(0),
DecompositionBaseLog(0),
MessageModulus(4),
CarryModulus(1),
CiphertextModulus::new_native(),

View File

@@ -82,7 +82,7 @@ fn main() {
let msg1 = 1;
let msg2 = 0;
let modulus = client_key.parameters.message_modulus.0;
let modulus = client_key.parameters.message_modulus().0;
// We use the client key to encrypt two messages:
let ct_1 = client_key.encrypt(msg1);

View File

@@ -0,0 +1,479 @@
# Dark Market Tutorial
In this tutorial, we are going to build a dark market application using TFHE-rs. A dark market is a marketplace where
buy and sell orders are not visible to the public before they are filled. Different algorithms aim to
solve this problem; here, we implement the algorithm defined [in this paper](https://eprint.iacr.org/2022/923.pdf) with TFHE-rs.
We will first implement the algorithm in plain Rust and then we will see how to use TFHE-rs to
implement the same algorithm with FHE.
In addition, we will implement a modified version of the algorithm that allows for more concurrent operations,
which improves performance on multi-core hardware.
## Specifications
#### Inputs:
* A list of sell orders, where each sell order is defined only by its volume; the price is assumed to be fetched from a different source.
* A list of buy orders, where each buy order is defined only by its volume; the price is assumed to be fetched from a different source.
#### Input constraints:
* The sell and buy orders are within the range [1,100].
* There are at most 500 sell orders and at most 500 buy orders.
#### Outputs:
The algorithm returns no output. Instead, it modifies the given input lists in place: the filled amount is written
over the original order volume in the respective lists, and orders that cannot be filled at all are set to zero.
#### Example input and output:
##### Example 1:
| | Sell | Buy |
|--------|--------------------|-----------|
| Input | [ 5, 12, 7, 4, 3 ] | [ 19, 2 ] |
| Output | [ 5, 12, 4, 0, 0 ] | [ 19, 2 ] |
The third sell order is only partially filled and the last two are set to zero because there are not enough buy orders to match them.
##### Example 2:
| | Sell | Buy |
|--------|-------------------|----------------------|
| Input | [ 3, 1, 1, 4, 2 ] | [ 5, 3, 3, 2, 4, 1 ] |
| Output | [ 3, 1, 1, 4, 2 ] | [ 5, 3, 3, 0, 0, 0 ] |
The last three buy orders are set to zero because there are not enough sell orders to match them.
## Plain Implementation
1. Calculate the total sell volume and the total buy volume.
```rust
let total_sell_volume: u16 = sell_orders.iter().sum();
let total_buy_volume: u16 = buy_orders.iter().sum();
```
2. Find the total volume that will be transacted. In the paper, this amount is calculated with the formula:
```
(total_sell_volume > total_buy_volume) * (total_buy_volume - total_sell_volume) + total_sell_volume
```
When closely observed, we can see that this formula can be replaced with the `min` function. Therefore, we calculate this
value by taking the minimum of the total sell volume and the total buy volume.
```rust
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
```
3. Beginning with the first item, fill the sell orders one by one. We apply the same `min` replacement here as well.
```rust
let mut volume_left_to_transact = total_volume;
for sell_order in sell_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
*sell_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
```
The number of orders that are filled is indicated by modifying the input list. For example, if the first sell order is
1000 and the total volume is 500, then the first sell order will be modified to 500 and the second sell order will be
modified to 0.
4. Do the fill operation also for the buy orders.
```rust
let mut volume_left_to_transact = total_volume;
for buy_order in buy_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
*buy_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
```
#### The complete algorithm in plain Rust:
```rust
fn volume_match_plain(sell_orders: &mut Vec<u16>, buy_orders: &mut Vec<u16>) {
let total_sell_volume: u16 = sell_orders.iter().sum();
let total_buy_volume: u16 = buy_orders.iter().sum();
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
let mut volume_left_to_transact = total_volume;
for sell_order in sell_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
*sell_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
let mut volume_left_to_transact = total_volume;
for buy_order in buy_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
*buy_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
}
```
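As a quick sanity check, we can run this function on the inputs of Example 1 and recover the expected output:
```rust
fn main() {
    let mut sell_orders: Vec<u16> = vec![5, 12, 7, 4, 3];
    let mut buy_orders: Vec<u16> = vec![19, 2];

    volume_match_plain(&mut sell_orders, &mut buy_orders);

    // Matches the expected output of Example 1.
    assert_eq!(sell_orders, vec![5, 12, 4, 0, 0]);
    assert_eq!(buy_orders, vec![19, 2]);
}
```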
## FHE Implementation
For the FHE implementation, we first need to find the right bit size for the algorithm to work without overflows.
The variables that are declared in the algorithm and their maximum values are described in the table below:
| Variable | Maximum Value | Bit Size |
|-------------------------|---------------|----------|
| total_sell_volume | 50000 | 16 |
| total_buy_volume | 50000 | 16 |
| total_volume | 50000 | 16 |
| volume_left_to_transact | 50000 | 16 |
| sell_order | 100 | 7 |
| buy_order | 100 | 7 |
As we can observe from the table, we need **16 bits of message space** to be able to run the algorithm without
overflows. TFHE-rs provides different presets for the different bit sizes. Since we need 16 bits of message, we are
going to use the `integer` module to implement the algorithm.
Here are the input types of our algorithm:
* `sell_orders` is of type `Vec<tfhe::integer::RadixCipherText>`
* `buy_orders` is of type `Vec<tfhe::integer::RadixCipherText>`
* `server_key` is of type `tfhe::integer::ServerKey`
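For example (a minimal sketch, not part of the tutorial's code), these encrypted inputs could be produced by generating radix keys over 8 blocks of 2-bit messages, which gives the 16 bits of message space identified above, and encrypting each order individually:
```rust
use tfhe::integer::gen_keys_radix;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;

// 8 blocks of 2-bit messages give the 16 bits of message space we need.
const NUMBER_OF_BLOCKS: usize = 8;

fn main() {
    let (client_key, server_key) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, NUMBER_OF_BLOCKS);

    // Encrypt the orders of Example 1, one ciphertext per order.
    let mut sell_orders: Vec<_> = [5u64, 12, 7, 4, 3]
        .iter()
        .map(|&volume| client_key.encrypt(volume))
        .collect();
    let mut buy_orders: Vec<_> = [19u64, 2]
        .iter()
        .map(|&volume| client_key.encrypt(volume))
        .collect();

    // `volume_match_fhe` (implemented in the rest of this section) then operates
    // on these vectors together with the server key.
    // volume_match_fhe(&mut sell_orders, &mut buy_orders, &server_key);
}
```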
Now, we can start implementing the algorithm with FHE:
1. Calculate the total sell volume and the total buy volume.
```rust
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for sell_order in sell_orders.iter_mut() {
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
}
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for buy_order in buy_orders.iter_mut() {
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
}
```
2. Find the total volume that will be transacted by taking the minimum of the total sell volume and the total buy
volume.
```rust
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
```
3. Beginning with the first item, start filling the sell and buy orders one by one. We can create a `fill_orders` closure to
reduce code duplication, since the code for filling buy orders and sell orders is the same.
```rust
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
let mut volume_left_to_transact = total_volume.clone();
for mut order in orders.iter_mut() {
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, &mut order);
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
*order = filled_amount;
}
};
fill_orders(sell_orders);
fill_orders(buy_orders);
```
#### The complete algorithm in TFHE-rs:
```rust
const NUMBER_OF_BLOCKS: usize = 8;
fn volume_match_fhe(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for sell_order in sell_orders.iter_mut() {
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
}
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for buy_order in buy_orders.iter_mut() {
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
}
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
let mut volume_left_to_transact = total_volume.clone();
for mut order in orders.iter_mut() {
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, &mut order);
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
*order = filled_amount;
}
};
fill_orders(sell_orders);
fill_orders(buy_orders);
}
```
### Optimizing the implementation
* TFHE-rs provides parallelized implementations of the operations. We can use these parallelized
implementations to speed up the algorithm. For example, we can use `smart_add_assign_parallelized` instead of
`smart_add_assign`.
* We can parallelize the vector sum using Rayon's parallel `reduce` operation.
```rust
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
vec.to_vec().into_par_iter().reduce(
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
server_key.smart_add_parallelized(&mut acc, &mut ele)
},
)
};
```
* We can run vector summation on `buy_orders` and `sell_orders` in parallel since these operations do not depend on each other.
```rust
let (mut total_sell_volume, mut total_buy_volume) =
rayon::join(|| vector_sum(sell_orders), || vector_sum(buy_orders));
```
* We can match sell and buy orders in parallel since the matching does not depend on each other.
```rust
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
```
#### Optimized algorithm
```rust
fn volume_match_fhe_parallelized(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
vec.to_vec().into_par_iter().reduce(
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
server_key.smart_add_parallelized(&mut acc, &mut ele)
},
)
};
let (mut total_sell_volume, mut total_buy_volume) = rayon::join(
|| parallel_vector_sum(sell_orders),
|| parallel_vector_sum(buy_orders),
);
let total_volume =
server_key.smart_min_parallelized(&mut total_sell_volume, &mut total_buy_volume);
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
let mut volume_left_to_transact = total_volume.clone();
for mut order in orders.iter_mut() {
let mut filled_amount =
server_key.smart_min_parallelized(&mut volume_left_to_transact, &mut order);
server_key
.smart_sub_assign_parallelized(&mut volume_left_to_transact, &mut filled_amount);
*order = filled_amount;
}
};
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
}
```
## Modified Algorithm
When observed closely, there is only a small amount of concurrency introduced in the `fill_orders` part of the algorithm.
The reason is that `volume_left_to_transact` is shared between all the orders and must be updated sequentially, so the
orders cannot be filled in parallel. If we can remove this dependency, we can fill the orders in parallel.
To do so, we look more closely at the role of the `volume_left_to_transact` variable: it is only used to decide how much
of the current order can be filled. Instead of subtracting the current order from `volume_left_to_transact` at each step,
we can precompute a new list in which each entry holds the sum of all orders up to and including that index, and compare
these running sums with the total matching volume. If the running sum at the current index fits within the total matching
volume, the order can be fully filled; if only the running sum at the previous index fits, the order is partially filled
with whatever matching volume is left; otherwise it is set to zero.
We will call the new list the "prefix sum" of the array.
The new version for the plain `fill_orders` is as follows:
```rust
let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64| {
    orders.iter_mut().enumerate().for_each(|(i, order)| {
        // Sum of all orders up to (but not including) the current one.
        let previous = if i == 0 { 0 } else { prefix_sum[i - 1] };
        if total_orders >= prefix_sum[i] {
            // Enough matching volume left: the order is fully filled, leave it unchanged.
        } else if total_orders >= previous {
            // Partially fill the order with what is left of the matching volume.
            *order = total_orders - previous;
        } else {
            // No matching volume left for this order.
            *order = 0;
        }
    });
};
```
To port this function to FHE, we need to transform the conditional code into a mathematical expression, since FHE does not support conditional branching on encrypted values.
```rust
let fill_orders = |orders: &mut [u64], prefix_sum: &[u64], total_orders: u64| {
    orders.iter_mut().enumerate().for_each(|(i, order)| {
        let previous = if i == 0 { 0 } else { prefix_sum[i - 1] };
        // Wrapping arithmetic mirrors the modular arithmetic of the encrypted version.
        *order = order.wrapping_add(
            ((prefix_sum[i] > total_orders) as u64)
                .wrapping_mul((total_orders - previous.min(total_orders)).wrapping_sub(*order)),
        );
    });
};
```
The new `fill_orders` function requires a prefix sum array. We compute this prefix sum in parallel
with the algorithm described [here](https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda).
The sample code in that chapter is written in CUDA, and a direct translation to Rust does not compile: even though no two
threads ever access the same array element (the index calculations using the `d` and `k` values never overlap), the Rust
compiler cannot prove this and therefore does not let us share the same array mutably between threads.
So we change how the algorithm is implemented, but we do not change the algorithm itself.
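To make the restructuring concrete, here is the same chunk-based scan sketched on plain `u64` values (this sketch is not part of the tutorial's code and assumes the `rayon` crate); the encrypted version below follows the same shape, with the plain additions replaced by `smart_add_assign_parallelized` calls on ciphertexts:
```rust
use rayon::prelude::*;

// Blelloch-style parallel scan on plain u64 values, organized around
// `par_chunks_exact_mut` so no two tasks ever touch the same element.
fn compute_prefix_sum(arr: &[u64]) -> Vec<u64> {
    if arr.is_empty() {
        return Vec::new();
    }
    // Pad to the next power of two so the chunking is regular.
    let mut prefix_sum: Vec<u64> = (0..arr.len().next_power_of_two())
        .map(|i| arr.get(i).copied().unwrap_or(0))
        .collect();
    // Up sweep: accumulate partial sums at the end of each chunk.
    for d in 0..prefix_sum.len().ilog2() {
        prefix_sum
            .par_chunks_exact_mut(2usize.pow(d + 1))
            .for_each(|chunk| {
                let left = chunk[(chunk.len() - 1) / 2];
                *chunk.last_mut().unwrap() += left;
            });
    }
    // Down sweep: propagate the partial sums back down (exclusive scan).
    let total = *prefix_sum.last().unwrap();
    *prefix_sum.last_mut().unwrap() = 0;
    for d in (0..prefix_sum.len().ilog2()).rev() {
        prefix_sum
            .par_chunks_exact_mut(2usize.pow(d + 1))
            .for_each(|chunk| {
                let mid = (chunk.len() - 1) / 2;
                let (mid_val, last_val) = (chunk[mid], *chunk.last().unwrap());
                *chunk.last_mut().unwrap() = last_val + mid_val;
                chunk[mid] = last_val;
            });
    }
    // Shift from an exclusive to an inclusive prefix sum, as the FHE version does.
    prefix_sum.push(total);
    prefix_sum[1..=arr.len()].to_vec()
}
```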
Here is the modified version of the algorithm in TFHE-rs:
```rust
fn volume_match_fhe_modified(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
let compute_prefix_sum = |arr: &[RadixCiphertextBig]| {
if arr.is_empty() {
return arr.to_vec();
}
let mut prefix_sum: Vec<RadixCiphertextBig> = (0..arr.len().next_power_of_two())
.into_par_iter()
.map(|i| {
if i < arr.len() {
arr[i].clone()
} else {
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
}
})
.collect();
// Up sweep
for d in 0..(prefix_sum.len().ilog2() as u32) {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
});
}
// Down sweep
let last = prefix_sum.last().unwrap().clone();
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for d in (0..(prefix_sum.len().ilog2() as u32)).rev() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let t = chunk.last().unwrap().clone();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left);
chunk[(length - 1) / 2] = t;
});
}
prefix_sum.push(last);
prefix_sum[1..=arr.len()].to_vec()
};
println!("Creating prefix sum arrays...");
let time = Instant::now();
let (prefix_sum_sell_orders, prefix_sum_buy_orders) = rayon::join(
|| compute_prefix_sum(sell_orders),
|| compute_prefix_sum(buy_orders),
);
println!("Created prefix sum arrays in {:?}", time.elapsed());
let fill_orders = |total_orders: &RadixCiphertextBig,
orders: &mut [RadixCiphertextBig],
prefix_sum_arr: &[RadixCiphertextBig]| {
orders
.into_par_iter()
.enumerate()
.for_each(move |(i, order)| {
server_key.smart_add_assign_parallelized(
order,
&mut server_key.smart_mul_parallelized(
&mut server_key
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
&mut server_key.smart_sub_parallelized(
&mut server_key.smart_sub_parallelized(
&mut total_orders.clone(),
&mut server_key.smart_min_parallelized(
&mut total_orders.clone(),
&mut prefix_sum_arr
.get(i - 1)
.unwrap_or(
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
)
.clone(),
),
),
&mut order.clone(),
),
),
);
});
};
let total_buy_orders = &mut prefix_sum_buy_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
let total_sell_orders = &mut prefix_sum_sell_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
println!("Matching orders...");
let time = Instant::now();
rayon::join(
|| fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders),
|| fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders),
);
println!("Matched orders in {:?}", time.elapsed());
}
```
## Running the tutorial
The plain, FHE, parallelized FHE, and modified FHE implementations can be run by providing the respective arguments, as described below.
```bash
# Runs FHE implementation
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe
# Runs parallelized FHE implementation
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe-parallel
# Runs modified FHE implementation
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- fhe-modified
# Runs plain implementation
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- plain
# Multiple implementations can be run within same instance
cargo run --release --package tfhe --example dark_market --features="integer internal-keycache" -- plain fhe-parallel
```
## Conclusion
In this tutorial, we've learned how to implement the volume matching algorithm described [in this paper](https://eprint.iacr.org/2022/923.pdf) in plain Rust and in TFHE-rs.
We've identified the right bit size for our problem at hand, used operations defined in `TFHE-rs`, and introduced concurrency to the algorithm to increase its performance.

View File

@@ -0,0 +1,512 @@
# FHE Regex Pattern Matching Tutorial
This tutorial explains how to build a regex Pattern Matching Engine (PME) where ciphertext is the
content that is evaluated.
A regex PME is an essential tool for programmers. It allows you to perform complex searches on content.
A simple string search, by contrast, can only find matches of an exact given sequence of
characters (e.g., your browser's default search function). Regex PMEs
are more powerful, allowing searches on certain structures of text, where a
structure may take any form in multiple possible sequences of characters. The
structure to be searched is defined with the regex, a very concise
language.
Here are some example regexes to give you an idea of what is possible:
Regex | Semantics
--- | ---
/abc/ | Searches for the sequence `abc` (equivalent to a simple text search)
/^abc/ | Searches for the sequence `abc` at the beginning of the content
/a?bc/ | Searches for sequences `abc`, `bc`
/ab\|c+d/ | Searches for either the sequence `ab`, or `c` repeated 1 or more times followed by `d`
Regexes are powerful enough to be able to express structures like email address
formats. This capability is what makes regexes useful for many programming
solutions.
There are two main components identifiable in a PME:
1. The pattern that is to be matched has to be parsed, translated from a
textual representation into a recursively structured object (an Abstract
Syntax Tree, or AST).
2. This AST must then be applied to the text that it is to be matched against,
resulting in a 'yes' or 'no' to whether the pattern has matched (in the case of
our FHE implementation, this result is an encrypted 'yes' or an encrypted 'no').
Parsing is a well understood problem. There are a couple of different
approaches possible here. Regardless of the approach chosen, it starts with
figuring out what language we want to support. That is, what are
the kinds of sentences we want our regex language to include? A few
example sentences we definitely want to support are `/a/`, `/a?bc/`, `/^ab$/`, and `/ab|cd/`. However,
example sentences don't suffice as a specification because they can never be exhaustive (there are infinitely many of them). We need
something to specify _exactly_ the full set of sentences our language supports.
There exists a language that can help us describe our own language's structure exactly:
Grammar.
## The Grammar and datastructure
It is useful to start with defining the Grammar before starting to write
code for the parser because the code structure follows directly from the
Grammar. A Grammar consists of a generally small set of rules. For example,
a very basic Grammar could look like this:
```
Start := 'a'
```
This describes a language that only contains the sentence "a". Not a very interesting language.
We can make it more interesting though by introducing choice into the Grammar
with \| (called a 'pipe') operators. If we want the above Grammar to accept
either "a" or "b":
```
Start := 'a' | 'b'
```
So far, only Grammars with a single rule have been shown. However, a Grammar can
consist of multiple rules. Most languages require it. So let's consider a more meaningful language,
one that accepts sentences consisting of one or more digits. We could describe such a language
with the following Grammar:
```
Start := Digit+
Digit := '0' | '1' | '2' | '3' | '4' | '5' | '6' | '7' | '8' | '9'
```
The `+` after `Digit` is another Grammar operator. With it, we specify that
Digit must be matched one or more times. Here are all the Grammar operators that
are relevant for this tutorial:
Operator | Example | Semantics
--- | --- | ---
`\|` | a \| b | we first try matching on 'a' - if no match, we try to match on 'b'
`+` | a+ | match 'a' one or more times
`*` | a* | match 'a' any amount of times (including zero times)
`?` | a? | optionally match 'a' (match zero or one time)
`.` | . | match any character
`..` | a .. b | match on a range of alphabetically ordered characters from 'a', up to and including 'b'
` ` | a b | sequencing; match on 'a' and then on 'b'
In the case of the example PME, the Grammar is as follows (notice the difference between the unquoted `?` and the quoted `'?'`, etc.: unquoted characters are Grammar operators, while quoted characters are literals to be matched during parsing).
```
Start := '/' '^'? Regex '$'? '/' Modifier?
Regex := Term '|' Term
| Term
Term := Factor*
Factor := Atom '?'
| Repeated
| Atom
Repeated := Atom '*'
| Atom '+'
| Atom '{' Digit* ','? '}'
| Atom '{' Digit+ ',' Digit* '}'
Atom := '.'
| '\' .
| Character
| '[' Range ']'
| '(' Regex ')'
Range := '^' Range
| AlphaNum '-' AlphaNum
| AlphaNum+
Digit := '0' .. '9'
Character := AlphaNum
| '&' | ';' | ':' | ',' | '`' | '~' | '-' | '_' | '!' | '@' | '#' | '%' | '\'' | '\"'
AlphaNum := 'a' .. 'z'
| 'A' .. 'Z'
| '0' .. '9'
Modifier := 'i'
```
We will refer occasionally to specific parts in the Grammar listed above by \<rule name\>.\<variant index\> (where the first rule variant has index 1).
With the Grammar defined, we can start defining a type to parse into. In Rust, we
have the `enum` kind of type that is perfect for this, as it allows you to define
multiple variants that may recurse. I prefer to start by defining variants that
do not recurse (i.e., that don't contain nested regex expressions):
```rust
enum RegExpr {
Char { c: char }, // matching against a single character (Atom.2 and Atom.3)
AnyChar, // matching _any_ character (Atom.1)
SOF, // matching only at the beginning of the content ('^' in Start.1)
EOF, // matching only at the end of the content (the '$' in Start.1)
Range { cs: Vec<char> }, // matching on a list of characters (Range.3, eg '[acd]')
Between { from: char, to: char }, // matching between 2 characters based on ascii ordering (Range.2, eg '[a-g]')
}
```
With this, we can translate the following basic regexes:
Pattern | RegExpr value
--- | ---
`/a/` | `RegExpr::Char { c: 'a' }`
`/\\^/` | `RegExpr::Char { c: '^' }`
`/./` | `RegExpr::AnyChar`
`/^/` | `RegExpr::SOF`
`/$/` | `RegExpr::EOF`
`/[acd]/` | `RegExpr::Range { vec!['a', 'c', 'd'] }`
`/[a-g]/` | `RegExpr::Between { from: 'a', to: 'g' }`
Notice we're not yet able to sequence multiple components together. Let's define
the first variant that captures recursive RegExpr for this:
```rust
enum RegExpr {
...
Seq { re_xs: Vec<RegExpr> }, // matching sequences of RegExpr components (Term.1)
}
```
With this Seq (short for sequence) variant, we allow translating patterns that
contain multiple components:
Pattern | RegExpr value
--- | ---
`/ab/` | `RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'b' }] }`
`/^a.$/` | `RegExpr::Seq { re_xs: vec![RegExpr::SOF, RexExpr::Char { 'a' }, RegExpr::AnyChar, RegExpr::EOF] }`
`/a[f-l]/` | `RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Between { from: 'f', to: 'l' }] }`
Let's finish the RegExpr datastructure by adding variants for 'Optional' matching,
'Not' logic in a range, and 'Either' left or right matching:
```rust
enum RegExpr {
...
Optional { opt_re: Box<RegExpr> }, // matching optionally (Factor.1)
Not { not_re: Box<RegExpr> }, // matching inversely on a range (Range.1)
Either { l_re: Box<RegExpr>, r_re: Box<RegExpr> }, // matching the left or right regex (Regex.1)
}
```
Some features are easiest to implement as a post-processing step on the parsed datastructure. For
example, the case-insensitivity feature (the `i` Modifier) is implemented in the example
implementation by taking the parsed RegExpr and mutating every character mentioned inside it so
that it covers both the lower-case and the upper-case variant (see the function `case_insensitive`
in `parser.rs`).
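As an illustration only (a sketch, not the tutorial's actual `case_insensitive` function), such a post-processing pass could widen every alphabetic character to both of its cases:
```rust
// Sketch of the post-processing idea: widen literal characters to both cases and
// recurse into the nested variants. A full implementation would also widen
// `Between` ranges; they are left untouched here for brevity.
fn case_insensitive(re: RegExpr) -> RegExpr {
    match re {
        RegExpr::Char { c } if c.is_ascii_alphabetic() => RegExpr::Range {
            cs: vec![c.to_ascii_lowercase(), c.to_ascii_uppercase()],
        },
        RegExpr::Range { cs } => RegExpr::Range {
            cs: cs
                .into_iter()
                .flat_map(|c| [c.to_ascii_lowercase(), c.to_ascii_uppercase()])
                .collect(),
        },
        RegExpr::Seq { re_xs } => RegExpr::Seq {
            re_xs: re_xs.into_iter().map(case_insensitive).collect(),
        },
        RegExpr::Optional { opt_re } => RegExpr::Optional {
            opt_re: Box::new(case_insensitive(*opt_re)),
        },
        RegExpr::Not { not_re } => RegExpr::Not {
            not_re: Box::new(case_insensitive(*not_re)),
        },
        RegExpr::Either { l_re, r_re } => RegExpr::Either {
            l_re: Box::new(case_insensitive(*l_re)),
            r_re: Box::new(case_insensitive(*r_re)),
        },
        other => other,
    }
}
```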
We are now able to translate any complex regex into a RegExpr value. For example:
Pattern | RegExpr value
--- | ---
`/a?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: 'a' }) }`
`/[a-d]?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Between { from: 'a', to: 'd' }) }`
`/[^ab]/` | `RegExpr::Not { not_re: Box::new(RegExpr::Range { cs: vec!['a', 'b'] }) }`
`/av\|d?/` | `RegExpr::Either { l_re: Box::new(RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'v' }] }), r_re: Box::new(RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: 'd' }) }) }`
`/(av\|d)?/` | `RegExpr::Optional { opt_re: Box::new(RegExpr::Either { l_re: Box::new(RegExpr::Seq { re_xs: vec![RegExpr::Char { c: 'a' }, RegExpr::Char { c: 'v' }] }), r_re: Box::new(RegExpr::Char { c: 'd' }) }) }`
With both the Grammar and the datastructure to parse into defined, we can now
start implementing the actual parsing logic. There are multiple ways this can
be done. For example, there exist tools that can automatically generate parser
code by giving it the Grammar definition (these are called parser generators).
However, you might prefer to write the parser with a parser combinator library:
the runtime behavior of a parser built from combinators is generally easier to
follow than that of code generated by a parser generator tool.
Rust offers a number of popular parser combinator libraries. This tutorial used
`combine`, but any other library would work just as well. Choose whichever appeals
the most to you (including any parser generator tool). The implementation of
our regex parser will differ significantly depending on the approach you choose,
so we will not cover this in detail here. You may look at the parser code in the example
implementation to get an idea of how this could be done. In general though, the Grammar and the
datastructure are the important components, while the parser code follows directly from these.
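To give a feel for how closely the parser mirrors the Grammar regardless of the tool used, here is a hand-rolled sketch (using no library at all) of how an Atom could be parsed; the `combine`-based parser in the example implementation looks quite different:
```rust
fn parse_atom(input: &mut std::iter::Peekable<std::str::Chars>) -> Option<RegExpr> {
    match input.next()? {
        '.' => Some(RegExpr::AnyChar),                                 // the '.' Atom
        '\\' => input.next().map(|c| RegExpr::Char { c }),             // an escaped symbol
        c if !"[]()|?*+{}^$".contains(c) => Some(RegExpr::Char { c }), // a plain character
        _ => None, // not an Atom; another rule (Range, parentheses, ...) should handle it
    }
}
```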
## Matching the RegExpr to encrypted content
The next challenge is to build the execution engine, where we take a RegExpr
value and recurse into it to apply the necessary actions on the encrypted
content. We first have to define how we actually encode our content into an
encrypted state. Once that is defined, we can start working on how we will
execute our RegExpr onto the encrypted content.
### Encoding and encrypting the content.
It is not possible to encrypt the entire content into a single encrypted value.
With FHE we can only encrypt numbers and perform operations on those encrypted
numbers. Therefore, we have to find a scheme where the content is encoded into a
sequence of numbers that are then encrypted individually, forming a sequence of
encrypted numbers.
There are two obvious strategies to consider:
1. map each character of the content to its u8 ascii value, and then encrypt
each bit of these u8 values individually;
2. encrypt each u8 ascii value in its entirety, instead of encrypting each bit
individually.
Strategy 1 requires more high-level TFHE-rs operations to check for
a simple character match (we have to check each bit individually for
equality, as opposed to checking the entire byte in one high-level TFHE-rs
operation), though some experimentation showed that both strategies performed
equally well on a regex like `/a/`. This is likely because bitwise FHE
operations are relatively cheap compared to u8 FHE operations. However,
strategy 1 falls apart as soon as you introduce `[a-z]` regex logic.
With strategy 2, it is possible to complete this match with just three TFHE-rs
operations: `ge`, `le`, and `bitand`.
```rust
// note: this is pseudocode
let c = <the encrypted character under inspection>;
let sk = <the server key, aka the public key>;
let ge_from = sk.ge(c, 'a');
let le_to = sk.le(c, 'z');
let result = sk.bitand(ge_from, le_to);
```
If, on the other hand, we had encrypted the content with strategy 1,
there would be no way to test for `greater than or equal to from` and `less
than or equal to to`. We would have to check for the potential equality of each
character between `from` and `to`, and then join the results together with a
sequence of `sk.bitor` calls; that would require far more cryptographic
operations than strategy 2.
Because FHE operations are computationally expensive, and strategy 1 requires
significantly more FHE operations for matching on `[a-z]` regex logic, we
should opt for strategy 2.
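As a reference, this is roughly how strategy 2 looks in code; it is a simplified version of `encrypt_str` in the example implementation's `ciphertext.rs`, where every ascii byte of the content becomes its own radix ciphertext:
```rust
use tfhe::integer::{RadixCiphertextBig, RadixClientKey};

pub type StringCiphertext = Vec<RadixCiphertextBig>;

pub fn encrypt_content(client_key: &RadixClientKey, s: &str) -> StringCiphertext {
    assert!(s.is_ascii(), "only ascii content is supported");
    s.as_bytes()
        .iter()
        .map(|byte| client_key.encrypt(*byte as u64))
        .collect()
}
```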
### Matching with the AST versus matching with a derived DFA.
Regex pattern matching engines have been built many times and researched
thoroughly, so there are multiple strategies we could follow here.
A straightforward strategy is to recurse directly into our RegExpr
value and apply the necessary matching operations onto the content. This is
appealing because it ties the RegExpr structure directly to the matching
semantics, resulting in code that is easier to understand, maintain, and so on.
Alternatively, there exists an algorithm that transforms the AST (i.e., the
RegExpr, in our case) into a Deterministic Finite Automaton (DFA). Normally, this
is the more efficient approach because the derived DFA can be walked over
without backtracking (whereas the former strategy cannot avoid backtracking).
The content is walked over character by character and, depending on the
character at the cursor, the DFA is traversed in a single definite direction,
which ultimately leads us to either `yes, there is a match` or `no, there is no
match`. There is a small upfront cost in translating the AST into the DFA, but
the lack of backtracking during matching generally makes up for it, especially
if the content being matched against is large.
In our case though, we are matching on encrypted content. We have no way to know
what the character at our cursor is, and therefore no way to find this definite
direction to go forward in the DFA. Therefore, translating the AST into the DFA does
not help us as it does in normal regex PMEs. For this reason, consider opting for the
former strategy because it allows for matching logic that is easier to understand.
### Matching.
In the previous section, we decided we'll match by traversing into the RegExpr
value. This section will explain exactly how to do that. Similarly to defining
the Grammar, it is often best to start with working out the non-recursive
RegExpr variants.
We'll start by defining the function that will recursively traverse into the RegExpr value:
```rust
type StringCiphertext = Vec<RadixCiphertextBig>;
type ResultCiphertext = RadixCiphertextBig;

// named `regex_match` here because `match` is a reserved keyword in Rust
fn regex_match(
    sk: &ServerKey,
    content: &StringCiphertext,
    re: &RegExpr,
    c_pos: usize,
) -> Vec<(ResultCiphertext, usize)> {
    let content_char = &content[c_pos];
    match re {
        ...
    }
}
```
`sk` is the server key (a.k.a. the public key), `content` is what we'll be matching
against, `re` is the RegExpr value we built when parsing the regex, and `c_pos`
is the cursor position (the index in the content we are currently matching
against).
The result is a vector of tuples, with the first value of each tuple being the
computed ciphertext result and the second value being the content position after
the regex component was applied. It's a vector because certain RegExpr variants
require considering a list of possible execution paths. For example,
RegExpr::Optional might succeed either by applying the optional regex _or_ by
not applying it (notice that in the former case `c_pos` moves forward, whereas
in the latter case it stays put).
On the first call, `regex_match` is applied to the entire regex pattern with `c_pos=0`.
Then `regex_match` is called again for the entire regex pattern with `c_pos=1`, and so
on, until `c_pos` exceeds the length of the content. Each of these alternative match
results is then joined together with `sk.bitor` operations (this works because if any
one of them results in 'true', our matching algorithm should return 'true').
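A sketch of this top-level driver, in the same pseudocode style as the match arms below (`ct_false` being the trivially encrypted counterpart of the `ct_true` constant introduced further down); the real `has_match` in the example implementation's `engine.rs` first builds lazy execution branches, but the idea is the same:
```rust
// note: pseudocode
fn has_match(sk: &ServerKey, content: &StringCiphertext, re: &RegExpr) -> ResultCiphertext {
    let mut result = ct_false; // a trivial encryption of 'false'
    for c_pos in 0..content.len() {
        for (branch_result, _final_pos) in regex_match(sk, content, re, c_pos) {
            result = sk.bitor(result, branch_result);
        }
    }
    result
}
```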
The `...` within the match statement above is what we will be working out for
some of the RegExpr variants now. Starting with `RegExpr::Char`:
```rust
RegExpr::Char { c } => {
    vec![(sk.eq(content_char, c), c_pos + 1)]
},
```
Let's consider an example for the variant above. If we apply `/a/` to the content
`bac`, we get the following list of `regex_match` calls, with their `re` and `c_pos` values
(for simplicity, `re` is denoted as a regex pattern instead of as a RegExpr value):
re | c\_pos | Ciphertext operation
--- | --- | ---
/a/ | 0 | sk.eq(content[0], a)
/a/ | 1 | sk.eq(content[1], a)
/a/ | 2 | sk.eq(content[2], a)
And we would arrive at the following sequence of ciphertext operations:
```
sk.bitor(sk.eq(content[0], a), sk.bitor(sk.eq(content[1], a), sk.eq(content[2], a)))
```
AnyChar requires no ciphertext operation:
```rust
RegExpr::AnyChar => {
    // note: ct_true is just some constant representing 'true' that is trivially encoded into a ciphertext
    vec![(ct_true, c_pos + 1)]
},
```
The sequence iterates over its `re_xs`, increasing the content position
accordingly, and joins the results with `bitand` operations:
```rust
RegExpr::Seq { re_xs } => {
    // note: still somewhat simplified; the fold threads every intermediate
    // (result, position) pair through the next component of the sequence
    re_xs.iter().fold(vec![(ct_true, c_pos)], |prev_results, re_x| {
        prev_results
            .iter()
            .flat_map(|(prev_res, prev_c_pos)| {
                regex_match(sk, content, re_x, *prev_c_pos)
                    .into_iter()
                    .map(|(x_res, new_c_pos)| (sk.bitand(prev_res, x_res), new_c_pos))
            })
            .collect()
    })
},
```
Other variants are similar, as they recurse and manipulate `re` and `c_pos`
accordingly. Hopefully, the general idea is already clear.
Ultimately, the entire pattern-matching logic unfolds into a combination of
the following FHE operations:
1. eq (tests for an exact character match)
2. ge (tests for 'greater than' or 'equal to' a character)
3. le (tests for 'less than' or 'equal to' a character)
4. bitand (bitwise AND, used for sequencing multiple regex components)
5. bitor (bitwise OR, used for folding multiple possible execution variants'
results into a single result)
6. bitxor (bitwise XOR, used for the 'not' logic in ranges; see the sketch after this list)
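For instance, the 'not' logic boils down to flipping the result of the inner range with an XOR against the encrypted `true` constant. A sketch of the corresponding match arm, in the same pseudocode style as the arms above:
```rust
RegExpr::Not { not_re } => regex_match(sk, content, not_re, c_pos)
    .into_iter()
    .map(|(res, new_c_pos)| (sk.bitxor(res, ct_true), new_c_pos))
    .collect(),
```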
### Optimizations.
Generally, the included example PME follows the approach outlined above. However,
two additional optimizations were applied, both aimed at reducing the number of
unnecessary FHE operations. Given how computationally expensive these operations
are, it makes sense to optimize for this (even at the cost of, for example,
suboptimal memory usage in our PME).
The first optimization is to delay the execution of FHE operations until _after_
all possible execution paths have been generated. This allows us to prune, during
execution path construction, paths that are provably going to result in an
encrypted false value, without having already performed the FHE operations up to
the point of pruning. Consider the regex `/^a+b$/` applied to a content of size 4.
A naive execution would check for every possible number of `a` repetitions:
`ab`, `aab`, `aaab`. However, while building the execution paths, we can use the
fact that `a+` must begin at the start of the content and that `b` must be its
final character. It follows that we only have to check a single candidate:
`aaab`. Delaying the FHE operations until after the possible execution paths were
built reduced the number of FHE operations applied in this example by
approximately half.
The second optimization is to prevent the same FHE condition from being
re-evaluated. Consider the regex `/^a?ab/`. This gives us the following
possible execution paths to consider:
1. `content[0] == a && content[1] == a && content[2] == b` (we match the `a` in
`a?`)
2. `content[0] == a && content[1] == b` (we don't match the `a` in `a?`)
Notice that both execution paths check `content[0] == a`.
Even though we cannot see what the encrypted result is, we do know that it will
be either an encrypted false for both cases or an encrypted true for both cases.
Therefore, we can skip the re-evaluation of `content[0] == a` and simply copy the
result from the first evaluation. This optimization involves maintaining a cache
of known expression evaluation results and reusing those where possible, as
sketched below.
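A minimal sketch of such a cache, assuming each evaluated expression can be described by a hashable key (the `Execution` struct in the example implementation's `execution.rs` works along these lines):
```rust
use std::collections::HashMap;

// `ExprKey` is a placeholder for a hashable description of an expression,
// for example "content[0] == 'a'".
struct EvalCache {
    cache: HashMap<ExprKey, ResultCiphertext>,
}

impl EvalCache {
    // Only runs `compute` the first time a given key is seen; afterwards the
    // previously computed ciphertext is cloned from the cache.
    fn eval(&mut self, key: ExprKey, compute: impl FnOnce() -> ResultCiphertext) -> ResultCiphertext {
        self.cache.entry(key).or_insert_with(compute).clone()
    }
}
```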
## Trying out the example implementation
The implementation that guided the writing of this tutorial can be found
under `tfhe/examples/regex_engine`.
When compiled with `--example regex_engine`, a binary is produced that serves
as a basic demo. Call it with the content string as the first argument and
the pattern string as the second argument. For example:
`cargo run --release --features=x86_64-unix,integer --example regex_engine -- 'this is the content' '/^pattern$/'`
Note that it is advised to compile the executable with the `--release` flag, as
key generation and the homomorphic operations otherwise suffer a heavy
performance penalty.
On execution, a client (private) and server (public) key pair is created. The
content is then encrypted with the client key, and the regex pattern is applied
to the encrypted content string with access given only to the server key.
Finally, the program decrypts the resulting encrypted result using the client
key and prints the verdict to the console.
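In code, this flow corresponds roughly to the following sketch, reusing `gen_keys` and `encrypt_str` from the example's `ciphertext.rs` and `has_match` from its `engine.rs`:
```rust
fn run(content: &str, pattern: &str) -> Result<u64, Box<dyn std::error::Error>> {
    let (client_key, server_key) = gen_keys();                     // client (private) + server (public) keys
    let ct_content = encrypt_str(&client_key, content)?;           // encrypt with the client key
    let ct_result = has_match(&server_key, &ct_content, pattern)?; // match using only the server key
    Ok(client_key.decrypt(&ct_result))                             // decrypt the verdict (1 = match, 0 = no match)
}
```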
To get more information on exact computations and performance, set the `RUST_LOG`
environment variable to `debug` or to `trace`.
### Supported regex patterns
This section specifies the supported set of regex patterns in the regex engine.
#### Components
A regex is described by a sequence of components surrounded by `/`. The
following components are supported:
Name | Notation | Examples
--- | --- | ---
Character | Simply the character itself | `/a/`, `/b/`, `/Z/`, `/5/`
Character range | `[<character>-<character>]` | `/[a-d]/`, `/[C-H]/`
Any character | `.` | `/a.c/`
Escaped symbol | `\<symbol>` | `/\^/`, `/\$/`
Parenthesis | `(<regex>)` | `/(abc)*/`, `/d(ab)?/`
Optional | `<regex>?` | `/a?/`, `/(az)?/`
Zero or more | `<regex>*` | `/a*/`, `/ab*c/`
One or more | `<regex>+` | `/a+/`, `/ab+c/`
Exact repeat | `<regex>{<number>}` | `/ab{2}c/`
At least repeat | `<regex>{<number>,}` | `/ab{2,}c/`
At most repeat | `<regex>{,<number>}` | `/ab{,2}c/`
Repeat between | `<regex>{<number>,<number>}` | `/ab{2,4}c/`
Either | `<regex>\|<regex>` | `/a\|b/`, `/ab\|cd/`
Start matching | `/^<regex>` | `/^abc/`
End matching | `<regex>$/` | `/abc$/`
#### Modifiers
Modifiers are mode selectors that affect the entire regex behavior. One modifier is
currently supported:
- Case insensitive matching, by appending an `i` after the regex pattern. For example: `/abc/i`
#### General examples
These components and modifiers can be combined to form any desired regex
pattern. To give some idea of what is possible, here is a non-exhaustive list of
supported regex patterns:
Pattern | Description
--- | ---
`/^abc$/` | Matches with content that equals exactly `abc` (case sensitive)
`/^abc$/i` | Matches with content that equals `abc` (case insensitive)
`/abc/` | Matches with content that contains `abc` somewhere
`/ab?c/` | Matches with content that contains `abc` or `ab` somewhere
`/^ab*c$/` | For example, matches with: `ac`, `abc`, `abbbbc`
`/^[a-c]b\|cd$/` | Matches with: `ab`, `bb`, `cb`, `cd`
`/^[a-c]b\|cd$/i` | Matches with: `ab`, `Ab`, `aB`, ..., `cD`, `CD`
`/^d(abc)+d$/` | For example, matches with: `dabcd`, `dabcabcd`, `dabcabcabcd`
`/^a.*d$/` | Matches with any content that starts with `a` and ends with `d`


@@ -0,0 +1,322 @@
# Tutorial
## Intro
In this tutorial we will go through the steps to turn a regular sha256 implementation into its homomorphic version. We explain the basics of the sha256 function first, and then how to implement it homomorphically with performance considerations.
## Sha256
The first step in this experiment is actually implementing the sha256 function. We can find the specification [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), but let's summarize the three main sections of the document.
#### Padding
The sha256 function processes the input data in blocks or chunks of 512 bits. Before actually performing the hash computations we have to pad the input in the following way:
* Append a single "1" bit
* Append a number of "0" bits such that exactly 64 bits are left to make the message length a multiple of 512
* Append the last 64 bits as a binary encoding of the original input length
Or visually:
```
0                                   L   L+1                              L+1+k           L+1+k+64
|-----------------------------------|---|--------------------------------|----------------------|
      Original input (L bits)       "1" bit          "0" bits             Encoding of the number L
```
Where the numbers at the top represent the bit position within the padded input, and the total length L+1+k+64 (the length of the padded input) is a multiple of 512.
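As a reference, here is a minimal sketch of this padding step, operating on a plain vector of bits (the example's `pad_sha256_input` starts from the input string, but the logic is the same):
```rust
fn pad_input_bits(mut bits: Vec<bool>) -> Vec<bool> {
    let original_len = bits.len() as u64; // L, the original length in bits
    bits.push(true); // append a single "1" bit
    // append "0" bits until the length is 64 bits short of a multiple of 512
    while bits.len() % 512 != 448 {
        bits.push(false);
    }
    // append the 64-bit big-endian encoding of L
    for i in (0..64).rev() {
        bits.push((original_len >> i) & 1 == 1);
    }
    bits
}
```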
#### Operations and functions
Let's take a look at the operations that we will use as building blocks for functions inside the sha256 computation. These are bitwise AND, XOR, NOT, addition modulo 2^32 and the Rotate Right (ROTR) and Shift Right (SHR) operations, all working with 32-bit words and producing a new word.
We combine these operations inside the sigma (with 4 variations), Ch and Maj functions. In the end, when we switch to computing sha256 homomorphically, we will mainly be changing the isolated code of each of these operations.
Here is the definition of each function:
```
Ch(x, y, z) = (x AND y) XOR ((NOT x) AND z)
Maj(x, y, z) = (x AND y) XOR (x AND z) XOR (y AND z)
Σ0(x) = ROTR-2(x) XOR ROTR-13(x) XOR ROTR-22(x)
Σ1(x) = ROTR-6(x) XOR ROTR-11(x) XOR ROTR-25(x)
σ0(x) = ROTR-7(x) XOR ROTR-18(x) XOR SHR-3(x)
σ1(x) = ROTR-17(x) XOR ROTR-19(x) XOR SHR-10(x)
```
There are some things to note about the functions. Firstly we see that Maj can be simplified by applying the boolean distributive law (x AND y) XOR (x AND z) = x AND (y XOR z). So the new Maj function looks like this:
```
Maj(x, y, z) = (x AND (y XOR z)) XOR (y AND z)
```
Next we can also see that Ch can be simplified by using a single bitwise multiplexer. Let's take a look at the truth table of the Ch expression.
| x | y | z | Result |
| - | - | - | ------ |
| 0 | 0 | 0 | 0 |
| 0 | 0 | 1 | 1 |
| 0 | 1 | 0 | 0 |
| 0 | 1 | 1 | 1 |
| 1 | 0 | 0 | 0 |
| 1 | 0 | 1 | 0 |
| 1 | 1 | 0 | 1 |
| 1 | 1 | 1 | 1 |
When ```x = 0``` the result is identical to ```z```, but when ```x = 1``` the result is identical to ```y```. This is the same as saying ```if x {y} else {z}```. Hence we can replace the 4 bitwise operations of Ch by a single bitwise multiplexer.
Note that all these operations can be evaluated homomorphically. ROTR and SHR can be evaluated by changing the index of each individual bit of the word, even if each bit is encrypted, without using any homomorphic operation. Bitwise AND, XOR and multiplexer can be computed homomorphically and addition modulo 2^32 can be broken down into boolean homomorphic operations as well.
#### Sha256 computation
As we have mentioned, the sha256 function works with chunks of 512 bits. For each chunk, we will compute 64 32-bit words. 16 will come from the 512 bits and the rest will be computed using the previous functions. After computing the 64 words, and still within the same chunk iteration, a compression loop will compute a hash value (8 32-bit words), again using the previous functions and some constants to mix everything up. When we finish the last chunk iteration, the resulting hash values will be the output of the sha256 function.
Here is what this function looks like, using arrays of 32 bools to represent words:
```rust
fn sha256(padded_input: Vec<bool>) -> [bool; 256] {
    // Initialize hash values with constant values
    let mut hash: [[bool; 32]; 8] = [
        hex_to_bools(0x6a09e667), hex_to_bools(0xbb67ae85),
        hex_to_bools(0x3c6ef372), hex_to_bools(0xa54ff53a),
        hex_to_bools(0x510e527f), hex_to_bools(0x9b05688c),
        hex_to_bools(0x1f83d9ab), hex_to_bools(0x5be0cd19),
    ];

    let chunks = padded_input.chunks(512);

    for chunk in chunks {
        let mut w = [[false; 32]; 64];

        // Copy first 16 words from current chunk
        for i in 0..16 {
            w[i].copy_from_slice(&chunk[i * 32..(i + 1) * 32]);
        }

        // Compute the other 48 words
        for i in 16..64 {
            w[i] = add(add(add(sigma1(&w[i - 2]), w[i - 7]), sigma0(&w[i - 15])), w[i - 16]);
        }

        let mut a = hash[0];
        let mut b = hash[1];
        let mut c = hash[2];
        let mut d = hash[3];
        let mut e = hash[4];
        let mut f = hash[5];
        let mut g = hash[6];
        let mut h = hash[7];

        // Compression loop, each iteration uses a specific constant from K
        for i in 0..64 {
            let temp1 = add(add(add(add(h, ch(&e, &f, &g)), w[i]), hex_to_bools(K[i])), sigma_upper_case_1(&e));
            let temp2 = add(sigma_upper_case_0(&a), maj(&a, &b, &c));
            h = g;
            g = f;
            f = e;
            e = add(d, temp1);
            d = c;
            c = b;
            b = a;
            a = add(temp1, temp2);
        }

        hash[0] = add(hash[0], a);
        hash[1] = add(hash[1], b);
        hash[2] = add(hash[2], c);
        hash[3] = add(hash[3], d);
        hash[4] = add(hash[4], e);
        hash[5] = add(hash[5], f);
        hash[6] = add(hash[6], g);
        hash[7] = add(hash[7], h);
    }

    // Concatenate the final hash values to produce a 256-bit hash
    let mut output = [false; 256];
    for i in 0..8 {
        output[i * 32..(i + 1) * 32].copy_from_slice(&hash[i]);
    }
    output
}
```
## Making it homomorphic
The key idea is that we can replace each bit of ```padded_input``` with a Fully Homomorphic Encryption of the same bit value, and operate over the encrypted values using homomorphic operations. To achieve this we need to change the function signatures and deal with the borrowing rules of the Ciphertext type (which represents an encrypted bit) but the structure of the sha256 function remains the same. The part of the code that requires more consideration is the implementation of the sha256 operations, since they will use homomorphic boolean operations internally.
Homomorphic operations are very expensive, so we have to eliminate unnecessary uses of them and maximize parallelization in order to speed up the program. To simplify our code we use the Rayon crate, which provides parallel iterators and efficiently manages threads. Let's now take a look at each sha256 operation!
#### Rotate Right and Shift Right
As we have highlighted, these two operations can be evaluated by changing the position of each encrypted bit in the word, thereby requiring 0 homomorphic operations. Here is our implementation:
```rust
fn rotate_right(x: &[Ciphertext; 32], n: usize) -> [Ciphertext; 32] {
    let mut result = x.clone();
    result.rotate_right(n);
    result
}

fn shift_right(x: &[Ciphertext; 32], n: usize, sk: &ServerKey) -> [Ciphertext; 32] {
    let mut result = x.clone();
    result.rotate_right(n);
    result[..n].fill_with(|| sk.trivial_encrypt(false));
    result
}
```
#### Bitwise XOR, AND, Multiplexer
To implement these operations we will use the ```xor```, ```and``` and ```mux``` methods provided by the tfhe library to evaluate each boolean operation homomorphically. It's important to note that, since we will operate bitwise, we can parallelize the homomorphic computations. In other words, we can homomorphically XOR the bits at index 0 of two words using a thread, while XORing the bits at index 1 using another thread, and so on. This means we could compute these bitwise operations using up to 32 concurrent threads (since we work with 32-bit words).
Here is our implementation of the bitwise homomorphic XOR operation. The ```par_iter``` and ```par_iter_mut``` methods create a parallel iterator that we use to compute each individual XOR efficiently. The other two bitwise operations are implemented in the same way.
```rust
fn xor(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
    let mut result = a.clone();
    result
        .par_iter_mut()
        .zip(a.par_iter().zip(b.par_iter()))
        .for_each(|(dst, (lhs, rhs))| *dst = sk.xor(lhs, rhs));
    result
}
```
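The bitwise multiplexer used for the Ch function follows exactly the same pattern, with ```mux``` instead of ```xor```; a sketch (the example implementation may organize this differently):
```rust
fn ch(x: &[Ciphertext; 32], y: &[Ciphertext; 32], z: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
    let mut result = x.clone();
    result
        .par_iter_mut()
        .zip(x.par_iter().zip(y.par_iter().zip(z.par_iter())))
        .for_each(|(dst, (xb, (yb, zb)))| *dst = sk.mux(xb, yb, zb)); // if x { y } else { z }, bit by bit
    result
}
```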
#### Addition modulo 2^32
This is perhaps the trickiest operation to efficiently implement in a homomorphic fashion. A naive implementation could use the Ripple Carry Adder algorithm, which is straightforward but cannot be parallelized because each step depends on the previous one.
A better choice would be the Carry Lookahead Adder, which allows us to use the parallelized AND and XOR bitwise operations. With this design, our adder is around 50% faster than the Ripple Carry Adder.
```rust
pub fn add(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
    let propagate = xor(a, b, sk); // Parallelized bitwise XOR
    let generate = and(a, b, sk);  // Parallelized bitwise AND

    let carry = compute_carry(&propagate, &generate, sk);
    let sum = xor(&propagate, &carry, sk); // Parallelized bitwise XOR

    sum
}

fn compute_carry(propagate: &[Ciphertext; 32], generate: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
    let mut carry = trivial_bools(&[false; 32], sk);
    carry[31] = sk.trivial_encrypt(false);

    for i in (0..31).rev() {
        carry[i] = sk.or(&generate[i + 1], &sk.and(&propagate[i + 1], &carry[i + 1]));
    }

    carry
}
```
To improve performance even further, the function that computes the carry signals can also be parallelized using parallel prefix algorithms. These algorithms involve more boolean operations (and hence more homomorphic operations for us) but may be faster because of their parallel nature. We have implemented the Brent-Kung and Ladner-Fischer algorithms, which entail different tradeoffs.
Brent-Kung has the lowest number of boolean operations we could find (140 when using grey cells, for 32-bit numbers), which makes it suitable when we cannot process many operations concurrently. Our results confirm that it is indeed faster than both the sequential algorithm and Ladner-Fischer when run on regular computers.
On the other hand, Ladner-Fischer performs more boolean operations (209 using grey cells) than Brent-Kung, but they are performed in larger batches. Hence we can compute more operations in parallel and finish earlier, but we need more fast threads available or they will slow down the carry signal computation. Ladner-Fischer can be suitable when using cloud-based computing services, which offer many high-speed threads.
Our implementation uses Brent-Kung by default, but Ladner-Fischer can be enabled when needed by using the ```--ladner-fischer``` command line argument.
For more information about parallel prefix adders you can read [this paper](https://www.iosrjournals.org/iosr-jece/papers/Vol6-Issue1/A0610106.pdf) or [this other paper](https://www.ijert.org/research/design-and-implementation-of-parallel-prefix-adder-for-improving-the-performance-of-carry-lookahead-adder-IJERTV4IS120608.pdf).
Finally, with all these sha256 operations working homomorphically, our functions will be homomorphic as well, along with the whole sha256 function (after adapting the code to work with the Ciphertext type). Let's talk about other performance improvements we can make before we finish.
### More parallel processing
If we inspect the main ```sha256_fhe``` function, we will find operations that can be performed in parallel. For instance, within the compression loop, ```temp1``` and ```temp2``` can be computed concurrently. An efficient way to parallelize computations here is using the ```rayon::join()``` function, which uses parallel processing only when there are available CPUs. Recall that the two temporary values in the compression loop are the result of several additions, so we can use nested calls to ```rayon::join()``` to potentially parallelize more operations.
Another way to speed up consecutive additions would be using the Carry Save Adder, a very efficient adder that takes 3 numbers and returns a sum and carry sequence. If our inputs are A, B and C, we can construct a CSA with our previously implemented Maj function and the bitwise XOR operation as follows:
```
Carry = Maj(A, B, C)
Sum = A XOR B XOR C
```
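Assuming we reuse the ```maj``` and ```xor``` helpers shown earlier, a CSA is only a few lines. Note that, as in any carry-save adder, the carry word must also be shifted one bit to the left before it is finally added; where the example implementation performs that shift is not shown here.
```rust
// Takes three 32-bit encrypted words and returns (sum, carry) without propagating carries.
fn csa(
    a: &[Ciphertext; 32],
    b: &[Ciphertext; 32],
    c: &[Ciphertext; 32],
    sk: &ServerKey,
) -> ([Ciphertext; 32], [Ciphertext; 32]) {
    let (carry, sum) = rayon::join(
        || maj(a, b, c, sk),           // Carry = Maj(A, B, C)
        || xor(&xor(a, b, sk), c, sk), // Sum = A XOR B XOR C
    );
    (sum, carry)
}
```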
By chaining CSAs, we can input the sum and carry from a preceding stage along with another number into a new CSA. Finally, to get the result of the additions we add the sum and carry sequences using a conventional adder. At the end we are performing the same number of additions, but some of them are now CSAs, speeding up the process. Let's see all this together in the ```temp1``` and ```temp2``` computations.
```rust
let (temp1, temp2) = rayon::join(
    || {
        let ((sum, carry), s1) = rayon::join(
            || {
                let ((sum, carry), ch) = rayon::join(
                    || csa(&h, &w[i], &trivial_bools(&hex_to_bools(K[i]), sk), sk),
                    || ch(&e, &f, &g, sk),
                );
                csa(&sum, &carry, &ch, sk)
            },
            || sigma_upper_case_1(&e, sk),
        );
        let (sum, carry) = csa(&sum, &carry, &s1, sk);
        add(&sum, &carry, sk)
    },
    || {
        add(&sigma_upper_case_0(&a, sk), &maj(&a, &b, &c, sk), sk)
    },
);
```
The first closure of the outer call to join will return ```temp1``` and the second ```temp2```. Inside the first outer closure we call join recursively until we reach the addition of the value ```h```, the current word ```w[i]``` and the current constant ```K[i]``` by using the CSA, while potentially computing in parallel the ```ch``` function. Then we take the sum, carry and ch values and add them again using the CSA.
All this is done while potentially computing the ```sigma_upper_case_1``` function. Finally we input the previous sum, carry and sigma values to the CSA and perform the final addition with ```add```. Once again, this is done while potentially computing ```sigma_upper_case_0``` and ```maj``` and adding them to get ```temp2```, in the second outer closure.
With some changes of this type, we finally get a homomorphic sha256 function that doesn't leave unused computational resources.
## How to use sha256_bool
First of all, the most important thing when running the program is using the ```--release``` flag. The use of sha256_bool would look like this, given the implementation of ```encrypt_bools``` and ```decrypt_bools```:
```rust
fn main() {
    let matches = Command::new("Homomorphic sha256")
        .arg(Arg::new("ladner_fischer")
            .long("ladner-fischer")
            .help("Use the Ladner Fischer parallel prefix algorithm for additions")
            .action(ArgAction::SetTrue))
        .get_matches();

    // If set using the command line flag "--ladner-fischer" this algorithm will be used in additions
    let ladner_fischer: bool = matches.get_flag("ladner_fischer");

    // INTRODUCE INPUT FROM STDIN
    let mut input = String::new();
    println!("Write input to hash:");

    io::stdin()
        .read_line(&mut input)
        .expect("Failed to read line");

    input = input.trim_end_matches('\n').to_string();
    println!("You entered: \"{}\"", input);

    // CLIENT PADS DATA AND ENCRYPTS IT
    let (ck, sk) = gen_keys();
    let padded_input = pad_sha256_input(&input);
    let encrypted_input = encrypt_bools(&padded_input, &ck);

    // SERVER COMPUTES OVER THE ENCRYPTED PADDED DATA
    println!("Computing the hash");
    let encrypted_output = sha256_fhe(encrypted_input, ladner_fischer, &sk);

    // CLIENT DECRYPTS THE OUTPUT
    let output = decrypt_bools(&encrypted_output, &ck);
    let outhex = bools_to_hex(output);

    println!("{}", outhex);
}
```
By using ```stdin``` we can supply the data to hash using a file instead of the command line. For example, if our file ```input.txt``` is in the same directory as the project, we can use the following shell command after building with ```cargo build --release```:
```sh
./target/release/examples/sha256_bool < input.txt
```
Our implementation also accepts hexadecimal inputs. To be considered as such, the input must start with "0x" and contain only valid hex digits (otherwise it's interpreted as text).
Finally, note that padding is performed on the client side. This has the advantage of hiding the exact length of the input from the server, which already knows nothing about the contents of the input but could otherwise extract information from its length.
Another option would be to perform padding on the server side. The padding function would receive the encrypted input and pad it with trivial bit encryptions. We could then integrate the padding function inside the ```sha256_fhe``` function computed by the server.
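A sketch of that alternative, padding an already encrypted bit vector with trivial encryptions (this assumes the original bit length is known to the server, which is exactly the information that client-side padding hides):
```rust
fn pad_on_server(mut input: Vec<Ciphertext>, sk: &ServerKey) -> Vec<Ciphertext> {
    let original_len = input.len() as u64; // known to the server in this variant
    input.push(sk.trivial_encrypt(true)); // the "1" bit
    while input.len() % 512 != 448 {
        input.push(sk.trivial_encrypt(false)); // the "0" bits
    }
    for i in (0..64).rev() {
        input.push(sk.trivial_encrypt((original_len >> i) & 1 == 1)); // 64-bit length encoding
    }
    input
}
```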


@@ -0,0 +1,418 @@
use std::time::Instant;
use rayon::prelude::*;
use tfhe::integer::ciphertext::RadixCiphertextBig;
use tfhe::integer::keycache::IntegerKeyCache;
use tfhe::integer::ServerKey;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
/// The number of blocks to be used in the Radix.
const NUMBER_OF_BLOCKS: usize = 8;
/// Plain implementation of the volume matching algorithm.
///
/// Matches the given [sell_orders] with [buy_orders].
/// The amount of the orders that are successfully filled is written over the original order count.
fn volume_match_plain(sell_orders: &mut [u16], buy_orders: &mut [u16]) {
let total_sell_volume: u16 = sell_orders.iter().sum();
let total_buy_volume: u16 = buy_orders.iter().sum();
let total_volume = std::cmp::min(total_buy_volume, total_sell_volume);
let mut volume_left_to_transact = total_volume;
for sell_order in sell_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *sell_order);
*sell_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
let mut volume_left_to_transact = total_volume;
for buy_order in buy_orders.iter_mut() {
let filled_amount = std::cmp::min(volume_left_to_transact, *buy_order);
*buy_order = filled_amount;
volume_left_to_transact -= filled_amount;
}
}
/// FHE implementation of the volume matching algorithm.
///
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
/// [server_key]. The amount of the orders that are successfully filled is written over the original
/// order count.
fn volume_match_fhe(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
println!("Calculating total sell and buy volumes...");
let time = Instant::now();
let mut total_sell_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for sell_order in sell_orders.iter_mut() {
server_key.smart_add_assign(&mut total_sell_volume, sell_order);
}
let mut total_buy_volume = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for buy_order in buy_orders.iter_mut() {
server_key.smart_add_assign(&mut total_buy_volume, buy_order);
}
println!(
"Total sell and buy volumes are calculated in {:?}",
time.elapsed()
);
println!("Calculating total volume to be matched...");
let time = Instant::now();
let total_volume = server_key.smart_min(&mut total_sell_volume, &mut total_buy_volume);
println!(
"Calculated total volume to be matched in {:?}",
time.elapsed()
);
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
let mut volume_left_to_transact = total_volume.clone();
for order in orders.iter_mut() {
let mut filled_amount = server_key.smart_min(&mut volume_left_to_transact, order);
server_key.smart_sub_assign(&mut volume_left_to_transact, &mut filled_amount);
*order = filled_amount;
}
};
println!("Filling orders...");
let time = Instant::now();
fill_orders(sell_orders);
fill_orders(buy_orders);
println!("Filled orders in {:?}", time.elapsed());
}
/// FHE implementation of the volume matching algorithm.
///
/// This version of the algorithm utilizes parallelization to speed up the computation.
///
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
/// [server_key]. The amount of the orders that are successfully filled is written over the original
/// order count.
fn volume_match_fhe_parallelized(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
// Calculate the element sum of the given vector in parallel
let parallel_vector_sum = |vec: &mut [RadixCiphertextBig]| {
vec.to_vec().into_par_iter().reduce(
|| server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
|mut acc: RadixCiphertextBig, mut ele: RadixCiphertextBig| {
server_key.smart_add_parallelized(&mut acc, &mut ele)
},
)
};
println!("Calculating total sell and buy volumes...");
let time = Instant::now();
// Total sell and buy volumes can be calculated in parallel because they have no dependency on
// each other.
let (mut total_sell_volume, mut total_buy_volume) = rayon::join(
|| parallel_vector_sum(sell_orders),
|| parallel_vector_sum(buy_orders),
);
println!(
"Total sell and buy volumes are calculated in {:?}",
time.elapsed()
);
println!("Calculating total volume to be matched...");
let time = Instant::now();
let total_volume =
server_key.smart_min_parallelized(&mut total_sell_volume, &mut total_buy_volume);
println!(
"Calculated total volume to be matched in {:?}",
time.elapsed()
);
let fill_orders = |orders: &mut [RadixCiphertextBig]| {
let mut volume_left_to_transact = total_volume.clone();
for order in orders.iter_mut() {
let mut filled_amount =
server_key.smart_min_parallelized(&mut volume_left_to_transact, order);
server_key
.smart_sub_assign_parallelized(&mut volume_left_to_transact, &mut filled_amount);
*order = filled_amount;
}
};
println!("Filling orders...");
let time = Instant::now();
rayon::join(|| fill_orders(sell_orders), || fill_orders(buy_orders));
println!("Filled orders in {:?}", time.elapsed());
}
/// FHE implementation of the volume matching algorithm.
///
/// In this function, the implemented algorithm is modified to utilize more concurrency.
///
/// Matches the given encrypted [sell_orders] with encrypted [buy_orders] using the given
/// [server_key]. The amount of the orders that are successfully filled is written over the original
/// order count.
fn volume_match_fhe_modified(
sell_orders: &mut [RadixCiphertextBig],
buy_orders: &mut [RadixCiphertextBig],
server_key: &ServerKey,
) {
let compute_prefix_sum = |arr: &[RadixCiphertextBig]| {
if arr.is_empty() {
return arr.to_vec();
}
let mut prefix_sum: Vec<RadixCiphertextBig> = (0..arr.len().next_power_of_two())
.into_par_iter()
.map(|i| {
if i < arr.len() {
arr[i].clone()
} else {
server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS)
}
})
.collect();
for d in 0..prefix_sum.len().ilog2() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let mut left = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut left)
});
}
let last = prefix_sum.last().unwrap().clone();
*prefix_sum.last_mut().unwrap() = server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS);
for d in (0..prefix_sum.len().ilog2()).rev() {
prefix_sum
.par_chunks_exact_mut(2_usize.pow(d + 1))
.for_each(move |chunk| {
let length = chunk.len();
let temp = chunk.last().unwrap().clone();
let mut mid = chunk.get((length - 1) / 2).unwrap().clone();
server_key.smart_add_assign_parallelized(chunk.last_mut().unwrap(), &mut mid);
chunk[(length - 1) / 2] = temp;
});
}
prefix_sum.push(last);
prefix_sum[1..=arr.len()].to_vec()
};
println!("Creating prefix sum arrays...");
let time = Instant::now();
let (prefix_sum_sell_orders, prefix_sum_buy_orders) = rayon::join(
|| compute_prefix_sum(sell_orders),
|| compute_prefix_sum(buy_orders),
);
println!("Created prefix sum arrays in {:?}", time.elapsed());
let fill_orders = |total_orders: &RadixCiphertextBig,
orders: &mut [RadixCiphertextBig],
prefix_sum_arr: &[RadixCiphertextBig]| {
orders
.into_par_iter()
.enumerate()
.for_each(move |(i, order)| {
server_key.smart_add_assign_parallelized(
order,
&mut server_key.smart_mul_parallelized(
&mut server_key
.smart_ge_parallelized(&mut order.clone(), &mut total_orders.clone()),
&mut server_key.smart_sub_parallelized(
&mut server_key.smart_sub_parallelized(
&mut total_orders.clone(),
&mut server_key.smart_min_parallelized(
&mut total_orders.clone(),
&mut prefix_sum_arr
.get(i - 1)
.unwrap_or(
&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS),
)
.clone(),
),
),
&mut order.clone(),
),
),
);
});
};
let total_buy_orders = &mut prefix_sum_buy_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
let total_sell_orders = &mut prefix_sum_sell_orders
.last()
.unwrap_or(&server_key.create_trivial_zero_radix(NUMBER_OF_BLOCKS))
.clone();
println!("Matching orders...");
let time = Instant::now();
rayon::join(
|| fill_orders(total_sell_orders, buy_orders, &prefix_sum_buy_orders),
|| fill_orders(total_buy_orders, sell_orders, &prefix_sum_sell_orders),
);
println!("Matched orders in {:?}", time.elapsed());
}
/// Runs the given [tester] function with the test cases for volume matching algorithm.
fn run_test_cases<F: Fn(&[u16], &[u16], &[u16], &[u16])>(tester: F) {
println!("Testing empty sell orders...");
tester(
&[],
&(1..11).collect::<Vec<_>>(),
&[],
&(1..11).map(|_| 0).collect::<Vec<_>>(),
);
println!();
println!("Testing empty buy orders...");
tester(
&(1..11).collect::<Vec<_>>(),
&[],
&(1..11).map(|_| 0).collect::<Vec<_>>(),
&[],
);
println!();
println!("Testing exact matching of sell and buy orders...");
tester(
&(1..11).collect::<Vec<_>>(),
&(1..11).collect::<Vec<_>>(),
&(1..11).collect::<Vec<_>>(),
&(1..11).collect::<Vec<_>>(),
);
println!();
println!("Testing the case where there are more buy orders than sell orders...");
tester(
&(1..11).map(|_| 10).collect::<Vec<_>>(),
&[200],
&(1..11).map(|_| 10).collect::<Vec<_>>(),
&[100],
);
println!();
println!("Testing the case where there are more sell orders than buy orders...");
tester(
&[200],
&(1..11).map(|_| 10).collect::<Vec<_>>(),
&[100],
&(1..11).map(|_| 10).collect::<Vec<_>>(),
);
println!();
println!("Testing maximum input size for sell and buy orders...");
tester(
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
&(1..=500).map(|_| 100).collect::<Vec<_>>(),
);
println!();
}
/// Runs the test cases for the plain implementation of the volume matching algorithm.
fn test_volume_match_plain() {
let tester = |input_sell_orders: &[u16],
input_buy_orders: &[u16],
expected_filled_sells: &[u16],
expected_filled_buys: &[u16]| {
let mut sell_orders = input_sell_orders.to_vec();
let mut buy_orders = input_buy_orders.to_vec();
println!("Running plain implementation...");
let time = Instant::now();
volume_match_plain(&mut sell_orders, &mut buy_orders);
println!("Ran plain implementation in {:?}", time.elapsed());
assert_eq!(sell_orders, expected_filled_sells);
assert_eq!(buy_orders, expected_filled_buys);
};
println!("Running test cases for the plain implementation");
run_test_cases(tester);
}
/// Runs the test cases for the fhe implementation of the volume matching algorithm.
///
/// [parallelized] indicates whether the fhe implementation should be run in parallel.
fn test_volume_match_fhe(
fhe_function: fn(&mut [RadixCiphertextBig], &mut [RadixCiphertextBig], &ServerKey),
) {
let working_dir = std::env::current_dir().unwrap();
if working_dir.file_name().unwrap() != std::path::Path::new("tfhe") {
std::env::set_current_dir(working_dir.join("tfhe")).unwrap();
}
println!("Generating keys...");
let time = Instant::now();
let (client_key, server_key) = IntegerKeyCache.get_from_params(PARAM_MESSAGE_2_CARRY_2);
println!("Keys generated in {:?}", time.elapsed());
let tester = |input_sell_orders: &[u16],
input_buy_orders: &[u16],
expected_filled_sells: &[u16],
expected_filled_buys: &[u16]| {
let mut encrypted_sell_orders = input_sell_orders
.iter()
.cloned()
.map(|pt| client_key.encrypt_radix(pt as u64, NUMBER_OF_BLOCKS))
.collect::<Vec<RadixCiphertextBig>>();
let mut encrypted_buy_orders = input_buy_orders
.iter()
.cloned()
.map(|pt| client_key.encrypt_radix(pt as u64, NUMBER_OF_BLOCKS))
.collect::<Vec<RadixCiphertextBig>>();
println!("Running FHE implementation...");
let time = Instant::now();
fhe_function(
&mut encrypted_sell_orders,
&mut encrypted_buy_orders,
&server_key,
);
println!("Ran FHE implementation in {:?}", time.elapsed());
let decrypted_filled_sells = encrypted_sell_orders
.iter()
.map(|ct| client_key.decrypt_radix::<u64, _>(ct) as u16)
.collect::<Vec<u16>>();
let decrypted_filled_buys = encrypted_buy_orders
.iter()
.map(|ct| client_key.decrypt_radix::<u64, _>(ct) as u16)
.collect::<Vec<u16>>();
assert_eq!(decrypted_filled_sells, expected_filled_sells);
assert_eq!(decrypted_filled_buys, expected_filled_buys);
};
println!("Running test cases for the FHE implementation");
run_test_cases(tester);
}
fn main() {
for argument in std::env::args() {
if argument == "fhe-modified" {
println!("Running modified fhe version");
test_volume_match_fhe(volume_match_fhe_modified);
println!();
}
if argument == "fhe-parallel" {
println!("Running parallelized fhe version");
test_volume_match_fhe(volume_match_fhe_parallelized);
println!();
}
if argument == "plain" {
println!("Running plain version");
test_volume_match_plain();
println!();
}
if argument == "fhe" {
println!("Running fhe version");
test_volume_match_fhe(volume_match_fhe);
println!();
}
}
}


@@ -1,62 +1,100 @@
use clap::{Arg, ArgAction, Command};
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS};
use tfhe::shortint::parameters::parameters_wopbs_message_carry::{
WOPBS_PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_3_CARRY_3,
WOPBS_PARAM_MESSAGE_4_CARRY_4,
};
use tfhe::shortint::parameters::{
Parameters, ALL_PARAMETER_VEC, PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2,
PARAM_MESSAGE_3_CARRY_3, PARAM_MESSAGE_4_CARRY_4,
ClassicPBSParameters, WopbsParameters, ALL_MULTI_BIT_PARAMETER_VEC, ALL_PARAMETER_VEC,
PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3,
PARAM_MESSAGE_4_CARRY_4,
};
fn client_server_keys() {
println!("Generating shortint (ClientKey, ServerKey)");
for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}",
i + 1,
ALL_PARAMETER_VEC.len(),
params.name()
);
let matches = Command::new("test key gen")
.arg(
Arg::new("multi_bit_only")
.long("multi-bit-only")
.help("Set to generate only multi bit keys, otherwise only PBS and WoPBS keys are generated")
.action(ArgAction::SetTrue),
)
.get_matches();
let start = std::time::Instant::now();
// If set using the command line flag "--ladner-fischer" this algorithm will be used in
// additions
let multi_bit_only: bool = matches.get_flag("multi_bit_only");
let _ = KEY_CACHE.get_from_param(params);
if multi_bit_only {
println!("Generating shortint multibit (ClientKey, ServerKey)");
for (i, params) in ALL_MULTI_BIT_PARAMETER_VEC.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}",
i + 1,
ALL_MULTI_BIT_PARAMETER_VEC.len(),
params.name()
);
let stop = start.elapsed().as_secs();
let start = std::time::Instant::now();
println!("Generation took {stop} seconds");
let _ = KEY_CACHE.get_from_param(params);
// Clear keys as we go to avoid filling the RAM
KEY_CACHE.clear_in_memory_cache()
}
let stop = start.elapsed().as_secs();
const WOPBS_PARAMS: [(Parameters, Parameters); 4] = [
(PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1),
(PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2),
(PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3),
(PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4),
];
println!("Generation took {stop} seconds");
println!("Generating woPBS keys");
for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}, {}",
i + 1,
WOPBS_PARAMS.len(),
params_shortint.name(),
params_wopbs.name(),
);
// Clear keys as we go to avoid filling the RAM
KEY_CACHE.clear_in_memory_cache()
}
} else {
println!("Generating shortint (ClientKey, ServerKey)");
for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}",
i + 1,
ALL_PARAMETER_VEC.len(),
params.name()
);
let start = std::time::Instant::now();
let start = std::time::Instant::now();
let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs));
let _ = KEY_CACHE.get_from_param(params);
let stop = start.elapsed().as_secs();
let stop = start.elapsed().as_secs();
println!("Generation took {stop} seconds");
println!("Generation took {stop} seconds");
// Clear keys as we go to avoid filling the RAM
KEY_CACHE_WOPBS.clear_in_memory_cache()
// Clear keys as we go to avoid filling the RAM
KEY_CACHE.clear_in_memory_cache()
}
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [
(PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1),
(PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2),
(PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3),
(PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4),
];
println!("Generating woPBS keys");
for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() {
println!(
"Generating [{} / {}] : {}, {}",
i + 1,
WOPBS_PARAMS.len(),
params_shortint.name(),
params_wopbs.name(),
);
let start = std::time::Instant::now();
let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs));
let stop = start.elapsed().as_secs();
println!("Generation took {stop} seconds");
// Clear keys as we go to avoid filling the RAM
KEY_CACHE_WOPBS.clear_in_memory_cache()
}
}
}


@@ -0,0 +1,92 @@
use rand::Rng;
use tfhe::core_crypto::commons::numeric::Numeric;
use tfhe::integer::block_decomposition::{DecomposableInto, RecomposableFrom};
use tfhe::integer::public_key::{CompactPublicKeyBig, CompactPublicKeySmall};
use tfhe::integer::{gen_keys, U256};
use tfhe::shortint::keycache::NamedParam;
use tfhe::shortint::parameters::parameters_compact_pk::*;
pub fn main() {
fn size_func<Scalar: Numeric + DecomposableInto<u64> + RecomposableFrom<u64> + From<u32>>() {
let mut rng = rand::thread_rng();
let num_bits = Scalar::BITS;
let params = PARAM_MESSAGE_2_CARRY_2_COMPACT_PK;
{
println!("Sizes for: {} and {num_bits} bits", params.name());
let (cks, _) = gen_keys(params);
let pk = CompactPublicKeyBig::new(&cks);
println!("PK size: {} bytes", bincode::serialize(&pk).unwrap().len());
let num_block =
(num_bits as f64 / (params.message_modulus.0 as f64).log(2.0)).ceil() as usize;
const MAX_CT: usize = 20;
let mut clear_vec = Vec::with_capacity(MAX_CT);
// 5 inputs to a smart contract
let num_ct_for_this_iter = 5;
clear_vec.truncate(0);
for _ in 0..num_ct_for_this_iter {
let clear = rng.gen::<u32>();
clear_vec.push(Scalar::from(clear));
}
let compact_encrypted_list = pk.encrypt_slice_radix_compact(&clear_vec, num_block);
println!(
"Compact CT list for {num_ct_for_this_iter} CTs: {} bytes",
bincode::serialize(&compact_encrypted_list).unwrap().len()
);
let ciphertext_vec = compact_encrypted_list.expand();
for (ciphertext, clear) in ciphertext_vec.iter().zip(clear_vec.iter().copied()) {
let decrypted: Scalar = cks.decrypt_radix(ciphertext);
assert_eq!(decrypted, clear);
}
}
let params = PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_SMALL;
{
println!("Sizes for: {} and {num_bits} bits", params.name());
let (cks, _) = gen_keys(params);
let pk = CompactPublicKeySmall::new(&cks);
println!("PK size: {} bytes", bincode::serialize(&pk).unwrap().len());
let num_block =
(num_bits as f64 / (params.message_modulus.0 as f64).log(2.0)).ceil() as usize;
const MAX_CT: usize = 20;
let mut clear_vec = Vec::with_capacity(MAX_CT);
// 5 inputs to a smart contract
let num_ct_for_this_iter = 5;
clear_vec.truncate(0);
for _ in 0..num_ct_for_this_iter {
let clear = rng.gen::<u32>();
clear_vec.push(Scalar::from(clear));
}
let compact_encrypted_list = pk.encrypt_slice_radix_compact(&clear_vec, num_block);
println!(
"Compact CT list for {num_ct_for_this_iter} CTs: {} bytes",
bincode::serialize(&compact_encrypted_list).unwrap().len()
);
let ciphertext_vec = compact_encrypted_list.expand();
for (ciphertext, clear) in ciphertext_vec.iter().zip(clear_vec.iter().copied()) {
let decrypted: Scalar = cks.decrypt_radix(ciphertext);
assert_eq!(decrypted, clear);
}
}
}
size_func::<u32>();
size_func::<U256>();
}


@@ -0,0 +1,22 @@
use tfhe::integer::{gen_keys_radix, RadixCiphertextBig, RadixClientKey, ServerKey};
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
pub type StringCiphertext = Vec<RadixCiphertextBig>;
pub fn encrypt_str(
client_key: &RadixClientKey,
s: &str,
) -> Result<StringCiphertext, Box<dyn std::error::Error>> {
if !s.is_ascii() {
return Err("content contains non-ascii characters".into());
}
Ok(s.as_bytes()
.iter()
.map(|byte| client_key.encrypt(*byte as u64))
.collect())
}
pub fn gen_keys() -> (RadixClientKey, ServerKey) {
let num_block = 4;
gen_keys_radix(PARAM_MESSAGE_2_CARRY_2, num_block)
}


@@ -0,0 +1,263 @@
use crate::execution::{Executed, Execution, LazyExecution};
use crate::parser::{parse, RegExpr};
use std::rc::Rc;
use tfhe::integer::{RadixCiphertextBig, ServerKey};
pub fn has_match(
sk: &ServerKey,
content: &[RadixCiphertextBig],
pattern: &str,
) -> Result<RadixCiphertextBig, Box<dyn std::error::Error>> {
let re = parse(pattern)?;
let branches: Vec<LazyExecution> = (0..content.len())
.flat_map(|i| build_branches(content, &re, i))
.map(|(lazy_branch_res, _)| lazy_branch_res)
.collect();
let mut exec = Execution::new(sk.clone());
let res = if branches.len() <= 1 {
branches
.get(0)
.map_or(exec.ct_false(), |branch| branch(&mut exec))
.0
} else {
branches[1..]
.iter()
.fold(branches[0](&mut exec), |res, branch| {
let branch_res = branch(&mut exec);
exec.ct_or(res, branch_res)
})
.0
};
info!(
"{} ciphertext operations, {} cache hits",
exec.ct_operations_count(),
exec.cache_hits(),
);
Ok(res)
}
fn build_branches(
content: &[RadixCiphertextBig],
re: &RegExpr,
c_pos: usize,
) -> Vec<(LazyExecution, usize)> {
trace!("program pointer: regex={:?}, content pos={}", re, c_pos);
match re {
RegExpr::Sof => {
if c_pos == 0 {
return vec![(Rc::new(|exec| exec.ct_true()), c_pos)];
} else {
return vec![];
}
}
RegExpr::Eof => {
if c_pos == content.len() {
return vec![(Rc::new(|exec| exec.ct_true()), c_pos)];
} else {
return vec![];
}
}
_ => (),
};
if c_pos >= content.len() {
return vec![];
}
match re.clone() {
RegExpr::Char { c } => {
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
vec![(
Rc::new(move |exec| exec.ct_eq(c_char.clone(), exec.ct_constant(c))),
c_pos + 1,
)]
}
RegExpr::AnyChar => vec![(Rc::new(|exec| exec.ct_true()), c_pos + 1)],
RegExpr::Not { not_re } => build_branches(content, &not_re, c_pos)
.into_iter()
.map(|(branch, c_pos)| {
(
Rc::new(move |exec: &mut Execution| {
let branch_res = branch(exec);
exec.ct_not(branch_res)
}) as LazyExecution,
c_pos,
)
})
.collect(),
RegExpr::Either { l_re, r_re } => {
let mut res = build_branches(content, &l_re, c_pos);
res.append(&mut build_branches(content, &r_re, c_pos));
res
}
RegExpr::Between { from, to } => {
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
vec![(
Rc::new(move |exec| {
let ct_from = exec.ct_constant(from);
let ct_to = exec.ct_constant(to);
let ge_from = exec.ct_ge(c_char.clone(), ct_from);
let le_to = exec.ct_le(c_char.clone(), ct_to);
exec.ct_and(ge_from, le_to)
}),
c_pos + 1,
)]
}
RegExpr::Range { cs } => {
let c_char = (content[c_pos].clone(), Executed::ct_pos(c_pos));
vec![(
Rc::new(move |exec| {
cs[1..].iter().fold(
exec.ct_eq(c_char.clone(), exec.ct_constant(cs[0])),
|res, c| {
let ct_c_char_eq = exec.ct_eq(c_char.clone(), exec.ct_constant(*c));
exec.ct_or(res, ct_c_char_eq)
},
)
}),
c_pos + 1,
)]
}
RegExpr::Repeated {
repeat_re,
at_least,
at_most,
} => {
let at_least = at_least.unwrap_or(0);
let at_most = at_most.unwrap_or(content.len() - c_pos);
if at_least > at_most {
return vec![];
}
let mut res = vec![
if at_least == 0 {
vec![(
Rc::new(|exec: &mut Execution| exec.ct_true()) as LazyExecution,
c_pos,
)]
} else {
vec![]
},
build_branches(
content,
&(RegExpr::Seq {
re_xs: std::iter::repeat(*repeat_re.clone())
.take(std::cmp::max(1, at_least))
.collect(),
}),
c_pos,
),
];
for _ in (at_least + 1)..(at_most + 1) {
res.push(
res.last()
.unwrap()
.iter()
.flat_map(|(branch_prev, branch_c_pos)| {
build_branches(content, &repeat_re, *branch_c_pos)
.into_iter()
.map(move |(branch_x, branch_x_c_pos)| {
let branch_prev = branch_prev.clone();
(
Rc::new(move |exec: &mut Execution| {
let res_prev = branch_prev(exec);
let res_x = branch_x(exec);
exec.ct_and(res_prev, res_x)
}) as LazyExecution,
branch_x_c_pos,
)
})
})
.collect(),
);
}
res.into_iter().flatten().collect()
}
RegExpr::Optional { opt_re } => {
let mut res = build_branches(content, &opt_re, c_pos);
res.push((Rc::new(|exec| exec.ct_true()), c_pos));
res
}
RegExpr::Seq { re_xs } => re_xs[1..].iter().fold(
build_branches(content, &re_xs[0], c_pos),
|continuations, re_x| {
continuations
.into_iter()
.flat_map(|(branch_prev, branch_prev_c_pos)| {
build_branches(content, re_x, branch_prev_c_pos)
.into_iter()
.map(move |(branch_x, branch_x_c_pos)| {
let branch_prev = branch_prev.clone();
(
Rc::new(move |exec: &mut Execution| {
let res_prev = branch_prev(exec);
let res_x = branch_x(exec);
exec.ct_and(res_prev, res_x)
}) as LazyExecution,
branch_x_c_pos,
)
})
})
.collect()
},
),
_ => panic!("unmatched regex variant"),
}
}
#[cfg(test)]
mod tests {
use crate::engine::has_match;
use test_case::test_case;
use crate::ciphertext::{encrypt_str, gen_keys, StringCiphertext};
use lazy_static::lazy_static;
use tfhe::integer::{RadixClientKey, ServerKey};
lazy_static! {
pub static ref KEYS: (RadixClientKey, ServerKey) = gen_keys();
}
#[test_case("ab", "/ab/", 1)]
#[test_case("b", "/ab/", 0)]
#[test_case("ab", "/a?b/", 1)]
#[test_case("b", "/a?b/", 1)]
#[test_case("ab", "/^ab|cd$/", 1)]
#[test_case(" ab", "/^ab|cd$/", 0)]
#[test_case(" cd", "/^ab|cd$/", 0)]
#[test_case("cd", "/^ab|cd$/", 1)]
#[test_case("abcd", "/^ab|cd$/", 0)]
#[test_case("abcd", "/ab|cd$/", 1)]
#[test_case("abc", "/abc/", 1)]
#[test_case("123abc", "/abc/", 1)]
#[test_case("123abc456", "/abc/", 1)]
#[test_case("123abdc456", "/abc/", 0)]
#[test_case("abc456", "/abc/", 1)]
#[test_case("bc", "/a*bc/", 1)]
#[test_case("cdaabc", "/a*bc/", 1)]
#[test_case("cdbc", "/a+bc/", 0)]
#[test_case("bc", "/a+bc/", 0)]
#[test_case("Ab", "/ab/i", 1 ; "ab case insensitive")]
#[test_case("Ab", "/ab/", 0 ; "ab case sensitive")]
#[test_case("cD", "/ab|cd/i", 1)]
#[test_case("cD", "/cD/", 1)]
#[test_case("test a num 8", "/8/", 1)]
#[test_case("test a num 8", "/^8/", 0)]
#[test_case("4453", "/^[0-9]*$/", 1)]
#[test_case("4453", "/^[09]*$/", 0)]
#[test_case("09009", "/^[09]*$/", 1)]
#[test_case("de", "/^ab|cd|de$/", 1 ; "multiple or")]
#[test_case(" de", "/^ab|cd|de$/", 0 ; "multiple or nests below ^")]
fn test_has_match(content: &str, pattern: &str, exp: u64) {
let ct_content: StringCiphertext = encrypt_str(&KEYS.0, content).unwrap();
let ct_res = has_match(&KEYS.1, &ct_content, pattern).unwrap();
let got = KEYS.0.decrypt(&ct_res);
assert_eq!(exp, got);
}
}


@@ -0,0 +1,272 @@
use std::collections::HashMap;
use std::rc::Rc;
use tfhe::integer::{RadixCiphertextBig, ServerKey};
use crate::parser::u8_to_char;
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) enum Executed {
Constant { c: u8 },
CtPos { at: usize },
And { a: Box<Executed>, b: Box<Executed> },
Or { a: Box<Executed>, b: Box<Executed> },
Equal { a: Box<Executed>, b: Box<Executed> },
GreaterOrEqual { a: Box<Executed>, b: Box<Executed> },
LessOrEqual { a: Box<Executed>, b: Box<Executed> },
Not { a: Box<Executed> },
}
type ExecutedResult = (RadixCiphertextBig, Executed);
impl Executed {
pub(crate) fn ct_pos(at: usize) -> Self {
Executed::CtPos { at }
}
fn get_trivial_constant(&self) -> Option<u8> {
match self {
Self::Constant { c } => Some(*c),
_ => None,
}
}
}
const CT_FALSE: u8 = 0;
const CT_TRUE: u8 = 1;
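// Execution evaluates the lazily built expression tree over ciphertexts. Every result is
// memoized in `cache`, keyed by the structural `Executed` description of its sub-expression,
// so identical homomorphic operations are performed only once; `ct_ops` counts the gates that
// were actually evaluated and `cache_hits` the evaluations that were skipped.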
pub(crate) struct Execution {
sk: ServerKey,
cache: HashMap<Executed, RadixCiphertextBig>,
ct_ops: usize,
cache_hits: usize,
}
pub(crate) type LazyExecution = Rc<dyn Fn(&mut Execution) -> ExecutedResult>;
impl Execution {
pub(crate) fn new(sk: ServerKey) -> Self {
Self {
sk,
cache: HashMap::new(),
ct_ops: 0,
cache_hits: 0,
}
}
pub(crate) fn ct_operations_count(&self) -> usize {
self.ct_ops
}
pub(crate) fn cache_hits(&self) -> usize {
self.cache_hits
}
pub(crate) fn ct_eq(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
let ctx = Executed::Equal {
a: Box::new(a.1.clone()),
b: Box::new(b.1.clone()),
};
self.with_cache(
ctx.clone(),
Rc::new(move |exec: &mut Execution| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = b.0.clone();
(exec.sk.smart_eq(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_ge(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
let ctx = Executed::GreaterOrEqual {
a: Box::new(a.1.clone()),
b: Box::new(b.1.clone()),
};
self.with_cache(
ctx.clone(),
Rc::new(move |exec| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = b.0.clone();
(exec.sk.smart_ge(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_le(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
let ctx = Executed::LessOrEqual {
a: Box::new(a.1.clone()),
b: Box::new(b.1.clone()),
};
self.with_cache(
ctx.clone(),
Rc::new(move |exec| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = b.0.clone();
(exec.sk.smart_le(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_and(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
let ctx = Executed::And {
a: Box::new(a.1.clone()),
b: Box::new(b.1.clone()),
};
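// Short-circuit on trivially known operands so that no homomorphic gate is evaluated for them.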
let c_a = a.1.get_trivial_constant();
let c_b = b.1.get_trivial_constant();
if c_a == Some(CT_TRUE) {
return (b.0, ctx);
}
if c_a == Some(CT_FALSE) {
return (a.0, ctx);
}
if c_b == Some(CT_TRUE) {
return (a.0, ctx);
}
if c_b == Some(CT_FALSE) {
return (b.0, ctx);
}
self.with_cache(
ctx.clone(),
Rc::new(move |exec| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = b.0.clone();
(exec.sk.smart_bitand(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_or(&mut self, a: ExecutedResult, b: ExecutedResult) -> ExecutedResult {
let ctx = Executed::Or {
a: Box::new(a.1.clone()),
b: Box::new(b.1.clone()),
};
let c_a = a.1.get_trivial_constant();
let c_b = b.1.get_trivial_constant();
if c_a == Some(CT_TRUE) {
return (a.0, ctx);
}
if c_b == Some(CT_TRUE) {
return (b.0, ctx);
}
if c_a == Some(CT_FALSE) && c_b == Some(CT_FALSE) {
return (a.0, ctx);
}
self.with_cache(
ctx.clone(),
Rc::new(move |exec| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = b.0.clone();
(exec.sk.smart_bitor(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_not(&mut self, a: ExecutedResult) -> ExecutedResult {
let ctx = Executed::Not {
a: Box::new(a.1.clone()),
};
self.with_cache(
ctx.clone(),
Rc::new(move |exec| {
exec.ct_ops += 1;
let mut ct_a = a.0.clone();
let mut ct_b = exec.ct_constant(1).0;
(exec.sk.smart_bitxor(&mut ct_a, &mut ct_b), ctx.clone())
}),
)
}
pub(crate) fn ct_false(&self) -> ExecutedResult {
self.ct_constant(CT_FALSE)
}
pub(crate) fn ct_true(&self) -> ExecutedResult {
self.ct_constant(CT_TRUE)
}
pub(crate) fn ct_constant(&self, c: u8) -> ExecutedResult {
(
self.sk.create_trivial_radix(c as u64, 4),
Executed::Constant { c },
)
}
fn with_cache(&mut self, ctx: Executed, f: LazyExecution) -> ExecutedResult {
if let Some(res) = self.cache.get(&ctx) {
trace!("cache hit: {:?}", &ctx);
self.cache_hits += 1;
return (res.clone(), ctx);
}
debug!("evaluation for: {:?}", &ctx);
let res = f(self);
self.cache.insert(ctx, res.0.clone());
res
}
}
impl std::fmt::Debug for Executed {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Constant { c } => match c {
0 => write!(f, "f"),
1 => write!(f, "t"),
_ => write!(f, "{}", u8_to_char(*c)),
},
Self::CtPos { at } => write!(f, "ct_{}", at),
Self::And { a, b } => {
write!(f, "(")?;
a.fmt(f)?;
write!(f, "/\\")?;
b.fmt(f)?;
write!(f, ")")
}
Self::Or { a, b } => {
write!(f, "(")?;
a.fmt(f)?;
write!(f, "\\/")?;
b.fmt(f)?;
write!(f, ")")
}
Self::Equal { a, b } => {
write!(f, "(")?;
a.fmt(f)?;
write!(f, "==")?;
b.fmt(f)?;
write!(f, ")")
}
Self::GreaterOrEqual { a, b } => {
write!(f, "(")?;
a.fmt(f)?;
write!(f, ">=")?;
b.fmt(f)?;
write!(f, ")")
}
Self::LessOrEqual { a, b } => {
write!(f, "(")?;
a.fmt(f)?;
write!(f, "<=")?;
b.fmt(f)?;
write!(f, ")")
}
Self::Not { a } => {
write!(f, "(!")?;
a.fmt(f)?;
write!(f, ")")
}
}
}
}


@@ -0,0 +1,30 @@
#[macro_use]
extern crate log;
mod ciphertext;
mod engine;
mod execution;
mod parser;
use env_logger::Env;
use std::env;
fn main() {
let env = Env::default().filter_or("RUST_LOG", "info");
env_logger::init_from_env(env);
let args: Vec<String> = env::args().collect();
let content = &args[1];
let pattern = &args[2];
let (client_key, server_key) = ciphertext::gen_keys();
let ct_content = ciphertext::encrypt_str(&client_key, content).unwrap();
let ct_res = engine::has_match(&server_key, &ct_content, pattern).unwrap();
let res: u64 = client_key.decrypt(&ct_res);
if res == 0 {
println!("no match");
} else {
println!("match");
}
}


@@ -0,0 +1,701 @@
use combine::parser::byte;
use combine::parser::byte::byte;
use combine::*;
use std::fmt;
#[derive(Clone, PartialEq, Eq, Hash)]
pub(crate) enum RegExpr {
Sof,
Eof,
Char {
c: u8,
},
AnyChar,
Between {
from: u8,
to: u8,
},
Range {
cs: Vec<u8>,
},
Not {
not_re: Box<RegExpr>,
},
Either {
l_re: Box<RegExpr>,
r_re: Box<RegExpr>,
},
Optional {
opt_re: Box<RegExpr>,
},
Repeated {
repeat_re: Box<RegExpr>,
at_least: Option<usize>, // None means no lower bound, i.e. 0 repetitions
at_most: Option<usize>,  // None means no upper bound
},
Seq {
re_xs: Vec<RegExpr>,
},
}
impl RegExpr {
fn case_insensitive(self) -> Self {
match self {
Self::Char { c } => Self::Range {
cs: case_insensitive(c),
},
Self::Not { not_re } => Self::Not {
not_re: Box::new(not_re.case_insensitive()),
},
Self::Either { l_re, r_re } => Self::Either {
l_re: Box::new(l_re.case_insensitive()),
r_re: Box::new(r_re.case_insensitive()),
},
Self::Optional { opt_re } => Self::Optional {
opt_re: Box::new(opt_re.case_insensitive()),
},
Self::Repeated {
repeat_re,
at_least,
at_most,
} => Self::Repeated {
repeat_re: Box::new(repeat_re.case_insensitive()),
at_least,
at_most,
},
Self::Seq { re_xs } => Self::Seq {
re_xs: re_xs.into_iter().map(|re| re.case_insensitive()).collect(),
},
_ => self,
}
}
}
fn case_insensitive(x: u8) -> Vec<u8> {
let c = u8_to_char(x);
if c.is_ascii_lowercase() {
return vec![x, c.to_ascii_uppercase() as u8];
}
if c.is_ascii_uppercase() {
return vec![x, c.to_ascii_lowercase() as u8];
}
vec![x]
}
pub(crate) fn u8_to_char(c: u8) -> char {
char::from_u32(c as u32).unwrap()
}
impl fmt::Debug for RegExpr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::Sof => write!(f, "^"),
Self::Eof => write!(f, "$"),
Self::Char { c } => write!(f, "{}", u8_to_char(*c)),
Self::AnyChar => write!(f, "."),
Self::Not { not_re } => {
write!(f, "[^")?;
not_re.fmt(f)?;
write!(f, "]")
}
Self::Between { from, to } => {
write!(f, "[{}->{}]", u8_to_char(*from), u8_to_char(*to),)
}
Self::Range { cs } => write!(
f,
"[{}]",
cs.iter().map(|c| u8_to_char(*c)).collect::<String>(),
),
Self::Either { l_re, r_re } => {
write!(f, "(")?;
l_re.fmt(f)?;
write!(f, "|")?;
r_re.fmt(f)?;
write!(f, ")")
}
Self::Repeated {
repeat_re,
at_least,
at_most,
} => {
let stringify_opt_n = |opt_n: &Option<usize>| -> String {
opt_n.map_or("*".to_string(), |n| format!("{:?}", n))
};
repeat_re.fmt(f)?;
write!(
f,
"{{{},{}}}",
stringify_opt_n(at_least),
stringify_opt_n(at_most)
)
}
Self::Optional { opt_re } => {
opt_re.fmt(f)?;
write!(f, "?")
}
Self::Seq { re_xs } => {
write!(f, "<")?;
for re_x in re_xs {
re_x.fmt(f)?;
}
write!(f, ">")?;
Ok(())
}
}
}
}
pub(crate) fn parse(pattern: &str) -> Result<RegExpr, Box<dyn std::error::Error>> {
let (parsed, unparsed) = (
between(
byte(b'/'),
byte(b'/'),
(optional(byte(b'^')), regex(), optional(byte(b'$'))),
)
.map(|(sof, re, eof)| {
if sof.is_none() && eof.is_none() {
return re;
}
let mut re_xs = vec![];
if sof.is_some() {
re_xs.push(RegExpr::Sof);
}
re_xs.push(re);
if eof.is_some() {
re_xs.push(RegExpr::Eof);
}
RegExpr::Seq { re_xs }
}),
optional(byte(b'i')),
)
.map(|(re, case_insensitive)| {
if case_insensitive.is_some() {
re.case_insensitive()
} else {
re
}
})
.parse(pattern.as_bytes())?;
if !unparsed.is_empty() {
return Err(format!(
"failed to parse regular expression, unexpected token at start of: {}",
std::str::from_utf8(unparsed).unwrap()
)
.into());
}
Ok(parsed)
}
// based on grammar from: https://matt.might.net/articles/parsing-regex-with-recursive-descent/
//
// <regex> ::= <term> '|' <regex>
// | <term>
//
// <term> ::= { <factor> }
//
// <factor> ::= <base> { '*' }
//
// <base> ::= <char>
// | '\' <char>
// | '(' <regex> ')'
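//
// The parser below extends this base grammar with the '?', '+', '{m}', '{m,}', '{,n}' and
// '{m,n}' repetition forms, character classes ('[...]', '[^...]' and ranges such as 'a-d'),
// escaping with '\', the '.' wildcard, and the '^' / '$' anchors handled at the top level
// in `parse`.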
parser! {
fn regex[Input]()(Input) -> RegExpr
where [Input: Stream<Token = u8>]
{
regex_()
}
}
fn regex_<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
choice((
attempt(
(term(), byte(b'|'), regex()).map(|(l_re, _, r_re)| RegExpr::Either {
l_re: Box::new(l_re),
r_re: Box::new(r_re),
}),
),
term(),
))
}
fn term<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
many(factor()).map(|re_xs: Vec<RegExpr>| {
if re_xs.len() == 1 {
re_xs[0].clone()
} else {
RegExpr::Seq { re_xs }
}
})
}
fn factor<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
choice((
attempt((atom(), byte(b'?'))).map(|(re, _)| RegExpr::Optional {
opt_re: Box::new(re),
}),
attempt(repeated()),
atom(),
))
}
const NON_ESCAPABLE_SYMBOLS: [u8; 14] = [
b'&', b';', b':', b',', b'`', b'~', b'-', b'_', b'!', b'@', b'#', b'%', b'\'', b'\"',
];
fn atom<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
choice((
byte(b'.').map(|_| RegExpr::AnyChar),
attempt(byte(b'\\').with(parser::token::any())).map(|c| RegExpr::Char { c }),
choice((
byte::alpha_num(),
parser::token::one_of(NON_ESCAPABLE_SYMBOLS),
))
.map(|c| RegExpr::Char { c }),
between(byte(b'['), byte(b']'), range()),
between(byte(b'('), byte(b')'), regex()),
))
}
parser! {
fn range[Input]()(Input) -> RegExpr
where [Input: Stream<Token = u8>]
{
range_()
}
}
fn range_<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
choice((
byte(b'^').with(range()).map(|re| RegExpr::Not {
not_re: Box::new(re),
}),
attempt(
(byte::alpha_num(), byte(b'-'), byte::alpha_num())
.map(|(from, _, to)| RegExpr::Between { from, to }),
),
many1(byte::alpha_num()).map(|cs| RegExpr::Range { cs }),
))
}
fn repeated<Input>() -> impl Parser<Input, Output = RegExpr>
where
Input: Stream<Token = u8>,
Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
choice((
attempt((atom(), choice((byte(b'*'), byte(b'+'))))).map(|(re, c)| RegExpr::Repeated {
repeat_re: Box::new(re),
at_least: if c == b'*' { None } else { Some(1) },
at_most: None,
}),
attempt((
atom(),
between(byte(b'{'), byte(b'}'), many::<Vec<u8>, _, _>(byte::digit())),
))
.map(|(re, repeat_digits)| {
let repeat = parse_digits(&repeat_digits);
RegExpr::Repeated {
repeat_re: Box::new(re),
at_least: Some(repeat),
at_most: Some(repeat),
}
}),
(
atom(),
between(
byte(b'{'),
byte(b'}'),
(
many::<Vec<u8>, _, _>(byte::digit()),
byte(b','),
many::<Vec<u8>, _, _>(byte::digit()),
),
),
)
.map(
|(re, (at_least_digits, _, at_most_digits))| RegExpr::Repeated {
repeat_re: Box::new(re),
at_least: if at_least_digits.is_empty() {
None
} else {
Some(parse_digits(&at_least_digits))
},
at_most: if at_most_digits.is_empty() {
None
} else {
Some(parse_digits(&at_most_digits))
},
},
),
))
}
fn parse_digits(digits: &[u8]) -> usize {
std::str::from_utf8(digits).unwrap().parse().unwrap()
}
#[cfg(test)]
mod tests {
use crate::parser::{parse, RegExpr};
use test_case::test_case;
#[test_case("/h/", RegExpr::Char { c: b'h' }; "char")]
#[test_case("/&/", RegExpr::Char { c: b'&' }; "not necessary to escape ampersand")]
#[test_case("/;/", RegExpr::Char { c: b';' }; "not necessary to escape semicolon")]
#[test_case("/:/", RegExpr::Char { c: b':' }; "not necessary to escape colon")]
#[test_case("/,/", RegExpr::Char { c: b',' }; "not necessary to escape comma")]
#[test_case("/`/", RegExpr::Char { c: b'`' }; "not necessary to escape backtick")]
#[test_case("/~/", RegExpr::Char { c: b'~' }; "not necessary to escape tilde")]
#[test_case("/-/", RegExpr::Char { c: b'-' }; "not necessary to escape minus")]
#[test_case("/_/", RegExpr::Char { c: b'_' }; "not necessary to escape underscore")]
#[test_case("/%/", RegExpr::Char { c: b'%' }; "not necessary to escape percentage")]
#[test_case("/#/", RegExpr::Char { c: b'#' }; "not necessary to escape hashtag")]
#[test_case("/@/", RegExpr::Char { c: b'@' }; "not necessary to escape at")]
#[test_case("/!/", RegExpr::Char { c: b'!' }; "not necessary to escape exclamation")]
#[test_case("/'/", RegExpr::Char { c: b'\'' }; "not necessary to escape single quote")]
#[test_case("/\"/", RegExpr::Char { c: b'\"' }; "not necessary to escape double quote")]
#[test_case("/\\h/", RegExpr::Char { c: b'h' }; "anything can be escaped")]
#[test_case("/./", RegExpr::AnyChar; "any")]
#[test_case("/abc/",
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
RegExpr::Char { c: b'c' },
]};
"abc")]
#[test_case("/^abc/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
RegExpr::Char { c: b'c' },
]},
]};
"<sof>abc")]
#[test_case("/abc$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"abc<eof>")]
#[test_case("/^abc$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>abc<eof>")]
#[test_case("/^ab?c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Optional { opt_re: Box::new(RegExpr::Char { c: b'b' }) },
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<question>c<eof>")]
#[test_case("/^ab*c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
at_least: None,
at_most: None,
},
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<star>c<eof>")]
#[test_case("/^ab+c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
at_least: Some(1),
at_most: None,
},
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<plus>c<eof>")]
#[test_case("/^ab{2}c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
at_least: Some(2),
at_most: Some(2),
},
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<twice>c<eof>")]
#[test_case("/^ab{3,}c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
at_least: Some(3),
at_most: None,
},
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<atleast 3>c<eof>")]
#[test_case("/^ab{2,4}c$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'b' }),
at_least: Some(2),
at_most: Some(4),
},
RegExpr::Char { c: b'c' },
]},
RegExpr::Eof,
]};
"<sof>ab<between 2 and 4>c<eof>")]
#[test_case("/^.$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::AnyChar,
RegExpr::Eof,
]};
"<sof><any><eof>")]
#[test_case("/^[abc]$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Range { cs: vec![b'a', b'b', b'c'] },
RegExpr::Eof,
]};
"<sof><a or b or c><eof>")]
#[test_case("/^[a-d]$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Between { from: b'a', to: b'd' },
RegExpr::Eof,
]};
"<sof><between a and d><eof>")]
#[test_case("/^[^abc]$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Not { not_re: Box::new(RegExpr::Range { cs: vec![b'a', b'b', b'c'] })},
RegExpr::Eof,
]};
"<sof><not <a or b or c>><eof>")]
#[test_case("/^[^a-d]$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Not { not_re: Box::new(RegExpr::Between { from: b'a', to: b'd' }) },
RegExpr::Eof,
]};
"<sof><not <between a and d>><eof>")]
#[test_case("/^abc$/i",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Range { cs: vec![b'a', b'A'] },
RegExpr::Range { cs: vec![b'b', b'B'] },
RegExpr::Range { cs: vec![b'c', b'C'] },
]},
RegExpr::Eof,
]};
"<sof>abc<eof> (case insensitive)")]
#[test_case("/^/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq { re_xs: vec![] }
]};
"sof")]
#[test_case("/$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Seq { re_xs: vec![] },
RegExpr::Eof
]};
"eof")]
#[test_case("/a*/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
at_least: None,
at_most: None,
};
"repeat unbounded (w/ *)")]
#[test_case("/a+/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
at_least: Some(1),
at_most: None,
};
"repeat bounded at least (w/ +)")]
#[test_case("/a{104,}/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
at_least: Some(104),
at_most: None,
};
"repeat bounded at least (w/ {x,}")]
#[test_case("/a{,15}/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
at_least: None,
at_most: Some(15),
};
"repeat bounded at most (w/ {,x}")]
#[test_case("/a{12,15}/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Char { c: b'a' }),
at_least: Some(12),
at_most: Some(15),
};
"repeat bounded at least and at most (w/ {x,y}")]
#[test_case("/(a|b)*/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Either {
l_re: Box::new(RegExpr::Char { c: b'a' }),
r_re: Box::new(RegExpr::Char { c: b'b' }),
}),
at_least: None,
at_most: None,
};
"repeat complex unbounded")]
#[test_case("/(a|b){3,7}/",
RegExpr::Repeated {
repeat_re: Box::new(RegExpr::Either {
l_re: Box::new(RegExpr::Char { c: b'a' }),
r_re: Box::new(RegExpr::Char { c: b'b' }),
}),
at_least: Some(3),
at_most: Some(7),
};
"repeat complex bounded")]
#[test_case("/^ab|cd/",
RegExpr::Seq { re_xs: vec![
RegExpr::Sof,
RegExpr::Either {
l_re: Box::new(RegExpr::Seq { re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
] }),
r_re: Box::new(RegExpr::Seq { re_xs: vec![
RegExpr::Char { c: b'c' },
RegExpr::Char { c: b'd' },
]}),
},
]};
"Sof encapsulates full RHS")]
#[test_case("/ab|cd$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Either {
l_re: Box::new(RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
]}),
r_re: Box::new(RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'c' },
RegExpr::Char { c: b'd' },
]}),
},
RegExpr::Eof,
]};
"Eof encapsulates full RHS" )]
#[test_case("/^ab|cd$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Either {
l_re: Box::new(RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'b' },
]}),
r_re: Box::new(RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'c' },
RegExpr::Char { c: b'd' },
]}),
},
RegExpr::Eof,
]};
"Sof + Eof both encapsulate full center")]
#[test_case("/\\^/",
RegExpr::Char { c: b'^' };
"escaping sof symbol")]
#[test_case("/\\./",
RegExpr::Char { c: b'.' };
"escaping period symbol")]
#[test_case("/\\*/",
RegExpr::Char { c: b'*' };
"escaping star symbol")]
#[test_case("/^ca\\^b$/",
RegExpr::Seq {re_xs: vec![
RegExpr::Sof,
RegExpr::Seq {re_xs: vec![
RegExpr::Char { c: b'c' },
RegExpr::Char { c: b'a' },
RegExpr::Char { c: b'^' },
RegExpr::Char { c: b'b' },
]},
RegExpr::Eof,
]};
"escaping, more realistic")]
#[test_case("/8/",
RegExpr::Char { c: b'8' };
"able to match numbers")]
#[test_case("/[7-9]/",
RegExpr::Between { from: b'7', to: b'9' };
"able to match a number range")]
#[test_case("/[79]/",
RegExpr::Range { cs: vec![b'7', b'9'] };
"able to match a number range (part 2)")]
fn test_parser(pattern: &str, exp: RegExpr) {
match parse(pattern) {
Ok(got) => assert_eq!(exp, got),
Err(e) => panic!("got err: {}", e),
}
}
}


@@ -0,0 +1,469 @@
// This module contains all the operations and functions used in the sha256 function, implemented
// with homomorphic boolean operations. Both the bitwise operations, which serve as the building
// blocks for other functions, and the adders employ parallel processing techniques.
use rayon::prelude::*;
use std::array;
use tfhe::boolean::prelude::{BinaryBooleanGates, Ciphertext, ServerKey};
// Implementation of a carry-save adder (CSA), which computes the sum and carry sequences very
// efficiently; the final sum and carry values are then added to obtain the result. CSAs are
// useful for speeding up chains of additions.
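// Per bit, with the big-endian bit layout used throughout this module:
//   sum_i   = a_i XOR b_i XOR c_i
//   carry_i = Maj(a_i, b_i, c_i), shifted one position towards the most significant bit
// (the carry out of the most significant bit is dropped, i.e. addition is mod 2^32).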
pub fn csa(
a: &[Ciphertext; 32],
b: &[Ciphertext; 32],
c: &[Ciphertext; 32],
sk: &ServerKey,
) -> ([Ciphertext; 32], [Ciphertext; 32]) {
let (carry, sum) = rayon::join(|| maj(a, b, c, sk), || xor(a, &xor(b, c, sk), sk));
// perform a left shift by one to discard the carry-out and set the carry-in to 0
let mut shifted_carry = trivial_bools(&[false; 32], sk);
for (i, elem) in carry.into_iter().enumerate() {
if i == 0 {
continue;
} else {
shifted_carry[i - 1] = elem;
}
}
(sum, shifted_carry)
}
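// Full 32-bit addition via carry-lookahead: compute the per-bit propagate (a XOR b) and
// generate (a AND b) signals, derive all carries with a parallel prefix network (Brent-Kung
// by default, Ladner-Fischer when requested), and XOR the carries back into the propagate
// signals to obtain the sum.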
pub fn add(
a: &[Ciphertext; 32],
b: &[Ciphertext; 32],
ladner_fischer_opt: bool,
sk: &ServerKey,
) -> [Ciphertext; 32] {
let (propagate, generate) = rayon::join(|| xor(a, b, sk), || and(a, b, sk));
let carry = if ladner_fischer_opt {
ladner_fischer(&propagate, &generate, sk)
} else {
brent_kung(&propagate, &generate, sk)
};
xor(&propagate, &carry, sk)
}
// Implementation of the Brent Kung parallel prefix algorithm
// This function computes the carry signals in parallel while minimizing the number of homomorphic
// operations
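// Each cell combines a more significant group (G, P) with a less significant group (G', P')
// using the usual carry-lookahead prefix operator:
//   G_out = G OR (P AND G'),   P_out = P AND P'
// Grey cells only need the generate output. After the network, generate[i + 1] holds the
// carry into bit i, which is what the final copy into `carry` below relies on.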
fn brent_kung(
propagate: &[Ciphertext; 32],
generate: &[Ciphertext; 32],
sk: &ServerKey,
) -> [Ciphertext; 32] {
let mut propagate = propagate.clone();
let mut generate = generate.clone();
for d in 0..5 {
// first 5 stages
let stride = 1 << d;
let indices: Vec<(usize, usize)> = (0..32 - stride)
.rev()
.step_by(2 * stride)
.map(|i| i + 1 - stride)
.enumerate()
.collect();
let updates: Vec<(usize, Ciphertext, Ciphertext)> = indices
.into_par_iter()
.map(|(n, index)| {
let new_p;
let new_g;
if n == 0 {
// grey cell
new_p = propagate[index].clone();
new_g = sk.or(
&generate[index],
&sk.and(&generate[index + stride], &propagate[index]),
);
} else {
// black cell
new_p = sk.and(&propagate[index], &propagate[index + stride]);
new_g = sk.or(
&generate[index],
&sk.and(&generate[index + stride], &propagate[index]),
);
}
(index, new_p, new_g)
})
.collect();
for (index, new_p, new_g) in updates {
propagate[index] = new_p;
generate[index] = new_g;
}
if d == 4 {
let mut cells = 0;
for d_2 in 0..4 {
// last 4 stages
let stride = 1 << (4 - d_2 - 1);
cells += 1 << d_2;
let indices: Vec<(usize, usize)> = (0..cells)
.map(|cell| (cell, stride + 2 * stride * cell))
.collect();
let updates: Vec<(usize, Ciphertext)> = indices
.into_par_iter()
.map(|(_, index)| {
let new_g = sk.or(
&generate[index],
&sk.and(&generate[index + stride], &propagate[index]),
);
(index, new_g)
})
.collect();
for (index, new_g) in updates {
generate[index] = new_g;
}
}
}
}
let mut carry = trivial_bools(&[false; 32], sk);
carry[..31].clone_from_slice(&generate[1..(31 + 1)]);
carry
}
// Implementation of the Ladner-Fischer parallel prefix algorithm
// This function may perform better than the previous one when many threads are available, as
// it has fewer stages
fn ladner_fischer(
propagate: &[Ciphertext; 32],
generate: &[Ciphertext; 32],
sk: &ServerKey,
) -> [Ciphertext; 32] {
let mut propagate = propagate.clone();
let mut generate = generate.clone();
for d in 0..5 {
let stride = 1 << d;
let indices: Vec<(usize, usize)> = (0..32 - stride)
.rev()
.step_by(2 * stride)
.flat_map(|i| (0..stride).map(move |count| (i, count)))
.collect();
let updates: Vec<(usize, Ciphertext, Ciphertext)> = indices
.into_par_iter()
.map(|(i, count)| {
let index = i - count; // current column
let p = propagate[i + 1].clone(); // propagate from a previous column
let g = generate[i + 1].clone(); // generate from a previous column
let new_p;
let new_g;
if index < 32 - (2 * stride) {
// black cell
new_p = sk.and(&propagate[index], &p);
new_g = sk.or(&generate[index], &sk.and(&g, &propagate[index]));
} else {
// grey cell
new_p = propagate[index].clone();
new_g = sk.or(&generate[index], &sk.and(&g, &propagate[index]));
}
(index, new_p, new_g)
})
.collect();
for (index, new_p, new_g) in updates {
propagate[index] = new_p;
generate[index] = new_g;
}
}
let mut carry = trivial_bools(&[false; 32], sk);
carry[..31].clone_from_slice(&generate[1..(31 + 1)]);
carry
}
// 2 (homomorphic) bitwise ops
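// The four SHA-256 sigma functions below follow FIPS 180-4 (ROTR = rotate right, SHR = shift right):
//   sigma0(x) = ROTR^7(x)  XOR ROTR^18(x) XOR SHR^3(x)
//   sigma1(x) = ROTR^17(x) XOR ROTR^19(x) XOR SHR^10(x)
//   Sigma0(x) = ROTR^2(x)  XOR ROTR^13(x) XOR ROTR^22(x)
//   Sigma1(x) = ROTR^6(x)  XOR ROTR^11(x) XOR ROTR^25(x)
// Rotations and shifts only move ciphertexts around (or fill with trivial encryptions), so each
// function costs the two homomorphic XOR word operations mentioned above.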
pub fn sigma0(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let a = rotate_right(x, 7);
let b = rotate_right(x, 18);
let c = shift_right(x, 3, sk);
xor(&xor(&a, &b, sk), &c, sk)
}
pub fn sigma1(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let a = rotate_right(x, 17);
let b = rotate_right(x, 19);
let c = shift_right(x, 10, sk);
xor(&xor(&a, &b, sk), &c, sk)
}
pub fn sigma_upper_case_0(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let a = rotate_right(x, 2);
let b = rotate_right(x, 13);
let c = rotate_right(x, 22);
xor(&xor(&a, &b, sk), &c, sk)
}
pub fn sigma_upper_case_1(x: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let a = rotate_right(x, 6);
let b = rotate_right(x, 11);
let c = rotate_right(x, 25);
xor(&xor(&a, &b, sk), &c, sk)
}
// 0 bitwise ops
fn rotate_right(x: &[Ciphertext; 32], n: usize) -> [Ciphertext; 32] {
let mut result = x.clone();
result.rotate_right(n);
result
}
fn shift_right(x: &[Ciphertext; 32], n: usize, sk: &ServerKey) -> [Ciphertext; 32] {
let mut result = x.clone();
result.rotate_right(n);
result[..n].fill_with(|| sk.trivial_encrypt(false));
result
}
// 1 bitwise op
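// Ch(x, y, z) = (x AND y) XOR ((NOT x) AND z): z where x is 0 and y where x is 1, i.e. a
// per-bit multiplexer, hence the single mux gate per bit.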
pub fn ch(
x: &[Ciphertext; 32],
y: &[Ciphertext; 32],
z: &[Ciphertext; 32],
sk: &ServerKey,
) -> [Ciphertext; 32] {
mux(x, y, z, sk)
}
// 4 bitwise ops
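// Maj(x, y, z) = (x AND y) XOR (x AND z) XOR (y AND z). The implementation uses the equivalent
// form (x AND (y XOR z)) XOR (y AND z), which needs only the 4 word-level gates counted above.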
pub fn maj(
x: &[Ciphertext; 32],
y: &[Ciphertext; 32],
z: &[Ciphertext; 32],
sk: &ServerKey,
) -> [Ciphertext; 32] {
let (lhs, rhs) = rayon::join(|| and(x, &xor(y, z, sk), sk), || and(y, z, sk));
xor(&lhs, &rhs, sk)
}
// Parallelized homomorphic bitwise ops
// Building block for most of the previous functions
fn xor(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let mut result = a.clone();
result
.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(dst, (lhs, rhs))| *dst = sk.xor(lhs, rhs));
result
}
fn and(a: &[Ciphertext; 32], b: &[Ciphertext; 32], sk: &ServerKey) -> [Ciphertext; 32] {
let mut result = a.clone();
result
.par_iter_mut()
.zip(a.par_iter().zip(b.par_iter()))
.for_each(|(dst, (lhs, rhs))| *dst = sk.and(lhs, rhs));
result
}
fn mux(
condition: &[Ciphertext; 32],
then: &[Ciphertext; 32],
otherwise: &[Ciphertext; 32],
sk: &ServerKey,
) -> [Ciphertext; 32] {
let mut result = condition.clone();
result
.par_iter_mut()
.zip(
condition
.par_iter()
.zip(then.par_iter().zip(otherwise.par_iter())),
)
.for_each(|(dst, (condition, (then, other)))| *dst = sk.mux(condition, then, other));
result
}
// Trivial encryption of 32 bools
pub fn trivial_bools(bools: &[bool; 32], sk: &ServerKey) -> [Ciphertext; 32] {
array::from_fn(|i| sk.trivial_encrypt(bools[i]))
}
#[cfg(test)]
mod tests {
use super::*;
use tfhe::boolean::prelude::*;
fn to_bool_array(arr: [i32; 32]) -> [bool; 32] {
let mut bool_arr = [false; 32];
for i in 0..32 {
if arr[i] == 1 {
bool_arr[i] = true;
}
}
bool_arr
}
fn encrypt(bools: &[bool; 32], ck: &ClientKey) -> [Ciphertext; 32] {
array::from_fn(|i| ck.encrypt(bools[i]))
}
fn decrypt(bools: &[Ciphertext; 32], ck: &ClientKey) -> [bool; 32] {
array::from_fn(|i| ck.decrypt(&bools[i]))
}
#[test]
fn test_add_modulo_2_32() {
let (ck, sk) = gen_keys();
let a = encrypt(
&to_bool_array([
0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1,
1, 0, 0, 1,
]),
&ck,
);
let b = encrypt(
&to_bool_array([
0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
1, 0, 1, 1,
]),
&ck,
);
let c = encrypt(
&to_bool_array([
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0,
1, 1, 0, 0,
]),
&ck,
);
let d = encrypt(
&to_bool_array([
0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1,
1, 0, 0, 0,
]),
&ck,
);
let e = encrypt(
&to_bool_array([
0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
1, 1, 0, 0,
]),
&ck,
);
let (sum, carry) = csa(&c, &d, &e, &sk);
let (sum, carry) = csa(&b, &sum, &carry, &sk);
let (sum, carry) = csa(&a, &sum, &carry, &sk);
let output = add(&sum, &carry, false, &sk);
let result = decrypt(&output, &ck);
let expected = to_bool_array([
0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0,
1, 0, 0,
]);
assert_eq!(result, expected);
}
#[test]
fn test_sigma0() {
let (ck, sk) = gen_keys();
let input = encrypt(
&to_bool_array([
0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,
1, 1, 1, 1,
]),
&ck,
);
let output = sigma0(&input, &sk);
let result = decrypt(&output, &ck);
let expected = to_bool_array([
1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1,
0, 1, 1,
]);
assert_eq!(result, expected);
} //the other sigmas are implemented in the same way
#[test]
fn test_ch() {
let (ck, sk) = gen_keys();
let e = encrypt(
&to_bool_array([
0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1,
1, 1, 1, 1,
]),
&ck,
);
let f = encrypt(
&to_bool_array([
1, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
1, 1, 0, 0,
]),
&ck,
);
let g = encrypt(
&to_bool_array([
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0,
1, 0, 1, 1,
]),
&ck,
);
let output = ch(&e, &f, &g, &sk);
let result = decrypt(&output, &ck);
let expected = to_bool_array([
0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1,
1, 0, 0,
]);
assert_eq!(result, expected);
}
#[test]
fn test_maj() {
let (ck, sk) = gen_keys();
let a = encrypt(
&to_bool_array([
0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0,
0, 1, 1, 1,
]),
&ck,
);
let b = encrypt(
&to_bool_array([
1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0,
0, 1, 0, 1,
]),
&ck,
);
let c = encrypt(
&to_bool_array([
0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1,
0, 0, 1, 0,
]),
&ck,
);
let output = maj(&a, &b, &c, &sk);
let result = decrypt(&output, &ck);
let expected = to_bool_array([
0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
1, 1, 1,
]);
assert_eq!(result, expected);
}
}


@@ -0,0 +1,74 @@
mod boolean_ops;
mod padding;
mod sha256_function;
use clap::{Arg, ArgAction, Command};
use padding::pad_sha256_input;
use sha256_function::{bools_to_hex, sha256_fhe};
use std::io;
use tfhe::boolean::prelude::*;
fn main() {
let matches = Command::new("Homomorphic sha256")
.arg(
Arg::new("ladner_fischer")
.long("ladner-fischer")
.help("Use the Ladner Fischer parallel prefix algorithm for additions")
.action(ArgAction::SetTrue),
)
.get_matches();
// If the command line flag "--ladner-fischer" is set, this algorithm will be used for
// additions
let ladner_fischer: bool = matches.get_flag("ladner_fischer");
// INTRODUCE INPUT FROM STDIN
let mut input = String::new();
println!("Write input to hash:");
io::stdin()
.read_line(&mut input)
.expect("Failed to read line");
input = input.trim_end_matches('\n').to_string();
println!("You entered: \"{}\"", input);
// CLIENT PADS DATA AND ENCRYPTS IT
let (ck, sk) = gen_keys();
let padded_input = pad_sha256_input(&input);
let encrypted_input = encrypt_bools(&padded_input, &ck);
// SERVER COMPUTES OVER THE ENCRYPTED PADDED DATA
println!("Computing the hash");
let encrypted_output = sha256_fhe(encrypted_input, ladner_fischer, &sk);
// CLIENT DECRYPTS THE OUTPUT
let output = decrypt_bools(&encrypted_output, &ck);
let outhex = bools_to_hex(output);
println!("{}", outhex);
}
fn encrypt_bools(bools: &Vec<bool>, ck: &ClientKey) -> Vec<Ciphertext> {
let mut ciphertext = vec![];
for bool in bools {
ciphertext.push(ck.encrypt(*bool));
}
ciphertext
}
fn decrypt_bools(ciphertext: &Vec<Ciphertext>, ck: &ClientKey) -> Vec<bool> {
let mut bools = vec![];
for cipher in ciphertext {
bools.push(ck.decrypt(cipher));
}
bools
}


@@ -0,0 +1,70 @@
// This module contains the padding function, which is computed by the client over the plaintext.
// The function returns the padded data as a vector of bools, ready for encryption. Padding could
// also be performed by the server by appending trivially encrypted bools, but by padding on the
// client side our implementation does not reveal the exact length of the pre-image (the hashed
// message) to the server.
// If the input starts with "0x" and the remaining characters are valid hexadecimal digits, it is
// interpreted as hex; otherwise the input is interpreted as text.
pub fn pad_sha256_input(input: &str) -> Vec<bool> {
let bytes = if input.starts_with("0x") && is_valid_hex(&input[2..]) {
let no_prefix = &input[2..];
let hex_input = if no_prefix.len() % 2 == 0 {
// hex value can be converted to bytes
no_prefix.to_string()
} else {
format!("0{}", no_prefix) // pad hex value to ensure a correct conversion to bytes
};
hex_input
.as_bytes()
.chunks(2)
.map(|chunk| u8::from_str_radix(std::str::from_utf8(chunk).unwrap(), 16).unwrap())
.collect::<Vec<u8>>()
} else {
input.as_bytes().to_vec()
};
pad_sha256_data(&bytes)
}
fn is_valid_hex(hex: &str) -> bool {
hex.chars().all(|c| c.is_ascii_hexdigit())
}
fn pad_sha256_data(data: &[u8]) -> Vec<bool> {
let mut bits: Vec<bool> = data
.iter()
.flat_map(|byte| (0..8).rev().map(move |i| (byte >> i) & 1 == 1))
.collect();
// Append a single '1' bit
bits.push(true);
// Calculate the number of padding zeros required
let padding_zeros = (512 - ((bits.len() + 64) % 512)) % 512;
bits.extend(std::iter::repeat(false).take(padding_zeros));
// Append a 64-bit big-endian representation of the original message length
let data_len_bits = (data.len() as u64) * 8;
bits.extend((0..64).rev().map(|i| (data_len_bits >> i) & 1 == 1));
bits
}
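// Worked example (plaintext): the 3-byte message "abc" is 24 bits, so padding appends a single
// '1' bit, then 423 zero bits so that 24 + 1 + 423 + 64 = 512, and finally the length 24 encoded
// as a 64-bit big-endian integer, producing exactly one 512-bit block.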
#[cfg(test)]
mod tests {
use super::*;
use crate::sha256_function::bools_to_hex;
#[test]
fn test_pad_sha256_input() {
let input = "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq";
let expected_output = "6162636462636465636465666465666765666768666768696768696a68696a6\
b696a6b6c6a6b6c6d6b6c6d6e6c6d6e6f6d6e6f706e6f70718000000000000000000000000000000000000000000\
000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001c0";
let result = pad_sha256_input(input);
let hex_result = bools_to_hex(result);
assert_eq!(hex_result, expected_output);
}
}


@@ -0,0 +1,236 @@
// This module implements the main sha256 homomorphic function using parallel processing when
// possible and some helper functions
use crate::boolean_ops::{
add, ch, csa, maj, sigma0, sigma1, sigma_upper_case_0, sigma_upper_case_1, trivial_bools,
};
use std::array;
use tfhe::boolean::prelude::*;
pub fn sha256_fhe(
padded_input: Vec<Ciphertext>,
ladner_fischer: bool,
sk: &ServerKey,
) -> Vec<Ciphertext> {
assert_eq!(
padded_input.len() % 512,
0,
"padded input length is not a multiple of 512"
);
// Initialize hash values
let mut hash: [[Ciphertext; 32]; 8] = [
trivial_bools(&hex_to_bools(0x6a09e667), sk),
trivial_bools(&hex_to_bools(0xbb67ae85), sk),
trivial_bools(&hex_to_bools(0x3c6ef372), sk),
trivial_bools(&hex_to_bools(0xa54ff53a), sk),
trivial_bools(&hex_to_bools(0x510e527f), sk),
trivial_bools(&hex_to_bools(0x9b05688c), sk),
trivial_bools(&hex_to_bools(0x1f83d9ab), sk),
trivial_bools(&hex_to_bools(0x5be0cd19), sk),
];
let chunks = padded_input.chunks_exact(512);
for chunk in chunks {
// Compute the 64 words
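// Message schedule per FIPS 180-4: w[0..16] are taken from the chunk and, for 16 <= i < 64,
//   w[i] = sigma1(w[i - 2]) + w[i - 7] + sigma0(w[i - 15]) + w[i - 16]   (mod 2^32).
// The two CSA calls below fold this four-term addition into a single carry-propagating add,
// and two consecutive words are computed in parallel.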
let mut w = initialize_w(sk);
for i in 0..16 {
w[i].clone_from_slice(&chunk[i * 32..(i + 1) * 32]);
}
for i in (16..64).step_by(2) {
let u = i + 1;
let (word_i, word_u) = rayon::join(
|| {
let (s0, s1) = rayon::join(|| sigma0(&w[i - 15], sk), || sigma1(&w[i - 2], sk));
let (sum, carry) = csa(&s0, &w[i - 7], &w[i - 16], sk);
let (sum, carry) = csa(&s1, &sum, &carry, sk);
add(&sum, &carry, ladner_fischer, sk)
},
|| {
let (s0, s1) = rayon::join(|| sigma0(&w[u - 15], sk), || sigma1(&w[u - 2], sk));
let (sum, carry) = csa(&s0, &w[u - 7], &w[u - 16], sk);
let (sum, carry) = csa(&s1, &sum, &carry, sk);
add(&sum, &carry, ladner_fischer, sk)
},
);
w[i] = word_i;
w[u] = word_u;
}
let mut a = hash[0].clone();
let mut b = hash[1].clone();
let mut c = hash[2].clone();
let mut d = hash[3].clone();
let mut e = hash[4].clone();
let mut f = hash[5].clone();
let mut g = hash[6].clone();
let mut h = hash[7].clone();
// Compression loop
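// Round function per FIPS 180-4:
//   temp1 = h + Sigma1(e) + Ch(e, f, g) + K[i] + w[i]   (mod 2^32)
//   temp2 = Sigma0(a) + Maj(a, b, c)                    (mod 2^32)
// The nested CSA calls accumulate the five-term temp1 sum before a single carry-propagating add.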
for i in 0..64 {
let (temp1, temp2) = rayon::join(
|| {
let ((sum, carry), s1) = rayon::join(
|| {
let ((sum, carry), ch) = rayon::join(
|| csa(&h, &w[i], &trivial_bools(&hex_to_bools(K[i]), sk), sk),
|| ch(&e, &f, &g, sk),
);
csa(&sum, &carry, &ch, sk)
},
|| sigma_upper_case_1(&e, sk),
);
let (sum, carry) = csa(&sum, &carry, &s1, sk);
add(&sum, &carry, ladner_fischer, sk)
},
|| {
add(
&sigma_upper_case_0(&a, sk),
&maj(&a, &b, &c, sk),
ladner_fischer,
sk,
)
},
);
let (temp_e, temp_a) = rayon::join(
|| add(&d, &temp1, ladner_fischer, sk),
|| add(&temp1, &temp2, ladner_fischer, sk),
);
h = g;
g = f;
f = e;
e = temp_e;
d = c;
c = b;
b = a;
a = temp_a;
}
hash[0] = add(&hash[0], &a, ladner_fischer, sk);
hash[1] = add(&hash[1], &b, ladner_fischer, sk);
hash[2] = add(&hash[2], &c, ladner_fischer, sk);
hash[3] = add(&hash[3], &d, ladner_fischer, sk);
hash[4] = add(&hash[4], &e, ladner_fischer, sk);
hash[5] = add(&hash[5], &f, ladner_fischer, sk);
hash[6] = add(&hash[6], &g, ladner_fischer, sk);
hash[7] = add(&hash[7], &h, ladner_fischer, sk);
}
// Concatenate the final hash values to produce a 256-bit hash
let mut output = vec![];
for item in &hash {
for j in item.iter().take(32) {
output.push(j.clone());
}
}
output
}
// Initialize the 64 words with trivial encryption
fn initialize_w(sk: &ServerKey) -> [[Ciphertext; 32]; 64] {
array::from_fn(|_| trivial_bools(&[false; 32], sk))
}
// To represent decrypted digest bools as hexadecimal String
pub fn bools_to_hex(bools: Vec<bool>) -> String {
let mut hex_string = String::new();
let mut byte = 0u8;
let mut counter = 0;
for bit in bools {
byte <<= 1;
if bit {
byte |= 1;
}
counter += 1;
if counter == 8 {
hex_string.push_str(&format!("{:02x}", byte));
byte = 0;
counter = 0;
}
}
// Handle any remaining bits in case the bools vector length is not a multiple of 8
if counter > 0 {
byte <<= 8 - counter;
hex_string.push_str(&format!("{:02x}", byte));
}
hex_string
}
// To represent constant values as bool arrays
fn hex_to_bools(hex_value: u32) -> [bool; 32] {
let mut bool_array = [false; 32];
let mut mask = 0x8000_0000;
for item in &mut bool_array {
*item = (hex_value & mask) != 0;
mask >>= 1;
}
bool_array
}
const K: [u32; 64] = [
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
];
#[cfg(test)]
mod tests {
use super::*;
fn to_bool_array(arr: [i32; 32]) -> [bool; 32] {
let mut bool_arr = [false; 32];
for i in 0..32 {
if arr[i] == 1 {
bool_arr[i] = true;
}
}
bool_arr
}
#[test]
fn test_bools_to_hex() {
let bools = to_bool_array([
1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 1, 0,
]);
let hex_bools = bools_to_hex(bools.to_vec());
assert_eq!(hex_bools, "90befffa");
}
#[test]
fn test_hex_to_bools() {
let hex = 0x428a2f98;
let result = hex_to_bools(hex);
let expected = to_bool_array([
0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1,
0, 0, 0,
]);
assert_eq!(result, expected);
}
}


@@ -0,0 +1,555 @@
const test = require('node:test');
const assert = require('node:assert').strict;
const { performance } = require('perf_hooks');
const {
ShortintParametersName,
ShortintParameters,
TfheClientKey,
TfhePublicKey,
TfheCompressedPublicKey,
TfheCompactPublicKey,
TfheCompressedServerKey,
TfheConfigBuilder,
CompressedFheUint8,
FheUint8,
FheUint32,
CompactFheUint32,
CompactFheUint32List,
CompressedFheUint128,
FheUint128,
CompressedFheUint256,
CompactFheUint256,
CompactFheUint256List,
FheUint256
} = require("../pkg/tfhe.js");
const U256_MAX = BigInt("115792089237316195423570985008687907853269984665640564039457584007913129639935");
const U128_MAX = BigInt("340282366920938463463374607431768211455");
const U32_MAX = 4294967295;
// Integers are not enabled here, but we try to use them anyway,
// so an error should be returned
// (the underlying panic should have been trapped)
test('hlapi_panic', (t) => {
let config = TfheConfigBuilder.all_disabled()
.build();
let clientKey = TfheClientKey.generate(config);
let clear = 73;
try {
let _ = FheUint8.encrypt_with_client_key(clear, clientKey);
assert(false);
} catch (e) {
assert(true);
}
});
test('hlapi_key_gen_big', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers()
.build();
let clientKey = TfheClientKey.generate(config);
let compressedServerKey = TfheCompressedServerKey.new(clientKey);
try {
let publicKey = TfhePublicKey.new(clientKey);
assert(false);
} catch (e) {
assert(true)
}
let serializedClientKey = clientKey.serialize();
let serializedCompressedServerKey = compressedServerKey.serialize();
});
test('hlapi_key_gen_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let compressedServerKey = TfheCompressedServerKey.new(clientKey);
let publicKey = TfhePublicKey.new(clientKey);
let serializedClientKey = clientKey.serialize();
let serializedCompressedServerKey = compressedServerKey.serialize();
let serializedPublicKey = publicKey.serialize();
});
test('hlapi_client_key_encrypt_decrypt_uint8_big', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers()
.build();
let clientKey = TfheClientKey.generate(config);
let clear = 73;
let encrypted = FheUint8.encrypt_with_client_key(clear, clientKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, clear);
let serialized = encrypted.serialize();
let deserialized = FheUint8.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, clear);
});
test('hlapi_public_key_encrypt_decrypt_uint32_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let publicKey = TfhePublicKey.new(clientKey);
let encrypted = FheUint32.encrypt_with_public_key(U32_MAX, publicKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U32_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint32.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
});
test('hlapi_decompress_public_key_then_encrypt_decrypt_uint32_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
var startTime = performance.now()
let compressedPublicKey = TfheCompressedPublicKey.new(clientKey);
var endTime = performance.now()
let data = compressedPublicKey.serialize()
let publicKey = compressedPublicKey.decompress();
var startTime = performance.now()
let encrypted = FheUint8.encrypt_with_public_key(255, publicKey);
var endTime = performance.now()
let ser = encrypted.serialize();
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, 255);
let serialized = encrypted.serialize();
let deserialized = FheUint32.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, 255);
});
test('hlapi_compressed_public_client_uint8_big', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers()
.build();
let clientKey = TfheClientKey.generate(config);
let clear = 73;
let compressed_encrypted = CompressedFheUint8.encrypt_with_client_key(clear, clientKey);
let compressed_serialized = compressed_encrypted.serialize();
let compressed_deserialized = CompressedFheUint8.deserialize(compressed_serialized);
let decompressed = compressed_deserialized.decompress()
let decrypted = decompressed.decrypt(clientKey);
assert.deepStrictEqual(decrypted, clear);
});
test('hlapi_client_key_encrypt_decrypt_uint128_big', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers()
.build();
let clientKey = TfheClientKey.generate(config);
let encrypted = FheUint128.encrypt_with_client_key(U128_MAX, clientKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U128_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint128.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U128_MAX);
// Compressed
let compressed_encrypted = CompressedFheUint128.encrypt_with_client_key(U128_MAX, clientKey);
let compressed_serialized = compressed_encrypted.serialize();
let compressed_deserialized = CompressedFheUint128.deserialize(compressed_serialized);
let decompressed = compressed_deserialized.decompress()
decrypted = decompressed.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U128_MAX);
});
test('hlapi_client_key_encrypt_decrypt_uint128_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let encrypted = FheUint128.encrypt_with_client_key(U128_MAX, clientKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U128_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint128.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U128_MAX);
// Compressed
let compressed_encrypted = CompressedFheUint128.encrypt_with_client_key(U128_MAX, clientKey);
let compressed_serialized = compressed_encrypted.serialize();
let compressed_deserialized = CompressedFheUint128.deserialize(compressed_serialized);
let decompressed = compressed_deserialized.decompress()
decrypted = decompressed.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U128_MAX);
});
test('hlapi_client_key_encrypt_decrypt_uint256_big', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers()
.build();
let clientKey = TfheClientKey.generate(config);
let encrypted = FheUint256.encrypt_with_client_key(U256_MAX, clientKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
// Compressed
let compressed_encrypted = CompressedFheUint256.encrypt_with_client_key(U256_MAX, clientKey);
let compressed_serialized = compressed_encrypted.serialize();
let compressed_deserialized = CompressedFheUint256.deserialize(compressed_serialized);
let decompressed = compressed_deserialized.decompress()
decrypted = decompressed.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
});
test('hlapi_client_key_encrypt_decrypt_uint256_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let encrypted = FheUint256.encrypt_with_client_key(U256_MAX, clientKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
// Compressed
let compressed_encrypted = CompressedFheUint256.encrypt_with_client_key(U256_MAX, clientKey);
let compressed_serialized = compressed_encrypted.serialize();
let compressed_deserialized = CompressedFheUint256.deserialize(compressed_serialized);
let decompressed = compressed_deserialized.decompress()
decrypted = decompressed.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
});
test('hlapi_decompress_public_key_then_encrypt_decrypt_uint256_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let compressedPublicKey = TfheCompressedPublicKey.new(clientKey);
let publicKey = compressedPublicKey.decompress();
let encrypted = FheUint256.encrypt_with_public_key(U256_MAX, publicKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
});
test('hlapi_public_key_encrypt_decrypt_uint256_small', (t) => {
let config = TfheConfigBuilder.all_disabled()
.enable_default_integers_small()
.build();
let clientKey = TfheClientKey.generate(config);
let publicKey = TfhePublicKey.new(clientKey);
let encrypted = FheUint256.encrypt_with_public_key(U256_MAX, publicKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
});
//////////////////////////////////////////////////////////////////////////////
/// 32 bits compact
//////////////////////////////////////////////////////////////////////////////
function hlapi_compact_public_key_encrypt_decrypt_uint32_single(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let encrypted = FheUint32.encrypt_with_compact_public_key(U32_MAX, publicKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U32_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint32.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
}
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_single', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_single(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_single', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_single(config);
});
function hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let compact_encrypted = CompactFheUint32.encrypt_with_compact_public_key(U32_MAX, publicKey);
let encrypted = compact_encrypted.expand();
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U32_MAX);
let serialized = compact_encrypted.serialize();
let deserialized = CompactFheUint32.deserialize(serialized);
let deserialized_decrypted = deserialized.expand().decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U32_MAX);
}
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_single_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_single_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_single_compact(config);
});
function hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let values = [0, 1, 2394, U32_MAX];
let compact_list = CompactFheUint32List.encrypt_with_compact_public_key(values, publicKey);
{
let encrypted_list = compact_list.expand();
assert.deepStrictEqual(encrypted_list.length, values.length);
for (let i = 0; i < values.length; i++)
{
let decrypted = encrypted_list[i].decrypt(clientKey);
assert.deepStrictEqual(decrypted, values[i]);
}
}
let serialized_list = compact_list.serialize();
let deserialized_list = CompactFheUint32List.deserialize(serialized_list);
let encrypted_list = deserialized_list.expand();
assert.deepStrictEqual(encrypted_list.length, values.length);
for (let i = 0; i < values.length; i++)
{
let decrypted = encrypted_list[i].decrypt(clientKey);
assert.deepStrictEqual(decrypted, values[i]);
}
}
test('hlapi_compact_public_key_encrypt_decrypt_uint32_small_list_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint32_big_list_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint32_list_compact(config);
});
//////////////////////////////////////////////////////////////////////////////
/// 256 bits compact
//////////////////////////////////////////////////////////////////////////////
function hlapi_compact_public_key_encrypt_decrypt_uint256_single(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let encrypted = FheUint256.encrypt_with_compact_public_key(U256_MAX, publicKey);
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = encrypted.serialize();
let deserialized = FheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
}
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_single', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_single(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_single', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_single(config);
});
function hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let compact_encrypted = CompactFheUint256.encrypt_with_compact_public_key(U256_MAX, publicKey);
let encrypted = compact_encrypted.expand();
let decrypted = encrypted.decrypt(clientKey);
assert.deepStrictEqual(decrypted, U256_MAX);
let serialized = compact_encrypted.serialize();
let deserialized = CompactFheUint256.deserialize(serialized);
let deserialized_decrypted = deserialized.expand().decrypt(clientKey);
assert.deepStrictEqual(deserialized_decrypted, U256_MAX);
}
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_single_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_single_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_single_compact(config);
});
function hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config) {
let clientKey = TfheClientKey.generate(config);
let publicKey = TfheCompactPublicKey.new(clientKey);
let values = [BigInt(0), BigInt(1), BigInt(2394), BigInt(2309840239), BigInt(U32_MAX), U256_MAX, U128_MAX];
let compact_list = CompactFheUint256List.encrypt_with_compact_public_key(values, publicKey);
{
let encrypted_list = compact_list.expand();
assert.deepStrictEqual(encrypted_list.length, values.length);
for (let i = 0; i < values.length; i++)
{
let decrypted = encrypted_list[i].decrypt(clientKey);
assert.deepStrictEqual(decrypted, values[i]);
}
}
let serialized_list = compact_list.serialize();
let deserialized_list = CompactFheUint256List.deserialize(serialized_list);
let encrypted_list = deserialized_list.expand();
assert.deepStrictEqual(encrypted_list.length, values.length);
for (let i = 0; i < values.length; i++)
{
let decrypted = encrypted_list[i].decrypt(clientKey);
assert.deepStrictEqual(decrypted, values[i]);
}
}
test('hlapi_compact_public_key_encrypt_decrypt_uint256_small_list_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_SMALL_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config);
});
test('hlapi_compact_public_key_encrypt_decrypt_uint256_big_list_compact', (t) => {
const block_params = new ShortintParameters(ShortintParametersName.PARAM_MESSAGE_2_CARRY_2_COMPACT_PK);
let config = TfheConfigBuilder.all_disabled()
.enable_custom_integers(block_params)
.build();
hlapi_compact_public_key_encrypt_decrypt_uint256_list_compact(config);
});

View File

@@ -94,7 +94,7 @@ impl ServerKey {
}
pub fn bootstrapping_key_size_bytes(&self) -> usize {
self.bootstrapping_key_size_elements() * std::mem::size_of::<concrete_fft::c64>()
std::mem::size_of_val(self.bootstrapping_key.as_view().data())
}
pub fn key_switching_key_size_elements(&self) -> usize {
@@ -102,7 +102,7 @@ impl ServerKey {
}
pub fn key_switching_key_size_bytes(&self) -> usize {
self.key_switching_key_size_elements() * std::mem::size_of::<u64>()
std::mem::size_of_val(self.key_switching_key.as_ref())
}
}

View File

@@ -5,7 +5,7 @@
use crate::boolean::ciphertext::{Ciphertext, CompressedCiphertext};
use crate::boolean::parameters::BooleanParameters;
use crate::boolean::{ClientKey, PublicKey, PLAINTEXT_FALSE, PLAINTEXT_TRUE};
use crate::boolean::{ClientKey, CompressedPublicKey, PublicKey, PLAINTEXT_FALSE, PLAINTEXT_TRUE};
use crate::core_crypto::algorithms::*;
use crate::core_crypto::entities::*;
use std::cell::RefCell;
@@ -144,6 +144,38 @@ impl BooleanEngine {
}
}
pub fn create_compressed_public_key(&mut self, client_key: &ClientKey) -> CompressedPublicKey {
let client_parameters = client_key.parameters;
// Formula is (n + 1) * log2(q) + 128
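// (n + 1 = LweSize, LOG2_Q_32 = log2 of the 32-bit ciphertext modulus; the extra 128
// zero encryptions act as a security margin so that random subset sums of the public
// key remain statistically close to uniform.)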
let zero_encryption_count = LwePublicKeyZeroEncryptionCount(
client_parameters.lwe_dimension.to_lwe_size().0 * LOG2_Q_32 + 128,
);
#[cfg(not(feature = "__wasm_api"))]
let compressed_lwe_public_key = par_allocate_and_generate_new_seeded_lwe_public_key(
&client_key.lwe_secret_key,
zero_encryption_count,
client_key.parameters.lwe_modular_std_dev,
CiphertextModulus::new_native(),
&mut self.bootstrapper.seeder,
);
#[cfg(feature = "__wasm_api")]
let compressed_lwe_public_key = allocate_and_generate_new_seeded_lwe_public_key(
&client_key.lwe_secret_key,
zero_encryption_count,
client_key.parameters.lwe_modular_std_dev,
CiphertextModulus::new_native(),
&mut self.bootstrapper.seeder,
);
CompressedPublicKey {
compressed_lwe_public_key,
parameters: client_parameters,
}
}
pub fn trivial_encrypt(&mut self, message: bool) -> Ciphertext {
Ciphertext::Trivial(message)
}
@@ -212,6 +244,32 @@ impl BooleanEngine {
Ciphertext::Encrypted(output)
}
pub fn encrypt_with_compressed_public_key(
&mut self,
message: bool,
compressed_pk: &CompressedPublicKey,
) -> Ciphertext {
let plain: Plaintext<u32> = if message {
Plaintext(PLAINTEXT_TRUE)
} else {
Plaintext(PLAINTEXT_FALSE)
};
let mut output = LweCiphertext::new(
0u32,
compressed_pk.parameters.lwe_dimension.to_lwe_size(),
CiphertextModulus::new_native(),
);
encrypt_lwe_ciphertext_with_seeded_public_key(
&compressed_pk.compressed_lwe_public_key,
&mut output,
plain,
&mut self.secret_generator,
);
Ciphertext::Encrypted(output)
}
pub fn decrypt(&mut self, ct: &Ciphertext, cks: &ClientKey) -> bool {
match ct {

View File

@@ -54,7 +54,7 @@
use crate::boolean::client_key::ClientKey;
use crate::boolean::parameters::DEFAULT_PARAMETERS;
use crate::boolean::public_key::PublicKey;
use crate::boolean::public_key::{CompressedPublicKey, PublicKey};
use crate::boolean::server_key::ServerKey;
#[cfg(test)]
use rand::Rng;

View File

@@ -7,5 +7,5 @@ pub use super::ciphertext::{Ciphertext, CompressedCiphertext};
pub use super::client_key::ClientKey;
pub use super::gen_keys;
pub use super::parameters::*;
pub use super::public_key::PublicKey;
pub use super::public_key::{CompressedPublicKey, PublicKey};
pub use super::server_key::{BinaryBooleanGates, ServerKey};

View File

@@ -0,0 +1,132 @@
use serde::{Deserialize, Serialize};
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
use crate::boolean::prelude::{BooleanParameters, Ciphertext, ClientKey};
use crate::core_crypto::prelude::SeededLwePublicKeyOwned;
/// A structure containing a compressed public key.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct CompressedPublicKey {
pub(crate) compressed_lwe_public_key: SeededLwePublicKeyOwned<u32>,
pub parameters: BooleanParameters,
}
impl CompressedPublicKey {
/// Generates a new public key that is compressed
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let cpks = CompressedPublicKey::new(&cks);
/// # }
/// ```
///
/// Decompressing the key
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let cpks = CompressedPublicKey::new(&cks);
/// let pks = PublicKey::from(cpks);
/// # }
/// ```
pub fn new(client_key: &ClientKey) -> Self {
BooleanEngine::with_thread_local_mut(|engine| {
engine.create_compressed_public_key(client_key)
})
}
/// Encrypt a Boolean message using the compressed public key.
///
/// # Note
///
/// It is recommended to use the compressed public key to save on
/// storage / transfer, and to decompress it in the program before
/// doing encryptions.
///
/// This is because encrypting with the compressed public key
/// lazily decompresses parts of the key for each encryption.
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let cpks = CompressedPublicKey::new(&cks);
///
/// // Encryption of one message:
/// let ct1 = cpks.encrypt(true);
/// let ct2 = cpks.encrypt(false);
/// let ct_res = sks.and(&ct1, &ct2);
///
/// // Decryption:
/// let dec = cks.decrypt(&ct_res);
/// assert_eq!(false, dec);
/// # }
/// ```
pub fn encrypt(&self, message: bool) -> Ciphertext {
BooleanEngine::with_thread_local_mut(|engine| {
engine.encrypt_with_compressed_public_key(message, self)
})
}
}
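// Sketch of the recommended flow from the note on `encrypt` above: transfer the compressed
// key, decompress it once on the encrypting side, then encrypt with the resulting
// `PublicKey` (all items used below are part of the boolean prelude in this diff):
//
//     let (cks, _sks) = gen_keys();
//     let cpks = CompressedPublicKey::new(&cks);
//     // ... serialize `cpks`, send it, deserialize it ...
//     let pks = PublicKey::from(cpks);
//     let ct = pks.encrypt(true);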
#[cfg(test)]
mod tests {
use crate::boolean::prelude::{
BinaryBooleanGates, BooleanParameters, ClientKey, CompressedPublicKey, ServerKey,
DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS,
};
use crate::boolean::random_boolean;
const NB_TEST: usize = 32;
#[test]
fn test_compressed_public_key_default_parameters() {
test_compressed_public_key(DEFAULT_PARAMETERS);
}
#[test]
fn test_compressed_public_key_tfhe_lib_parameters() {
test_compressed_public_key(TFHE_LIB_PARAMETERS);
}
fn test_compressed_public_key(parameters: BooleanParameters) {
let cks = ClientKey::new(&parameters);
let sks = ServerKey::new(&cks);
let cpks = CompressedPublicKey::new(&cks);
for _ in 0..NB_TEST {
let b1 = random_boolean();
let b2 = random_boolean();
let expected_result = !(b1 && b2);
let ct1 = cpks.encrypt(b1);
let ct2 = cpks.encrypt(b2);
let ct_res = sks.nand(&ct1, &ct2);
let dec_ct1 = cks.decrypt(&ct1);
let dec_ct2 = cks.decrypt(&ct2);
let dec_nand = cks.decrypt(&ct_res);
assert_eq!(dec_ct1, b1);
assert_eq!(dec_ct2, b2);
assert_eq!(dec_nand, expected_result);
}
}
}

View File

@@ -1,62 +1,5 @@
//! Module with the definition of the encryption PublicKey.
mod compressed;
mod standard;
use crate::boolean::ciphertext::Ciphertext;
use crate::boolean::client_key::ClientKey;
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
use crate::boolean::parameters::BooleanParameters;
use crate::core_crypto::entities::*;
use serde::{Deserialize, Serialize};
/// A structure containing a public key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicKey {
pub(crate) lwe_public_key: LwePublicKeyOwned<u32>,
pub(crate) parameters: BooleanParameters,
}
impl PublicKey {
/// Encrypt a Boolean message using the client key.
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let pks = PublicKey::new(&cks);
///
/// // Encryption of one message:
/// let ct1 = pks.encrypt(true);
/// let ct2 = pks.encrypt(false);
/// let ct_res = sks.and(&ct1, &ct2);
///
/// // Decryption:
/// let dec = cks.decrypt(&ct_res);
/// assert_eq!(false, dec);
/// # }
/// ```
pub fn encrypt(&self, message: bool) -> Ciphertext {
BooleanEngine::with_thread_local_mut(|engine| engine.encrypt_with_public_key(message, self))
}
/// Allocate and generate a client key.
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let pks = PublicKey::new(&cks);
/// # }
/// ```
pub fn new(client_key: &ClientKey) -> PublicKey {
BooleanEngine::with_thread_local_mut(|engine| engine.create_public_key(client_key))
}
}
pub use compressed::CompressedPublicKey;
pub use standard::PublicKey;

View File

@@ -0,0 +1,162 @@
//! Module with the definition of the encryption PublicKey.
use crate::boolean::ciphertext::Ciphertext;
use crate::boolean::client_key::ClientKey;
use crate::boolean::engine::{BooleanEngine, WithThreadLocalEngine};
use crate::boolean::parameters::BooleanParameters;
use crate::core_crypto::entities::*;
use serde::{Deserialize, Serialize};
use super::compressed::CompressedPublicKey;
/// A structure containing a public key.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PublicKey {
pub(crate) lwe_public_key: LwePublicKeyOwned<u32>,
pub(crate) parameters: BooleanParameters,
}
impl PublicKey {
/// Encrypt a Boolean message using the public key.
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let pks = PublicKey::new(&cks);
///
/// // Encryption of one message:
/// let ct1 = pks.encrypt(true);
/// let ct2 = pks.encrypt(false);
/// let ct_res = sks.and(&ct1, &ct2);
///
/// // Decryption:
/// let dec = cks.decrypt(&ct_res);
/// assert_eq!(false, dec);
/// # }
/// ```
pub fn encrypt(&self, message: bool) -> Ciphertext {
BooleanEngine::with_thread_local_mut(|engine| engine.encrypt_with_public_key(message, self))
}
/// Allocate and generate a new public key from a client key.
///
/// # Example
///
/// ```rust
/// # fn main() {
/// use tfhe::boolean::prelude::*;
///
/// // Generate the client key and the server key:
/// let (cks, sks) = gen_keys();
///
/// let pks = PublicKey::new(&cks);
/// # }
/// ```
pub fn new(client_key: &ClientKey) -> PublicKey {
BooleanEngine::with_thread_local_mut(|engine| engine.create_public_key(client_key))
}
}
impl From<CompressedPublicKey> for PublicKey {
fn from(compressed_public_key: CompressedPublicKey) -> Self {
let parameters = compressed_public_key.parameters;
let decompressed_public_key = compressed_public_key
.compressed_lwe_public_key
.decompress_into_lwe_public_key();
Self {
lwe_public_key: decompressed_public_key,
parameters,
}
}
}
#[cfg(test)]
mod tests {
use crate::boolean::prelude::{
BinaryBooleanGates, BooleanParameters, ClientKey, CompressedPublicKey, ServerKey,
DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS,
};
use crate::boolean::random_boolean;
use super::PublicKey;
const NB_TEST: usize = 32;
#[test]
fn test_public_key_default_parameters() {
test_public_key(DEFAULT_PARAMETERS);
}
#[test]
fn test_public_key_tfhe_lib_parameters() {
test_public_key(TFHE_LIB_PARAMETERS);
}
fn test_public_key(parameters: BooleanParameters) {
let cks = ClientKey::new(&parameters);
let sks = ServerKey::new(&cks);
let pks = PublicKey::new(&cks);
for _ in 0..NB_TEST {
let b1 = random_boolean();
let b2 = random_boolean();
let expected_result = !(b1 && b2);
let ct1 = pks.encrypt(b1);
let ct2 = pks.encrypt(b2);
let ct_res = sks.nand(&ct1, &ct2);
let dec_ct1 = cks.decrypt(&ct1);
let dec_ct2 = cks.decrypt(&ct2);
let dec_nand = cks.decrypt(&ct_res);
assert_eq!(dec_ct1, b1);
assert_eq!(dec_ct2, b2);
assert_eq!(dec_nand, expected_result);
}
}
#[test]
fn test_decompressing_public_key_default_parameters() {
test_decompressing_public_key(DEFAULT_PARAMETERS);
}
#[test]
fn test_decompressing_public_key_tfhe_lib_parameters() {
test_decompressing_public_key(TFHE_LIB_PARAMETERS);
}
fn test_decompressing_public_key(parameters: BooleanParameters) {
let cks = ClientKey::new(&parameters);
let sks = ServerKey::new(&cks);
let cpks = CompressedPublicKey::new(&cks);
let pks = PublicKey::from(cpks);
for _ in 0..NB_TEST {
let b1 = random_boolean();
let b2 = random_boolean();
let expected_result = !(b1 && b2);
let ct1 = pks.encrypt(b1);
let ct2 = pks.encrypt(b2);
let ct_res = sks.nand(&ct1, &ct2);
let dec_ct1 = cks.decrypt(&ct1);
let dec_ct2 = cks.decrypt(&ct2);
let dec_nand = cks.decrypt(&ct_res);
assert_eq!(dec_ct1, b1);
assert_eq!(dec_ct2, b2);
assert_eq!(dec_nand, expected_result);
}
}
}

View File

@@ -72,6 +72,20 @@ define_enable_default_fn!(integers);
#[cfg(feature = "integer")]
define_enable_default_fn!(integers @small);
#[no_mangle]
pub unsafe extern "C" fn config_builder_enable_custom_integers(
builder: *mut *mut ConfigBuilder,
shortint_block_parameters: crate::c_api::shortint::parameters::ShortintPBSParameters,
) -> ::std::os::raw::c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(builder).unwrap();
let inner = Box::from_raw(*builder)
.0
.enable_custom_integers(shortint_block_parameters.into(), None);
*builder = Box::into_raw(Box::new(ConfigBuilder(inner)));
})
}
/// Takes ownership of the builder
#[no_mangle]
pub unsafe extern "C" fn config_builder_build(

View File

@@ -1,10 +1,11 @@
use crate::c_api::high_level_api::keys::{ClientKey, PublicKey};
use crate::c_api::high_level_api::keys::{ClientKey, CompactPublicKey, PublicKey};
use crate::high_level_api::prelude::*;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, MulAssign,
Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
use crate::c_api::high_level_api::u128::U128;
use crate::c_api::high_level_api::u256::U256;
use crate::c_api::utils::*;
use std::os::raw::c_int;
@@ -67,6 +68,49 @@ macro_rules! create_integer_wrapper_type {
})
}
}
// The compact list version of the ciphertext type
::paste::paste! {
pub struct [<Compact $name List>]($crate::high_level_api::[<Compact $name List>]);
impl_destroy_on_type!([<Compact $name List>]);
impl_clone_on_type!([<Compact $name List>]);
impl_serialize_deserialize_on_type!([<Compact $name List>]);
#[no_mangle]
pub unsafe extern "C" fn [<compact_ $name:snake _list_len>](
sself: *const [<Compact $name List>],
result: *mut usize,
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
let list = $crate::c_api::utils::get_ref_checked(sself).unwrap();
*result = list.0.len();
})
}
#[no_mangle]
pub unsafe extern "C" fn [<compact_ $name:snake _list_expand>](
sself: *const [<Compact $name List>],
output: *mut *mut $name,
output_len: usize
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
check_ptr_is_non_null_and_aligned(output).unwrap();
let list = $crate::c_api::utils::get_ref_checked(sself).unwrap();
let expanded = list.0.expand();
// Clamp to the caller-provided buffer length to avoid writing out of bounds
let num_to_take = output_len.min(list.0.len());
let iter = expanded.into_iter().take(num_to_take).enumerate();
for (i, fhe_uint) in iter {
let ptr = output.wrapping_add(i);
*ptr = Box::into_raw(Box::new($name(fhe_uint)));
}
})
}
}
};
}
@@ -84,52 +128,65 @@ impl_decrypt_on_type!(FheUint8, u8);
impl_try_encrypt_trivial_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
impl_try_encrypt_with_client_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
impl_try_encrypt_with_public_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint8{crate::high_level_api::FheUint8}, u8);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint8{crate::high_level_api::CompressedFheUint8}, u8);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint8List{crate::high_level_api::CompactFheUint8List}, u8);
impl_decrypt_on_type!(FheUint10, u16);
impl_try_encrypt_trivial_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
impl_try_encrypt_with_client_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
impl_try_encrypt_with_public_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint10{crate::high_level_api::FheUint10}, u16);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint10{crate::high_level_api::CompressedFheUint10}, u16);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint10List{crate::high_level_api::CompactFheUint10List}, u16);
impl_decrypt_on_type!(FheUint12, u16);
impl_try_encrypt_trivial_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
impl_try_encrypt_with_client_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
impl_try_encrypt_with_public_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint12{crate::high_level_api::FheUint12}, u16);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint12{crate::high_level_api::CompressedFheUint12}, u16);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint12List{crate::high_level_api::CompactFheUint12List}, u16);
impl_decrypt_on_type!(FheUint14, u16);
impl_try_encrypt_trivial_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
impl_try_encrypt_with_client_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
impl_try_encrypt_with_public_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint14{crate::high_level_api::FheUint14}, u16);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint14{crate::high_level_api::CompressedFheUint14}, u16);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint14List{crate::high_level_api::CompactFheUint14List}, u16);
impl_decrypt_on_type!(FheUint16, u16);
impl_try_encrypt_trivial_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
impl_try_encrypt_with_client_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
impl_try_encrypt_with_public_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint16{crate::high_level_api::FheUint16}, u16);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint16{crate::high_level_api::CompressedFheUint16}, u16);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint16List{crate::high_level_api::CompactFheUint16List}, u16);
impl_decrypt_on_type!(FheUint32, u32);
impl_try_encrypt_trivial_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
impl_try_encrypt_with_client_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
impl_try_encrypt_with_public_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint32{crate::high_level_api::FheUint32}, u32);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint32{crate::high_level_api::CompressedFheUint32}, u32);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint32List{crate::high_level_api::CompactFheUint32List}, u32);
impl_decrypt_on_type!(FheUint64, u64);
impl_try_encrypt_trivial_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
impl_try_encrypt_with_client_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
impl_try_encrypt_with_public_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
impl_try_encrypt_with_compact_public_key_on_type!(FheUint64{crate::high_level_api::FheUint64}, u64);
impl_try_encrypt_with_client_key_on_type!(CompressedFheUint64{crate::high_level_api::CompressedFheUint64}, u64);
impl_try_encrypt_list_with_compact_public_key_on_type!(CompactFheUint64List{crate::high_level_api::CompactFheUint64List}, u64);
#[no_mangle]
pub unsafe extern "C" fn fhe_uint128_try_encrypt_trivial_u128(
low_word: u64,
high_word: u64,
value: U128,
result: *mut *mut FheUint128,
) -> c_int {
catch_panic(|| {
let value = ((high_word as u128) << 64u128) | low_word as u128;
let value = u128::from(value);
let inner = <crate::high_level_api::FheUint128>::try_encrypt_trivial(value).unwrap();
@@ -139,15 +196,14 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_trivial_u128(
#[no_mangle]
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_client_key_u128(
low_word: u64,
high_word: u64,
value: U128,
client_key: *const ClientKey,
result: *mut *mut FheUint128,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let value = ((high_word as u128) << 64u128) | low_word as u128;
let value = u128::from(value);
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &client_key.0).unwrap();
@@ -157,15 +213,14 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_client_key_u128(
#[no_mangle]
pub unsafe extern "C" fn compressed_fhe_uint128_try_encrypt_with_client_key_u128(
low_word: u64,
high_word: u64,
value: U128,
client_key: *const ClientKey,
result: *mut *mut CompressedFheUint128,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let value = ((high_word as u128) << 64u128) | low_word as u128;
let value = u128::from(value);
let inner =
<crate::high_level_api::CompressedFheUint128>::try_encrypt(value, &client_key.0)
@@ -177,15 +232,14 @@ pub unsafe extern "C" fn compressed_fhe_uint128_try_encrypt_with_client_key_u128
#[no_mangle]
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_public_key_u128(
low_word: u64,
high_word: u64,
value: U128,
public_key: *const PublicKey,
result: *mut *mut FheUint128,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let value = ((high_word as u128) << 64u128) | low_word as u128;
let value = u128::from(value);
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &public_key.0).unwrap();
@@ -193,12 +247,48 @@ pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_public_key_u128(
})
}
#[no_mangle]
pub unsafe extern "C" fn fhe_uint128_try_encrypt_with_compact_public_key_u128(
value: U128,
public_key: *const CompactPublicKey,
result: *mut *mut FheUint128,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let value = u128::from(value);
let inner = <crate::high_level_api::FheUint128>::try_encrypt(value, &public_key.0).unwrap();
*result = Box::into_raw(Box::new(FheUint128(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u128(
input: *const U128,
input_len: usize,
public_key: *const CompactPublicKey,
result: *mut *mut CompactFheUint256List,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let slc = ::std::slice::from_raw_parts(input, input_len);
let values = slc.iter().copied().map(u128::from).collect::<Vec<_>>();
let inner =
<crate::high_level_api::CompactFheUint256List>::try_encrypt(&values, &public_key.0)
.unwrap();
*result = Box::into_raw(Box::new(CompactFheUint256List(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn fhe_uint128_decrypt(
encrypted_value: *const FheUint128,
client_key: *const ClientKey,
low_word: *mut u64,
high_word: *mut u64,
result: *mut U128,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
@@ -206,18 +296,18 @@ pub unsafe extern "C" fn fhe_uint128_decrypt(
let inner: u128 = encrypted_value.0.decrypt(&client_key.0);
*low_word = (inner & (u64::MAX as u128)) as u64;
*high_word = (inner >> 64) as u64;
*result = U128::from(inner);
})
}
#[no_mangle]
pub unsafe extern "C" fn fhe_uint256_try_encrypt_trivial_u256(
value: *const U256,
value: U256,
result: *mut *mut FheUint256,
) -> c_int {
catch_panic(|| {
let inner = <crate::high_level_api::FheUint256>::try_encrypt_trivial((*value).0).unwrap();
let value = crate::integer::U256::from(value);
let inner = <crate::high_level_api::FheUint256>::try_encrypt_trivial(value).unwrap();
*result = Box::into_raw(Box::new(FheUint256(inner)));
})
@@ -225,15 +315,15 @@ pub unsafe extern "C" fn fhe_uint256_try_encrypt_trivial_u256(
#[no_mangle]
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_client_key_u256(
value: *const U256,
value: U256,
client_key: *const ClientKey,
result: *mut *mut FheUint256,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let inner =
<crate::high_level_api::FheUint256>::try_encrypt((*value).0, &client_key.0).unwrap();
let value = crate::integer::U256::from(value);
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &client_key.0).unwrap();
*result = Box::into_raw(Box::new(FheUint256(inner)));
})
@@ -241,15 +331,16 @@ pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_client_key_u256(
#[no_mangle]
pub unsafe extern "C" fn compressed_fhe_uint256_try_encrypt_with_client_key_u256(
value: *const U256,
value: U256,
client_key: *const ClientKey,
result: *mut *mut CompressedFheUint256,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let value = crate::integer::U256::from(value);
let inner =
<crate::high_level_api::CompressedFheUint256>::try_encrypt((*value).0, &client_key.0)
<crate::high_level_api::CompressedFheUint256>::try_encrypt(value, &client_key.0)
.unwrap();
*result = Box::into_raw(Box::new(CompressedFheUint256(inner)));
@@ -258,32 +349,72 @@ pub unsafe extern "C" fn compressed_fhe_uint256_try_encrypt_with_client_key_u256
#[no_mangle]
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_public_key_u256(
value: *const U256,
value: U256,
public_key: *const PublicKey,
result: *mut *mut FheUint256,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let inner =
<crate::high_level_api::FheUint256>::try_encrypt((*value).0, &public_key.0).unwrap();
let value = crate::integer::U256::from(value);
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &public_key.0).unwrap();
*result = Box::into_raw(Box::new(FheUint256(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn fhe_uint256_try_encrypt_with_compact_public_key_u256(
value: U256,
public_key: *const CompactPublicKey,
result: *mut *mut FheUint256,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let value = crate::integer::U256::from(value);
let inner = <crate::high_level_api::FheUint256>::try_encrypt(value, &public_key.0).unwrap();
*result = Box::into_raw(Box::new(FheUint256(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn compact_fhe_uint256_list_try_encrypt_with_compact_public_key_u256(
input: *const U256,
input_len: usize,
public_key: *const CompactPublicKey,
result: *mut *mut CompactFheUint256List,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
let slc = ::std::slice::from_raw_parts(input, input_len);
let values = slc
.iter()
.copied()
.map(crate::integer::U256::from)
.collect::<Vec<_>>();
let inner =
<crate::high_level_api::CompactFheUint256List>::try_encrypt(&values, &public_key.0)
.unwrap();
*result = Box::into_raw(Box::new(CompactFheUint256List(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn fhe_uint256_decrypt(
encrypted_value: *const FheUint256,
client_key: *const ClientKey,
result: *mut *mut U256,
result: *mut U256,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let encrypted_value = get_ref_checked(encrypted_value).unwrap();
let inner: crate::integer::U256 = encrypted_value.0.decrypt(&client_key.0);
*result = Box::into_raw(Box::new(U256(inner)));
*result = U256::from(inner);
})
}

View File

@@ -3,14 +3,20 @@ use std::os::raw::c_int;
pub struct ClientKey(pub(crate) crate::high_level_api::ClientKey);
pub struct PublicKey(pub(crate) crate::high_level_api::PublicKey);
pub struct CompactPublicKey(pub(crate) crate::high_level_api::CompactPublicKey);
pub struct CompressedCompactPublicKey(pub(crate) crate::high_level_api::CompressedCompactPublicKey);
pub struct ServerKey(pub(crate) crate::high_level_api::ServerKey);
impl_destroy_on_type!(ClientKey);
impl_destroy_on_type!(PublicKey);
impl_destroy_on_type!(CompactPublicKey);
impl_destroy_on_type!(CompressedCompactPublicKey);
impl_destroy_on_type!(ServerKey);
impl_serialize_deserialize_on_type!(ClientKey);
impl_serialize_deserialize_on_type!(PublicKey);
impl_serialize_deserialize_on_type!(CompactPublicKey);
impl_serialize_deserialize_on_type!(CompressedCompactPublicKey);
impl_serialize_deserialize_on_type!(ServerKey);
#[no_mangle]
@@ -69,3 +75,43 @@ pub unsafe extern "C" fn public_key_new(
*result_public_key = Box::into_raw(Box::new(PublicKey(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn compact_public_key_new(
client_key: *const ClientKey,
result_public_key: *mut *mut CompactPublicKey,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let inner = crate::high_level_api::CompactPublicKey::new(&client_key.0);
*result_public_key = Box::into_raw(Box::new(CompactPublicKey(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn compressed_compact_public_key_new(
client_key: *const ClientKey,
result_public_key: *mut *mut CompressedCompactPublicKey,
) -> c_int {
catch_panic(|| {
let client_key = get_ref_checked(client_key).unwrap();
let inner = crate::high_level_api::CompressedCompactPublicKey::new(&client_key.0);
*result_public_key = Box::into_raw(Box::new(CompressedCompactPublicKey(inner)));
})
}
#[no_mangle]
pub unsafe extern "C" fn compressed_compact_public_key_decompress(
public_key: *const CompressedCompactPublicKey,
result_public_key: *mut *mut CompactPublicKey,
) -> c_int {
catch_panic(|| {
let public_key = get_ref_checked(public_key).unwrap();
*result_public_key = Box::into_raw(Box::new(CompactPublicKey(
public_key.0.clone().decompress(),
)));
})
}

View File

@@ -7,4 +7,6 @@ pub mod config;
pub mod integers;
pub mod keys;
#[cfg(feature = "integer")]
pub mod u128;
#[cfg(feature = "integer")]
pub mod u256;

View File

@@ -0,0 +1,20 @@
#[repr(C)]
#[derive(Copy, Clone)]
pub struct U128 {
pub w0: u64,
pub w1: u64,
}
impl From<u128> for U128 {
fn from(value: u128) -> Self {
let w0 = (value & (u64::MAX as u128)) as u64;
let w1 = (value >> 64) as u64;
Self { w0, w1 }
}
}
impl From<U128> for u128 {
fn from(value: U128) -> Self {
((value.w1 as u128) << 64u128) | value.w0 as u128
}
}
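// Example of the word split used by these conversions (w0 = low 64 bits, w1 = high 64 bits):
//
//     let v: u128 = (1u128 << 64) | 42;
//     let c = U128::from(v);
//     assert_eq!((c.w0, c.w1), (42, 1));
//     assert_eq!(u128::from(c), v);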

View File

@@ -1,47 +1,30 @@
use crate::c_api::utils::*;
use std::os::raw::c_int;
pub struct U256(pub(in crate::c_api) crate::integer::U256);
impl_destroy_on_type!(U256);
/// w0 is the least significant, w4 is the most significant
#[no_mangle]
pub unsafe extern "C" fn u256_from_u64_words(
w0: u64,
w1: u64,
w2: u64,
w3: u64,
result: *mut *mut U256,
) -> c_int {
catch_panic(|| {
let inner = crate::integer::U256::from((w0, w1, w2, w3));
*result = Box::into_raw(Box::new(U256(inner)));
})
#[repr(C)]
#[derive(Copy, Clone)]
pub struct U256 {
pub w0: u64,
pub w1: u64,
pub w2: u64,
pub w3: u64,
}
/// w0 is the least significant, w4 is the most significant
#[no_mangle]
pub unsafe extern "C" fn u256_to_u64_words(
input: *const U256,
w0: *mut u64,
w1: *mut u64,
w2: *mut u64,
w3: *mut u64,
) -> c_int {
catch_panic(|| {
let input = get_ref_checked(input).unwrap();
impl From<crate::integer::U256> for U256 {
fn from(value: crate::integer::U256) -> Self {
Self {
w0: value.0[0],
w1: value.0[1],
w2: value.0[2],
w3: value.0[3],
}
}
}
check_ptr_is_non_null_and_aligned(w0).unwrap();
check_ptr_is_non_null_and_aligned(w1).unwrap();
check_ptr_is_non_null_and_aligned(w2).unwrap();
check_ptr_is_non_null_and_aligned(w3).unwrap();
*w0 = input.0 .0[0];
*w1 = input.0 .0[1];
*w2 = input.0 .0[2];
*w3 = input.0 .0[3];
})
impl From<U256> for crate::integer::U256 {
fn from(value: U256) -> Self {
Self([value.w0, value.w1, value.w2, value.w3])
}
}
/// Creates a U256 from little endian bytes
@@ -51,7 +34,7 @@ pub unsafe extern "C" fn u256_to_u64_words(
pub unsafe extern "C" fn u256_from_little_endian_bytes(
input: *const u8,
len: usize,
result: *mut *mut U256,
result: *mut U256,
) -> c_int {
catch_panic(|| {
let mut inner = crate::integer::U256::default();
@@ -59,7 +42,7 @@ pub unsafe extern "C" fn u256_from_little_endian_bytes(
let input = std::slice::from_raw_parts(input, len);
inner.copy_from_le_byte_slice(input);
*result = Box::into_raw(Box::new(U256(inner)));
*result = U256::from(inner)
})
}
@@ -70,7 +53,7 @@ pub unsafe extern "C" fn u256_from_little_endian_bytes(
pub unsafe extern "C" fn u256_from_big_endian_bytes(
input: *const u8,
len: usize,
result: *mut *mut U256,
result: *mut U256,
) -> c_int {
catch_panic(|| {
let mut inner = crate::integer::U256::default();
@@ -78,38 +61,32 @@ pub unsafe extern "C" fn u256_from_big_endian_bytes(
let input = std::slice::from_raw_parts(input, len);
inner.copy_from_be_byte_slice(input);
*result = Box::into_raw(Box::new(U256(inner)));
*result = U256::from(inner)
})
}
/// len must be 32
#[no_mangle]
pub unsafe extern "C" fn u256_little_endian_bytes(
input: *const U256,
input: U256,
result: *mut u8,
len: usize,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
let input = get_ref_checked(input).unwrap();
let bytes = std::slice::from_raw_parts_mut(result, len);
input.0.copy_to_le_byte_slice(bytes);
crate::integer::U256::from(input).copy_to_le_byte_slice(bytes);
})
}
/// len must be 32
#[no_mangle]
pub unsafe extern "C" fn u256_big_endian_bytes(
input: *const U256,
result: *mut u8,
len: usize,
) -> c_int {
pub unsafe extern "C" fn u256_big_endian_bytes(input: U256, result: *mut u8, len: usize) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
let input = get_ref_checked(input).unwrap();
let bytes = std::slice::from_raw_parts_mut(result, len);
input.0.copy_to_be_byte_slice(bytes);
crate::integer::U256::from(input).copy_to_be_byte_slice(bytes);
})
}

View File

@@ -59,6 +59,49 @@ macro_rules! impl_try_encrypt_with_public_key_on_type {
};
}
macro_rules! impl_try_encrypt_with_compact_public_key_on_type {
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$wrapper_type:snake _try_encrypt_with_compact_public_key_ $input_type:snake>](
value: $input_type,
public_key: *const $crate::c_api::high_level_api::keys::CompactPublicKey,
result: *mut *mut $wrapper_type,
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
let public_key = $crate::c_api::utils::get_ref_checked(public_key).unwrap();
let inner = <$wrapped_type>::try_encrypt(value, &public_key.0).unwrap();
*result = Box::into_raw(Box::new($wrapper_type(inner)));
})
}
}
};
}
macro_rules! impl_try_encrypt_list_with_compact_public_key_on_type {
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$wrapper_type:snake _try_encrypt_with_compact_public_key_ $input_type:snake>](
input: *const $input_type,
input_len: usize,
public_key: *const $crate::c_api::high_level_api::keys::CompactPublicKey,
result: *mut *mut $wrapper_type,
) -> ::std::os::raw::c_int {
$crate::c_api::utils::catch_panic(|| {
let public_key = $crate::c_api::utils::get_ref_checked(public_key).unwrap();
let slc = ::std::slice::from_raw_parts(input, input_len);
let inner = <$wrapped_type>::try_encrypt(slc, &public_key.0).unwrap();
*result = Box::into_raw(Box::new($wrapper_type(inner)));
})
}
}
};
}
macro_rules! impl_try_encrypt_trivial_on_type {
($wrapper_type:ty{$wrapped_type:ty}, $input_type:ty) => {
::paste::paste! {

View File

@@ -10,7 +10,7 @@ pub struct ShortintClientKey(pub(in crate::c_api) shortint::client_key::ClientKe
#[no_mangle]
pub unsafe extern "C" fn shortint_gen_client_key(
shortint_parameters: *const super::parameters::ShortintParameters,
shortint_parameters: super::parameters::ShortintPBSParameters,
result_client_key: *mut *mut ShortintClientKey,
) -> c_int {
catch_panic(|| {
@@ -20,9 +20,10 @@ pub unsafe extern "C" fn shortint_gen_client_key(
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result_client_key = std::ptr::null_mut();
let shortint_parameters = get_ref_checked(shortint_parameters).unwrap();
let shortint_parameters: crate::shortint::parameters::ClassicPBSParameters =
shortint_parameters.into();
let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned());
let client_key = shortint::client_key::ClientKey::new(shortint_parameters);
let heap_allocated_client_key = Box::new(ShortintClientKey(client_key));

View File

@@ -1,7 +1,6 @@
use crate::c_api::utils::*;
use std::os::raw::c_int;
use super::parameters::ShortintParameters;
use super::{
ShortintBivariatePBSLookupTable, ShortintCiphertext, ShortintClientKey,
ShortintCompressedCiphertext, ShortintCompressedPublicKey, ShortintCompressedServerKey,
@@ -57,17 +56,6 @@ pub unsafe extern "C" fn destroy_shortint_compressed_public_key(
})
}
#[no_mangle]
pub unsafe extern "C" fn destroy_shortint_parameters(
shortint_parameters: *mut ShortintParameters,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(shortint_parameters).unwrap();
drop(Box::from_raw(shortint_parameters));
})
}
#[no_mangle]
pub unsafe extern "C" fn destroy_shortint_ciphertext(
shortint_ciphertext: *mut ShortintCiphertext,

View File

@@ -19,7 +19,7 @@ pub use server_key::{ShortintCompressedServerKey, ShortintServerKey};
#[no_mangle]
pub unsafe extern "C" fn shortint_gen_keys_with_parameters(
shortint_parameters: *const parameters::ShortintParameters,
shortint_parameters: parameters::ShortintPBSParameters,
result_client_key: *mut *mut ShortintClientKey,
result_server_key: *mut *mut ShortintServerKey,
) -> c_int {
@@ -32,9 +32,10 @@ pub unsafe extern "C" fn shortint_gen_keys_with_parameters(
*result_client_key = std::ptr::null_mut();
*result_server_key = std::ptr::null_mut();
let shortint_parameters = get_ref_checked(shortint_parameters).unwrap();
let shortint_parameters: crate::shortint::parameters::ClassicPBSParameters =
shortint_parameters.into();
let client_key = shortint::client_key::ClientKey::new(shortint_parameters.0.to_owned());
let client_key = shortint::client_key::ClientKey::new(shortint_parameters);
let server_key = shortint::server_key::ServerKey::new(&client_key);
let heap_allocated_client_key = Box::new(ShortintClientKey(client_key));

View File

@@ -3,11 +3,14 @@ pub use crate::core_crypto::commons::dispersion::StandardDev;
pub use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
};
pub use crate::shortint::parameters::parameters_compact_pk::*;
pub use crate::shortint::parameters::*;
use std::os::raw::c_int;
use crate::shortint;
#[repr(C)]
#[derive(Copy, Clone)]
pub enum ShortintEncryptionKeyChoice {
ShortintEncryptionKeyChoiceBig,
ShortintEncryptionKeyChoiceSmall,
@@ -26,13 +29,185 @@ impl From<ShortintEncryptionKeyChoice> for crate::shortint::parameters::Encrypti
}
}
pub struct ShortintParameters(pub(in crate::c_api) shortint::parameters::Parameters);
#[repr(C)]
#[derive(Copy, Clone)]
pub struct ShortintPBSParameters {
pub lwe_dimension: usize,
pub glwe_dimension: usize,
pub polynomial_size: usize,
pub lwe_modular_std_dev: f64,
pub glwe_modular_std_dev: f64,
pub pbs_base_log: usize,
pub pbs_level: usize,
pub ks_base_log: usize,
pub ks_level: usize,
pub message_modulus: usize,
pub carry_modulus: usize,
pub modulus_power_of_2_exponent: usize,
pub encryption_key_choice: ShortintEncryptionKeyChoice,
}
impl From<ShortintPBSParameters> for crate::shortint::ClassicPBSParameters {
fn from(c_params: ShortintPBSParameters) -> Self {
Self {
lwe_dimension: LweDimension(c_params.lwe_dimension),
glwe_dimension: GlweDimension(c_params.glwe_dimension),
polynomial_size: PolynomialSize(c_params.polynomial_size),
lwe_modular_std_dev: StandardDev(c_params.lwe_modular_std_dev),
glwe_modular_std_dev: StandardDev(c_params.glwe_modular_std_dev),
pbs_base_log: DecompositionBaseLog(c_params.pbs_base_log),
pbs_level: DecompositionLevelCount(c_params.pbs_level),
ks_base_log: DecompositionBaseLog(c_params.ks_base_log),
ks_level: DecompositionLevelCount(c_params.ks_level),
message_modulus: crate::shortint::parameters::MessageModulus(c_params.message_modulus),
carry_modulus: crate::shortint::parameters::CarryModulus(c_params.carry_modulus),
ciphertext_modulus: crate::shortint::parameters::CiphertextModulus::try_new_power_of_2(
c_params.modulus_power_of_2_exponent,
)
.unwrap(),
encryption_key_choice: c_params.encryption_key_choice.into(),
}
}
}
impl From<crate::shortint::ClassicPBSParameters> for ShortintPBSParameters {
fn from(rust_params: crate::shortint::ClassicPBSParameters) -> Self {
Self::convert(rust_params)
}
}
impl ShortintEncryptionKeyChoice {
// From::from cannot be marked as const, so we define our own const conversion function
const fn convert(rust_choice: crate::shortint::EncryptionKeyChoice) -> Self {
match rust_choice {
crate::shortint::EncryptionKeyChoice::Big => Self::ShortintEncryptionKeyChoiceBig,
crate::shortint::EncryptionKeyChoice::Small => Self::ShortintEncryptionKeyChoiceSmall,
}
}
}
const fn convert_modulus(rust_modulus: crate::shortint::CiphertextModulus) -> usize {
if rust_modulus.is_native_modulus() {
64 // shortint ciphertexts are stored on 64 bits, so the native modulus is 2^64
} else {
assert!(rust_modulus.is_power_of_two());
let modulus = rust_modulus.get_custom_modulus();
let exponent = modulus.ilog2() as usize;
assert!(exponent <= 64);
exponent
}
}
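// For example, the native 64-bit modulus is reported as the exponent 64, while a custom
// power-of-two modulus of 2^32 is reported as 32; the Rust-side parameters are rebuilt
// from this exponent via `CiphertextModulus::try_new_power_of_2` in the `From` impl above.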
impl ShortintPBSParameters {
const fn convert(rust_params: crate::shortint::ClassicPBSParameters) -> Self {
Self {
lwe_dimension: rust_params.lwe_dimension.0,
glwe_dimension: rust_params.glwe_dimension.0,
polynomial_size: rust_params.polynomial_size.0,
lwe_modular_std_dev: rust_params.lwe_modular_std_dev.0,
glwe_modular_std_dev: rust_params.glwe_modular_std_dev.0,
pbs_base_log: rust_params.pbs_base_log.0,
pbs_level: rust_params.pbs_level.0,
ks_base_log: rust_params.ks_base_log.0,
ks_level: rust_params.ks_level.0,
message_modulus: rust_params.message_modulus.0,
carry_modulus: rust_params.carry_modulus.0,
modulus_power_of_2_exponent: convert_modulus(rust_params.ciphertext_modulus),
encryption_key_choice: ShortintEncryptionKeyChoice::convert(
rust_params.encryption_key_choice,
),
}
}
}
macro_rules! expose_as_shortint_pbs_parameters(
(
$(
$param_name:ident
),*
$(,)?
) => {
::paste::paste!{
$(
#[no_mangle]
pub static [<SHORTINT_ $param_name>]: ShortintPBSParameters =
ShortintPBSParameters::convert($param_name);
)*
}
// Test that converting a param from its rust struct
// to the c struct and then to the rust struct again
// yields the same values as the original struct
//
// This is what will essentially happen in the real code
#[test]
fn test_shortint_pbs_parameters_roundtrip_c_rust() {
$(
// 1 scope for each parameters
{
let rust_params = crate::shortint::parameters::$param_name;
let c_params = ShortintPBSParameters::from(rust_params);
let rust_params_from_c = crate::shortint::parameters::ClassicPBSParameters::from(c_params);
assert_eq!(rust_params, rust_params_from_c);
}
)*
}
}
);
expose_as_shortint_pbs_parameters!(
PARAM_MESSAGE_1_CARRY_0,
PARAM_MESSAGE_1_CARRY_1,
PARAM_MESSAGE_2_CARRY_0,
PARAM_MESSAGE_1_CARRY_2,
PARAM_MESSAGE_2_CARRY_1,
PARAM_MESSAGE_3_CARRY_0,
PARAM_MESSAGE_1_CARRY_3,
PARAM_MESSAGE_2_CARRY_2,
PARAM_MESSAGE_3_CARRY_1,
PARAM_MESSAGE_4_CARRY_0,
PARAM_MESSAGE_1_CARRY_4,
PARAM_MESSAGE_2_CARRY_3,
PARAM_MESSAGE_3_CARRY_2,
PARAM_MESSAGE_4_CARRY_1,
PARAM_MESSAGE_5_CARRY_0,
PARAM_MESSAGE_1_CARRY_5,
PARAM_MESSAGE_2_CARRY_4,
PARAM_MESSAGE_3_CARRY_3,
PARAM_MESSAGE_4_CARRY_2,
PARAM_MESSAGE_5_CARRY_1,
PARAM_MESSAGE_6_CARRY_0,
PARAM_MESSAGE_1_CARRY_6,
PARAM_MESSAGE_2_CARRY_5,
PARAM_MESSAGE_3_CARRY_4,
PARAM_MESSAGE_4_CARRY_3,
PARAM_MESSAGE_5_CARRY_2,
PARAM_MESSAGE_6_CARRY_1,
PARAM_MESSAGE_7_CARRY_0,
PARAM_MESSAGE_1_CARRY_7,
PARAM_MESSAGE_2_CARRY_6,
PARAM_MESSAGE_3_CARRY_5,
PARAM_MESSAGE_4_CARRY_4,
PARAM_MESSAGE_5_CARRY_3,
PARAM_MESSAGE_6_CARRY_2,
PARAM_MESSAGE_7_CARRY_1,
PARAM_MESSAGE_8_CARRY_0,
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK,
// Small params
PARAM_SMALL_MESSAGE_1_CARRY_1,
PARAM_SMALL_MESSAGE_2_CARRY_2,
PARAM_SMALL_MESSAGE_3_CARRY_3,
PARAM_SMALL_MESSAGE_4_CARRY_4,
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_SMALL,
);
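// For reference, each parameter name listed above expands (via the macro) to a C-visible
// static, e.g.:
//
//     #[no_mangle]
//     pub static SHORTINT_PARAM_MESSAGE_2_CARRY_2: ShortintPBSParameters =
//         ShortintPBSParameters::convert(PARAM_MESSAGE_2_CARRY_2);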
#[no_mangle]
pub unsafe extern "C" fn shortint_get_parameters(
message_bits: u32,
carry_bits: u32,
result: *mut *mut ShortintParameters,
result: *mut ShortintPBSParameters,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
@@ -76,12 +251,8 @@ pub unsafe extern "C" fn shortint_get_parameters(
_ => None,
};
match params {
Some(params) => {
let params = Box::new(ShortintParameters(params));
*result = Box::into_raw(params);
}
None => *result = std::ptr::null_mut(),
if let Some(params) = params {
*result = params.into();
}
})
}
@@ -90,7 +261,7 @@ pub unsafe extern "C" fn shortint_get_parameters(
pub unsafe extern "C" fn shortint_get_parameters_small(
message_bits: u32,
carry_bits: u32,
result: *mut *mut ShortintParameters,
result: *mut ShortintPBSParameters,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
@@ -102,71 +273,8 @@ pub unsafe extern "C" fn shortint_get_parameters_small(
_ => None,
};
match params {
Some(params) => {
let params = Box::new(ShortintParameters(params));
*result = Box::into_raw(params);
}
None => *result = std::ptr::null_mut(),
if let Some(params) = params {
*result = params.into();
}
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_create_parameters(
lwe_dimension: usize,
glwe_dimension: usize,
polynomial_size: usize,
lwe_modular_std_dev: f64,
glwe_modular_std_dev: f64,
pbs_base_log: usize,
pbs_level: usize,
ks_base_log: usize,
ks_level: usize,
pfks_level: usize,
pfks_base_log: usize,
pfks_modular_std_dev: f64,
cbs_level: usize,
cbs_base_log: usize,
message_modulus: usize,
carry_modulus: usize,
modulus_power_of_2_exponent: usize,
encryption_key_choice: ShortintEncryptionKeyChoice,
result: *mut *mut ShortintParameters,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
// First fill the result with a null ptr so that if we fail and the return code is not
// checked, then any access to the result pointer will segfault (mimics malloc on failure)
*result = std::ptr::null_mut();
let heap_allocated_parameters =
Box::new(ShortintParameters(shortint::parameters::Parameters {
lwe_dimension: LweDimension(lwe_dimension),
glwe_dimension: GlweDimension(glwe_dimension),
polynomial_size: PolynomialSize(polynomial_size),
lwe_modular_std_dev: StandardDev(lwe_modular_std_dev),
glwe_modular_std_dev: StandardDev(glwe_modular_std_dev),
pbs_base_log: DecompositionBaseLog(pbs_base_log),
pbs_level: DecompositionLevelCount(pbs_level),
ks_base_log: DecompositionBaseLog(ks_base_log),
ks_level: DecompositionLevelCount(ks_level),
pfks_level: DecompositionLevelCount(pfks_level),
pfks_base_log: DecompositionBaseLog(pfks_base_log),
pfks_modular_std_dev: StandardDev(pfks_modular_std_dev),
cbs_level: DecompositionLevelCount(cbs_level),
cbs_base_log: DecompositionBaseLog(cbs_base_log),
message_modulus: crate::shortint::parameters::MessageModulus(message_modulus),
carry_modulus: crate::shortint::parameters::CarryModulus(carry_modulus),
ciphertext_modulus:
crate::shortint::parameters::CiphertextModulus::try_new_power_of_2(
modulus_power_of_2_exponent,
)
.unwrap(),
encryption_key_choice: encryption_key_choice.into(),
}));
*result = Box::into_raw(heap_allocated_parameters);
})
}

View File

@@ -111,6 +111,8 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
let decomp_base_log = output.decomposition_base_log();
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
for (level_index, (mut level_matrix, mut generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
@@ -121,7 +123,7 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
.0
.wrapping_neg()
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = generator
@@ -249,6 +251,8 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
let decomp_base_log = output.decomposition_base_log();
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
@@ -258,7 +262,7 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
.0
.wrapping_neg()
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
// We iterate over the rows of the level matrix, the last row needs special
// treatment
@@ -367,6 +371,8 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
let decomp_base_log = output.decomposition_base_log();
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
for (level_index, (mut level_matrix, mut loop_generator)) in
output.iter_mut().zip(gen_iter).enumerate()
{
@@ -377,7 +383,7 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.0
.wrapping_neg()
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = loop_generator
@@ -542,6 +548,8 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
let decomp_base_log = output.decomposition_base_log();
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
@@ -551,7 +559,7 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
.0
.wrapping_neg()
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = generator
@@ -832,13 +840,13 @@ where
let plaintext_ref = decrypted_plaintext_list.get(0);
let ciphertext_modulus = ggsw_ciphertext.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
// GLWE decryption may map to a smaller torus; map back to the native torus
let rounded = decomposer.closest_representable(
(*plaintext_ref.0).wrapping_mul(
ggsw_ciphertext
.ciphertext_modulus()
.get_scaling_to_native_torus(),
),
(*plaintext_ref.0)
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
let decoded =
rounded.wrapping_div(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));

View File

@@ -37,6 +37,8 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
generator.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
output_body.as_mut(),
@@ -45,7 +47,7 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
);
if !ciphertext_modulus.is_native_modulus() {
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
}
@@ -250,6 +252,8 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
generator.fill_slice_with_random_noise_custom_mod(
output_body.as_mut(),
@@ -263,7 +267,7 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
);
if !ciphertext_modulus.is_native_modulus() {
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
}
@@ -572,6 +576,8 @@ pub fn decrypt_glwe_ciphertext<Scalar, KeyCont, InputCont, OutputCont>(
let ciphertext_modulus = input_glwe_ciphertext.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let (mask, body) = input_glwe_ciphertext.get_mask_and_body();
output_plaintext_list
.as_mut()
@@ -585,7 +591,7 @@ pub fn decrypt_glwe_ciphertext<Scalar, KeyCont, InputCont, OutputCont>(
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_div_assign(
output_plaintext_list.as_mut(),
ciphertext_modulus.get_scaling_to_native_torus(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
}
@@ -720,10 +726,12 @@ pub fn trivially_encrypt_glwe_ciphertext<Scalar, InputCont, OutputCont>(
let ciphertext_modulus = body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
body.as_mut(),
ciphertext_modulus.get_scaling_to_native_torus(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
}
@@ -801,6 +809,8 @@ where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
{
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let polynomial_size = PolynomialSize(encoded.plaintext_count().0);
let mut new_ct =
@@ -812,7 +822,7 @@ where
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
body.as_mut(),
ciphertext_modulus.get_scaling_to_native_torus(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}

View File

@@ -0,0 +1,110 @@
//! Module with primitives pertaining to [`LweCompactCiphertextList`] expansion.
use crate::core_crypto::algorithms::polynomial_algorithms::polynomial_wrapping_monic_monomial_mul_assign;
use crate::core_crypto::commons::parameters::MonomialDegree;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Expand an [`LweCompactCiphertextList`] into an [`LweCiphertextList`].
///
/// Consider using [`par_expand_lwe_compact_ciphertext_list`] for better performance.
pub fn expand_lwe_compact_ciphertext_list<Scalar, InputCont, OutputCont>(
output_lwe_ciphertext_list: &mut LweCiphertextList<OutputCont>,
input_lwe_compact_ciphertext_list: &LweCompactCiphertextList<InputCont>,
) where
Scalar: UnsignedInteger,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
output_lwe_ciphertext_list.entity_count()
== input_lwe_compact_ciphertext_list.lwe_ciphertext_count().0
);
assert!(output_lwe_ciphertext_list.lwe_size() == input_lwe_compact_ciphertext_list.lwe_size());
let (input_mask_list, input_body_list) =
input_lwe_compact_ciphertext_list.get_mask_and_body_list();
let lwe_dimension = input_mask_list.lwe_dimension();
let max_ciphertext_per_bin = lwe_dimension.0;
for (input_mask, (mut output_ct_chunk, input_body_chunk)) in input_mask_list.iter().zip(
output_lwe_ciphertext_list
.chunks_mut(max_ciphertext_per_bin)
.zip(input_body_list.chunks(max_ciphertext_per_bin)),
) {
for (ct_idx, (mut out_ct, input_body)) in output_ct_chunk
.iter_mut()
.zip(input_body_chunk.iter())
.enumerate()
{
let (mut out_mask, out_body) = out_ct.get_mut_mask_and_body();
out_mask.as_mut().copy_from_slice(input_mask.as_ref());
let mut out_mask_as_polynomial = Polynomial::from_container(out_mask.as_mut());
// This is the Psi_jl from the paper, it's equivalent to a multiplication in the X^N + 1
// ring for our choice of i == n
polynomial_wrapping_monic_monomial_mul_assign(
&mut out_mask_as_polynomial,
MonomialDegree(lwe_dimension.0 - (ct_idx + 1)),
);
*out_body.data = *input_body.data;
}
}
}
/// Parallel version of [`expand_lwe_compact_ciphertext_list`].
pub fn par_expand_lwe_compact_ciphertext_list<Scalar, InputCont, OutputCont>(
output_lwe_ciphertext_list: &mut LweCiphertextList<OutputCont>,
input_lwe_compact_ciphertext_list: &LweCompactCiphertextList<InputCont>,
) where
Scalar: UnsignedInteger + Send + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
output_lwe_ciphertext_list.entity_count()
== input_lwe_compact_ciphertext_list.lwe_ciphertext_count().0
);
assert!(output_lwe_ciphertext_list.lwe_size() == input_lwe_compact_ciphertext_list.lwe_size());
let (input_mask_list, input_body_list) =
input_lwe_compact_ciphertext_list.get_mask_and_body_list();
let lwe_dimension = input_mask_list.lwe_dimension();
let max_ciphertext_per_bin = lwe_dimension.0;
input_mask_list
.par_iter()
.zip(
output_lwe_ciphertext_list
.par_chunks_mut(max_ciphertext_per_bin)
.zip(input_body_list.par_chunks(max_ciphertext_per_bin)),
)
.for_each(|(input_mask, (mut output_ct_chunk, input_body_chunk))| {
output_ct_chunk
.par_iter_mut()
.zip(input_body_chunk.par_iter())
.enumerate()
.for_each(|(ct_idx, (mut out_ct, input_body))| {
let (mut out_mask, out_body) = out_ct.get_mut_mask_and_body();
out_mask.as_mut().copy_from_slice(input_mask.as_ref());
let mut out_mask_as_polynomial = Polynomial::from_container(out_mask.as_mut());
// This is the Psi_jl from the paper, it's equivalent to a multiplication in the
// X^N + 1 ring for our choice of i == n
polynomial_wrapping_monic_monomial_mul_assign(
&mut out_mask_as_polynomial,
MonomialDegree(lwe_dimension.0 - (ct_idx + 1)),
);
*out_body.data = *input_body.data;
});
});
}
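
Both expansion functions above rotate the shared mask by X^(n - (j + 1)) in the ring Z[X]/(X^N + 1), the Psi_jl map mentioned in the comments. The following is a stand-alone sketch of that monic monomial multiplication for u64 coefficients (the crate's polynomial_wrapping_monic_monomial_mul_assign is generic and its internals may differ): rotate the coefficients and negate the ones that wrap around.

fn monic_monomial_mul_negacyclic_u64(poly: &mut [u64], degree: usize) {
    // Multiply by X^degree in Z_{2^64}[X] / (X^N + 1).
    let n = poly.len();
    let k = degree % (2 * n);
    let wrap = k % n;
    let negate_all = k >= n; // an extra X^N factor (degree >= N) flips every sign
    poly.rotate_right(wrap);
    for (i, c) in poly.iter_mut().enumerate() {
        // Coefficients that wrapped past X^(N-1) pick up a minus sign.
        if negate_all ^ (i < wrap) {
            *c = c.wrapping_neg();
        }
    }
}

fn main() {
    let mut p = vec![1u64, 2, 3, 4]; // 1 + 2X + 3X^2 + 4X^3, N = 4
    monic_monomial_mul_negacyclic_u64(&mut p, 1); // multiply by X
    assert_eq!(p, vec![4u64.wrapping_neg(), 1, 2, 3]); // -4 + X + 2X^2 + 3X^3
}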

View File

@@ -0,0 +1,149 @@
//! Module containing primitives pertaining to [`LWE compact public key
//! generation`](`LweCompactPublicKey`).
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::dispersion::DispersionParameter;
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::ActivatedRandomGenerator;
use slice_algorithms::*;
/// Fill an [`LWE compact public key`](`LweCompactPublicKey`) with an actual public key constructed
/// from a private [`LWE secret key`](`LweSecretKey`).
pub fn generate_lwe_compact_public_key<Scalar, InputKeyCont, OutputKeyCont, Gen>(
lwe_secret_key: &LweSecretKey<InputKeyCont>,
output: &mut LweCompactPublicKey<OutputKeyCont>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert!(
output.ciphertext_modulus().is_native_modulus(),
"This operation only supports native moduli"
);
assert!(
lwe_secret_key.lwe_dimension() == output.lwe_dimension(),
"Mismatched LweDimension between input LweSecretKey {:?} \
and output LweCompactPublicKey {:?}",
lwe_secret_key.lwe_dimension(),
output.lwe_dimension()
);
let (mut mask, mut body) = output.get_mut_mask_and_body();
generator.fill_slice_with_random_mask(mask.as_mut());
slice_semi_reverse_negacyclic_convolution(
body.as_mut(),
mask.as_ref(),
lwe_secret_key.as_ref(),
);
generator
.unsigned_torus_slice_wrapping_add_random_noise_assign(body.as_mut(), noise_parameters);
}
/// Allocate a new [`LWE compact public key`](`LweCompactPublicKey`) and fill it with an actual
/// public key constructed from a private [`LWE secret key`](`LweSecretKey`).
///
/// See [`encrypt_lwe_ciphertext_with_compact_public_key`] for usage.
pub fn allocate_and_generate_new_lwe_compact_public_key<Scalar, InputKeyCont, Gen>(
lwe_secret_key: &LweSecretKey<InputKeyCont>,
noise_parameters: impl DispersionParameter,
ciphertext_modulus: CiphertextModulus<Scalar>,
generator: &mut EncryptionRandomGenerator<Gen>,
) -> LweCompactPublicKeyOwned<Scalar>
where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
Gen: ByteRandomGenerator,
{
let mut pk = LweCompactPublicKeyOwned::new(
Scalar::ZERO,
lwe_secret_key.lwe_dimension(),
ciphertext_modulus,
);
generate_lwe_compact_public_key(lwe_secret_key, &mut pk, noise_parameters, generator);
pk
}
/// Fill a [`seeded LWE compact public key`](`LweCompactPublicKey`) with an actual public key
/// constructed from a private [`LWE secret key`](`LweSecretKey`).
pub fn generate_seeded_lwe_compact_public_key<Scalar, InputKeyCont, OutputKeyCont, NoiseSeeder>(
lwe_secret_key: &LweSecretKey<InputKeyCont>,
output: &mut SeededLweCompactPublicKey<OutputKeyCont>,
noise_parameters: impl DispersionParameter,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: ContainerMut<Element = Scalar>,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
assert!(
output.ciphertext_modulus().is_native_modulus(),
"This operation only supports native moduli"
);
assert!(
lwe_secret_key.lwe_dimension() == output.lwe_dimension(),
"Mismatched LweDimension between input LweSecretKey {:?} \
and output LweCompactPublicKey {:?}",
lwe_secret_key.lwe_dimension(),
output.lwe_dimension()
);
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
output.compression_seed().seed,
noise_seeder,
);
let mut tmp_mask = vec![Scalar::ZERO; output.lwe_dimension().0];
generator.fill_slice_with_random_mask(tmp_mask.as_mut());
let mut body = output.get_mut_body();
slice_semi_reverse_negacyclic_convolution(
body.as_mut(),
tmp_mask.as_ref(),
lwe_secret_key.as_ref(),
);
generator
.unsigned_torus_slice_wrapping_add_random_noise_assign(body.as_mut(), noise_parameters);
}
/// Allocate a new [`seeded LWE compact public key`](`SeededLweCompactPublicKey`) and fill it with
/// an actual public key constructed from a private [`LWE secret key`](`LweSecretKey`).
pub fn allocate_and_generate_new_seeded_lwe_compact_public_key<Scalar, InputKeyCont, NoiseSeeder>(
lwe_secret_key: &LweSecretKey<InputKeyCont>,
noise_parameters: impl DispersionParameter,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweCompactPublicKeyOwned<Scalar>
where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
let mut pk = SeededLweCompactPublicKey::new(
Scalar::ZERO,
lwe_secret_key.lwe_dimension(),
noise_seeder.seed().into(),
ciphertext_modulus,
);
generate_seeded_lwe_compact_public_key(lwe_secret_key, &mut pk, noise_parameters, noise_seeder);
pk
}
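
The NoiseSeeder: Seeder + ?Sized bounds above exist so that a caller holding a Box<dyn Seeder> can hand over &mut *boxed directly. A small self-contained sketch with a stand-in trait (not tfhe's actual Seeder, whose seed() returns a Seed):

trait Seeder {
    fn seed(&mut self) -> u128;
}

struct FixedSeeder(u128);

impl Seeder for FixedSeeder {
    fn seed(&mut self) -> u128 {
        self.0
    }
}

// Without `?Sized`, `S` would default to `Sized` and the unsized type
// `dyn Seeder` could not be used as the generic argument.
fn takes_any_seeder<S: Seeder + ?Sized>(seeder: &mut S) -> u128 {
    seeder.seed()
}

fn main() {
    let mut boxed: Box<dyn Seeder> = Box::new(FixedSeeder(42));
    assert_eq!(takes_any_seeder(&mut *boxed), 42);
}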

View File

@@ -36,6 +36,8 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
let ciphertext_modulus = output_mask.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
// generate an error from the normal distribution described by std_dev
@@ -43,7 +45,7 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
*output_body.data = (*output_body.data).wrapping_add(encoded.0);
if !ciphertext_modulus.is_native_modulus() {
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
*output_body.data = (*output_body.data).wrapping_mul(torus_scaling);
}
@@ -296,9 +298,10 @@ pub fn trivially_encrypt_lwe_ciphertext<Scalar, OutputCont>(
*output_body.data = encoded.0;
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if !ciphertext_modulus.is_native_modulus() {
*output_body.data =
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
*output_body.data = (*output_body.data)
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
}
@@ -374,9 +377,10 @@ where
*output_body.data = encoded.0;
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if !ciphertext_modulus.is_native_modulus() {
*output_body.data =
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
*output_body.data = (*output_body.data)
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
new_ct
@@ -409,6 +413,8 @@ where
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let (mask, body) = lwe_ciphertext.get_mask_and_body();
if ciphertext_modulus.is_native_modulus() {
@@ -423,7 +429,7 @@ where
mask.as_ref(),
lwe_secret_key.as_ref(),
))
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus()),
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
)
}
}
@@ -777,6 +783,8 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
output.as_mut().fill(Scalar::ZERO);
let mut tmp_zero_encryption =
@@ -806,7 +814,7 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
*body.data = (*body.data).wrapping_add(
encoded
.0
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
}
@@ -912,6 +920,8 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let mut tmp_zero_encryption =
LweCiphertext::new(Scalar::ZERO, lwe_public_key.lwe_size(), ciphertext_modulus);
@@ -926,7 +936,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
mask.as_mut(),
ciphertext_modulus.get_scaling_to_native_torus(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
*body.data = *public_encryption_of_zero_body.data;
@@ -945,7 +955,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
*body.data = (*body.data).wrapping_add(
encoded
.0
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
}
@@ -1473,3 +1483,791 @@ where
seeded_ct
}
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using an
/// [`LWE compact public key`](`LweCompactPublicKey`). The ciphertext can be decrypted using the
/// [`LWE secret key`](`LweSecretKey`) that was used to generate the public key.
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweCiphertext creation
/// let lwe_dimension = LweDimension(2048);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
///
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
/// &lwe_secret_key,
/// glwe_modular_std_dev,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// // Create the plaintext
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Create a new LweCiphertext
/// let mut lwe = LweCiphertext::new(0u64, lwe_dimension.to_lwe_size(), ciphertext_modulus);
///
/// encrypt_lwe_ciphertext_with_compact_public_key(
/// &lwe_compact_public_key,
/// &mut lwe,
/// plaintext,
/// glwe_modular_std_dev,
/// glwe_modular_std_dev,
/// &mut secret_generator,
/// &mut encryption_generator,
/// );
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&lwe_secret_key, &lwe);
///
/// // Round and remove encoding
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn encrypt_lwe_ciphertext_with_compact_public_key<
Scalar,
KeyCont,
OutputCont,
SecretGen,
EncryptionGen,
>(
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
output: &mut LweCiphertext<OutputCont>,
encoded: Plaintext<Scalar>,
mask_noise_parameters: impl DispersionParameter,
body_noise_parameters: impl DispersionParameter,
secret_generator: &mut SecretRandomGenerator<SecretGen>,
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
) where
Scalar: UnsignedTorus,
KeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
SecretGen: ByteRandomGenerator,
EncryptionGen: ByteRandomGenerator,
{
assert!(
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
"Mismatch between LweDimension of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.lwe_size().to_lwe_dimension(),
lwe_compact_public_key.lwe_dimension()
);
assert!(
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
"Mismatch between CiphertextModulus of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.ciphertext_modulus(),
lwe_compact_public_key.ciphertext_modulus()
);
assert!(
output.ciphertext_modulus().is_native_modulus(),
"This operation only supports native moduli"
);
let mut binary_random_vector = vec![Scalar::ZERO; lwe_compact_public_key.lwe_dimension().0];
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
let (mut ct_mask, ct_body) = output.get_mut_mask_and_body();
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
slice_semi_reverse_negacyclic_convolution(
ct_mask.as_mut(),
pk_mask.as_ref(),
&binary_random_vector,
);
// Noise from Chi_1 for the mask part of the encryption
encryption_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
ct_mask.as_mut(),
mask_noise_parameters,
);
*ct_body.data = slice_wrapping_dot_product(pk_body.as_ref(), &binary_random_vector);
// Noise from Chi_2 for the body part of the encryption
*ct_body.data =
(*ct_body.data).wrapping_add(encryption_generator.random_noise(body_noise_parameters));
*ct_body.data = (*ct_body.data).wrapping_add(encoded.0);
}
/// Encrypt an input plaintext list in an output [`LWE compact ciphertext
/// list`](`LweCompactCiphertextList`) using an [`LWE compact public key`](`LweCompactPublicKey`).
/// The expanded ciphertext list can be decrypted using the [`LWE secret key`](`LweSecretKey`) that
/// was used to generate the public key.
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweCiphertext creation
/// let lwe_dimension = LweDimension(2048);
/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
///
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
/// &lwe_secret_key,
/// glwe_modular_std_dev,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
/// input_plaintext_list
/// .iter_mut()
/// .enumerate()
/// .for_each(|(idx, x)| {
/// *x.0 = (idx as u64 % 16) << 60;
/// });
///
/// // Create a new LweCompactCiphertextList
/// let mut output_compact_ct_list = LweCompactCiphertextList::new(
/// 0u64,
/// lwe_dimension.to_lwe_size(),
/// lwe_ciphertext_count,
/// ciphertext_modulus,
/// );
///
/// encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
/// &lwe_compact_public_key,
/// &mut output_compact_ct_list,
/// &input_plaintext_list,
/// glwe_modular_std_dev,
/// glwe_modular_std_dev,
/// &mut secret_generator,
/// &mut encryption_generator,
/// );
///
/// let mut output_plaintext_list = input_plaintext_list.clone();
/// output_plaintext_list.as_mut().fill(0u64);
///
/// let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list();
///
/// decrypt_lwe_ciphertext_list(
/// &lwe_secret_key,
/// &lwe_ciphertext_list,
/// &mut output_plaintext_list,
/// );
///
/// let signed_decomposer =
/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// // Round the plaintexts
/// output_plaintext_list
/// .iter_mut()
/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
///
/// // Check we recovered the original messages
/// assert_eq!(input_plaintext_list, output_plaintext_list);
/// ```
pub fn encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
Scalar,
KeyCont,
InputCont,
OutputCont,
SecretGen,
EncryptionGen,
>(
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
output: &mut LweCompactCiphertextList<OutputCont>,
encoded: &PlaintextList<InputCont>,
mask_noise_parameters: impl DispersionParameter,
body_noise_parameters: impl DispersionParameter,
secret_generator: &mut SecretRandomGenerator<SecretGen>,
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
) where
Scalar: UnsignedTorus,
KeyCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
SecretGen: ByteRandomGenerator,
EncryptionGen: ByteRandomGenerator,
{
assert!(
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
"Mismatch between LweDimension of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.lwe_size().to_lwe_dimension(),
lwe_compact_public_key.lwe_dimension()
);
assert!(
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
"Mismatch between CiphertextModulus of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.ciphertext_modulus(),
lwe_compact_public_key.ciphertext_modulus()
);
assert!(
output.lwe_ciphertext_count().0 == encoded.plaintext_count().0,
"Mismatch between LweCiphertextCount of output cipertext and \
PlaintextCount of input list. Got {:?} in output, and {:?} in input plaintext list.",
output.lwe_ciphertext_count(),
encoded.plaintext_count()
);
assert!(
output.ciphertext_modulus().is_native_modulus(),
"This operation only supports native moduli"
);
let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list();
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
let lwe_mask_count = output_mask_list.lwe_mask_count();
let lwe_dimension = output_mask_list.lwe_dimension();
let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()];
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
let max_ciphertext_per_bin = lwe_dimension.0;
let gen_iter = encryption_generator
.fork_lwe_compact_ciphertext_list_to_bin::<Scalar>(lwe_mask_count, lwe_dimension)
.expect("Failed to split generator into lwe compact ciphertext bins");
// Loop over the ciphertext "bins"
output_mask_list
.iter_mut()
.zip(
output_body_list
.chunks_mut(max_ciphertext_per_bin)
.zip(encoded.chunks(max_ciphertext_per_bin))
.zip(binary_random_vector.chunks(max_ciphertext_per_bin))
.zip(gen_iter),
)
.for_each(
|(
mut output_mask,
(
((mut output_body_chunk, input_plaintext_chunk), binary_random_slice),
mut loop_generator,
),
)| {
// output_body_chunk may not be able to fit the full convolution result so we create
// a temp buffer to compute the full convolution
let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0];
slice_semi_reverse_negacyclic_convolution(
output_mask.as_mut(),
pk_mask.as_ref(),
binary_random_slice,
);
// Fill the temp buffer with b convolved with r
slice_semi_reverse_negacyclic_convolution(
pk_body_convolved.as_mut_slice(),
pk_body.as_ref(),
binary_random_slice,
);
// Noise from Chi_1 for the mask part of the encryption
loop_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
output_mask.as_mut(),
mask_noise_parameters,
);
// Fill the body chunk afterwards manually as it most likely will be smaller than
// the full convolution result: b convolved with r + Delta * m + e2,
// taking noise from Chi_2 for the body part of the encryption
output_body_chunk
.iter_mut()
.zip(pk_body_convolved.iter().zip(input_plaintext_chunk.iter()))
.for_each(|(dst, (&src, plaintext))| {
*dst.data = src
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
.wrapping_add(*plaintext.0)
});
},
);
}
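
On sizes: the compact list filled above stores one shared mask of length n per bin of up to n ciphertexts plus one body per ciphertext, instead of a full mask per ciphertext, which is where the bandwidth gain of the compact public-key scheme comes from. A back-of-the-envelope sketch, with the formulas inferred from the chunking logic above rather than quoted from the crate:

fn compact_vs_expanded_len(lwe_dimension: usize, ct_count: usize) -> (usize, usize) {
    // One shared mask of length `lwe_dimension` per bin of up to `lwe_dimension`
    // ciphertexts, plus one body per ciphertext.
    let mask_count = (ct_count + lwe_dimension - 1) / lwe_dimension;
    let compact = mask_count * lwe_dimension + ct_count;
    // A regular LWE ciphertext list stores a full mask and a body per ciphertext.
    let expanded = ct_count * (lwe_dimension + 1);
    (compact, expanded)
}

fn main() {
    // Same shape as the doc example: n = 2048, 4 * n ciphertexts.
    let (compact, expanded) = compact_vs_expanded_len(2048, 4 * 2048);
    assert_eq!(compact, 4 * 2048 + 4 * 2048); // 4 shared masks + 8192 bodies
    assert_eq!(expanded, 4 * 2048 * 2049); // full masks dominate the expanded size
    println!("compact: {compact} scalars, expanded: {expanded} scalars");
}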
/// Parallel variant of [`encrypt_lwe_compact_ciphertext_list_with_compact_public_key`]. Encrypt an
/// input plaintext list in an output [`LWE compact ciphertext list`](`LweCompactCiphertextList`)
/// using an [`LWE compact public key`](`LweCompactPublicKey`). The expanded ciphertext list can be
/// decrypted using the [`LWE secret key`](`LweSecretKey`) that was used to generate the public key.
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweCiphertext creation
/// let lwe_dimension = LweDimension(2048);
/// let lwe_ciphertext_count = LweCiphertextCount(lwe_dimension.0 * 4);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
///
/// let lwe_compact_public_key = allocate_and_generate_new_lwe_compact_public_key(
/// &lwe_secret_key,
/// glwe_modular_std_dev,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// let mut input_plaintext_list = PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
/// input_plaintext_list
/// .iter_mut()
/// .enumerate()
/// .for_each(|(idx, x)| {
/// *x.0 = (idx as u64 % 16) << 60;
/// });
///
/// // Create a new LweCompactCiphertextList
/// let mut output_compact_ct_list = LweCompactCiphertextList::new(
/// 0u64,
/// lwe_dimension.to_lwe_size(),
/// lwe_ciphertext_count,
/// ciphertext_modulus,
/// );
///
/// par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
/// &lwe_compact_public_key,
/// &mut output_compact_ct_list,
/// &input_plaintext_list,
/// glwe_modular_std_dev,
/// glwe_modular_std_dev,
/// &mut secret_generator,
/// &mut encryption_generator,
/// );
///
/// let mut output_plaintext_list = input_plaintext_list.clone();
/// output_plaintext_list.as_mut().fill(0u64);
///
/// let lwe_ciphertext_list = output_compact_ct_list.par_expand_into_lwe_ciphertext_list();
///
/// decrypt_lwe_ciphertext_list(
/// &lwe_secret_key,
/// &lwe_ciphertext_list,
/// &mut output_plaintext_list,
/// );
///
/// let signed_decomposer =
/// SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// // Round the plaintexts
/// output_plaintext_list
/// .iter_mut()
/// .for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
///
/// // Check we recovered the original messages
/// assert_eq!(input_plaintext_list, output_plaintext_list);
/// ```
pub fn par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key<
Scalar,
KeyCont,
InputCont,
OutputCont,
SecretGen,
EncryptionGen,
>(
lwe_compact_public_key: &LweCompactPublicKey<KeyCont>,
output: &mut LweCompactCiphertextList<OutputCont>,
encoded: &PlaintextList<InputCont>,
mask_noise_parameters: impl DispersionParameter + Sync,
body_noise_parameters: impl DispersionParameter + Sync,
secret_generator: &mut SecretRandomGenerator<SecretGen>,
encryption_generator: &mut EncryptionRandomGenerator<EncryptionGen>,
) where
Scalar: UnsignedTorus + Sync + Send,
KeyCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
SecretGen: ByteRandomGenerator,
EncryptionGen: ParallelByteRandomGenerator,
{
assert!(
output.lwe_size().to_lwe_dimension() == lwe_compact_public_key.lwe_dimension(),
"Mismatch between LweDimension of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.lwe_size().to_lwe_dimension(),
lwe_compact_public_key.lwe_dimension()
);
assert!(
lwe_compact_public_key.ciphertext_modulus() == output.ciphertext_modulus(),
"Mismatch between CiphertextModulus of output cipertext and input public key. \
Got {:?} in output, and {:?} in public key.",
output.ciphertext_modulus(),
lwe_compact_public_key.ciphertext_modulus()
);
assert!(
output.lwe_ciphertext_count().0 == encoded.plaintext_count().0,
"Mismatch between LweCiphertextCount of output cipertext and \
PlaintextCount of input list. Got {:?} in output, and {:?} in input plaintext list.",
output.lwe_ciphertext_count(),
encoded.plaintext_count()
);
assert!(
output.ciphertext_modulus().is_native_modulus(),
"This operation only supports native moduli"
);
let (mut output_mask_list, mut output_body_list) = output.get_mut_mask_and_body_list();
let (pk_mask, pk_body) = lwe_compact_public_key.get_mask_and_body();
let lwe_mask_count = output_mask_list.lwe_mask_count();
let lwe_dimension = output_mask_list.lwe_dimension();
let mut binary_random_vector = vec![Scalar::ZERO; output_mask_list.lwe_mask_list_size()];
secret_generator.fill_slice_with_random_uniform_binary(&mut binary_random_vector);
let max_ciphertext_per_bin = lwe_dimension.0;
let gen_iter = encryption_generator
.par_fork_lwe_compact_ciphertext_list_to_bin::<Scalar>(lwe_mask_count, lwe_dimension)
.expect("Failed to split generator into lwe compact ciphertext bins");
// Loop over the ciphertext "bins"
output_mask_list
.par_iter_mut()
.zip(
output_body_list
.par_chunks_mut(max_ciphertext_per_bin)
.zip(encoded.par_chunks(max_ciphertext_per_bin))
.zip(binary_random_vector.par_chunks(max_ciphertext_per_bin))
.zip(gen_iter),
)
.for_each(
|(
mut output_mask,
(
((mut output_body_chunk, input_plaintext_chunk), binary_random_slice),
mut loop_generator,
),
)| {
// output_body_chunk may not be able to fit the full convolution result so we create
// a temp buffer to compute the full convolution
let mut pk_body_convolved = vec![Scalar::ZERO; lwe_dimension.0];
rayon::join(
|| {
slice_semi_reverse_negacyclic_convolution(
output_mask.as_mut(),
pk_mask.as_ref(),
binary_random_slice,
);
},
|| {
// Fill the temp buffer with b convolved with r
slice_semi_reverse_negacyclic_convolution(
pk_body_convolved.as_mut_slice(),
pk_body.as_ref(),
binary_random_slice,
);
},
);
// Noise from Chi_1 for the mask part of the encryption
loop_generator.unsigned_torus_slice_wrapping_add_random_noise_assign(
output_mask.as_mut(),
mask_noise_parameters,
);
// Fill the body chunk afterwards manually as it most likely will be smaller than
// the full convolution result: b convolved with r + Delta * m + e2,
// taking noise from Chi_2 for the body part of the encryption
output_body_chunk
.iter_mut()
.zip(pk_body_convolved.iter().zip(input_plaintext_chunk.iter()))
.for_each(|(dst, (&src, plaintext))| {
*dst.data = src
.wrapping_add(loop_generator.random_noise(body_noise_parameters))
.wrapping_add(*plaintext.0)
});
},
);
}
#[cfg(test)]
mod test {
use crate::core_crypto::commons::test_tools;
use crate::core_crypto::prelude::*;
use crate::core_crypto::commons::generators::{
DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator,
};
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
#[test]
fn test_compact_public_key_encryption() {
use rand::Rng;
let lwe_dimension = LweDimension(2048);
let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
let ciphertext_modulus = CiphertextModulus::new_native();
let mut secret_random_generator = test_tools::new_secret_random_generator();
let mut encryption_random_generator = test_tools::new_encryption_random_generator();
let mut thread_rng = rand::thread_rng();
for _ in 0..10_000 {
let lwe_sk =
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
let mut compact_lwe_pk =
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
generate_lwe_compact_public_key(
&lwe_sk,
&mut compact_lwe_pk,
glwe_modular_std_dev,
&mut encryption_random_generator,
);
let msg: u64 = thread_rng.gen();
let msg = msg % 16;
let plaintext = Plaintext(msg << 60);
let mut output_ct = LweCiphertext::new(
0u64,
lwe_dimension.to_lwe_size(),
CiphertextModulus::new_native(),
);
encrypt_lwe_ciphertext_with_compact_public_key(
&compact_lwe_pk,
&mut output_ct,
plaintext,
glwe_modular_std_dev,
glwe_modular_std_dev,
&mut secret_random_generator,
&mut encryption_random_generator,
);
let decrypted_plaintext = decrypt_lwe_ciphertext(&lwe_sk, &output_ct);
let signed_decomposer =
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
let cleartext = signed_decomposer.closest_representable(decrypted_plaintext.0) >> 60;
assert_eq!(cleartext, msg);
}
}
#[test]
fn test_par_compact_lwe_list_public_key_encryption_equivalence() {
use rand::Rng;
let lwe_dimension = LweDimension(2048);
let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
let ciphertext_modulus = CiphertextModulus::new_native();
let mut thread_rng = rand::thread_rng();
for _ in 0..100 {
// We'll encrypt between 1 and 4 * lwe_dimension ciphertexts
let ct_count: usize = thread_rng.gen();
let ct_count = ct_count % (lwe_dimension.0 * 4) + 1;
let lwe_ciphertext_count = LweCiphertextCount(ct_count);
println!("{lwe_dimension:?} {ct_count:?}");
let seed = test_tools::random_seed();
let mut input_plaintext_list =
PlaintextList::new(0u64, PlaintextCount(lwe_ciphertext_count.0));
input_plaintext_list.iter_mut().for_each(|x| {
let msg: u64 = thread_rng.gen();
*x.0 = (msg % 16) << 60;
});
let par_lwe_ct_list = {
let mut deterministic_seeder =
DeterministicSeeder::<ActivatedRandomGenerator>::new(seed);
let mut secret_random_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(
deterministic_seeder.seed(),
);
let mut encryption_random_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
deterministic_seeder.seed(),
&mut deterministic_seeder,
);
let lwe_sk =
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
let mut compact_lwe_pk =
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
generate_lwe_compact_public_key(
&lwe_sk,
&mut compact_lwe_pk,
glwe_modular_std_dev,
&mut encryption_random_generator,
);
let mut output_compact_ct_list = LweCompactCiphertextList::new(
0u64,
lwe_dimension.to_lwe_size(),
lwe_ciphertext_count,
ciphertext_modulus,
);
par_encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
&compact_lwe_pk,
&mut output_compact_ct_list,
&input_plaintext_list,
glwe_modular_std_dev,
glwe_modular_std_dev,
&mut secret_random_generator,
&mut encryption_random_generator,
);
let mut output_plaintext_list = input_plaintext_list.clone();
output_plaintext_list.as_mut().fill(0u64);
let lwe_ciphertext_list =
output_compact_ct_list.par_expand_into_lwe_ciphertext_list();
decrypt_lwe_ciphertext_list(
&lwe_sk,
&lwe_ciphertext_list,
&mut output_plaintext_list,
);
let signed_decomposer =
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
output_plaintext_list
.iter_mut()
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
assert_eq!(input_plaintext_list, output_plaintext_list);
lwe_ciphertext_list
};
let ser_lwe_ct_list = {
let mut deterministic_seeder =
DeterministicSeeder::<ActivatedRandomGenerator>::new(seed);
let mut secret_random_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(
deterministic_seeder.seed(),
);
let mut encryption_random_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
deterministic_seeder.seed(),
&mut deterministic_seeder,
);
let lwe_sk =
LweSecretKey::generate_new_binary(lwe_dimension, &mut secret_random_generator);
let mut compact_lwe_pk =
LweCompactPublicKey::new(0u64, lwe_dimension, ciphertext_modulus);
generate_lwe_compact_public_key(
&lwe_sk,
&mut compact_lwe_pk,
glwe_modular_std_dev,
&mut encryption_random_generator,
);
let mut output_compact_ct_list = LweCompactCiphertextList::new(
0u64,
lwe_dimension.to_lwe_size(),
lwe_ciphertext_count,
ciphertext_modulus,
);
encrypt_lwe_compact_ciphertext_list_with_compact_public_key(
&compact_lwe_pk,
&mut output_compact_ct_list,
&input_plaintext_list,
glwe_modular_std_dev,
glwe_modular_std_dev,
&mut secret_random_generator,
&mut encryption_random_generator,
);
let mut output_plaintext_list = input_plaintext_list.clone();
output_plaintext_list.as_mut().fill(0u64);
let lwe_ciphertext_list = output_compact_ct_list.expand_into_lwe_ciphertext_list();
decrypt_lwe_ciphertext_list(
&lwe_sk,
&lwe_ciphertext_list,
&mut output_plaintext_list,
);
let signed_decomposer =
SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
output_plaintext_list
.iter_mut()
.for_each(|x| *x.0 = signed_decomposer.closest_representable(*x.0));
assert_eq!(input_plaintext_list, output_plaintext_list);
lwe_ciphertext_list
};
assert_eq!(ser_lwe_ct_list, par_lwe_ct_list);
}
}
}
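
The tests above decode by rounding the decrypted value to the closest multiple of Delta = 2^60 before shifting. For this specific 4-bit-in-the-MSBs encoding, the SignedDecomposer call should behave like the following minimal rounding sketch (the decomposer itself is more general):

fn closest_multiple_of_delta(decrypted: u64) -> u64 {
    const DELTA: u64 = 1 << 60; // 4-bit message encoded in the top bits
    // Round to the nearest multiple of DELTA with wrapping arithmetic, which
    // absorbs the small encryption noise sitting in the low bits.
    decrypted.wrapping_add(DELTA >> 1) & !(DELTA - 1)
}

fn main() {
    let msg = 3u64;
    let noisy_positive = (msg << 60).wrapping_add(12_345);
    let noisy_negative = (msg << 60).wrapping_sub(12_345);
    assert_eq!(closest_multiple_of_delta(noisy_positive) >> 60, msg);
    assert_eq!(closest_multiple_of_delta(noisy_negative) >> 60, msg);
}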

View File

@@ -93,6 +93,7 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
@@ -115,7 +116,7 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
encrypt_lwe_ciphertext_list(
@@ -255,6 +256,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
@@ -282,7 +284,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
encrypt_seeded_lwe_ciphertext_list_with_existing_generator(

View File

@@ -238,12 +238,13 @@ pub fn lwe_ciphertext_plaintext_add_assign<Scalar, InCont>(
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if ciphertext_modulus.is_native_modulus() {
*body.data = (*body.data).wrapping_add(rhs.0);
} else {
*body.data = (*body.data).wrapping_add(
rhs.0
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
}
@@ -313,12 +314,13 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if ciphertext_modulus.is_native_modulus() {
*body.data = (*body.data).wrapping_sub(rhs.0);
} else {
*body.data = (*body.data).wrapping_sub(
rhs.0
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
}

View File

@@ -4,6 +4,7 @@
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::dispersion::DispersionParameter;
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
@@ -450,3 +451,324 @@ where
bsk
}
/// Fill a [`seeded LWE multi-bit bootstrap key`](`SeededLweMultiBitBootstrapKey`) with an actual
/// seeded bootstrapping key constructed from an input key [`LWE secret key`](`LweSecretKey`) and an
/// output key [`GLWE secret key`](`GlweSecretKey`).
///
/// Consider using [`par_generate_seeded_lwe_multi_bit_bootstrap_key`] for better key generation
/// times.
#[allow(clippy::too_many_arguments)]
pub fn generate_seeded_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
OutputCont,
NoiseSeeder,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
output: &mut SeededLweMultiBitBootstrapKey<OutputCont>,
noise_parameters: impl DispersionParameter,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: UnsignedTorus + CastFrom<usize>,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
assert!(
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
input_lwe_secret_key.lwe_dimension(),
output.input_lwe_dimension()
);
assert!(
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output.glwe_size()
);
assert!(
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
output_glwe_secret_key.polynomial_size(),
output.polynomial_size()
);
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
output.compression_seed().seed,
noise_seeder,
);
let gen_iter = generator
.fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
output.input_lwe_dimension(),
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
output.grouping_factor(),
)
.unwrap();
let grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(
input_lwe_secret_key
.as_ref()
.chunks_exact(grouping_factor.0),
)
.zip(gen_iter)
{
let gen_iter = loop_generator.fork_n(ggsw_per_multi_bit_element.0).unwrap();
for ((bit_inversion_idx, mut ggsw), mut inner_loop_generator) in
ggsw_group.iter_mut().enumerate().zip(gen_iter)
{
// Use the index of the ggsw as a way to know which bit to invert
let key_bits_plaintext = combine_key_bits(bit_inversion_idx, input_key_elements);
encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator(
output_glwe_secret_key,
&mut ggsw,
Plaintext(key_bits_plaintext),
noise_parameters,
&mut inner_loop_generator,
);
}
}
}
/// Allocate a new [`seeded LWE multi-bit bootstrap key`](`SeededLweMultiBitBootstrapKey`) and fill
/// it with an actual seeded bootstrapping key constructed from an input key
/// [`LWE secret key`](`LweSecretKey`) and an output key [`GLWE secret key`](`GlweSecretKey`).
///
/// Consider using [`par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key`] for better
/// key generation times.
#[allow(clippy::too_many_arguments)]
pub fn allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
NoiseSeeder,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_parameters: impl DispersionParameter,
grouping_factor: LweBskGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
where
Scalar: UnsignedTorus + CastFrom<usize>,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
let mut bsk = SeededLweMultiBitBootstrapKeyOwned::new(
Scalar::ZERO,
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output_glwe_secret_key.polynomial_size(),
decomp_base_log,
decomp_level_count,
input_lwe_secret_key.lwe_dimension(),
grouping_factor,
noise_seeder.seed().into(),
ciphertext_modulus,
);
generate_seeded_lwe_multi_bit_bootstrap_key(
input_lwe_secret_key,
output_glwe_secret_key,
&mut bsk,
noise_parameters,
noise_seeder,
);
bsk
}
/// Parallel variant of [`generate_seeded_lwe_multi_bit_bootstrap_key`]. Using this function is
/// recommended for better key generation times, as LWE bootstrapping keys can be quite large.
#[allow(clippy::too_many_arguments)]
pub fn par_generate_seeded_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
OutputCont,
NoiseSeeder,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
output: &mut SeededLweMultiBitBootstrapKey<OutputCont>,
noise_parameters: impl DispersionParameter + Sync,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
assert!(
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
input_lwe_secret_key.lwe_dimension(),
output.input_lwe_dimension()
);
assert!(
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output.glwe_size()
);
assert!(
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
output_glwe_secret_key.polynomial_size(),
output.polynomial_size()
);
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
output.compression_seed().seed,
noise_seeder,
);
let gen_iter = generator
.par_fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
output.input_lwe_dimension(),
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
output.grouping_factor(),
)
.unwrap();
let grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
output
.par_iter_mut()
.chunks(ggsw_per_multi_bit_element.0)
.zip(
input_lwe_secret_key
.as_ref()
.par_chunks_exact(grouping_factor.0),
)
.zip(gen_iter)
.for_each(
|((mut ggsw_group, input_key_elements), mut loop_generator)| {
let gen_iter = loop_generator
.par_fork_n(ggsw_per_multi_bit_element.0)
.unwrap();
ggsw_group
.par_iter_mut()
.enumerate()
.zip(gen_iter)
.for_each(|((bit_inversion_idx, ggsw), mut inner_loop_generator)| {
// Use the index of the ggsw as a way to know which bit to invert
let key_bits_plaintext =
combine_key_bits(bit_inversion_idx, input_key_elements);
par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator(
output_glwe_secret_key,
ggsw,
Plaintext(key_bits_plaintext),
noise_parameters,
&mut inner_loop_generator,
);
});
},
);
}
/// Parallel variant of [`allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key`]. Using this
/// function is recommended for better key generation times, as LWE bootstrapping keys can be quite
/// large.
#[allow(clippy::too_many_arguments)]
pub fn par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
NoiseSeeder,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_parameters: impl DispersionParameter + Sync,
grouping_factor: LweBskGroupingFactor,
ciphertext_modulus: CiphertextModulus<Scalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweMultiBitBootstrapKeyOwned<Scalar>
where
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar> + Sync,
// The ?Sized bound allows passing Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
let mut bsk = SeededLweMultiBitBootstrapKeyOwned::new(
Scalar::ZERO,
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output_glwe_secret_key.polynomial_size(),
decomp_base_log,
decomp_level_count,
input_lwe_secret_key.lwe_dimension(),
grouping_factor,
noise_seeder.seed().into(),
ciphertext_modulus,
);
par_generate_seeded_lwe_multi_bit_bootstrap_key(
input_lwe_secret_key,
output_glwe_secret_key,
&mut bsk,
noise_parameters,
noise_seeder,
);
bsk
}
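
For context on the combine_key_bits(bit_inversion_idx, input_key_elements) calls above: with grouping factor g, each group of 2^g GGSWs is expected to encrypt, at index j, a product over the g binary key bits of either s_i or 1 - s_i selected by the bits of j, so that exactly one GGSW in the group encrypts 1 (the one whose index matches the key bits) and the others encrypt 0. The sketch below only illustrates that indicator product for binary keys; it is not the crate's combine_key_bits, and the bit ordering is an assumption.

fn indicator_product_sketch(ggsw_idx: usize, key_bits: &[u64]) -> u64 {
    let g = key_bits.len();
    key_bits
        .iter()
        .enumerate()
        .map(|(i, &s)| {
            // Assumed ordering: the most significant selector bit of the index
            // corresponds to the first key bit of the group.
            let selector = (ggsw_idx >> (g - 1 - i)) & 1;
            if selector == 1 {
                s
            } else {
                1 - s
            }
        })
        .product()
}

fn main() {
    // With key bits (1, 0), only the GGSW at index 0b10 gets plaintext 1.
    let key_bits = [1u64, 0];
    let values: Vec<u64> = (0..4).map(|j| indicator_product_sketch(j, &key_bits)).collect();
    assert_eq!(values, vec![0, 0, 1, 0]);
}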

View File

@@ -1,5 +1,6 @@
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
use crate::core_crypto::algorithms::polynomial_algorithms::*;
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::parameters::*;
@@ -9,11 +10,89 @@ use crate::core_crypto::fft_impl::common::pbs_modulus_switch;
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd,
};
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
use concrete_fft::c64;
use std::sync::{mpsc, Condvar, Mutex};
use std::thread;
pub fn prepare_multi_bit_ggsw_mem_optimized<
Scalar,
GgswBufferCont,
GgswGroupCont,
PolyCont,
FourierPolyCont,
>(
fourier_ggsw_buffer: &mut FourierGgswCiphertext<GgswBufferCont>,
ggsw_group: &[FourierGgswCiphertext<GgswGroupCont>],
lwe_mask_elements: &[Scalar],
a_monomial: &mut Polynomial<PolyCont>,
fourier_a_monomial: &mut FourierPolynomial<FourierPolyCont>,
fft: FftView<'_>,
buffers: &mut ComputationBuffers,
) where
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
GgswBufferCont: ContainerMut<Element = c64>,
GgswGroupCont: Container<Element = c64>,
PolyCont: ContainerMut<Element = Scalar>,
FourierPolyCont: ContainerMut<Element = c64>,
{
let mut ggsw_group_iter = ggsw_group.iter();
// Keygen guarantees the first term is a constant term of the polynomial, no
// polynomial multiplication required
let ggsw_a_none = ggsw_group_iter.next().unwrap();
fourier_ggsw_buffer
.as_mut_view()
.data()
.copy_from_slice(ggsw_a_none.as_view().data());
let multi_bit_fourier_ggsw = fourier_ggsw_buffer.as_mut_view().data();
let polynomial_size = a_monomial.polynomial_size();
for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() {
// We already processed the first ggsw, advance the index by 1
let ggsw_idx = ggsw_idx + 1;
// Select the proper mask elements to build the monomial degree depending on
// the order the GGSW were generated in, using the bits from mask_idx and
// ggsw_idx as selector bits
let mut monomial_degree = Scalar::ZERO;
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
let selection_bit: Scalar = Scalar::cast_from((ggsw_idx >> mask_position) & 1);
monomial_degree =
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}
let switched_degree = pbs_modulus_switch(
monomial_degree,
polynomial_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
a_monomial.as_mut()[0] = Scalar::ONE;
a_monomial.as_mut()[1..].fill(Scalar::ZERO);
polynomial_wrapping_monic_monomial_mul_assign(a_monomial, MonomialDegree(switched_degree));
fft.forward_as_integer(
fourier_a_monomial.as_mut_view(),
a_monomial.as_view(),
buffers.stack(),
);
update_with_fmadd(
multi_bit_fourier_ggsw,
fourier_ggsw.as_view().data(),
fourier_a_monomial.as_view().data,
false,
polynomial_size.to_fourier_polynomial_size().0,
);
}
}
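
The switched_degree above comes from pbs_modulus_switch, which maps a torus element modulo 2^64 to a monomial degree modulo 2N. Roughly — ignoring the ModulusSwitchOffset and LutCountLog parameters, and as a sketch rather than the crate's implementation — this is a rounding division by 2^64 / (2N):

fn pbs_modulus_switch_sketch(value: u64, polynomial_size: usize) -> usize {
    // Round to the nearest multiple of 2^64 / (2N), assuming N is a power of two.
    let log_2n = (2 * polynomial_size).trailing_zeros();
    let shift = 64 - log_2n;
    (value.wrapping_add(1u64 << (shift - 1)) >> shift) as usize
}

fn main() {
    let n = 1024usize;
    // A value at 3/4 of the torus maps to degree 3N/2.
    let degree = pbs_modulus_switch_sketch(3 * (1u64 << 62), n);
    assert_eq!(degree, 3 * n / 2);
}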
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain.
@@ -270,6 +349,15 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
accumulator.ciphertext_modulus(),
);
assert!(
thread_count.0 != 0,
"Got thread_count == 0, this is not supported"
);
assert!(accumulator
.ciphertext_modulus()
.is_compatible_with_native_modulus());
let (lwe_mask, lwe_body) = input.get_mask_and_body();
// No way to chunk the result of ggsw_iter at the moment
@@ -338,10 +426,7 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
let mut a_monomial = Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size());
let work_queue = &work_queue;
@@ -367,64 +452,15 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
prepare_multi_bit_ggsw_mem_optimized(
&mut fourier_ggsw_buffer,
ggsw_group,
lwe_mask_elements,
&mut a_monomial,
&mut fourier_a_monomial,
fft,
&mut buffers,
);
// Drop the lock before we wake other threads
drop(fourier_ggsw_buffer);
@@ -801,6 +837,11 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
accumulator.ciphertext_modulus(),
);
assert!(
thread_count.0 != 0,
"Got thread_count == 0, this is not supported"
);
let mut local_accumulator = GlweCiphertext::new(
Scalar::ZERO,
accumulator.glwe_size(),
@@ -815,3 +856,430 @@ pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
pub fn std_prepare_multi_bit_ggsw<Scalar, GgswBufferCont, TmpGgswBufferCont, GgswGroupCont>(
multi_bit_ggsw: &mut GgswCiphertext<GgswBufferCont>,
tmp_ggsw_buffer: &mut GgswCiphertext<TmpGgswBufferCont>,
ggsw_group: &GgswCiphertextList<GgswGroupCont>,
lwe_mask_elements: &[Scalar],
) where
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize>,
GgswBufferCont: ContainerMut<Element = Scalar>,
TmpGgswBufferCont: ContainerMut<Element = Scalar>,
GgswGroupCont: Container<Element = Scalar>,
{
let mut ggsw_group_iter = ggsw_group.iter();
// Keygen guarantees that the first GGSW of the group encrypts the constant monomial, so no
// polynomial multiplication is required for it
let ggsw_a_none = ggsw_group_iter.next().unwrap();
multi_bit_ggsw
.as_mut()
.copy_from_slice(ggsw_a_none.as_ref());
let polynomial_size = multi_bit_ggsw.polynomial_size();
for (ggsw_idx, std_ggsw) in ggsw_group_iter.enumerate() {
// We already processed the first ggsw, advance the index by 1
let ggsw_idx = ggsw_idx + 1;
// Select the proper mask elements to build the monomial degree depending on
// the order the GGSW were generated in, using the bits from mask_idx and
// ggsw_idx as selector bits
let mut monomial_degree = Scalar::ZERO;
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
let selection_bit: Scalar = Scalar::cast_from((ggsw_idx >> mask_position) & 1);
monomial_degree =
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}
let switched_degree = pbs_modulus_switch(
monomial_degree,
polynomial_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
tmp_ggsw_buffer
.as_mut_polynomial_list()
.iter_mut()
.zip(std_ggsw.as_polynomial_list().iter())
.for_each(|(mut tmp_polynomial, input_polynomial)| {
polynomial_wrapping_monic_monomial_mul(
&mut tmp_polynomial,
&input_polynomial,
MonomialDegree(switched_degree),
);
});
slice_wrapping_add_assign(multi_bit_ggsw.as_mut(), tmp_ggsw_buffer.as_ref());
}
}
pub fn std_multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
accumulator: &mut GlweCiphertext<OutputCont>,
multi_bit_bsk: &LweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
InputCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
KeyCont: Container<Element = Scalar> + Sync,
{
assert_eq!(
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
);
assert_eq!(
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
);
assert_eq!(
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
);
assert_eq!(
accumulator.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
"Mimatched CiphertextModulus. Accumulator CiphertextModulus {:?}. \
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
accumulator.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
);
assert_eq!(
input.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
"Mimatched CiphertextModulus. LweCiphertext CiphertextModulus {:?}. \
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
input.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
);
let (lwe_mask, lwe_body) = input.get_mask_and_body();
// No way to chunk the result of ggsw_iter at the moment
// let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();
let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0);
let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
for (lwe_mask_elements, ggsw_group) in lwe_mask
.as_ref()
.chunks_exact(grouping_factor.0)
.zip(multi_bit_bsk.chunks_exact(ggsw_per_multi_bit_element.0))
{
work_queue.push((lwe_mask_elements, ggsw_group));
}
assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0);
let work_queue = Mutex::new(work_queue);
// Each producer thread works in a dedicated slot of the buffer
let thread_buffers: usize = thread_count.0;
let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = pbs_modulus_switch(
*lwe_body.data,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
// Modulus switching
accumulator
.as_mut_polynomial_list()
.iter_mut()
.for_each(|mut poly| {
polynomial_wrapping_monic_monomial_div_assign(
&mut poly,
MonomialDegree(monomial_degree),
)
});
let fourier_multi_bit_ggsw_buffers = (0..thread_buffers)
.map(|_| {
(
Mutex::new(false),
Condvar::new(),
Mutex::new(FourierGgswCiphertext::new(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
multi_bit_bsk.decomposition_base_log(),
multi_bit_bsk.decomposition_level_count(),
)),
)
})
.collect::<Vec<_>>();
let (tx, rx) = mpsc::channel::<usize>();
let fft = Fft::new(multi_bit_bsk.polynomial_size());
let fft = fft.as_view();
thread::scope(|s| {
let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender<usize>| {
let mut buffers = ComputationBuffers::new();
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
let mut std_ggsw_buffer = GgswCiphertext::new(
Scalar::ZERO,
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
multi_bit_bsk.decomposition_base_log(),
multi_bit_bsk.decomposition_level_count(),
multi_bit_bsk.ciphertext_modulus(),
);
let mut tmp_ggsw_buffer = GgswCiphertext::new(
Scalar::ZERO,
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
multi_bit_bsk.decomposition_base_log(),
multi_bit_bsk.decomposition_level_count(),
multi_bit_bsk.ciphertext_modulus(),
);
let work_queue = &work_queue;
let dest_idx = thread_id;
let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) =
&fourier_multi_bit_ggsw_buffers[dest_idx];
loop {
let maybe_work = {
let mut queue_lock = work_queue.lock().unwrap();
queue_lock.pop()
};
let Some((lwe_mask_elements, ggsw_group)) = maybe_work else {break};
let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();
// Wait while the previously produced GGSW has not been consumed yet; the condvar gets
// notified once the slot is free and we can start producing again
while *ready_for_consumer {
ready_for_consumer = condvar.wait(ready_for_consumer).unwrap();
}
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
std_prepare_multi_bit_ggsw(
&mut std_ggsw_buffer,
&mut tmp_ggsw_buffer,
&ggsw_group,
lwe_mask_elements,
);
fourier_ggsw_buffer.as_mut_view().fill_with_forward_fourier(
std_ggsw_buffer.as_view(),
fft,
buffers.stack(),
);
// Drop the lock before we wake other threads
drop(fourier_ggsw_buffer);
*ready_for_consumer = true;
tx.send(dest_idx).unwrap();
// Wake threads waiting on the condvar
condvar.notify_all();
}
};
let threads: Vec<_> = (0..thread_count.0)
.map(|id| {
let tx = tx.clone();
s.spawn(move || produce_multi_bit_fourier_ggsw(id, tx))
})
.collect();
// We initialize ct0 for the successive external products
let ct0 = accumulator;
let mut ct1 = GlweCiphertext::new(
Scalar::ZERO,
ct0.glwe_size(),
ct0.polynomial_size(),
ct0.ciphertext_modulus(),
);
let ct1 = &mut ct1;
let mut buffers = ComputationBuffers::new();
buffers.resize(
add_external_product_assign_scratch::<Scalar>(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let mut src_idx = 1usize;
for _ in 0..multi_bit_bsk.multi_bit_input_lwe_dimension().0 {
src_idx ^= 1;
let idx = rx.recv().unwrap();
let (ready_lock, condvar, multi_bit_fourier_ggsw) =
&fourier_multi_bit_ggsw_buffers[idx];
let (src_ct, mut dst_ct) = if src_idx == 0 {
(ct0.as_view(), ct1.as_mut_view())
} else {
(ct1.as_view(), ct0.as_mut_view())
};
dst_ct.as_mut().fill(Scalar::ZERO);
let mut ready = ready_lock.lock().unwrap();
assert!(*ready);
let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap();
add_external_product_assign(
dst_ct,
multi_bit_fourier_ggsw.as_view(),
src_ct,
fft,
buffers.stack(),
);
drop(multi_bit_fourier_ggsw);
*ready = false;
// Wake a single producer thread sleeping on the condvar (only one will get to work
// anyway)
condvar.notify_one();
}
if src_idx == 0 {
ct0.as_mut().copy_from_slice(ct1.as_ref());
}
let ciphertext_modulus = ct0.ciphertext_modulus();
if !ciphertext_modulus.is_native_modulus() {
// When we convert back from the Fourier domain, integer values will contain up to 53
// MSBs with information. In our representation of power of 2 moduli < native modulus we
// fill the MSBs and leave the LSBs empty; this usage of the signed decomposer allows us to
// round while keeping the data in the MSBs
let signed_decomposer = SignedDecomposer::new(
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
DecompositionLevelCount(1),
);
ct0.as_mut()
.iter_mut()
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
}
threads.into_iter().for_each(|t| t.join().unwrap());
});
}
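The consumer side above ping-pongs between `ct0` and `ct1`: `src_idx` flips each round so the freshly written ciphertext becomes the next round's source, and a final copy is only needed when the last result landed in the spare buffer. A minimal sketch of that pattern on plain vectors, with a doubling step standing in for the external product:
```
#[test]
fn ping_pong_buffer_sketch() {
    let rounds: u32 = 5;
    let mut ct0 = vec![1u64; 4];
    let mut ct1 = vec![0u64; 4];
    let mut src_idx = 1usize;
    for _ in 0..rounds {
        src_idx ^= 1;
        let (src, dst) = if src_idx == 0 {
            (&ct0, &mut ct1)
        } else {
            (&ct1, &mut ct0)
        };
        // Stand-in for add_external_product_assign: dst is recomputed from src
        for (d, s) in dst.iter_mut().zip(src.iter()) {
            *d = *s * 2;
        }
    }
    // When src_idx ends at 0 the last write went to ct1, so copy it back
    if src_idx == 0 {
        ct0.copy_from_slice(&ct1);
    }
    assert_eq!(ct0, vec![1u64 << rounds; 4]);
}
```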
pub fn std_multi_bit_programmable_bootstrap_lwe_ciphertext<
Scalar,
InputCont,
OutputCont,
AccCont,
KeyCont,
>(
input: &LweCiphertext<InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
multi_bit_bsk: &LweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync + Send,
InputCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
KeyCont: Container<Element = Scalar> + Sync,
{
assert_eq!(
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
);
assert_eq!(
output.lwe_size().to_lwe_dimension(),
multi_bit_bsk.output_lwe_dimension(),
"Mimatched output LweDimension. LweCiphertext output LweDimension {:?}. \
FourierLweMultiBitBootstrapKey output LweDimension {:?}.",
output.lwe_size().to_lwe_dimension(),
multi_bit_bsk.output_lwe_dimension(),
);
assert_eq!(
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
);
assert_eq!(
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
);
assert_eq!(
input.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
"Mimatched CiphertextModulus. LweCiphertext CiphertextModulus {:?}. \
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
input.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
);
assert_eq!(
accumulator.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
"Mimatched CiphertextModulus. Accumulator CiphertextModulus {:?}. \
LweMultiBitBootstrapKey CiphertextModulus {:?}.",
accumulator.ciphertext_modulus(),
multi_bit_bsk.ciphertext_modulus(),
);
let mut local_accumulator = GlweCiphertext::new(
Scalar::ZERO,
accumulator.glwe_size(),
accumulator.polynomial_size(),
accumulator.ciphertext_modulus(),
);
local_accumulator
.as_mut()
.copy_from_slice(accumulator.as_ref());
std_multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count);
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}

View File

@@ -349,8 +349,6 @@ pub fn add_external_product_assign<Scalar, OutputGlweCont, InputGlweCont, GgswCo
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
@@ -487,6 +485,8 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
InputGlweCont: Container<Element = Scalar>,
{
assert_eq!(out.ciphertext_modulus(), glwe.ciphertext_modulus());
let ciphertext_modulus = out.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
impl_add_external_product_assign(
out.as_mut_view(),
@@ -496,7 +496,6 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
stack,
);
if !ciphertext_modulus.is_native_modulus() {
// When we convert back from the fourier domain, integer values will contain up to 53
// MSBs with information. In our representation of power of 2 moduli < native modulus we
@@ -776,6 +775,8 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
GgswCont: Container<Element = c64>,
{
assert_eq!(ct0.ciphertext_modulus(), ct1.ciphertext_modulus());
let ciphertext_modulus = ct0.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
cmux(
ct0.as_mut_view(),
@@ -785,7 +786,6 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
stack,
);
if !ciphertext_modulus.is_native_modulus() {
// When we convert back from the fourier domain, integer values will contain up to 53
// MSBs with information. In our representation of power of 2 moduli < native modulus we

View File

@@ -0,0 +1,86 @@
//! Miscellaneous algorithms.
use crate::core_crypto::prelude::*;
#[inline]
pub fn divide_round_to_u128<Scalar>(numerator: Scalar, denominator: Scalar) -> u128
where
Scalar: UnsignedInteger,
{
let numerator_128: u128 = numerator.cast_into();
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
let denominator_128: u128 = denominator.cast_into();
// That's the rounding
(numerator_128 + half_denominator) / denominator_128
}
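A few worked values for the rounding rule above, assuming the `u64` scalar also used by the test at the end of this file: adding half the denominator before the integer division rounds to the nearest integer, with exact halves rounding up.
```
#[test]
fn divide_round_worked_examples() {
    // 7 / 2 = 3.5 rounds up to 4, 9 / 4 = 2.25 rounds down to 2
    assert_eq!(divide_round_to_u128(7u64, 2u64), 4);
    assert_eq!(divide_round_to_u128(9u64, 4u64), 2);
    // Exact divisions are left untouched
    assert_eq!(divide_round_to_u128(12u64, 3u64), 4);
}
```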
#[inline]
pub fn divide_round_to_u128_custom_mod<Scalar>(
numerator: Scalar,
denominator: Scalar,
modulus: u128,
) -> u128
where
Scalar: UnsignedInteger,
{
let numerator_128: u128 = numerator.cast_into();
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
let denominator_128: u128 = denominator.cast_into();
// That's the rounding
((numerator_128 + half_denominator) % modulus) / denominator_128
}
pub fn odd_modular_inverse_pow_2<Scalar>(odd_value_to_invert: Scalar, log2_modulo: usize) -> Scalar
where
Scalar: UnsignedInteger,
{
let t = log2_modulo.ilog2() + if log2_modulo.is_power_of_two() { 0 } else { 1 };
let mut y = Scalar::ONE;
let e = odd_value_to_invert;
for i in 1..=t {
// 1 << (1 << i) == 2 ^ {2 ^ i}
let curr_mod = Scalar::ONE.shl(1 << i);
// y = y * (2 - y * e) mod 2 ^ {2 ^ i}
// Here using wrapping ops is ok as the modulus used is a power of 2; as long as the shift
// amount 2 ^ i stays below Scalar::BITS we are fine, the discarded high bits would not have
// been used anyway, and 2 ^ {2 ^ i} is compatible with a native modulus
y = (y.wrapping_mul(Scalar::TWO.wrapping_sub(y.wrapping_mul(e)))).wrapping_rem(curr_mod);
}
y.wrapping_rem(Scalar::ONE.shl(log2_modulo))
}
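A small sanity check for the Newton/Hensel lifting above, assuming a `u64` scalar: each pass doubles the number of correct low bits of the inverse, so after the loop the product of an odd value with its computed inverse is congruent to 1 modulo 2^log2_modulo.
```
#[test]
fn odd_modular_inverse_pow_2_examples() {
    let log2_modulo = 16usize;
    let modulus = 1u64 << log2_modulo;
    for odd in [1u64, 3, 5, 12_345, 65_535] {
        let inv = odd_modular_inverse_pow_2(odd, log2_modulo);
        // The inverse is reduced modulo 2^log2_modulo and inverts the input
        assert!(inv < modulus);
        assert_eq!(odd.wrapping_mul(inv) % modulus, 1);
    }
}
```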
#[test]
fn test_divide_round() {
use rand::Rng;
let mut rng = rand::thread_rng();
const NB_TESTS: usize = 1_000_000_000;
const SCALING: f64 = u64::MAX as f64;
for _ in 0..NB_TESTS {
let num: f64 = rng.gen();
let mut denom = 0.0f64;
while denom == 0.0f64 {
denom = rng.gen();
}
let num = (num * SCALING).round();
let denom = (denom * SCALING).round();
let rounded = (num / denom).round();
let expected_rounded_u64: u64 = rounded as u64;
let num_u64: u64 = num as u64;
let denom_u64: u64 = denom as u64;
// sanity check
assert_eq!(num, num_u64 as f64);
assert_eq!(denom, denom_u64 as f64);
let rounded_u128 = divide_round_to_u128(num_u64, denom_u64);
assert_eq!(expected_rounded_u64, rounded_u128 as u64);
}
}

View File

@@ -9,6 +9,8 @@ pub mod glwe_sample_extraction;
pub mod glwe_secret_key_generation;
pub mod lwe_bootstrap_key_conversion;
pub mod lwe_bootstrap_key_generation;
pub mod lwe_compact_ciphertext_list_expansion;
pub mod lwe_compact_public_key_generation;
pub mod lwe_encryption;
pub mod lwe_keyswitch;
pub mod lwe_keyswitch_key_generation;
@@ -22,6 +24,7 @@ pub mod lwe_programmable_bootstrapping;
pub mod lwe_public_key_generation;
pub mod lwe_secret_key_generation;
pub mod lwe_wopbs;
pub mod misc;
pub mod polynomial_algorithms;
pub mod seeded_ggsw_ciphertext_decompression;
pub mod seeded_ggsw_ciphertext_list_decompression;
@@ -30,7 +33,9 @@ pub mod seeded_glwe_ciphertext_list_decompression;
pub mod seeded_lwe_bootstrap_key_decompression;
pub mod seeded_lwe_ciphertext_decompression;
pub mod seeded_lwe_ciphertext_list_decompression;
pub mod seeded_lwe_compact_public_key_decompression;
pub mod seeded_lwe_keyswitch_key_decompression;
pub mod seeded_lwe_multi_bit_bootstrap_key_decompression;
pub mod seeded_lwe_public_key_decompression;
pub mod slice_algorithms;
@@ -46,6 +51,8 @@ pub use glwe_sample_extraction::*;
pub use glwe_secret_key_generation::*;
pub use lwe_bootstrap_key_conversion::*;
pub use lwe_bootstrap_key_generation::*;
pub use lwe_compact_ciphertext_list_expansion::*;
pub use lwe_compact_public_key_generation::*;
pub use lwe_encryption::*;
pub use lwe_keyswitch::*;
pub use lwe_keyswitch_key_generation::*;
@@ -66,5 +73,7 @@ pub use seeded_glwe_ciphertext_list_decompression::*;
pub use seeded_lwe_bootstrap_key_decompression::*;
pub use seeded_lwe_ciphertext_decompression::*;
pub use seeded_lwe_ciphertext_list_decompression::*;
pub use seeded_lwe_compact_public_key_decompression::*;
pub use seeded_lwe_keyswitch_key_decompression::*;
pub use seeded_lwe_multi_bit_bootstrap_key_decompression::*;
pub use seeded_lwe_public_key_decompression::*;

View File

@@ -259,6 +259,124 @@ pub fn polynomial_wrapping_monic_monomial_mul_assign<Scalar, OutputCont>(
.for_each(|a| *a = a.wrapping_neg());
}
/// Divide (mod $(X^{N}+1)$) the input polynomial by a monic monomial of a given degree, i.e.
/// $X^{degree}$.
///
/// # Note
///
/// Computations wrap around (similar to computing modulo $2^{n\_{bits}}$) when exceeding the
/// unsigned integer capacity.
///
/// # Examples
///
/// ```
/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*;
/// use tfhe::core_crypto::commons::parameters::*;
/// use tfhe::core_crypto::entities::*;
/// let input = Polynomial::from_container(vec![1u8, 2, 3]);
/// let mut output = Polynomial::from_container(vec![0, 0, 0]);
/// polynomial_wrapping_monic_monomial_div(&mut output, &input, MonomialDegree(2));
/// assert_eq!(output.as_ref(), &[3, 255, 254]);
/// ```
pub fn polynomial_wrapping_monic_monomial_div<Scalar, OutputCont, InputCont>(
output: &mut Polynomial<OutputCont>,
input: &Polynomial<InputCont>,
monomial_degree: MonomialDegree,
) where
Scalar: UnsignedInteger,
OutputCont: ContainerMut<Element = Scalar>,
InputCont: Container<Element = Scalar>,
{
assert!(
output.polynomial_size() == input.polynomial_size(),
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
output.polynomial_size(),
input.polynomial_size(),
);
let polynomial_len = output.container_len();
let remaining_degree = monomial_degree.0 % output.as_ref().container_len();
let src_slice = &input[remaining_degree..];
let src_slice_len = src_slice.len();
let dst_slice = &mut output[..src_slice_len];
dst_slice.copy_from_slice(src_slice);
for (dst, &src) in output[polynomial_len - remaining_degree..]
.iter_mut()
.zip(input[..remaining_degree].iter())
{
*dst = src.wrapping_neg();
}
let full_cycles_count = monomial_degree.0 / polynomial_len;
if full_cycles_count % 2 != 0 {
output
.as_mut()
.iter_mut()
.for_each(|a| *a = a.wrapping_neg());
}
}
/// Multiply (mod $(X^{N}+1)$) the input polynomial by a monic monomial of a given degree, i.e.
/// $X^{degree}$.
///
/// # Note
///
/// Computations wrap around (similar to computing modulo $2^{n\_{bits}}$) when exceeding the
/// unsigned integer capacity.
///
/// # Examples
///
/// ```
/// use tfhe::core_crypto::algorithms::polynomial_algorithms::*;
/// use tfhe::core_crypto::commons::parameters::*;
/// use tfhe::core_crypto::entities::*;
/// let input = Polynomial::from_container(vec![1u8, 2, 3]);
/// let mut output = Polynomial::from_container(vec![0, 0, 0]);
/// polynomial_wrapping_monic_monomial_mul(&mut output, &input, MonomialDegree(2));
/// assert_eq!(output.as_ref(), &[254, 253, 1]);
/// ```
pub fn polynomial_wrapping_monic_monomial_mul<Scalar, OutputCont, InputCont>(
output: &mut Polynomial<OutputCont>,
input: &Polynomial<InputCont>,
monomial_degree: MonomialDegree,
) where
Scalar: UnsignedInteger,
OutputCont: ContainerMut<Element = Scalar>,
InputCont: Container<Element = Scalar>,
{
assert!(
output.polynomial_size() == input.polynomial_size(),
"Output polynomial size {:?} is not the same as input polynomial size {:?}.",
output.polynomial_size(),
input.polynomial_size(),
);
let polynomial_len = output.container_len();
let remaining_degree = monomial_degree.0 % output.as_ref().container_len();
for (dst, &src) in output[..remaining_degree]
.iter_mut()
.zip(input[polynomial_len - remaining_degree..].iter())
{
*dst = src.wrapping_neg();
}
let dst_slice = &mut output[remaining_degree..];
let dst_slice_len = dst_slice.len();
let src_slice = &input[..dst_slice_len];
dst_slice.copy_from_slice(src_slice);
let full_cycles_count = monomial_degree.0 / polynomial_len;
if full_cycles_count % 2 != 0 {
output
.as_mut()
.iter_mut()
.for_each(|a| *a = a.wrapping_neg());
}
}
/// Subtract the sum of the element-wise product between two lists of polynomials, to the output
/// polynomial.
///

View File

@@ -34,13 +34,14 @@ pub fn decompress_seeded_glwe_ciphertext_with_existing_generator<
let (mut output_mask, mut output_body) = output_glwe.get_mut_mask_and_body();
let ciphertext_modulus = output_mask.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
// generate a uniformly random mask
generator.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
output_body

View File

@@ -32,6 +32,7 @@ pub fn decompress_seeded_glwe_ciphertext_list_with_existing_generator<
);
let ciphertext_modulus = output_list.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
for (mut glwe_out, body_in) in output_list.iter_mut().zip(input_seeded_list.iter()) {
let (mut output_mask, mut output_body) = glwe_out.get_mut_mask_and_body();
@@ -42,7 +43,7 @@ pub fn decompress_seeded_glwe_ciphertext_list_with_existing_generator<
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
output_body.as_mut().copy_from_slice(body_in.as_ref());

View File

@@ -26,6 +26,7 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
);
let ciphertext_modulus = output_lwe.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let (mut output_mask, output_body) = output_lwe.get_mut_mask_and_body();
// generate a uniformly random mask
@@ -33,7 +34,7 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
*output_body.data = *input_seeded_lwe.get_body().data;

View File

@@ -5,6 +5,8 @@ use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Convenience function to share the core logic of the decompression algorithm for
/// [`SeededLweCiphertextList`] between all functions needing it.
pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
@@ -32,6 +34,7 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
);
let ciphertext_modulus = output_list.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
for (mut lwe_out, body_in) in output_list.iter_mut().zip(input_seeded_list.iter()) {
let (mut output_mask, output_body) = lwe_out.get_mut_mask_and_body();
@@ -42,13 +45,96 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
*output_body.data = *body_in.data;
}
}
/// Convenience function to share the core logic of the decompression algorithm for
/// [`SeededLweCiphertextList`] between all functions needing it.
pub fn par_decompress_seeded_lwe_ciphertext_list_with_existing_generator<
Scalar,
InputCont,
OutputCont,
Gen,
>(
output_list: &mut LweCiphertextList<OutputCont>,
input_seeded_list: &SeededLweCiphertextList<InputCont>,
generator: &mut RandomGenerator<Gen>,
) where
Scalar: UnsignedTorus + Send + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert_eq!(
output_list.ciphertext_modulus(),
input_seeded_list.ciphertext_modulus(),
"Mismatched CiphertextModulus \
between input SeededLweCiphertextList ({:?}) and output LweCiphertextList ({:?})",
input_seeded_list.ciphertext_modulus(),
output_list.ciphertext_modulus(),
);
let ciphertext_modulus = output_list.ciphertext_modulus();
let bytes_per_child = (Scalar::BITS / 8) * (output_list.lwe_size().to_lwe_dimension().0);
let generators = generator
.par_try_fork(output_list.lwe_ciphertext_count().0, bytes_per_child)
.expect("Failed to fork generator");
output_list
.par_iter_mut()
.zip(input_seeded_list.par_iter())
.zip(generators)
.for_each(|((mut lwe_out, body_in), mut generator)| {
let (mut output_mask, output_body) = lwe_out.get_mut_mask_and_body();
// generate a uniformly random mask
generator.fill_slice_with_random_uniform_custom_mod(
output_mask.as_mut(),
ciphertext_modulus,
);
if !ciphertext_modulus.is_native_modulus() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
);
}
*output_body.data = *body_in.data;
});
}
/// Decompress a [`SeededLweCiphertextList`], without consuming it, into a standard
/// [`LweCiphertextList`] using multiple threads.
pub fn par_decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen>(
output_list: &mut LweCiphertextList<OutputCont>,
input_seeded_list: &SeededLweCiphertextList<InputCont>,
) where
Scalar: UnsignedTorus + Send + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert_eq!(
output_list.ciphertext_modulus(),
input_seeded_list.ciphertext_modulus(),
"Mismatched CiphertextModulus \
between input SeededLweCiphertextList ({:?}) and output LweCiphertextList ({:?})",
input_seeded_list.ciphertext_modulus(),
output_list.ciphertext_modulus(),
);
let mut generator = RandomGenerator::<Gen>::new(input_seeded_list.compression_seed().seed);
par_decompress_seeded_lwe_ciphertext_list_with_existing_generator::<_, _, _, Gen>(
output_list,
input_seeded_list,
&mut generator,
)
}
/// Decompress a [`SeededLweCiphertextList`], without consuming it, into a standard
/// [`LweCiphertextList`].
pub fn decompress_seeded_lwe_ciphertext_list<Scalar, InputCont, OutputCont, Gen>(

View File

@@ -0,0 +1,49 @@
//! Module with primitives pertaining to [`SeededLweCompactPublicKey`] decompression.
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
/// Convenience function to share the core logic of the decompression algorithm for
/// [`SeededLweCompactPublicKey`] between all functions needing it.
pub fn decompress_seeded_lwe_compact_public_key_with_existing_generator<
Scalar,
InputCont,
OutputCont,
Gen,
>(
output_cpk: &mut LweCompactPublicKey<OutputCont>,
input_seeded_cpk: &SeededLweCompactPublicKey<InputCont>,
generator: &mut RandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
decompress_seeded_glwe_ciphertext_with_existing_generator(
&mut output_cpk.as_mut_glwe_ciphertext(),
&input_seeded_cpk.as_seeded_glwe_ciphertext(),
generator,
);
}
/// Decompress a [`SeededLweCompactPublicKey`], without consuming it, into a standard
/// [`LweCompactPublicKey`].
pub fn decompress_seeded_lwe_compact_public_key<Scalar, InputCont, OutputCont, Gen>(
output_cpk: &mut LweCompactPublicKey<OutputCont>,
input_seeded_cpk: &SeededLweCompactPublicKey<InputCont>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
let mut generator = RandomGenerator::<Gen>::new(input_seeded_cpk.compression_seed().seed);
decompress_seeded_lwe_compact_public_key_with_existing_generator::<_, _, _, Gen>(
output_cpk,
input_seeded_cpk,
&mut generator,
)
}

View File

@@ -0,0 +1,63 @@
//! Module with primitives pertaining to [`SeededLweMultiBitBootstrapKey`] decompression.
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
/// Convenience function to share the core logic of the decompression algorithm for
/// [`SeededLweMultiBitBootstrapKey`] between all functions needing it.
pub fn decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator<
Scalar,
InputCont,
OutputCont,
Gen,
>(
output_bsk: &mut LweMultiBitBootstrapKey<OutputCont>,
input_bsk: &SeededLweMultiBitBootstrapKey<InputCont>,
generator: &mut RandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert_eq!(
output_bsk.ciphertext_modulus(),
input_bsk.ciphertext_modulus(),
"Mismatched CiphertextModulus \
between input SeededLweMultiBitBootstrapKey ({:?}) and output LweMultiBitBootstrapKey ({:?})",
input_bsk.ciphertext_modulus(),
output_bsk.ciphertext_modulus(),
);
decompress_seeded_ggsw_ciphertext_list_with_existing_generator(output_bsk, input_bsk, generator)
}
/// Decompress a [`SeededLweMultiBitBootstrapKey`], without consuming it, into a standard
/// [`LweMultiBitBootstrapKey`].
pub fn decompress_seeded_lwe_multi_bit_bootstrap_key<Scalar, InputCont, OutputCont, Gen>(
output_bsk: &mut LweMultiBitBootstrapKey<OutputCont>,
input_bsk: &SeededLweMultiBitBootstrapKey<InputCont>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert_eq!(
output_bsk.ciphertext_modulus(),
input_bsk.ciphertext_modulus(),
"Mismatched CiphertextModulus \
between input SeededLweMultiBitBootstrapKey ({:?}) and output LweMultiBitBootstrapKey ({:?})",
input_bsk.ciphertext_modulus(),
output_bsk.ciphertext_modulus(),
);
let mut generator = RandomGenerator::<Gen>::new(input_bsk.compression_seed().seed);
decompress_seeded_lwe_multi_bit_bootstrap_key_with_existing_generator::<_, _, _, Gen>(
output_bsk,
input_bsk,
&mut generator,
)
}

View File

@@ -32,3 +32,31 @@ pub fn decompress_seeded_lwe_public_key<Scalar, InputCont, OutputCont, Gen>(
&mut generator,
);
}
/// Decompress a [`SeededLwePublicKey`], without consuming it, into a standard
/// [`LwePublicKey`] using multiple threads.
pub fn par_decompress_seeded_lwe_public_key<Scalar, InputCont, OutputCont, Gen>(
output_pk: &mut LwePublicKey<OutputCont>,
input_pk: &SeededLwePublicKey<InputCont>,
) where
Scalar: UnsignedTorus + Send + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert_eq!(
output_pk.ciphertext_modulus(),
input_pk.ciphertext_modulus(),
"Mismatched CiphertextModulus \
between input SeededLwePublicKey ({:?}) and output LwePublicKey ({:?})",
output_pk.ciphertext_modulus(),
input_pk.ciphertext_modulus(),
);
let mut generator = RandomGenerator::<Gen>::new(input_pk.compression_seed().seed);
par_decompress_seeded_lwe_ciphertext_list_with_existing_generator::<_, _, _, Gen>(
output_pk,
input_pk,
&mut generator,
);
}

View File

@@ -1,6 +1,8 @@
//! Module providing algorithms to perform computations on raw slices.
use crate::core_crypto::algorithms::polynomial_algorithms::polynomial_wrapping_add_mul_assign;
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::entities::Polynomial;
/// Compute a dot product between two slices containing unsigned integers.
///
@@ -309,3 +311,54 @@ where
lhs.iter_mut()
.for_each(|lhs| *lhs = (*lhs).wrapping_div(rhs));
}
/// Primitive for the compact LWE public key: computes a semi-reverse negacyclic convolution.
///
/// Here $i$ from section 3 of <https://eprint.iacr.org/2023/603> is taken equal to $n$.
/// ```
/// use tfhe::core_crypto::algorithms::slice_algorithms::*;
/// let lhs = vec![1u8, 2u8, 3u8];
/// let rhs = vec![4u8, 5u8, 6u8];
/// let mut output = vec![0u8; 3];
/// slice_semi_reverse_negacyclic_convolution(&mut output, &lhs, &rhs);
/// assert_eq!(&output, &[(-17i8) as u8, 5, 32]);
/// ```
pub fn slice_semi_reverse_negacyclic_convolution<Scalar>(
output: &mut [Scalar],
lhs: &[Scalar],
rhs: &[Scalar],
) where
Scalar: UnsignedInteger,
{
assert!(
lhs.len() == rhs.len(),
"lhs (len: {}) and rhs (len: {}) must have the same length",
lhs.len(),
rhs.len()
);
assert!(
output.len() == lhs.len(),
"output (len: {}) and lhs (len: {}) must have the same length",
output.len(),
lhs.len()
);
// Apply phi_1 to the rhs term
let mut phi_1_rhs: Vec<_> = rhs.to_vec();
phi_1_rhs.reverse();
let phi_1_rhs_as_polynomial = Polynomial::from_container(phi_1_rhs.as_slice());
// Clear output as we'll add the multiplication result
output.fill(Scalar::ZERO);
let mut output_as_polynomial = Polynomial::from_container(output);
let lhs_as_polynomial = Polynomial::from_container(lhs);
// Apply the classic negacyclic convolution via polynomial mul in the X^N + 1 ring, with the
// phi_1 rhs it is equivalent to the operator we need
polynomial_wrapping_add_mul_assign(
&mut output_as_polynomial,
&lhs_as_polynomial,
&phi_1_rhs_as_polynomial,
);
}

View File

@@ -0,0 +1,83 @@
use super::*;
use crate::core_crypto::commons::generators::{
DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator,
};
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
fn test_seeded_lwe_cpk_gen_equivalence<Scalar: UnsignedTorus>(
ciphertext_modulus: CiphertextModulus<Scalar>,
) {
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
// computations
// Define parameters for LweCompactPublicKey creation
let lwe_dimension = LweDimension(1024);
let lwe_modular_std_dev = StandardDev(0.00000004990272175010415);
// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mask_seed = seeder.seed();
let deterministic_seeder_seed = seeder.seed();
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
const NB_TEST: usize = 10;
for _ in 0..NB_TEST {
// Create the LweSecretKey
let input_lwe_secret_key =
allocate_and_generate_new_binary_lwe_secret_key(lwe_dimension, &mut secret_generator);
let mut cpk = LweCompactPublicKey::new(Scalar::ZERO, lwe_dimension, ciphertext_modulus);
let mut deterministic_seeder =
DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed);
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
mask_seed,
&mut deterministic_seeder,
);
generate_lwe_compact_public_key(
&input_lwe_secret_key,
&mut cpk,
lwe_modular_std_dev,
&mut encryption_generator,
);
assert!(check_content_respects_mod(&cpk, ciphertext_modulus));
let mut seeded_cpk = SeededLweCompactPublicKey::new(
Scalar::ZERO,
lwe_dimension,
mask_seed.into(),
ciphertext_modulus,
);
let mut deterministic_seeder =
DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed);
generate_seeded_lwe_compact_public_key(
&input_lwe_secret_key,
&mut seeded_cpk,
lwe_modular_std_dev,
&mut deterministic_seeder,
);
assert!(check_content_respects_mod(&seeded_cpk, ciphertext_modulus));
let decompressed_cpk = seeded_cpk.decompress_into_lwe_compact_public_key();
assert_eq!(cpk, decompressed_cpk);
}
}
#[test]
fn test_seeded_lwe_cpk_gen_equivalence_u32_native_mod() {
test_seeded_lwe_cpk_gen_equivalence::<u32>(CiphertextModulus::new_native())
}
#[test]
fn test_seeded_lwe_cpk_gen_equivalence_u64_native_mod() {
test_seeded_lwe_cpk_gen_equivalence::<u64>(CiphertextModulus::new_native())
}

View File

@@ -827,3 +827,67 @@ fn test_u128_encryption() {
}
}
}
fn lwe_compact_public_encrypt_decrypt_custom_mod<Scalar: UnsignedTorus>(
params: TestParams<Scalar>,
) {
let lwe_dimension = LweDimension(params.polynomial_size.0);
let glwe_modular_std_dev = params.glwe_modular_std_dev;
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let mut rsc = TestResources::new();
const NB_TESTS: usize = 10;
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
let mut msg = msg_modulus;
let delta: Scalar = encoding_with_padding / msg_modulus;
while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut rsc.secret_random_generator,
);
let pk = allocate_and_generate_new_lwe_compact_public_key(
&lwe_sk,
glwe_modular_std_dev,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
let mut ct = LweCiphertext::new(
Scalar::ZERO,
lwe_dimension.to_lwe_size(),
ciphertext_modulus,
);
let plaintext = Plaintext(msg * delta);
encrypt_lwe_ciphertext_with_compact_public_key(
&pk,
&mut ct,
plaintext,
glwe_modular_std_dev,
glwe_modular_std_dev,
&mut rsc.secret_random_generator,
&mut rsc.encryption_random_generator,
);
assert!(check_content_respects_mod(&ct, ciphertext_modulus));
let decrypted = decrypt_lwe_ciphertext(&lwe_sk, &ct);
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
assert_eq!(msg, decoded);
}
}
}
create_parametrized_test!(lwe_compact_public_encrypt_decrypt_custom_mod {
TEST_PARAMS_4_BITS_NATIVE_U64
});

View File

@@ -0,0 +1,176 @@
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::dispersion::StandardDev;
use crate::core_crypto::commons::generators::{DeterministicSeeder, EncryptionRandomGenerator};
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, Seed};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension,
LweBskGroupingFactor, LweDimension, PolynomialSize,
};
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::CastFrom;
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence<
T: UnsignedTorus + CastFrom<usize> + Sync + Send,
>(
ciphertext_modulus: CiphertextModulus<T>,
) {
for _ in 0..10 {
let mut lwe_dim =
LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
let glwe_dim =
GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
let poly_size =
PolynomialSize(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
let level = DecompositionLevelCount(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let base_log = DecompositionBaseLog(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let grouping_factor = LweBskGroupingFactor(
crate::core_crypto::commons::test_tools::random_usize_between(2..4),
);
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
let deterministic_seeder_seed =
Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
while lwe_dim.0 % grouping_factor.0 != 0 {
lwe_dim = LweDimension(lwe_dim.0 + 1);
}
let mut secret_generator = new_secret_random_generator();
let lwe_sk =
allocate_and_generate_new_binary_lwe_secret_key(lwe_dim, &mut secret_generator);
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dim,
poly_size,
&mut secret_generator,
);
let mut parallel_multi_bit_bsk = LweMultiBitBootstrapKeyOwned::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
grouping_factor,
ciphertext_modulus,
);
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
mask_seed,
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
);
par_generate_lwe_multi_bit_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut parallel_multi_bit_bsk,
StandardDev::from_standard_dev(10.),
&mut encryption_generator,
);
let mut sequential_multi_bit_bsk = LweMultiBitBootstrapKeyOwned::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
grouping_factor,
ciphertext_modulus,
);
let mut encryption_generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
mask_seed,
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
);
generate_lwe_multi_bit_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut sequential_multi_bit_bsk,
StandardDev::from_standard_dev(10.),
&mut encryption_generator,
);
assert_eq!(parallel_multi_bit_bsk, sequential_multi_bit_bsk);
let mut sequential_seeded_multi_bit_bsk = SeededLweMultiBitBootstrapKey::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
grouping_factor,
mask_seed.into(),
ciphertext_modulus,
);
generate_seeded_lwe_multi_bit_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut sequential_seeded_multi_bit_bsk,
StandardDev::from_standard_dev(10.),
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
);
let mut parallel_seeded_multi_bit_bsk = SeededLweMultiBitBootstrapKey::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
grouping_factor,
mask_seed.into(),
ciphertext_modulus,
);
par_generate_seeded_lwe_multi_bit_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut parallel_seeded_multi_bit_bsk,
StandardDev::from_standard_dev(10.),
&mut DeterministicSeeder::<ActivatedRandomGenerator>::new(deterministic_seeder_seed),
);
assert_eq!(
sequential_seeded_multi_bit_bsk,
parallel_seeded_multi_bit_bsk
);
let decompressed_multi_bit_bsk =
sequential_seeded_multi_bit_bsk.decompress_into_lwe_multi_bit_bootstrap_key();
assert_eq!(decompressed_multi_bit_bsk, sequential_multi_bit_bsk);
}
}
#[test]
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u32_native_mod() {
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u32>(CiphertextModulus::new_native());
}
#[test]
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u32_custom_mod() {
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u32>(
CiphertextModulus::try_new_power_of_2(31).unwrap(),
);
}
#[test]
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u64_native_mod() {
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u64>(CiphertextModulus::new_native());
}
#[test]
fn test_parallel_and_seeded_multi_bit_bsk_gen_equivalence_u64_custom_mod() {
test_parallel_and_seeded_multi_bit_bsk_gen_equivalence::<u64>(
CiphertextModulus::try_new_power_of_2(63).unwrap(),
);
}

View File

@@ -146,6 +146,125 @@ fn lwe_encrypt_multi_bit_pbs_decrypt_custom_mod<
}
}
fn lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod<
Scalar: UnsignedTorus + Sync + Send + CastFrom<usize> + CastInto<usize>,
>(
params: MultiBitParams<Scalar>,
) {
let input_lwe_dimension = params.input_lwe_dimension;
let lwe_modular_std_dev = params.lwe_modular_std_dev;
let glwe_modular_std_dev = params.glwe_modular_std_dev;
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let msg_modulus = Scalar::ONE.shl(message_modulus_log.0);
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
let glwe_dimension = params.glwe_dimension;
let polynomial_size = params.polynomial_size;
let decomp_base_log = params.decomp_base_log;
let decomp_level_count = params.decomp_level_count;
let grouping_factor = params.grouping_factor;
let thread_count = params.thread_count;
let mut rsc = TestResources::new();
let f = |x: Scalar| {
x.wrapping_mul(Scalar::TWO)
.wrapping_sub(Scalar::ONE)
.wrapping_rem(msg_modulus)
};
let delta: Scalar = encoding_with_padding / msg_modulus;
let mut msg = msg_modulus;
const NB_TESTS: usize = 10;
let accumulator = generate_accumulator(
polynomial_size,
glwe_dimension.to_glwe_size(),
msg_modulus.cast_into(),
ciphertext_modulus,
delta,
f,
);
assert!(check_content_respects_mod(&accumulator, ciphertext_modulus));
// Keygen is a bit slow on this one so we keep it out of the testing loop
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
input_lwe_dimension,
&mut rsc.secret_random_generator,
);
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut rsc.secret_random_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.clone().into_lwe_secret_key();
let mut bsk = LweMultiBitBootstrapKey::new(
Scalar::ZERO,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
input_lwe_dimension,
grouping_factor,
ciphertext_modulus,
);
par_generate_lwe_multi_bit_bootstrap_key(
&input_lwe_secret_key,
&output_glwe_secret_key,
&mut bsk,
glwe_modular_std_dev,
&mut rsc.encryption_random_generator,
);
assert!(check_content_respects_mod(&*bsk, ciphertext_modulus));
while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
let plaintext = Plaintext(msg * delta);
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
plaintext,
lwe_modular_std_dev,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);
assert!(check_content_respects_mod(
&lwe_ciphertext_in,
ciphertext_modulus
));
let mut out_pbs_ct = LweCiphertext::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
ciphertext_modulus,
);
std_multi_bit_programmable_bootstrap_lwe_ciphertext(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator,
&bsk,
thread_count,
);
assert!(check_content_respects_mod(&out_pbs_ct, ciphertext_modulus));
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
assert_eq!(decoded, f(msg));
}
}
}
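In isolation, the encode/decode arithmetic this test exercises looks as follows (a sketch assuming a `u64` native modulus with one bit of padding, so the encoding space is 2^63): messages are scaled into the MSBs by `delta`, and decoding divides by `delta` with rounding before reducing modulo the message modulus, so noise well below `delta / 2` is absorbed.
```
#[test]
fn delta_encoding_round_trip_sketch() {
    let message_modulus_log = 4u32;
    let msg_modulus = 1u64 << message_modulus_log;
    // One bit of padding on top of the 4 message bits, native u64 modulus assumed
    let encoding_with_padding = 1u64 << 63;
    let delta = encoding_with_padding / msg_modulus;
    for msg in 0..msg_modulus {
        let plaintext = msg * delta;
        // Simulate noise well below half of delta
        let noisy = plaintext.wrapping_add(delta / 8);
        let decoded = ((noisy + delta / 2) / delta) % msg_modulus;
        assert_eq!(decoded, msg);
    }
}
```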
#[test]
pub fn test_lwe_encrypt_multi_bit_pbs_decrypt_factor_2_thread_5_native_mod() {
lwe_encrypt_multi_bit_pbs_decrypt_custom_mod::<u64>(
@@ -229,3 +348,87 @@ pub fn test_lwe_encrypt_multi_bit_pbs_decrypt_factor_3_thread_12_custom_mod() {
},
);
}
#[test]
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_2_thread_5_native_mod() {
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
// correct computations
MultiBitParams {
input_lwe_dimension: LweDimension(786),
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
decomp_base_log: DecompositionBaseLog(23),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
message_modulus_log: CiphertextModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(2),
thread_count: ThreadCount(5),
},
);
}
#[test]
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_3_thread_12_native_mod() {
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
// correct computations
MultiBitParams {
input_lwe_dimension: LweDimension(786),
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
decomp_base_log: DecompositionBaseLog(22),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
message_modulus_log: CiphertextModulusLog(4),
ciphertext_modulus: CiphertextModulus::new_native(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
},
);
}
#[test]
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_2_thread_5_custom_mod() {
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
// correct computations
MultiBitParams {
input_lwe_dimension: LweDimension(786),
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
decomp_base_log: DecompositionBaseLog(23),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
message_modulus_log: CiphertextModulusLog(3),
ciphertext_modulus: CiphertextModulus::try_new_power_of_2(63).unwrap(),
grouping_factor: LweBskGroupingFactor(2),
thread_count: ThreadCount(5),
},
);
}
#[test]
pub fn test_lwe_encrypt_std_multi_bit_pbs_decrypt_factor_3_thread_12_custom_mod() {
lwe_encrypt_std_multi_bit_pbs_decrypt_custom_mod::<u64>(
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield
// correct computations
MultiBitParams {
input_lwe_dimension: LweDimension(786),
lwe_modular_std_dev: StandardDev(0.0000040164874030685975),
decomp_base_log: DecompositionBaseLog(22),
decomp_level_count: DecompositionLevelCount(1),
glwe_dimension: GlweDimension(2),
polynomial_size: PolynomialSize(1024),
glwe_modular_std_dev: StandardDev(0.0000000000000003152931493498455),
message_modulus_log: CiphertextModulusLog(3),
ciphertext_modulus: CiphertextModulus::try_new_power_of_2(63).unwrap(),
grouping_factor: LweBskGroupingFactor(3),
thread_count: ThreadCount(12),
},
);
}

View File

@@ -4,10 +4,12 @@ use paste::paste;
mod ggsw_encryption;
mod glwe_encryption;
mod lwe_bootstrap_key_generation;
mod lwe_compact_public_key_generation;
mod lwe_encryption;
mod lwe_keyswitch;
mod lwe_keyswitch_key_generation;
mod lwe_linear_algebra;
mod lwe_multi_bit_bootstrap_key_generation;
mod lwe_multi_bit_programmable_bootstrapping;
mod lwe_programmable_bootstrapping;
@@ -137,7 +139,7 @@ pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]
if !modulus.is_native_modulus() {
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
// we want the bits under the mask to be 0
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return input
.as_ref()
.iter()
@@ -153,7 +155,7 @@ pub fn check_scalar_respects_mod<Scalar: UnsignedInteger>(
modulus: CiphertextModulus<Scalar>,
) -> bool {
if !modulus.is_native_modulus() {
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return (input & power_2_diff_mask) == Scalar::ZERO;
}
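To make the scaling in the comment above concrete, a minimal sketch assuming a `u64` scalar: for a power-of-two modulus 2^60 the factor lifting values to the native torus is 2^(64-60) = 2^4, and masking with that factor minus one checks that the low bits, which such a modulus does not use, are indeed zero.
```
#[test]
fn power_of_two_scaling_and_mask_sketch() {
    let modulus_log2 = 60u32;
    let scaling = 1u64 << (u64::BITS - modulus_log2);
    assert_eq!(scaling, 1 << 4);
    let power_2_diff_mask = scaling - 1; // 0b1111
    // A value respecting the modulus keeps its 4 low bits clear
    let respects = 0xabcd_ef01_2345_0000u64;
    assert_eq!(respects & power_2_diff_mask, 0);
    // Any of those low bits being set violates the modulus
    assert_ne!((respects | 1) & power_2_diff_mask, 0);
}
```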

View File

@@ -121,6 +121,26 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
}
}
pub const fn try_new(modulus: u128) -> Result<Self, &'static str> {
if Scalar::BITS < 128 && modulus > (1 << Scalar::BITS) {
Err("Modulus is bigger than the maximum value of the associated Scalar type")
} else {
let res = match modulus {
0 => CiphertextModulus::new_native(),
modulus => {
let Some(non_zero_modulus) = NonZeroU128::new(modulus) else {
panic!("Got zero modulus for CiphertextModulusInner::Custom variant",)
};
CiphertextModulus {
inner: CiphertextModulusInner::Custom(non_zero_modulus),
_scalar: PhantomData,
}
}
};
Ok(res.canonicalize())
}
}
pub const fn canonicalize(self) -> Self {
match self.inner {
CiphertextModulusInner::Native => self,
@@ -151,10 +171,14 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
res.canonicalize()
}
pub fn get_power_of_two_scaling_to_native_torus(&self) -> Scalar {
match self.inner {
CiphertextModulusInner::Native => Scalar::ONE,
CiphertextModulusInner::Custom(modulus) => {
assert!(
modulus.is_power_of_two(),
"Cannot get scaling for non power of two modulus {modulus:}"
);
Scalar::ONE.wrapping_shl(Scalar::BITS as u32 - modulus.ilog2())
}
}

View File

@@ -9,7 +9,8 @@ use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{CastInto, UnsignedInteger};
use crate::core_crypto::commons::parameters::{
CiphertextModulus, DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweDimension,
GlweSize, LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweSize, PolynomialSize,
GlweSize, LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweMaskCount, LweSize,
PolynomialSize,
};
use concrete_csprng::generators::ForkError;
use rayon::prelude::*;
@@ -169,6 +170,16 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
self.try_fork(lwe_size.0, mask_bytes, noise_bytes)
}
pub(crate) fn fork_lwe_compact_ciphertext_list_to_bin<T: UnsignedInteger>(
&mut self,
lwe_mask_count: LweMaskCount,
lwe_dimension: LweDimension,
) -> Result<impl Iterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let mask_bytes = mask_bytes_per_lwe_compact_ciphertext_bin::<T>(lwe_dimension);
let noise_bytes = noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension);
self.try_fork(lwe_mask_count.0, mask_bytes, noise_bytes)
}
// Forks both generators into an iterator
fn try_fork(
&mut self,
@@ -431,6 +442,16 @@ impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
self.par_try_fork(lwe_size.0, mask_bytes, noise_bytes)
}
pub(crate) fn par_fork_lwe_compact_ciphertext_list_to_bin<T: UnsignedInteger>(
&mut self,
lwe_mask_count: LweMaskCount,
lwe_dimension: LweDimension,
) -> Result<impl IndexedParallelIterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let mask_bytes = mask_bytes_per_lwe_compact_ciphertext_bin::<T>(lwe_dimension);
let noise_bytes = noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension);
self.par_try_fork(lwe_mask_count.0, mask_bytes, noise_bytes)
}
// Forks both generators into a parallel iterator.
fn par_try_fork(
&mut self,
@@ -504,6 +525,12 @@ fn mask_bytes_per_pfpksk<T: UnsignedInteger>(
lwe_size.0 * mask_bytes_per_pfpksk_chunk::<T>(level, glwe_size, poly_size)
}
fn mask_bytes_per_lwe_compact_ciphertext_bin<T: UnsignedInteger>(
lwe_dimension: LweDimension,
) -> usize {
lwe_dimension.0 * mask_bytes_per_coef::<T>()
}
fn noise_bytes_per_coef() -> usize {
// We use f64 to sample the noise for every precision, and we need 4/pi inputs to generate
// such an output (here we take 32 to keep a safety margin).
@@ -553,6 +580,10 @@ fn noise_bytes_per_pfpksk(
lwe_size.0 * noise_bytes_per_pfpksk_chunk(level, poly_size)
}
fn noise_bytes_per_lwe_compact_ciphertext_bin(lwe_dimension: LweDimension) -> usize {
lwe_dimension.0 * noise_bytes_per_coef()
}
#[cfg(test)]
mod test {
use crate::core_crypto::algorithms::*;
@@ -562,7 +593,7 @@ mod test {
PolynomialSize,
};
use crate::core_crypto::commons::test_tools::{
new_encryption_random_generator, new_secret_random_generator,
new_encryption_random_generator, new_secret_random_generator, normality_test_f64,
};
use crate::core_crypto::commons::traits::UnsignedTorus;
@@ -728,6 +759,119 @@ mod test {
noise_gen_slice_native::<u128>();
}
fn test_normal_random_encryption_native<Scalar: UnsignedTorus>() {
const RUNS: usize = 10000;
const SAMPLES_PER_RUN: usize = 1000;
let mut rng = new_encryption_random_generator();
let failures: f64 = (0..RUNS)
.map(|_| {
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
rng.fill_slice_with_random_noise(&mut samples, StandardDev(f64::powi(2., -20)));
let samples: Vec<f64> = samples
.iter()
.copied()
.map(|x| {
let torus = x.into_torus();
// The upper half of the torus corresponds to the negative domain when
// mapping unsigned integer back to float (MSB or
// sign bit is set)
if torus > 0.5 {
torus - 1.0
} else {
torus
}
})
.collect();
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
// If the samples pass the normality test, return 0; it's not a failure
0.0
} else {
1.0
}
})
.sum::<f64>();
let failure_rate = failures / (RUNS as f64);
println!("failure_rate: {failure_rate}");
// The expected failure rate even on a proper Gaussian is 5%, so we take a small safety margin
assert!(failure_rate <= 0.065);
}
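A back-of-the-envelope check of the 0.065 threshold used above (not part of the diff): with RUNS = 10_000 independent trials that each fail with probability p = 0.05,
sigma = sqrt(p * (1 - p) / RUNS) = sqrt(0.05 * 0.95 / 10_000) ≈ 0.0022
(0.065 - 0.05) / sigma ≈ 6.9
so the cutoff sits roughly seven standard deviations above the expected 5% failure rate and should essentially never trigger for a correct Gaussian sampler.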
#[test]
fn test_normal_random_encryption_native_u32() {
test_normal_random_encryption_native::<u32>();
}
#[test]
fn test_normal_random_encryption_native_u64() {
test_normal_random_encryption_native::<u64>();
}
#[test]
fn test_normal_random_encryption_native_u128() {
test_normal_random_encryption_native::<u128>();
}
fn test_normal_random_encryption_add_assign_native<Scalar: UnsignedTorus>() {
const RUNS: usize = 10000;
const SAMPLES_PER_RUN: usize = 1000;
let mut rng = new_encryption_random_generator();
let failures: f64 = (0..RUNS)
.map(|_| {
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
rng.unsigned_torus_slice_wrapping_add_random_noise_assign(
&mut samples,
StandardDev(f64::powi(2., -20)),
);
let samples: Vec<f64> = samples
.iter()
.copied()
.map(|x| {
let torus = x.into_torus();
// The upper half of the torus corresponds to the negative domain when
// mapping unsigned integer back to float (MSB or
// sign bit is set)
if torus > 0.5 {
torus - 1.0
} else {
torus
}
})
.collect();
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
// If the samples pass the normality test, return 0; it's not a failure
0.0
} else {
1.0
}
})
.sum::<f64>();
let failure_rate = failures / (RUNS as f64);
println!("failure_rate: {failure_rate}");
// The expected failure rate even on a proper Gaussian is 5%, so we take a small safety margin
assert!(failure_rate <= 0.065);
}
#[test]
fn test_normal_random_encryption_add_assign_native_u32() {
test_normal_random_encryption_add_assign_native::<u32>();
}
#[test]
fn test_normal_random_encryption_add_assign_native_u64() {
test_normal_random_encryption_add_assign_native::<u64>();
}
#[test]
fn test_normal_random_encryption_add_assign_native_u128() {
test_normal_random_encryption_add_assign_native::<u128>();
}
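The torus-to-signed mapping repeated in these tests can be sanity-checked with plain u64 arithmetic (a minimal sketch that mimics `into_torus` with a direct division; it is not the crate's implementation): a sample whose MSB is set lands in the upper half of the torus and maps back to a small negative float.
fn main() {
    // 2^64 - 2^44 encodes -2^-20 on the torus (MSB set, so upper half).
    let x: u64 = 0u64.wrapping_sub(1 << 44);
    let torus = x as f64 / 2f64.powi(64);
    let signed = if torus > 0.5 { torus - 1.0 } else { torus };
    assert_eq!(signed, -(2f64.powi(-20)));
}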
fn noise_gen_slice_custom_mod<Scalar: UnsignedTorus>(
ciphertext_modulus: CiphertextModulus<Scalar>,
) {
@@ -783,6 +927,170 @@ mod test {
noise_gen_slice_custom_mod::<u128>(CiphertextModulus::new_native());
}
fn test_normal_random_encryption_custom_mod<Scalar: UnsignedTorus>(
ciphertext_modulus: CiphertextModulus<Scalar>,
) {
const RUNS: usize = 10000;
const SAMPLES_PER_RUN: usize = 1000;
let mut rng = new_encryption_random_generator();
let failures: f64 = (0..RUNS)
.map(|_| {
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
rng.fill_slice_with_random_noise_custom_mod(
&mut samples,
StandardDev(f64::powi(2., -20)),
ciphertext_modulus,
);
let samples: Vec<f64> = samples
.iter()
.copied()
.map(|x| {
let torus = x.into_torus();
// The upper half of the torus corresponds to the negative domain when
// mapping unsigned integer back to float (MSB or
// sign bit is set)
if torus > 0.5 {
torus - 1.0
} else {
torus
}
})
.collect();
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
// If the samples pass the normality test, return 0; it's not a failure
0.0
} else {
1.0
}
})
.sum::<f64>();
let failure_rate = failures / (RUNS as f64);
println!("failure_rate: {failure_rate}");
// The expected failure rate even on a proper Gaussian is 5%, so we take a small safety margin
assert!(failure_rate <= 0.065);
}
#[test]
fn test_normal_random_encryption_custom_mod_u32() {
test_normal_random_encryption_custom_mod::<u32>(
CiphertextModulus::try_new_power_of_2(31).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_custom_mod_u64() {
test_normal_random_encryption_custom_mod::<u64>(
CiphertextModulus::try_new_power_of_2(63).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_custom_mod_u128() {
test_normal_random_encryption_custom_mod::<u128>(
CiphertextModulus::try_new_power_of_2(127).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_native_custom_mod_u32() {
test_normal_random_encryption_custom_mod::<u32>(CiphertextModulus::new_native());
}
#[test]
fn test_normal_random_encryption_native_custom_mod_u64() {
test_normal_random_encryption_custom_mod::<u64>(CiphertextModulus::new_native());
}
#[test]
fn test_normal_random_encryption_native_custom_mod_u128() {
test_normal_random_encryption_custom_mod::<u128>(CiphertextModulus::new_native());
}
fn test_normal_random_encryption_add_assign_custom_mod<Scalar: UnsignedTorus>(
ciphertext_modulus: CiphertextModulus<Scalar>,
) {
const RUNS: usize = 10000;
const SAMPLES_PER_RUN: usize = 1000;
let mut rng = new_encryption_random_generator();
let failures: f64 = (0..RUNS)
.map(|_| {
let mut samples = vec![Scalar::ZERO; SAMPLES_PER_RUN];
rng.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
&mut samples,
StandardDev(f64::powi(2., -20)),
ciphertext_modulus,
);
let samples: Vec<f64> = samples
.iter()
.copied()
.map(|x| {
let torus = x.into_torus();
// The upper half of the torus corresponds to the negative domain when
// mapping unsigned integer back to float (MSB or
// sign bit is set)
if torus > 0.5 {
torus - 1.0
} else {
torus
}
})
.collect();
if normality_test_f64(&samples, 0.05).null_hypothesis_is_valid {
// If the samples pass the normality test, return 0; it's not a failure
0.0
} else {
1.0
}
})
.sum::<f64>();
let failure_rate = failures / (RUNS as f64);
println!("failure_rate: {failure_rate}");
// The expected failure rate even on a proper Gaussian is 5%, so we take a small safety margin
assert!(failure_rate <= 0.065);
}
#[test]
fn test_normal_random_encryption_add_assign_custom_mod_u32() {
test_normal_random_encryption_add_assign_custom_mod::<u32>(
CiphertextModulus::try_new_power_of_2(31).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_add_assign_custom_mod_u64() {
test_normal_random_encryption_add_assign_custom_mod::<u64>(
CiphertextModulus::try_new_power_of_2(63).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_add_assign_custom_mod_u128() {
test_normal_random_encryption_add_assign_custom_mod::<u128>(
CiphertextModulus::try_new_power_of_2(127).unwrap(),
);
}
#[test]
fn test_normal_random_encryption_add_assign_native_custom_mod_u32() {
test_normal_random_encryption_add_assign_custom_mod::<u32>(CiphertextModulus::new_native());
}
#[test]
fn test_normal_random_encryption_add_assign_native_custom_mod_u64() {
test_normal_random_encryption_add_assign_custom_mod::<u64>(CiphertextModulus::new_native());
}
#[test]
fn test_normal_random_encryption_add_assign_native_custom_mod_u128() {
test_normal_random_encryption_add_assign_custom_mod::<u128>(CiphertextModulus::new_native());
}
fn mask_gen_slice_native<Scalar: UnsignedTorus>() {
let mut gen = new_encryption_random_generator();

View File

@@ -1,6 +1,10 @@
use crate::core_crypto::commons::math::decomposition::SignedDecompositionIter;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::math::decomposition::{
SignedDecompositionIter, SignedDecompositionIterNonNative,
};
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use crate::core_crypto::prelude::misc::divide_round_to_u128_custom_mod;
use std::marker::PhantomData;
/// A structure that allows decomposing unsigned integers into a set of smaller terms.
@@ -174,3 +178,215 @@ where
}
}
}
/// A structure that allows decomposing unsigned integers into a set of smaller terms for moduli
/// that are not powers of 2.
///
/// See the [module level](super) documentation for a description of the signed decomposition.
#[derive(Debug)]
pub struct SignedDecomposerNonNative<Scalar>
where
Scalar: UnsignedInteger,
{
pub(crate) base_log: usize,
pub(crate) level_count: usize,
pub(crate) ciphertext_modulus: CiphertextModulus<Scalar>,
}
impl<Scalar> SignedDecomposerNonNative<Scalar>
where
Scalar: UnsignedInteger,
{
/// Create a new decomposer.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
/// ```
pub fn new(
base_log: DecompositionBaseLog,
level_count: DecompositionLevelCount,
ciphertext_modulus: CiphertextModulus<Scalar>,
) -> SignedDecomposerNonNative<Scalar> {
debug_assert!(
Scalar::BITS > base_log.0 * level_count.0,
"Decomposed bits exceeds the size of the integer to be decomposed"
);
SignedDecomposerNonNative {
base_log: base_log.0,
level_count: level_count.0,
ciphertext_modulus,
}
}
/// Return the logarithm in base two of the base of this decomposer.
///
/// If the decomposer uses a base $B=2^b$, this returns $b$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
/// ```
pub fn base_log(&self) -> DecompositionBaseLog {
DecompositionBaseLog(self.base_log)
}
/// Return the number of levels of this decomposer.
///
/// If the decomposer uses $l$ levels, this returns $l$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
/// ```
pub fn level_count(&self) -> DecompositionLevelCount {
DecompositionLevelCount(self.level_count)
}
/// Return the ciphertext modulus of this decomposer.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// assert_eq!(
/// decomposer.ciphertext_modulus(),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap()
/// );
/// ```
pub fn ciphertext_modulus(&self) -> CiphertextModulus<Scalar> {
self.ciphertext_modulus
}
/// Return the closest value representable by the decomposition.
///
/// For some input integer `k`, decomposition base `B`, decomposition level count `l` and given
/// ciphertext modulus `q`, the performed operation is the following:
///
/// $$
/// \left\lfloor \frac{k}{\lfloor q/B^{l} \rfloor} \right\rceil \cdot \left\lfloor \frac{q}{B^{l}} \right\rfloor
/// $$
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// let closest = decomposer.closest_representable(16982820785129133100u64);
/// assert_eq!(closest, 16983074190859960320u64);
/// ```
#[inline]
pub fn closest_representable(&self, input: Scalar) -> Scalar {
let ciphertext_modulus = self.ciphertext_modulus.get_custom_modulus();
// Floored approach
// B^l
let base_to_level_count = 1 << (self.base_log * self.level_count);
// sr = floor(q/(B^l))
let smallest_representable = ciphertext_modulus / base_to_level_count;
let input_128: u128 = input.cast_into();
// rounded = round(input/sr)
let rounded =
divide_round_to_u128_custom_mod(input_128, smallest_representable, ciphertext_modulus);
// rounded * sr
let closest_representable = rounded * smallest_representable;
Scalar::cast_from(closest_representable)
}
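Working the doctest above through the corrected formula (plain arithmetic, consistent with the floored approach in the code): with q = 2^64 - 2^32 + 1 and B^l = 2^12,
floor(q / B^l) = 2^52 - 2^20 = 4503599626321920
16982820785129133100 / 4503599626321920 ≈ 3770.94, which rounds to 3771
3771 * 4503599626321920 = 16983074190859960320
which is exactly the value asserted in the example.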
/// Generate an iterator over the terms of the decomposition of the input.
///
/// # Warning
///
/// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::numeric::UnsignedInteger;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
///
/// // These two values allow taking each arm of the half-basis check below
/// for value in [1u64 << 63, 16982820785129133100u64] {
/// for term in decomposer.decompose(value) {
/// assert!(1 <= term.level().0);
/// assert!(term.level().0 <= 3);
/// let term = term.value();
/// let abs_term = if term < decomposer.ciphertext_modulus().get_custom_modulus() as u64 / 2
/// {
/// term
/// } else {
/// decomposer.ciphertext_modulus().get_custom_modulus() as u64 - term
/// };
/// println!("abs_term: {abs_term}");
/// let half_basis = 2u64.pow(4) / 2u64;
/// println!("half_basis: {half_basis}");
/// assert!(abs_term <= half_basis);
/// }
/// assert_eq!(decomposer.decompose(1).count(), 3);
/// }
/// ```
pub fn decompose(&self, input: Scalar) -> SignedDecompositionIterNonNative<Scalar> {
// Note that it would make no sense to decompose an input that was not first rounded to the
// closest representable value, so we perform the rounding before decomposing.
SignedDecompositionIterNonNative::new(
self.closest_representable(input),
DecompositionBaseLog(self.base_log),
DecompositionLevelCount(self.level_count),
self.ciphertext_modulus,
)
}
}

View File

@@ -1,4 +1,7 @@
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm};
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::math::decomposition::{
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative,
};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
@@ -122,3 +125,158 @@ fn decompose_one_level<S: UnsignedInteger>(base_log: usize, state: &mut S, mod_b
*state += carry;
res.wrapping_sub(carry << base_log)
}
/// An iterator that yields the terms of the signed decomposition of an integer.
///
/// # Warning
///
/// This iterator yields the decomposition in reverse order. That means that the highest level
/// will be yielded first.
#[derive(Clone, Debug)]
pub struct SignedDecompositionIterNonNative<T>
where
T: UnsignedInteger,
{
// The base log of the decomposition
base_log: usize,
// The number of levels of the decomposition
level_count: usize,
// The internal state of the decomposition
state: T,
// The current level
current_level: usize,
// A mask which allows computing a value mod B. For B=2^4, it has the form:
// ...0001111
mod_b_mask: T,
// Ciphertext modulus
ciphertext_modulus: CiphertextModulus<T>,
// A flag which stores whether the iterator is fresh (used by the recompose method)
fresh: bool,
}
impl<T> SignedDecompositionIterNonNative<T>
where
T: UnsignedInteger,
{
pub(crate) fn new(
input: T,
base_log: DecompositionBaseLog,
level: DecompositionLevelCount,
ciphertext_modulus: CiphertextModulus<T>,
) -> SignedDecompositionIterNonNative<T> {
let base_to_the_level = 1 << (base_log.0 * level.0);
let smallest_representable = ciphertext_modulus.get_custom_modulus() / base_to_the_level;
let input_128: u128 = input.cast_into();
let state = T::cast_from(input_128 / smallest_representable);
SignedDecompositionIterNonNative {
base_log: base_log.0,
level_count: level.0,
state,
current_level: level.0,
mod_b_mask: (T::ONE << base_log.0) - T::ONE,
ciphertext_modulus,
fresh: true,
}
}
pub(crate) fn is_fresh(&self) -> bool {
self.fresh
}
/// Return the logarithm in base two of the base of this decomposition.
///
/// If the decomposition uses a base $B=2^b$, this returns $b$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// let val = 9_223_372_036_854_775_808u64;
/// let decomp = decomposer.decompose(val);
/// assert_eq!(decomp.base_log(), DecompositionBaseLog(4));
/// ```
pub fn base_log(&self) -> DecompositionBaseLog {
DecompositionBaseLog(self.base_log)
}
/// Return the number of levels of this decomposition.
///
/// If the decomposition uses $l$ levels, this returns $l$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
/// );
/// let val = 9_223_372_036_854_775_808u64;
/// let decomp = decomposer.decompose(val);
/// assert_eq!(decomp.level_count(), DecompositionLevelCount(3));
/// ```
pub fn level_count(&self) -> DecompositionLevelCount {
DecompositionLevelCount(self.level_count)
}
}
impl<T> Iterator for SignedDecompositionIterNonNative<T>
where
T: UnsignedInteger,
{
type Item = DecompositionTermNonNative<T>;
fn next(&mut self) -> Option<Self::Item> {
// The iterator is not fresh anymore
self.fresh = false;
// We check if the decomposition is over
if self.current_level == 0 {
return None;
}
// We decompose the current level
let output = decompose_one_level_non_native(
self.base_log,
&mut self.state,
self.mod_b_mask,
T::cast_from(self.ciphertext_modulus.get_custom_modulus()),
);
self.current_level -= 1;
// We return the output for this level
Some(DecompositionTermNonNative::new(
DecompositionLevel(self.current_level + 1),
DecompositionBaseLog(self.base_log),
output,
self.ciphertext_modulus,
))
}
}
fn decompose_one_level_non_native<S: UnsignedInteger>(
base_log: usize,
state: &mut S,
mod_b_mask: S,
ciphertext_modulus: S,
) -> S {
// Extract the current digit: the low base_log bits of the state (i.e. state mod B)
let res = *state & mod_b_mask;
// Shift the state so the next digit is exposed at the following level
*state >>= base_log;
// Decide whether the digit should be balanced into the negative range, in which case a
// carry is propagated into the next digit
let mut carry = (res.wrapping_sub(S::ONE) | *state) & res;
carry >>= base_log - 1;
*state += carry;
// Map the (possibly negative) digit res - carry * B back into [0, modulus)
res.wrapping_add(ciphertext_modulus)
.wrapping_sub(carry << base_log)
.wrapping_rem(ciphertext_modulus)
}
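A hand trace of one extraction step (a sketch, not from the diff), with base_log = 4 (B = 16) and state = 42: res = 42 & 0xF = 10, the state shifts to 2, the carry bit fires because 10 > B/2, so the state becomes 3 and the returned digit is (10 + q - 16) mod q = q - 6, i.e. -6, consistent with the balanced split 42 = 3 * 16 - 6.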

View File

@@ -1,3 +1,4 @@
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
use crate::core_crypto::commons::parameters::DecompositionBaseLog;
@@ -91,3 +92,117 @@ where
DecompositionLevel(self.level)
}
}
/// A member of the decomposition.
///
/// If we decompose a value $\theta$ as a sum $\sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i}$, this
/// represents a $\tilde{\theta}\_i$.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct DecompositionTermNonNative<T>
where
T: UnsignedInteger,
{
level: usize,
base_log: usize,
value: T,
ciphertext_modulus: CiphertextModulus<T>,
}
impl<T> DecompositionTermNonNative<T>
where
T: UnsignedInteger,
{
// Creates a new decomposition term.
pub(crate) fn new(
level: DecompositionLevel,
base_log: DecompositionBaseLog,
value: T,
ciphertext_modulus: CiphertextModulus<T>,
) -> DecompositionTermNonNative<T> {
DecompositionTermNonNative {
level: level.0,
base_log: base_log.0,
value,
ciphertext_modulus,
}
}
/// Turn this term into a summand.
///
/// If our member represents one $\tilde{\theta}\_i$ of the decomposition, this method returns
/// $\tilde{\theta}\_i\frac{q}{B^i}$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new(1 << 32).unwrap(),
/// );
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
/// assert_eq!(output.to_recomposition_summand(), 1048576);
/// ```
pub fn to_recomposition_summand(&self) -> T {
// Floored approach: the digit weight is floor(q / B^level)
let base_to_the_level = 1 << (self.base_log * self.level);
let digit_radix = self.ciphertext_modulus.get_custom_modulus() / base_to_the_level;
let value_u128: u128 = self.value.cast_into();
let summand_u128 = value_u128 * digit_radix;
T::cast_from(summand_u128)
}
/// Return the value of the term.
///
/// If our member represents one $\tilde{\theta}\_i$, this returns its actual value.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new(1 << 32).unwrap(),
/// );
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
/// assert_eq!(output.value(), 1);
/// ```
pub fn value(&self) -> T {
self.value
}
/// Return the level of the term.
///
/// If our member represents one $\tilde{\theta}\_i$, this returns the value of $i$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::{
/// DecompositionLevel, SignedDecomposerNonNative,
/// };
/// use tfhe::core_crypto::commons::parameters::{
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
/// };
/// let decomposer = SignedDecomposerNonNative::new(
/// DecompositionBaseLog(4),
/// DecompositionLevelCount(3),
/// CiphertextModulus::try_new(1 << 32).unwrap(),
/// );
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
/// assert_eq!(output.level(), DecompositionLevel(3));
/// ```
pub fn level(&self) -> DecompositionLevel {
DecompositionLevel(self.level)
}
}
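Tying the three accessors above together with the doctest numbers (q = 2^32, B = 2^4, decompose(2^19)): the first term yielded by the decomposition has value 1 at level 3, and its recomposition summand is value * floor(q / B^level) = 1 * floor(2^32 / 2^12) = 2^20 = 1048576, matching the asserted outputs.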

View File

@@ -1,9 +1,13 @@
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
use crate::core_crypto::commons::math::decomposition::{
SignedDecomposer, SignedDecomposerNonNative,
};
use crate::core_crypto::commons::math::random::{RandomGenerable, Uniform};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::{Numeric, SignedInteger, UnsignedInteger};
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
use crate::core_crypto::commons::test_tools::{any_uint, any_usize, random_usize_between};
use crate::core_crypto::commons::traits::CastInto;
use std::fmt::Debug;
// Return a random decomposition valid for the size of the T type.
@@ -69,7 +73,7 @@ fn test_round_to_closest_representable<T: UnsignedTorus>() {
let bit: usize = log_b * level_max;
let val = val << (bits - bit);
let delta = delta >> (bits - (bits - bit - 1));
let delta = delta >> (bit + 1);
let decomposer = SignedDecomposer::new(
DecompositionBaseLog(log_b),
@@ -117,3 +121,171 @@ fn test_round_to_closest_twice_u32() {
fn test_round_to_closest_twice_u64() {
test_round_to_closest_twice::<u64>();
}
// Return a random decomposition valid for the size of the T type.
fn random_decomp_non_native<T: UnsignedInteger>(
ciphertext_modulus: CiphertextModulus<T>,
) -> SignedDecomposerNonNative<T> {
let mut base_log;
let mut level_count;
loop {
base_log = random_usize_between(2..T::BITS);
level_count = random_usize_between(2..T::BITS);
if base_log * level_count < T::BITS {
break;
}
}
SignedDecomposerNonNative::new(
DecompositionBaseLog(base_log),
DecompositionLevelCount(level_count),
ciphertext_modulus,
)
}
fn test_round_to_closest_representable_non_native<T: UnsignedTorus>(
ciphertext_modulus: CiphertextModulus<T>,
) {
// Manage limit cases
{
let log_b = any_usize();
let level_max = any_usize();
let bits = T::BITS;
let log_b = (log_b % ((bits / 4) - 1)) + 1;
let level_count = (level_max % 4) + 1;
let rep_bits: usize = log_b * level_count;
let base_to_the_level_u128 = 1u128 << rep_bits;
let smallest_representable_u128 =
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
// Compute an epsilon that should not change the result of a closest representable
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
// Around 0
let val = T::ZERO;
let decomposer = SignedDecomposerNonNative::new(
DecompositionBaseLog(log_b),
DecompositionLevelCount(level_count),
ciphertext_modulus,
);
let val_u128: u128 = val.cast_into();
let val_plus_epsilon: T = val_u128
.wrapping_add(epsilon_u128)
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
.cast_into();
let closest = decomposer.closest_representable(val_plus_epsilon);
assert_eq!(
val, closest,
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
expected_closest: {val:064b}, \n \
closest: {closest:064b}\n \
decomp_base_log: {}, decomp_level_count: {}",
decomposer.base_log, decomposer.level_count
);
let val_minus_epsilon: T = val_u128
.wrapping_add(ciphertext_modulus.get_custom_modulus())
.wrapping_sub(epsilon_u128)
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
.cast_into();
let closest = decomposer.closest_representable(val_minus_epsilon);
assert_eq!(
val, closest,
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
expected_closest: {val:064b}, \n \
closest: {closest:064b}\n \
decomp_base_log: {}, decomp_level_count: {}",
decomposer.base_log, decomposer.level_count
);
}
for _ in 0..1000 {
let log_b = any_usize();
let level_max = any_usize();
let bits = T::BITS;
let log_b = (log_b % ((bits / 4) - 1)) + 1;
let level_count = (level_max % 4) + 1;
let rep_bits: usize = log_b * level_count;
let base_to_the_level_u128 = 1u128 << rep_bits;
let base_to_the_level = T::ONE << rep_bits;
let smallest_representable_u128 =
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
let smallest_representable: T = smallest_representable_u128.cast_into();
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
// Compute an epsilon that should not change the result of a closest representable
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
let multiple_of_smallest_representable = any_uint::<T>() % base_to_the_level;
let val = multiple_of_smallest_representable * smallest_representable;
let decomposer = SignedDecomposerNonNative::new(
DecompositionBaseLog(log_b),
DecompositionLevelCount(level_count),
ciphertext_modulus,
);
let val_u128: u128 = val.cast_into();
let val_plus_epsilon: T = val_u128
.wrapping_add(epsilon_u128)
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
.cast_into();
let closest = decomposer.closest_representable(val_plus_epsilon);
assert_eq!(
val, closest,
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
expected_closest: {val:064b}, \n \
closest: {closest:064b}\n \
decomp_base_log: {}, decomp_level_count: {}",
decomposer.base_log, decomposer.level_count
);
let val_minus_epsilon: T = val_u128
.wrapping_add(ciphertext_modulus.get_custom_modulus())
.wrapping_sub(epsilon_u128)
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
.cast_into();
let closest = decomposer.closest_representable(val_minus_epsilon);
assert_eq!(
val, closest,
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
expected_closest: {val:064b}, \n \
closest: {closest:064b}\n \
decomp_base_log: {}, decomp_level_count: {}",
decomposer.base_log, decomposer.level_count
);
}
}
#[test]
fn test_round_to_closest_representable_non_native_u64() {
test_round_to_closest_representable_non_native::<u64>(
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
);
}
fn test_round_to_closest_twice_non_native<T: UnsignedTorus + Debug>(
ciphertext_modulus: CiphertextModulus<T>,
) {
for _ in 0..1000 {
let decomp = random_decomp_non_native(ciphertext_modulus);
let input: T = any_uint();
let rounded_once = decomp.closest_representable(input);
let rounded_twice = decomp.closest_representable(rounded_once);
assert_eq!(rounded_once, rounded_twice);
}
}
#[test]
fn test_round_to_closest_twice_non_native_u64() {
test_round_to_closest_twice_non_native::<u64>(
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
);
}

Some files were not shown because too many files have changed in this diff.