mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-11 07:38:08 -05:00
Compare commits
27 Commits
tm/update-
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7043246c17 | ||
|
|
51735fb8ed | ||
|
|
23a348c9ae | ||
|
|
61b616b784 | ||
|
|
df48e176f3 | ||
|
|
dd2345df6b | ||
|
|
933800ea6f | ||
|
|
3e4cee3a75 | ||
|
|
00ea9b8e07 | ||
|
|
23ce85f6a2 | ||
|
|
126a95e929 | ||
|
|
23fffb1443 | ||
|
|
6d58a54266 | ||
|
|
9b8d5f5a43 | ||
|
|
696f964ecf | ||
|
|
a5323d1edf | ||
|
|
2d500d0de6 | ||
|
|
b1657876fb | ||
|
|
d2a570bdd6 | ||
|
|
122ef489fd | ||
|
|
ed84387bba | ||
|
|
1f4ba33a50 | ||
|
|
e645ee3397 | ||
|
|
569abd9a3b | ||
|
|
917bb5e1ef | ||
|
|
509aadcad2 | ||
|
|
e20aea90df |
@@ -2,6 +2,8 @@
|
||||
ignore = [
|
||||
# Ignoring unmaintained 'paste' advisory as it is a widely used, low-risk build dependency.
|
||||
"RUSTSEC-2024-0436",
|
||||
# Ignoring unmaintained 'bincode' crate. Getting rid of it would be too complex in the short term.
|
||||
"RUSTSEC-2025-0141",
|
||||
]
|
||||
|
||||
[output]
|
||||
|
||||
2
.github/actions/gpu_setup/action.yml
vendored
2
.github/actions/gpu_setup/action.yml
vendored
@@ -23,6 +23,8 @@ runs:
|
||||
echo "${CMAKE_SCRIPT_SHA} cmake-${CMAKE_VERSION}-linux-x86_64.sh" > checksum
|
||||
sha256sum -c checksum
|
||||
sudo bash cmake-"${CMAKE_VERSION}"-linux-x86_64.sh --skip-license --prefix=/usr/ --exclude-subdir
|
||||
sudo apt-get clean
|
||||
sudo rm -rf /var/lib/apt/lists/*
|
||||
sudo apt update
|
||||
sudo apt remove -y unattended-upgrades
|
||||
sudo apt install -y cmake-format libclang-dev
|
||||
|
||||
@@ -80,7 +80,7 @@ jobs:
|
||||
|
||||
- name: Retrieve data from cache
|
||||
id: retrieve-data-cache
|
||||
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/restore@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
with:
|
||||
path: |
|
||||
utils/tfhe-backward-compat-data/**/*.cbor
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
- name: Store data in cache
|
||||
if: steps.retrieve-data-cache.outputs.cache-hit != 'true'
|
||||
continue-on-error: true
|
||||
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/save@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
with:
|
||||
path: |
|
||||
utils/tfhe-backward-compat-data/**/*.cbor
|
||||
|
||||
4
.github/workflows/aws_tfhe_fast_tests.yml
vendored
4
.github/workflows/aws_tfhe_fast_tests.yml
vendored
@@ -219,7 +219,7 @@ jobs:
|
||||
|
||||
- name: Node cache restoration
|
||||
id: node-cache
|
||||
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/restore@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
with:
|
||||
path: |
|
||||
~/.nvm
|
||||
@@ -232,7 +232,7 @@ jobs:
|
||||
make install_node
|
||||
|
||||
- name: Node cache save
|
||||
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/save@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
if: steps.node-cache.outputs.cache-hit != 'true'
|
||||
with:
|
||||
path: |
|
||||
|
||||
4
.github/workflows/aws_tfhe_wasm_tests.yml
vendored
4
.github/workflows/aws_tfhe_wasm_tests.yml
vendored
@@ -80,7 +80,7 @@ jobs:
|
||||
|
||||
- name: Node cache restoration
|
||||
id: node-cache
|
||||
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/restore@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
with:
|
||||
path: |
|
||||
~/.nvm
|
||||
@@ -93,7 +93,7 @@ jobs:
|
||||
make install_node
|
||||
|
||||
- name: Node cache save
|
||||
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/save@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
if: steps.node-cache.outputs.cache-hit != 'true'
|
||||
with:
|
||||
path: |
|
||||
|
||||
@@ -195,7 +195,7 @@ jobs:
|
||||
uses: foundry-rs/foundry-toolchain@8b0419c685ef46cb79ec93fbdc131174afceb730
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4.3.0
|
||||
uses: actions/cache@9255dc7a253b0ccc959486e2bca901246202afeb # v5.0.1
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
|
||||
10
.github/workflows/benchmark_wasm_client.yml
vendored
10
.github/workflows/benchmark_wasm_client.yml
vendored
@@ -119,7 +119,7 @@ jobs:
|
||||
|
||||
- name: Node cache restoration
|
||||
id: node-cache
|
||||
uses: actions/cache/restore@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/restore@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
with:
|
||||
path: |
|
||||
~/.nvm
|
||||
@@ -132,7 +132,7 @@ jobs:
|
||||
make install_node
|
||||
|
||||
- name: Node cache save
|
||||
uses: actions/cache/save@0057852bfaa89a56745cba8c7296529d2fc39830 #v4.3.0
|
||||
uses: actions/cache/save@9255dc7a253b0ccc959486e2bca901246202afeb #v5.0.1
|
||||
if: steps.node-cache.outputs.cache-hit != 'true'
|
||||
with:
|
||||
path: |
|
||||
@@ -153,6 +153,12 @@ jobs:
|
||||
env:
|
||||
BROWSER: ${{ matrix.browser }}
|
||||
|
||||
- name: Run benchmarks (unsafe coop)
|
||||
run: |
|
||||
make bench_web_js_api_unsafe_coop_"${BROWSER}"_ci
|
||||
env:
|
||||
BROWSER: ${{ matrix.browser }}
|
||||
|
||||
- name: Parse results
|
||||
run: |
|
||||
make parse_wasm_benchmarks
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -10,6 +10,7 @@ target/
|
||||
**/*.rmeta
|
||||
**/Cargo.lock
|
||||
**/*.bin
|
||||
**/.DS_Store
|
||||
|
||||
# Some of our bench outputs
|
||||
/tfhe/benchmarks_parameters
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
/tfhe/src/core_crypto/gpu @agnesLeroy
|
||||
/tfhe/src/core_crypto/hpu @zama-ai/hardware
|
||||
|
||||
/tfhe/src/shortint/ @mayeul-zama
|
||||
/tfhe/src/shortint/ @mayeul-zama @nsarlin-zama
|
||||
|
||||
/tfhe/src/integer/ @tmontaigu
|
||||
/tfhe/src/integer/gpu @agnesLeroy
|
||||
@@ -19,8 +19,12 @@
|
||||
|
||||
/tfhe/src/high_level_api/ @tmontaigu
|
||||
|
||||
/tfhe-zk-pok/ @nsarlin-zama
|
||||
|
||||
/tfhe-benchmark/ @soonum
|
||||
|
||||
/utils/ @nsarlin-zama
|
||||
|
||||
/Makefile @IceTDrinker @soonum
|
||||
|
||||
/mockups/tfhe-hpu-mockup @zama-ai/hardware
|
||||
|
||||
@@ -36,6 +36,7 @@ rayon = "1.11"
|
||||
serde = { version = "1.0", default-features = false }
|
||||
wasm-bindgen = "0.2.101"
|
||||
getrandom = "0.2.8"
|
||||
# The project maintainers consider this to be the last version of the 1.3 branch; any newer version should not be trusted.
|
||||
bincode = "=1.3.3"
|
||||
|
||||
[profile.bench]
|
||||
|
||||
38
Makefile
38
Makefile
@@ -1300,13 +1300,14 @@ run_web_js_api_parallel: build_web_js_api_parallel setup_venv
|
||||
--browser-path $(browser_path) \
|
||||
--driver-path $(driver_path) \
|
||||
--browser-kind $(browser_kind) \
|
||||
--server-cmd "npm run server" \
|
||||
--server-cmd $(server_cmd) \
|
||||
--server-workdir "$(WEB_SERVER_DIR)" \
|
||||
--id-pattern $(filter)
|
||||
|
||||
test_web_js_api_parallel_chrome: browser_path = "$(WEB_RUNNER_DIR)/chrome/chrome-linux64/chrome"
|
||||
test_web_js_api_parallel_chrome: driver_path = "$(WEB_RUNNER_DIR)/chrome/chromedriver-linux64/chromedriver"
|
||||
test_web_js_api_parallel_chrome: browser_kind = chrome
|
||||
test_web_js_api_parallel_chrome: server_cmd = "npm run server:multithreaded"
|
||||
test_web_js_api_parallel_chrome: filter = Test
|
||||
|
||||
.PHONY: test_web_js_api_parallel_chrome # Run tests for the web wasm api on Chrome
|
||||
@@ -1322,6 +1323,7 @@ test_web_js_api_parallel_chrome_ci: setup_venv
|
||||
test_web_js_api_parallel_firefox: browser_path = "$(WEB_RUNNER_DIR)/firefox/firefox/firefox"
|
||||
test_web_js_api_parallel_firefox: driver_path = "$(WEB_RUNNER_DIR)/firefox/geckodriver"
|
||||
test_web_js_api_parallel_firefox: browser_kind = firefox
|
||||
test_web_js_api_parallel_firefox: server_cmd = "npm run server:multithreaded"
|
||||
test_web_js_api_parallel_firefox: filter = Test
|
||||
|
||||
.PHONY: test_web_js_api_parallel_firefox # Run tests for the web wasm api on Firefox
|
||||
@@ -1571,6 +1573,7 @@ bench_pbs128_gpu: install_rs_check_toolchain
|
||||
bench_web_js_api_parallel_chrome: browser_path = "$(WEB_RUNNER_DIR)/chrome/chrome-linux64/chrome"
|
||||
bench_web_js_api_parallel_chrome: driver_path = "$(WEB_RUNNER_DIR)/chrome/chromedriver-linux64/chromedriver"
|
||||
bench_web_js_api_parallel_chrome: browser_kind = chrome
|
||||
bench_web_js_api_parallel_chrome: server_cmd = "npm run server:multithreaded"
|
||||
bench_web_js_api_parallel_chrome: filter = Bench
|
||||
|
||||
.PHONY: bench_web_js_api_parallel_chrome # Run benchmarks for the web wasm api
|
||||
@@ -1586,6 +1589,7 @@ bench_web_js_api_parallel_chrome_ci: setup_venv
|
||||
bench_web_js_api_parallel_firefox: browser_path = "$(WEB_RUNNER_DIR)/firefox/firefox/firefox"
|
||||
bench_web_js_api_parallel_firefox: driver_path = "$(WEB_RUNNER_DIR)/firefox/geckodriver"
|
||||
bench_web_js_api_parallel_firefox: browser_kind = firefox
|
||||
bench_web_js_api_parallel_firefox: server_cmd = "npm run server:multithreaded"
|
||||
bench_web_js_api_parallel_firefox: filter = Bench
|
||||
|
||||
.PHONY: bench_web_js_api_parallel_firefox # Run benchmarks for the web wasm api
|
||||
@@ -1598,6 +1602,38 @@ bench_web_js_api_parallel_firefox_ci: setup_venv
|
||||
nvm use $(NODE_VERSION) && \
|
||||
$(MAKE) bench_web_js_api_parallel_firefox
|
||||
|
||||
bench_web_js_api_unsafe_coop_chrome: browser_path = "$(WEB_RUNNER_DIR)/chrome/chrome-linux64/chrome"
|
||||
bench_web_js_api_unsafe_coop_chrome: driver_path = "$(WEB_RUNNER_DIR)/chrome/chromedriver-linux64/chromedriver"
|
||||
bench_web_js_api_unsafe_coop_chrome: browser_kind = chrome
|
||||
bench_web_js_api_unsafe_coop_chrome: server_cmd = "npm run server:unsafe-coop"
|
||||
bench_web_js_api_unsafe_coop_chrome: filter = ZeroKnowledgeBench # Only bench zk with unsafe coop
|
||||
|
||||
.PHONY: bench_web_js_api_unsafe_coop_chrome # Run benchmarks for the web wasm api without cross-origin isolation
|
||||
bench_web_js_api_unsafe_coop_chrome: run_web_js_api_parallel
|
||||
|
||||
.PHONY: bench_web_js_api_unsafe_coop_chrome_ci # Run benchmarks for the web wasm api without cross-origin isolation
|
||||
bench_web_js_api_unsafe_coop_chrome_ci: setup_venv
|
||||
source ~/.nvm/nvm.sh && \
|
||||
nvm install $(NODE_VERSION) && \
|
||||
nvm use $(NODE_VERSION) && \
|
||||
$(MAKE) bench_web_js_api_unsafe_coop_chrome
|
||||
|
||||
bench_web_js_api_unsafe_coop_firefox: browser_path = "$(WEB_RUNNER_DIR)/firefox/firefox/firefox"
|
||||
bench_web_js_api_unsafe_coop_firefox: driver_path = "$(WEB_RUNNER_DIR)/firefox/geckodriver"
|
||||
bench_web_js_api_unsafe_coop_firefox: browser_kind = firefox
|
||||
bench_web_js_api_unsafe_coop_firefox: server_cmd = "npm run server:unsafe-coop"
|
||||
bench_web_js_api_unsafe_coop_firefox: filter = ZeroKnowledgeBench # Only bench zk with unsafe coop
|
||||
|
||||
.PHONY: bench_web_js_api_unsafe_coop_firefox # Run benchmarks for the web wasm api without cross-origin isolation
|
||||
bench_web_js_api_unsafe_coop_firefox: run_web_js_api_parallel
|
||||
|
||||
.PHONY: bench_web_js_api_unsafe_coop_firefox_ci # Run benchmarks for the web wasm api without cross-origin isolation
|
||||
bench_web_js_api_unsafe_coop_firefox_ci: setup_venv
|
||||
source ~/.nvm/nvm.sh && \
|
||||
nvm install $(NODE_VERSION) && \
|
||||
nvm use $(NODE_VERSION) && \
|
||||
$(MAKE) bench_web_js_api_unsafe_coop_firefox
|
||||
|
||||
.PHONY: bench_hlapi # Run benchmarks for integer operations
|
||||
bench_hlapi: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" __TFHE_RS_BENCH_BIT_SIZES_SET=$(BIT_SIZES_SET) \
|
||||
|
||||
@@ -1,24 +1,32 @@
|
||||
08f31a47c29cc4d72ad32c0b5411fa20b3deef5b84558dd2fb892d3cdf90528a data/toy_params/glwe_after_id_br_karatsuba.cbor
|
||||
29b6e3e7d27700004b70dca24d225816500490e2d6ee49b9af05837fd421896b data/valid_params_128/lwe_after_spec_pbs.cbor
|
||||
2c70d1d78cc3760733850a353ace2b9c4705e840141b75841739e90e51247e18 data/valid_params_128/small_lwe_secret_key.cbor
|
||||
2fb4bb45c259b8383da10fc8f9459c40a6972c49b1696eb107f0a75640724be5 data/toy_params/lwe_after_id_pbs_karatsuba.cbor
|
||||
36c9080b636475fcacca503ce041bbfeee800fd3e1890dee559ea18defff9fe8 data/toy_params/glwe_after_id_br.cbor
|
||||
377761beeb4216cf5aa2624a8b64b8259f5a75c32d28e850be8bced3a0cdd6f5 data/toy_params/ksk.cbor
|
||||
59dba26d457f96478eda130cab5301fce86f23c6a8807de42f2a1e78c4985ca7 data/valid_params_128/lwe_ks.cbor
|
||||
5d80dd93fefae4f4f89484dfcd65bbe99cc32e7e3b0a90c33dd0d77516c0a023 data/valid_params_128/glwe_after_id_br_karatsuba.cbor
|
||||
656f0009c7834c5bcb61621e222047516054b9bc5d0593d474ab8f1c086b67a6 data/valid_params_128/lwe_after_id_pbs.cbor
|
||||
699580ca92b9c2f9e1f57fb1e312c9e8cb29714f7acdef9d2ba05f798546751f data/toy_params/lwe_sum.cbor
|
||||
6e54ab41056984595b077baff70236d934308cf5c0c33b4482fbfb129b3756c6 data/valid_params_128/glwe_after_id_br.cbor
|
||||
70f5e5728822de05b49071efb5ec28551b0f5cc87aa709a455d8e7f04b9c96ee data/toy_params/lwe_after_id_pbs.cbor
|
||||
76a5c52cab7fec1dc167da676c6cd39479cda6b2bb9f4e0573cb7d99c2692faa data/valid_params_128/lwe_after_id_pbs_karatsuba.cbor
|
||||
7cc6803f5fbc3d5a1bf597f2b979ce17eecd3d6baca12183dea21022a7b65c52 data/toy_params/bsk.cbor
|
||||
7f3c40a134623b44779a556212477fea26eaed22450f3b6faeb8721d63699972 data/valid_params_128/lwe_sum.cbor
|
||||
837b3bd3245d4d0534ed255fdef896fb4fa6998a258a14543dfdadd0bfc9b6dd data/toy_params/lwe_prod.cbor
|
||||
9ece8ca9c1436258b94e8c5e629b8722f9b18fdd415dd5209b6167a9dde8491c data/toy_params/glwe_after_spec_br_karatsuba.cbor
|
||||
aa44aea29efd6d9e4d35a21a625d9cba155672e3f7ed3eddee1e211e62ad146b data/valid_params_128/lwe_ms.cbor
|
||||
b7a037b9eaa88d6385167579b93e26a0cb6976d9b8967416fd1173e113bda199 data/valid_params_128/large_lwe_secret_key.cbor
|
||||
b7b8e3586128887bd682120f3e3a43156139bce5e3fe0b03284f8753a864d647 data/toy_params/lwe_after_spec_pbs_karatsuba.cbor
|
||||
bd00a8ae7494e400de5753029552ee1647efe7e17409b863a26a13b081099b8c data/toy_params/lwe_after_spec_pbs.cbor
|
||||
c6df98676de04fe54b5ffc2eb30a82ebb706c9d7d5a4e0ed509700fec88761f7 data/toy_params/lwe_ms.cbor
|
||||
c7d5a864d5616a7d8ad50bbf40416e41e6c9b60c546dc14d4aa8fc40a418baa7 data/toy_params/large_lwe_secret_key.cbor
|
||||
c806533b325b1009db38be2f9bef5f3b2fad6b77b4c71f2855ccc9d3b4162e98 data/valid_params_128/lwe_b.cbor
|
||||
c9eb75bd2993639348a679cf48c06e3c38d1a513f48e5b0ce0047cea8cff6bbc data/toy_params/lwe_a.cbor
|
||||
d3391969acf26dc69de0927ba279139d8d79999944069addc8ff469ad6c5ae2d data/valid_params_128/lwe_after_spec_pbs_karatsuba.cbor
|
||||
d6da5baef0e787f6be56e218d8354e26904652602db964844156fdff08350ce6 data/toy_params/lwe_ks.cbor
|
||||
e591ab9af1b6a0aede273f9a3abb65a4c387feb5fa06a6959e9314058ca0f7e5 data/valid_params_128/ksk.cbor
|
||||
e59b002df3a9b01ad321ec51cf076fa35131ab9dbef141d1c54b717d61426c92 data/valid_params_128/glwe_after_spec_br_karatsuba.cbor
|
||||
e628354c81508a2d888016e8282df363dd12f1e19190b6475d4eb9d7ab8ae007 data/valid_params_128/glwe_after_spec_br.cbor
|
||||
e69d2d2c064fc8c0460b39191ca65338146990349954f5ec5ebd01d93610e7eb data/valid_params_128/lwe_a.cbor
|
||||
e76c24b2a0c9a842ad13dda35473c2514f9e7d20983b5ea0759c4521a91626d9 data/valid_params_128/lwe_prod.cbor
|
||||
|
||||
@@ -39,6 +39,9 @@ The following values are generated:
|
||||
| `glwe_after_spec_br` | The glwe returned by the application of the spec blind rotation on the mod switched ciphertexts. | `GlweCiphertext<Vec<u64>>` | rot spec LUT |
|
||||
| `lwe_after_spec_pbs` | The lwe returned by the application of the sample extract operation on the output of the spec blind rotation | `LweCiphertext<Vec<u64>>` | `spec(A)` |
|
||||
|
||||
Ciphertexts with the `_karatsuba` suffix are generated using the Karatsuba polynomial multiplication algorithm in the blind rotation, while default ciphertexts are generated using an FFT multiplication.
|
||||
This makes it easier to reproduce bit-exact results.
|
||||
|
||||
### Encodings
|
||||
#### Non native encoding
|
||||
Warning: TFHE-rs uses a specific encoding for non-native (i.e., u32, u64) power-of-two ciphertext moduli. This encoding puts the encoded value in the high bits of the native integer.
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:08f31a47c29cc4d72ad32c0b5411fa20b3deef5b84558dd2fb892d3cdf90528a
|
||||
size 4679
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9ece8ca9c1436258b94e8c5e629b8722f9b18fdd415dd5209b6167a9dde8491c
|
||||
size 4679
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2fb4bb45c259b8383da10fc8f9459c40a6972c49b1696eb107f0a75640724be5
|
||||
size 2365
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b7b8e3586128887bd682120f3e3a43156139bce5e3fe0b03284f8753a864d647
|
||||
size 2365
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5d80dd93fefae4f4f89484dfcd65bbe99cc32e7e3b0a90c33dd0d77516c0a023
|
||||
size 36935
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e59b002df3a9b01ad321ec51cf076fa35131ab9dbef141d1c54b717d61426c92
|
||||
size 36935
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:76a5c52cab7fec1dc167da676c6cd39479cda6b2bb9f4e0573cb7d99c2692faa
|
||||
size 18493
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d3391969acf26dc69de0927ba279139d8d79999944069addc8ff469ad6c5ae2d
|
||||
size 18493
|
||||
@@ -265,6 +265,7 @@ fn generate_test_vectors<P: AsRef<Path>>(
|
||||
|
||||
let mut id_lut = encoding.encode_lut(glwe_dimension, polynomial_size, ID_LUT);
|
||||
assert_data_not_zero(&id_lut);
|
||||
let mut id_lut_karatsuba = id_lut.clone();
|
||||
|
||||
blind_rotate_assign(&modswitched, &mut id_lut, &fourier_bsk);
|
||||
assert_data_not_zero(&id_lut);
|
||||
@@ -287,8 +288,32 @@ fn generate_test_vectors<P: AsRef<Path>>(
|
||||
assert_data_not_zero(&lwe_pbs_id);
|
||||
store_data(path, &lwe_pbs_id, "lwe_after_id_pbs");
|
||||
|
||||
blind_rotate_karatsuba_assign(&modswitched, &mut id_lut_karatsuba, &bsk);
|
||||
store_data(path, &id_lut_karatsuba, "glwe_after_id_br_karatsuba");
|
||||
|
||||
let mut lwe_pbs_karatsuba_id = LweCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension
|
||||
.to_equivalent_lwe_dimension(polynomial_size)
|
||||
.to_lwe_size(),
|
||||
encoding.ciphertext_modulus,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(
|
||||
&id_lut_karatsuba,
|
||||
&mut lwe_pbs_karatsuba_id,
|
||||
MonomialDegree(0),
|
||||
);
|
||||
|
||||
let decrypted_pbs_id = decrypt_lwe_ciphertext(&large_lwe_secret_key, &lwe_pbs_karatsuba_id);
|
||||
let res = encoding.decode(decrypted_pbs_id);
|
||||
|
||||
assert_eq!(res, MSG_A);
|
||||
store_data(path, &lwe_pbs_karatsuba_id, "lwe_after_id_pbs_karatsuba");
|
||||
|
||||
let mut spec_lut = encoding.encode_lut(glwe_dimension, polynomial_size, SPEC_LUT);
|
||||
assert_data_not_zero(&spec_lut);
|
||||
let mut spec_lut_karatsuba = spec_lut.clone();
|
||||
|
||||
blind_rotate_assign(&modswitched, &mut spec_lut, &fourier_bsk);
|
||||
assert_data_not_zero(&spec_lut);
|
||||
@@ -310,6 +335,33 @@ fn generate_test_vectors<P: AsRef<Path>>(
|
||||
assert_eq!(res, SPEC_LUT(MSG_A));
|
||||
assert_data_not_zero(&lwe_pbs_spec);
|
||||
store_data(path, &lwe_pbs_spec, "lwe_after_spec_pbs");
|
||||
|
||||
blind_rotate_karatsuba_assign(&modswitched, &mut spec_lut_karatsuba, &bsk);
|
||||
store_data(path, &spec_lut_karatsuba, "glwe_after_spec_br_karatsuba");
|
||||
|
||||
let mut lwe_pbs_karatsuba_spec = LweCiphertext::new(
|
||||
0u64,
|
||||
glwe_dimension
|
||||
.to_equivalent_lwe_dimension(polynomial_size)
|
||||
.to_lwe_size(),
|
||||
encoding.ciphertext_modulus,
|
||||
);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(
|
||||
&spec_lut_karatsuba,
|
||||
&mut lwe_pbs_karatsuba_spec,
|
||||
MonomialDegree(0),
|
||||
);
|
||||
|
||||
let decrypted_pbs_spec = decrypt_lwe_ciphertext(&large_lwe_secret_key, &lwe_pbs_karatsuba_spec);
|
||||
let res = encoding.decode(decrypted_pbs_spec);
|
||||
|
||||
assert_eq!(res, SPEC_LUT(MSG_A));
|
||||
store_data(
|
||||
path,
|
||||
&lwe_pbs_karatsuba_spec,
|
||||
"lwe_after_spec_pbs_karatsuba",
|
||||
);
|
||||
}
|
||||
|
||||
fn rm_dir_except_readme<P: AsRef<Path>>(dir: P) {
|
||||
|
||||
@@ -40,7 +40,7 @@ rand = "0.8.5"
|
||||
regex = "1.10.4"
|
||||
bitflags = { version = "2.5.0", features = ["serde"] }
|
||||
itertools = "0.11.0"
|
||||
lru = "0.12.3"
|
||||
lru = "0.16.3"
|
||||
bitfield-struct = "0.10.0"
|
||||
crossbeam = { version = "0.8.4", features = ["crossbeam-queue"] }
|
||||
rayon = { workspace = true }
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:35cc06547a23b862ab9829351d74d944e60ea9dad3ecf593d15f0ce8445d145e
|
||||
size 81710610
|
||||
oid sha256:934c8131c12010dc837f6a2af5111b83f8f5d42f10485e9b3b971edb24c467f8
|
||||
size 82201876
|
||||
|
||||
@@ -113,6 +113,7 @@ pub fn iop_add_simd(prog: &mut Program) {
|
||||
prog,
|
||||
crate::asm::iop::SIMD_N,
|
||||
fw_impl::llt::iop_add_ripple_rtl,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -227,14 +228,23 @@ pub fn iop_muls(prog: &mut Program) {
|
||||
pub fn iop_erc_20(prog: &mut Program) {
|
||||
// Add Comment header
|
||||
prog.push_comment("ERC_20 (new_from, new_to) <- (from, to, amount)".to_string());
|
||||
iop_erc_20_rtl(prog, 0).add_to_prog(prog);
|
||||
// TODO: Make sweep of kogge_blk_w
|
||||
// All these little parameters would be very handy to write an
|
||||
// exploration/compilation program which would try to minimize latency by
|
||||
// playing with these.
|
||||
iop_erc_20_rtl(prog, 0, Some(10)).add_to_prog(prog);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_erc_20_simd(prog: &mut Program) {
|
||||
// Add Comment header
|
||||
prog.push_comment("ERC_20_SIMD (new_from, new_to) <- (from, to, amount)".to_string());
|
||||
simd(prog, crate::asm::iop::SIMD_N, fw_impl::llt::iop_erc_20_rtl);
|
||||
simd(
|
||||
prog,
|
||||
crate::asm::iop::SIMD_N,
|
||||
fw_impl::llt::iop_erc_20_rtl,
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
@@ -381,7 +391,7 @@ pub fn iop_rotate_scalar_left(prog: &mut Program) {
|
||||
/// (dst_from[0], dst_to[0], ..., dst_from[N-1], dst_to[N-1])
|
||||
/// Where N is the batch size
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_erc_20_rtl(prog: &mut Program, batch_index: u8) -> Rtl {
|
||||
pub fn iop_erc_20_rtl(prog: &mut Program, batch_index: u8, kogge_blk_w: Option<usize>) -> Rtl {
|
||||
// Allocate metavariables:
|
||||
// Dest -> Operand
|
||||
let dst_from = prog.iop_template_var(OperandKind::Dst, 2 * batch_index);
|
||||
@@ -392,13 +402,6 @@ pub fn iop_erc_20_rtl(prog: &mut Program, batch_index: u8) -> Rtl {
|
||||
// Src Amount -> Operand
|
||||
let src_amount = prog.iop_template_var(OperandKind::Src, 3 * batch_index + 2);
|
||||
|
||||
// TODO: Make this a parameter or sweep this
|
||||
// All these little parameters would be very handy to write an
|
||||
// exploration/compilation program which would try to minimize latency by
|
||||
// playing with these.
|
||||
let kogge_blk_w = 10;
|
||||
let ripple = true;
|
||||
|
||||
{
|
||||
let props = prog.params();
|
||||
let tfhe_params: asm::DigitParameters = props.clone().into();
|
||||
@@ -429,19 +432,21 @@ pub fn iop_erc_20_rtl(prog: &mut Program, batch_index: u8) -> Rtl {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if ripple {
|
||||
if let Some(blk_w) = kogge_blk_w {
|
||||
kogge::add(prog, dst_to, src_to, src_amount.clone(), None, blk_w)
|
||||
+ kogge::sub(prog, dst_from, src_from, src_amount, blk_w)
|
||||
} else {
|
||||
// Default to ripple carry
|
||||
kogge::ripple_add(dst_to, src_to, src_amount.clone(), None)
|
||||
+ kogge::ripple_sub(prog, dst_from, src_from, src_amount)
|
||||
} else {
|
||||
kogge::add(prog, dst_to, src_to, src_amount.clone(), None, kogge_blk_w)
|
||||
+ kogge::sub(prog, dst_from, src_from, src_amount, kogge_blk_w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A SIMD implementation of add for maximum throughput
|
||||
/// NB: kogge_blk_w is not used here; this implementation is forced to use ripple carry
|
||||
#[instrument(level = "trace", skip(prog))]
|
||||
pub fn iop_add_ripple_rtl(prog: &mut Program, i: u8) -> Rtl {
|
||||
pub fn iop_add_ripple_rtl(prog: &mut Program, i: u8, _kogge_blk_w: Option<usize>) -> Rtl {
|
||||
// Allocate metavariables:
|
||||
let dst = prog.iop_template_var(OperandKind::Dst, i);
|
||||
let src_a = prog.iop_template_var(OperandKind::Src, 2 * i);
|
||||
@@ -899,13 +904,13 @@ fn bw_inv(prog: &mut Program, b: Vec<VarCell>) -> Vec<VarCell> {
|
||||
/// Maybe this should go into a SIMD firmware implementation... At some point we
|
||||
/// would need a mechanism to choose between implementations on the fly to make
|
||||
/// really good use of all of this.
|
||||
fn simd<F>(prog: &mut Program, batch_size: usize, rtl_closure: F)
|
||||
fn simd<F>(prog: &mut Program, batch_size: usize, rtl_closure: F, kogge_blk_w: Option<usize>)
|
||||
where
|
||||
F: Fn(&mut Program, u8) -> Rtl,
|
||||
F: Fn(&mut Program, u8, Option<usize>) -> Rtl,
|
||||
{
|
||||
(0..batch_size)
|
||||
.map(|i| i as u8)
|
||||
.map(|i| rtl_closure(prog, i))
|
||||
.map(|i| rtl_closure(prog, i, kogge_blk_w))
|
||||
.sum::<Rtl>()
|
||||
.add_to_prog(prog);
|
||||
}
|
||||
|
||||
@@ -160,9 +160,9 @@ impl ProgramInner {
|
||||
.filter(|(_, var)| var.is_none())
|
||||
.map(|(rid, _)| *rid)
|
||||
.collect::<Vec<_>>();
|
||||
demote_order
|
||||
.into_iter()
|
||||
.for_each(|rid| self.regs.demote(&rid));
|
||||
demote_order.into_iter().for_each(|rid| {
|
||||
self.regs.demote(&rid);
|
||||
});
|
||||
}
|
||||
|
||||
/// Release register entry
|
||||
@@ -179,7 +179,7 @@ impl ProgramInner {
|
||||
|
||||
/// Notify register access to update LRU state
|
||||
pub(crate) fn reg_access(&mut self, rid: asm::RegId) {
|
||||
self.regs.promote(&rid)
|
||||
self.regs.promote(&rid);
|
||||
}
|
||||
|
||||
/// Retrieve the least-recently-used heap entry
|
||||
@@ -220,9 +220,9 @@ impl ProgramInner {
|
||||
.filter(|(_mid, var)| var.is_none())
|
||||
.map(|(mid, _)| *mid)
|
||||
.collect::<Vec<_>>();
|
||||
demote_order
|
||||
.into_iter()
|
||||
.for_each(|mid| self.heap.demote(&mid));
|
||||
demote_order.into_iter().for_each(|mid| {
|
||||
self.heap.demote(&mid);
|
||||
});
|
||||
}
|
||||
_ => { /*Only release Heap slot*/ }
|
||||
}
|
||||
@@ -231,7 +231,9 @@ impl ProgramInner {
|
||||
/// Notify heap access to update LRU state
|
||||
pub(crate) fn heap_access(&mut self, mid: asm::MemId) {
|
||||
match mid {
|
||||
asm::MemId::Heap { .. } => self.heap.promote(&mid),
|
||||
asm::MemId::Heap { .. } => {
|
||||
self.heap.promote(&mid);
|
||||
}
|
||||
_ => { /* Do nothing: slot does not belong to the heap */ }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -367,6 +367,8 @@ def dump_benchmark_results(results, browser_kind):
|
||||
"""
|
||||
Dump as JSON benchmark results into a file.
|
||||
If `results` is an empty dict then this function is a no-op.
|
||||
If the file already exists, new results are merged with existing ones,
|
||||
overwriting keys that already exist.
|
||||
|
||||
:param results: benchmark results as :class:`dict`
|
||||
:param browser_kind: browser as :class:`BrowserKind`
|
||||
@@ -376,7 +378,15 @@ def dump_benchmark_results(results, browser_kind):
|
||||
key.replace("mean", "_".join((browser_kind.name, "mean"))): val
|
||||
for key, val in results.items()
|
||||
}
|
||||
pathlib.Path("tfhe-benchmark/wasm_benchmark_results.json").write_text(json.dumps(results))
|
||||
results_path = pathlib.Path("tfhe-benchmark/wasm_benchmark_results.json")
|
||||
existing_results = {}
|
||||
if results_path.exists():
|
||||
try:
|
||||
existing_results = json.loads(results_path.read_text())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
existing_results.update(results)
|
||||
results_path.write_text(json.dumps(existing_results))
|
||||
|
||||
|
||||
def start_web_server(
|
||||
|
||||
1
tfhe-benchmark/.gitignore
vendored
Normal file
1
tfhe-benchmark/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
benchmarks_parameters/*
|
||||
@@ -2,7 +2,9 @@ use benchmark::utilities::{
|
||||
hlapi_throughput_num_ops, write_to_json, BenchmarkType, BitSizesSet, EnvConfig, OperatorType,
|
||||
};
|
||||
use criterion::{black_box, Criterion, Throughput};
|
||||
use oprf::oprf_any_range2;
|
||||
use rand::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::*;
|
||||
use tfhe::core_crypto::prelude::Numeric;
|
||||
@@ -11,34 +13,42 @@ use tfhe::keycache::NamedParam;
|
||||
use tfhe::named::Named;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::{
|
||||
ClientKey, CompressedServerKey, FheIntegerType, FheUint10, FheUint12, FheUint128, FheUint14,
|
||||
FheUint16, FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, FheUintId, IntegerId,
|
||||
KVStore,
|
||||
ClientKey, CompressedServerKey, FheIntegerType, FheUint, FheUint10, FheUint12, FheUint128,
|
||||
FheUint14, FheUint16, FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, FheUintId,
|
||||
IntegerId, KVStore,
|
||||
};
|
||||
|
||||
use rayon::prelude::*;
|
||||
mod oprf;
|
||||
|
||||
fn bench_fhe_type<FheType>(
|
||||
/// Benchmark-local trait that lets the generic benchmark helper call
/// `wait_bench()` on an operation's result regardless of its concrete type
/// (it is the `R: BenchWait` bound on `bench_fhe_type_op`).
trait BenchWait {
|
||||
/// Waits on the benchmarked result; each implementation decides what is waited on.
fn wait_bench(&self);
|
||||
}
|
||||
|
||||
/// A single `FheUint` result simply forwards to its own `wait()` method.
impl<Id: FheUintId> BenchWait for FheUint<Id> {
|
||||
fn wait_bench(&self) {
|
||||
self.wait()
|
||||
}
|
||||
}
|
||||
|
||||
/// Tuple results: only the first element is waited on. The second element
/// (`T2`, which carries no `FheWait` bound) is never touched here.
impl<T1: FheWait, T2> BenchWait for (T1, T2) {
|
||||
fn wait_bench(&self) {
|
||||
// Only `.0` is waited on; `.1` is ignored.
self.0.wait()
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_fhe_type_op<FheType, F, R>(
|
||||
c: &mut Criterion,
|
||||
client_key: &ClientKey,
|
||||
type_name: &str,
|
||||
bit_size: usize,
|
||||
display_name: &str,
|
||||
func_name: &str,
|
||||
func: F,
|
||||
) where
|
||||
F: Fn(&FheType, &FheType) -> R,
|
||||
R: BenchWait,
|
||||
FheType: FheEncrypt<u128, ClientKey>,
|
||||
FheType: FheWait,
|
||||
for<'a> &'a FheType: Add<&'a FheType, Output = FheType>
|
||||
+ Sub<&'a FheType, Output = FheType>
|
||||
+ Mul<&'a FheType, Output = FheType>
|
||||
+ BitAnd<&'a FheType, Output = FheType>
|
||||
+ BitOr<&'a FheType, Output = FheType>
|
||||
+ BitXor<&'a FheType, Output = FheType>
|
||||
+ Shl<&'a FheType, Output = FheType>
|
||||
+ Shr<&'a FheType, Output = FheType>
|
||||
+ RotateLeft<&'a FheType, Output = FheType>
|
||||
+ RotateRight<&'a FheType, Output = FheType>
|
||||
+ OverflowingAdd<&'a FheType, Output = FheType>
|
||||
+ OverflowingSub<&'a FheType, Output = FheType>,
|
||||
for<'a> FheType: FheMin<&'a FheType, Output = FheType> + FheMax<&'a FheType, Output = FheType>,
|
||||
{
|
||||
let mut bench_group = c.benchmark_group(type_name);
|
||||
let mut bench_prefix = "hlapi".to_string();
|
||||
@@ -71,170 +81,90 @@ fn bench_fhe_type<FheType>(
|
||||
let lhs = FheType::encrypt(rng.gen(), client_key);
|
||||
let rhs = FheType::encrypt(rng.gen(), client_key);
|
||||
|
||||
let mut bench_id;
|
||||
let bench_id = format!("{bench_prefix}::{func_name}::{param_name}::{type_name}");
|
||||
|
||||
bench_id = format!("{bench_prefix}::add::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs + &rhs;
|
||||
res.wait();
|
||||
let res = func(&lhs, &rhs);
|
||||
res.wait_bench();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "add");
|
||||
|
||||
bench_id = format!("{bench_prefix}::overflowing_add::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let (res, flag) = lhs.overflowing_add(&rhs);
|
||||
res.wait();
|
||||
black_box((res, flag))
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "overflowing_add");
|
||||
|
||||
bench_id = format!("{bench_prefix}::overflowing_sub::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let (res, flag) = lhs.overflowing_sub(&rhs);
|
||||
res.wait();
|
||||
black_box((res, flag))
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "overflowing_sub");
|
||||
|
||||
bench_id = format!("{bench_prefix}::sub::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs - &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "sub");
|
||||
|
||||
bench_id = format!("{bench_prefix}::mul::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs * &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "mul");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitand::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs & &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitand");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitor::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs | &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitor");
|
||||
|
||||
bench_id = format!("{bench_prefix}::bitxor::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs ^ &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "bitxor");
|
||||
|
||||
bench_id = format!("{bench_prefix}::left_shift::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs << &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "left_shift");
|
||||
|
||||
bench_id = format!("{bench_prefix}::right_shift::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = &lhs >> &rhs;
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "right_shift");
|
||||
|
||||
bench_id = format!("{bench_prefix}::left_rotate::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = (&lhs).rotate_left(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "left_rotate");
|
||||
|
||||
bench_id = format!("{bench_prefix}::right_rotate::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = (&lhs).rotate_right(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "right_rotate");
|
||||
|
||||
bench_id = format!("{bench_prefix}::min::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = lhs.min(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "min");
|
||||
|
||||
bench_id = format!("{bench_prefix}::max::{param_name}::{type_name}");
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
let res = lhs.max(&rhs);
|
||||
res.wait();
|
||||
black_box(res)
|
||||
})
|
||||
});
|
||||
write_record(bench_id, "max");
|
||||
write_record(bench_id, display_name);
|
||||
}
|
||||
|
||||
macro_rules! bench_type {
|
||||
($fhe_type:ident) => {
|
||||
macro_rules! bench_type_op (
|
||||
(type_name: $fhe_type:ident, display_name: $display_name:literal, operation: $op:ident) => {
|
||||
::paste::paste! {
|
||||
fn [<bench_ $fhe_type:snake>](c: &mut Criterion, cks: &ClientKey) {
|
||||
bench_fhe_type::<$fhe_type>(c, cks, stringify!($fhe_type), $fhe_type::num_bits());
|
||||
fn [<bench_ $fhe_type:snake _ $op>](c: &mut Criterion, cks: &ClientKey) {
|
||||
bench_fhe_type_op::<$fhe_type, _, _>(
|
||||
c,
|
||||
cks,
|
||||
stringify!($fhe_type),
|
||||
$fhe_type::num_bits(),
|
||||
$display_name,
|
||||
stringify!($op),
|
||||
|lhs, rhs| lhs.$op(rhs)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
);
|
||||
|
||||
macro_rules! generate_typed_benches {
|
||||
($fhe_type:ident) => {
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "add", operation: add);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "overflowing_add", operation: overflowing_add);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "sub", operation: sub);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "overflowing_sub", operation: overflowing_sub);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "mul", operation: mul);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitand", operation: bitand);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitor", operation: bitor);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "bitxor", operation: bitxor);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "left_shift", operation: shl);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "right_shift", operation: shr);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "left_rotate", operation: rotate_left);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "right_rotate", operation: rotate_right);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "min", operation: min);
|
||||
bench_type_op!(type_name: $fhe_type, display_name: "max", operation: max);
|
||||
};
|
||||
}
|
||||
|
||||
bench_type!(FheUint2);
|
||||
bench_type!(FheUint4);
|
||||
bench_type!(FheUint6);
|
||||
bench_type!(FheUint8);
|
||||
bench_type!(FheUint10);
|
||||
bench_type!(FheUint12);
|
||||
bench_type!(FheUint14);
|
||||
bench_type!(FheUint16);
|
||||
bench_type!(FheUint32);
|
||||
bench_type!(FheUint64);
|
||||
bench_type!(FheUint128);
|
||||
// Generate benches for all FheUint types
|
||||
generate_typed_benches!(FheUint2);
|
||||
generate_typed_benches!(FheUint4);
|
||||
generate_typed_benches!(FheUint6);
|
||||
generate_typed_benches!(FheUint8);
|
||||
generate_typed_benches!(FheUint10);
|
||||
generate_typed_benches!(FheUint12);
|
||||
generate_typed_benches!(FheUint14);
|
||||
generate_typed_benches!(FheUint16);
|
||||
generate_typed_benches!(FheUint32);
|
||||
generate_typed_benches!(FheUint64);
|
||||
generate_typed_benches!(FheUint128);
|
||||
|
||||
macro_rules! run_benches {
|
||||
($c:expr, $cks:expr, $($fhe_type:ident),+ $(,)?) => {
|
||||
$(
|
||||
::paste::paste! {
|
||||
[<bench_ $fhe_type:snake _add>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _overflowing_add>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _sub>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _overflowing_sub>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _mul>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitand>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitor>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _bitxor>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _shl>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _shr>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _rotate_left>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _rotate_right>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _min>]($c, $cks);
|
||||
[<bench_ $fhe_type:snake _max>]($c, $cks);
|
||||
}
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
trait TypeDisplay {
|
||||
fn fmt(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -444,7 +374,7 @@ fn main() {
|
||||
|
||||
match env_config.bit_sizes_set {
|
||||
BitSizesSet::Fast => {
|
||||
bench_fhe_uint64(&mut c, &cks);
|
||||
run_benches!(&mut c, &cks, FheUint64);
|
||||
|
||||
// KVStore Benches
|
||||
if benched_device == tfhe::Device::Cpu {
|
||||
@@ -452,17 +382,11 @@ fn main() {
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
bench_fhe_uint2(&mut c, &cks);
|
||||
bench_fhe_uint4(&mut c, &cks);
|
||||
bench_fhe_uint6(&mut c, &cks);
|
||||
bench_fhe_uint8(&mut c, &cks);
|
||||
bench_fhe_uint10(&mut c, &cks);
|
||||
bench_fhe_uint12(&mut c, &cks);
|
||||
bench_fhe_uint14(&mut c, &cks);
|
||||
bench_fhe_uint16(&mut c, &cks);
|
||||
bench_fhe_uint32(&mut c, &cks);
|
||||
bench_fhe_uint64(&mut c, &cks);
|
||||
bench_fhe_uint128(&mut c, &cks);
|
||||
// Call all benchmarks for all types
|
||||
run_benches!(
|
||||
&mut c, &cks, FheUint2, FheUint4, FheUint6, FheUint8, FheUint10, FheUint12,
|
||||
FheUint14, FheUint16, FheUint32, FheUint64, FheUint128
|
||||
);
|
||||
|
||||
// KVStore Benches
|
||||
if benched_device == tfhe::Device::Cpu {
|
||||
@@ -481,5 +405,7 @@ fn main() {
|
||||
}
|
||||
}
|
||||
|
||||
oprf_any_range2();
|
||||
|
||||
c.final_summary();
|
||||
}
|
||||
|
||||
@@ -29,12 +29,21 @@ pub fn transfer_whitepaper<FheType>(
|
||||
) -> (FheType, FheType)
|
||||
where
|
||||
FheType: Add<Output = FheType> + for<'a> FheOrd<&'a FheType> + FheTrivialEncrypt<u64>,
|
||||
FheBool: IfThenElse<FheType>,
|
||||
FheBool: IfThenZero<FheType> + IfThenElse<FheType>,
|
||||
for<'a> &'a FheType: Add<Output = FheType> + Sub<Output = FheType>,
|
||||
{
|
||||
let has_enough_funds = (from_amount).ge(amount);
|
||||
let zero_amount = FheType::encrypt_trivial(0u64);
|
||||
let amount_to_transfer = has_enough_funds.select(amount, &zero_amount);
|
||||
let amount_to_transfer = {
|
||||
#[cfg(not(feature = "hpu"))]
|
||||
{
|
||||
let zero_amount = FheType::encrypt_trivial(0u64);
|
||||
has_enough_funds.select(amount, &zero_amount)
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
{
|
||||
has_enough_funds.if_then_zero(amount)
|
||||
}
|
||||
};
|
||||
|
||||
let new_to_amount = to_amount + &amount_to_transfer;
|
||||
let new_from_amount = from_amount - &amount_to_transfer;
|
||||
@@ -51,12 +60,21 @@ pub fn par_transfer_whitepaper<FheType>(
|
||||
where
|
||||
FheType:
|
||||
Add<Output = FheType> + for<'a> FheOrd<&'a FheType> + Send + Sync + FheTrivialEncrypt<u64>,
|
||||
FheBool: IfThenElse<FheType>,
|
||||
FheBool: IfThenZero<FheType> + IfThenElse<FheType>,
|
||||
for<'a> &'a FheType: Add<Output = FheType> + Sub<Output = FheType>,
|
||||
{
|
||||
let has_enough_funds = (from_amount).ge(amount);
|
||||
let zero_amount = FheType::encrypt_trivial(0u64);
|
||||
let amount_to_transfer = has_enough_funds.select(amount, &zero_amount);
|
||||
let amount_to_transfer = {
|
||||
#[cfg(feature = "gpu")]
|
||||
{
|
||||
let zero_amount = FheType::encrypt_trivial(0u64);
|
||||
has_enough_funds.select(amount, &zero_amount)
|
||||
}
|
||||
#[cfg(not(feature = "gpu"))]
|
||||
{
|
||||
has_enough_funds.if_then_zero(amount)
|
||||
}
|
||||
};
|
||||
|
||||
let (new_to_amount, new_from_amount) = rayon::join(
|
||||
|| to_amount + &amount_to_transfer,
|
||||
|
||||
44
tfhe-benchmark/benches/high_level_api/oprf.rs
Normal file
44
tfhe-benchmark/benches/high_level_api/oprf.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
use criterion::{black_box, criterion_group, Criterion};
|
||||
use std::num::NonZeroU64;
|
||||
use tfhe::{set_server_key, ClientKey, ConfigBuilder, FheUint64, RangeForRandom, Seed, ServerKey};
|
||||
|
||||
pub fn oprf_any_range(c: &mut Criterion) {
|
||||
let bench_name = "hlapi::oprf_any_range";
|
||||
|
||||
let mut bench_group = c.benchmark_group(bench_name);
|
||||
bench_group
|
||||
.sample_size(15)
|
||||
.measurement_time(std::time::Duration::from_secs(30));
|
||||
|
||||
let param = BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
let config = ConfigBuilder::with_custom_parameters(param).build();
|
||||
let cks = ClientKey::generate(config);
|
||||
let sks = ServerKey::new(&cks);
|
||||
|
||||
rayon::broadcast(|_| set_server_key(sks.clone()));
|
||||
set_server_key(sks);
|
||||
|
||||
for excluded_upper_bound in [3, 52] {
|
||||
let range = RangeForRandom::new_from_excluded_upper_bound(
|
||||
NonZeroU64::new(excluded_upper_bound).unwrap(),
|
||||
);
|
||||
|
||||
let bench_id_oprf = format!("{bench_name}::bound_{excluded_upper_bound}");
|
||||
|
||||
bench_group.bench_function(&bench_id_oprf, |b| {
|
||||
b.iter(|| {
|
||||
_ = black_box(FheUint64::generate_oblivious_pseudo_random_custom_range(
|
||||
Seed(0),
|
||||
&range,
|
||||
None,
|
||||
));
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
bench_group.finish()
|
||||
}
|
||||
|
||||
criterion_group!(oprf_any_range2, oprf_any_range);
|
||||
@@ -630,7 +630,7 @@ mod integer_params {
|
||||
#[cfg(feature = "hpu")]
|
||||
let params = vec![BENCH_HPU_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128.into()];
|
||||
#[cfg(not(feature = "hpu"))]
|
||||
let params = vec![BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into()];
|
||||
let params = vec![BENCH_PARAM_MESSAGE_2_CARRY_2_KS32_PBS.into()];
|
||||
|
||||
let params_and_bit_sizes = iproduct!(params, env_config.bit_sizes());
|
||||
Self {
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
#[cfg(any(feature = "shortint", feature = "integer"))]
|
||||
pub mod shortint_params_aliases {
|
||||
use tfhe::shortint::parameters::current_params::*;
|
||||
#[cfg(feature = "hpu")]
|
||||
use tfhe::shortint::parameters::KeySwitch32PBSParameters;
|
||||
use tfhe::shortint::parameters::{
|
||||
ClassicPBSParameters, CompactPublicKeyEncryptionParameters, CompressionParameters,
|
||||
MultiBitPBSParameters, NoiseSquashingCompressionParameters, NoiseSquashingParameters,
|
||||
ShortintKeySwitchingParameters,
|
||||
KeySwitch32PBSParameters, MultiBitPBSParameters, NoiseSquashingCompressionParameters,
|
||||
NoiseSquashingParameters, ShortintKeySwitchingParameters,
|
||||
};
|
||||
|
||||
// KS PBS Gaussian
|
||||
@@ -42,6 +40,8 @@ pub mod shortint_params_aliases {
|
||||
V1_5_PARAM_MESSAGE_4_CARRY_4_KS_PBS_TUNIFORM_2M128;
|
||||
pub const BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS: ClassicPBSParameters =
|
||||
V1_5_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
pub const BENCH_PARAM_MESSAGE_2_CARRY_2_KS32_PBS: KeySwitch32PBSParameters =
|
||||
V1_5_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
|
||||
|
||||
pub const BENCH_ALL_CLASSIC_PBS_PARAMETERS: [(&ClassicPBSParameters, &str); 141] =
|
||||
VEC_ALL_CLASSIC_PBS_PARAMETERS;
|
||||
|
||||
@@ -27,6 +27,7 @@ rand_distr = "0.4.3"
|
||||
criterion = "0.5.1"
|
||||
doc-comment = "0.3.3"
|
||||
serde_json = "1.0.94"
|
||||
num-bigint = "0.4.6"
|
||||
# clap has to be pinned as its minimum supported rust version
|
||||
# changes often between minor releases, which breaks our CI
|
||||
clap = { version = "=4.5.30", features = ["derive"] }
|
||||
|
||||
@@ -2,14 +2,30 @@
|
||||
|
||||
This document explains the mechanism and steps to generate an oblivious encrypted random value using only server keys.
|
||||
|
||||
The goal is to give to the server the possibility to generate a random value, which will be obtained in an encrypted format and will remain unknown to the server. The implementation is based on [this article](https://eprint.iacr.org/2024/665).
|
||||
The goal is to let the server generate a random value that is obtained in encrypted form and remains unknown to the server.
|
||||
|
||||
This is possible through two methods on `FheUint` and `FheInt`:
|
||||
The main method for this is `FheUint::generate_oblivious_pseudo_random_custom_range` which returns an integer in the given range.
|
||||
Currently, the range can only be of the form `[0, excluded_upper_bound[`, with any `excluded_upper_bound` in `[1, 2^64[`.
|
||||
The returned value follows a distribution close to the uniform distribution.
|
||||
|
||||
This function guarantees that the norm-1 distance (defined as ∆(P,Q) := 1/2 Sum[ω∈Ω] |P(ω) − Q(ω)|)
|
||||
between the actual distribution and the target uniform distribution will be below the `max_distance` argument (which must be in ]0, 1[).
|
||||
The higher the distance, the more dissimilar the actual distribution is from the target uniform distribution.
|
||||
|
||||
The default value for `max_distance` is `2^-128` if `None` is provided.
|
||||
|
||||
Higher values allow better performance but must be considered carefully in the context of their target application as it may have serious unintended consequences.
|
||||
|
||||
If the range is a power of 2, the distribution is uniform (for any `max_distance`) and the cost is smaller.
|
||||
|
||||
|
||||
For powers of 2 specifically there are two methods on `FheUint` and `FheInt` (based on [this article](https://eprint.iacr.org/2024/665)):
|
||||
- `generate_oblivious_pseudo_random`, which returns an integer taken uniformly in the full integer range (`[0; 2^N[` for a `FheUintN` and `[-2^(N-1); 2^(N-1)[` for a `FheIntN`).
|
||||
- `generate_oblivious_pseudo_random_bounded`, which returns an integer taken uniformly in `[0; 2^random_bits_count[`. For a `FheUintN`, we must have `random_bits_count <= N`. For a `FheIntN`, we must have `random_bits_count <= N - 1`.
|
||||
|
||||
Both methods functions take a seed `Seed` as input, which could be any `u128` value.
|
||||
They both rely on the use of the usual server key.
|
||||
|
||||
These methods take a seed of type `Seed` as input, which can hold any `u128` value.
|
||||
They rely on the use of the usual server key.
|
||||
The output is reproducible, i.e., the function is deterministic from the inputs: assuming the same hardware, seed and server key, this function outputs the same random encrypted value.
|
||||
|
||||
|
||||
@@ -18,7 +34,8 @@ Here is an example of the usage:
|
||||
|
||||
```rust
|
||||
use tfhe::prelude::FheDecrypt;
|
||||
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, FheInt8, Seed};
|
||||
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, FheInt8, RangeForRandom, Seed};
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
pub fn main() {
|
||||
let config = ConfigBuilder::default().build();
|
||||
@@ -26,23 +43,30 @@ pub fn main() {
|
||||
|
||||
set_server_key(server_key);
|
||||
|
||||
let random_bits_count = 3;
|
||||
|
||||
let ct_res = FheUint8::generate_oblivious_pseudo_random(Seed(0));
|
||||
let excluded_upper_bound = NonZeroU64::new(3).unwrap();
|
||||
let range = RangeForRandom::new_from_excluded_upper_bound(excluded_upper_bound);
|
||||
|
||||
// in [0, excluded_upper_bound[ = {0, 1, 2}
|
||||
let ct_res = FheUint8::generate_oblivious_pseudo_random_custom_range(Seed(0), &range, None);
|
||||
let dec_result: u8 = ct_res.decrypt(&client_key);
|
||||
|
||||
let ct_res = FheUint8::generate_oblivious_pseudo_random_bounded(Seed(0), random_bits_count);
|
||||
let random_bits_count = 3;
|
||||
|
||||
// in [0, 2^8[
|
||||
let ct_res = FheUint8::generate_oblivious_pseudo_random(Seed(0));
|
||||
let dec_result: u8 = ct_res.decrypt(&client_key);
|
||||
|
||||
// in [0, 2^random_bits_count[ = [0, 8[
|
||||
let ct_res = FheUint8::generate_oblivious_pseudo_random_bounded(Seed(0), random_bits_count);
|
||||
let dec_result: u8 = ct_res.decrypt(&client_key);
|
||||
assert!(dec_result < (1 << random_bits_count));
|
||||
|
||||
// in [-2^7, 2^7[
|
||||
let ct_res = FheInt8::generate_oblivious_pseudo_random(Seed(0));
|
||||
|
||||
let dec_result: i8 = ct_res.decrypt(&client_key);
|
||||
|
||||
// in [0, 2^random_bits_count[ = [0, 8[
|
||||
let ct_res = FheInt8::generate_oblivious_pseudo_random_bounded(Seed(0), random_bits_count);
|
||||
|
||||
let dec_result: i8 = ct_res.decrypt(&client_key);
|
||||
assert!(dec_result < (1 << random_bits_count));
|
||||
}
|
||||
|
||||
@@ -0,0 +1,415 @@
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use dyn_stack::{PodStack, StackReq};
|
||||
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::commons::utils::izip_eq;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::ggsw::collect_next_term;
|
||||
use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecompositionLendingIter;
|
||||
use crate::core_crypto::prelude::polynomial_algorithms::*;
|
||||
use crate::core_crypto::prelude::{
|
||||
extract_lwe_sample_from_glwe_ciphertext, lwe_ciphertext_modulus_switch, ComputationBuffers,
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweSize, ModulusSwitchedLweCiphertext,
|
||||
MonomialDegree, PolynomialSize, SignedDecomposer,
|
||||
};
|
||||
|
||||
pub fn programmable_bootstrap_karatsuba_lwe_ciphertext_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> StackReq {
|
||||
StackReq::all_of(&[
|
||||
// local accumulator
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
// blind rotation
|
||||
blind_rotate_karatsuba_assign_scratch::<Scalar>(glwe_size, polynomial_size),
|
||||
])
|
||||
}
|
||||
|
||||
/// Return the required memory for [`blind_rotate_karatsuba_assign`].
|
||||
pub fn blind_rotate_karatsuba_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> StackReq {
|
||||
StackReq::any_of(&[
|
||||
// tmp_poly allocation
|
||||
StackReq::new_aligned::<Scalar>(polynomial_size.0, CACHELINE_ALIGN),
|
||||
StackReq::all_of(&[
|
||||
// ct1 allocation
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
// external product
|
||||
karatsuba_add_external_product_assign_scratch::<Scalar>(glwe_size, polynomial_size),
|
||||
]),
|
||||
])
|
||||
}
|
||||
|
||||
/// Return the required memory for [`karatsuba_add_external_product_assign`].
|
||||
pub fn karatsuba_add_external_product_assign_scratch<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
) -> StackReq {
|
||||
StackReq::all_of(&[
|
||||
// Output buffer
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
// decomposition
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
// decomposition term
|
||||
StackReq::new_aligned::<Scalar>(glwe_size.0 * polynomial_size.0, CACHELINE_ALIGN),
|
||||
])
|
||||
}
|
||||
|
||||
/// Perform a programmable bootstrap given an input [`LWE ciphertext`](`LweCiphertext`), a
|
||||
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
/// key`](`LweBootstrapKey`) using the karatsuba polynomial multiplication. The result is written in
|
||||
/// the provided output [`LWE ciphertext`](`LweCiphertext`).
|
||||
///
|
||||
/// If you want to manage the computation memory manually you can use
|
||||
/// [`programmable_bootstrap_karatsuba_lwe_ciphertext_mem_optimized`].
|
||||
///
|
||||
/// # Warning
|
||||
/// For a more efficient implementation of the programmable bootstrap, see
|
||||
/// [`programmable_bootstrap_lwe_ciphertext`](super::programmable_bootstrap_lwe_ciphertext)
|
||||
pub fn programmable_bootstrap_karatsuba_lwe_ciphertext<InputCont, OutputCont, AccCont, KeyCont>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
bsk: &LweBootstrapKey<KeyCont>,
|
||||
) where
|
||||
InputCont: Container<Element = u64>,
|
||||
OutputCont: ContainerMut<Element = u64>,
|
||||
AccCont: Container<Element = u64>,
|
||||
KeyCont: Container<Element = u64>,
|
||||
{
|
||||
assert!(
|
||||
input.ciphertext_modulus().is_power_of_two(),
|
||||
"This operation requires the input to have a power of two modulus."
|
||||
);
|
||||
assert_eq!(
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(
|
||||
programmable_bootstrap_karatsuba_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
)
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
programmable_bootstrap_karatsuba_lwe_ciphertext_mem_optimized(
|
||||
input,
|
||||
output,
|
||||
accumulator,
|
||||
bsk,
|
||||
buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Perform a programmable bootstrap given an input [`LWE ciphertext`](`LweCiphertext`), a
|
||||
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
|
||||
/// key`](`LweBootstrapKey`) using the karatsuba polynomial multiplication. The result is written in
|
||||
/// the provided output [`LWE ciphertext`](`LweCiphertext`).
|
||||
///
|
||||
/// # Warning
|
||||
/// For a more efficient implementation of the programmable bootstrap, see
|
||||
/// [`programmable_bootstrap_lwe_ciphertext_mem_optimized`](super::programmable_bootstrap_lwe_ciphertext_mem_optimized)
|
||||
pub fn programmable_bootstrap_karatsuba_lwe_ciphertext_mem_optimized<
|
||||
InputCont,
|
||||
OutputCont,
|
||||
AccCont,
|
||||
KeyCont,
|
||||
>(
|
||||
input: &LweCiphertext<InputCont>,
|
||||
output: &mut LweCiphertext<OutputCont>,
|
||||
accumulator: &GlweCiphertext<AccCont>,
|
||||
bsk: &LweBootstrapKey<KeyCont>,
|
||||
stack: &mut PodStack,
|
||||
) where
|
||||
InputCont: Container<Element = u64>,
|
||||
OutputCont: ContainerMut<Element = u64>,
|
||||
AccCont: Container<Element = u64>,
|
||||
KeyCont: Container<Element = u64>,
|
||||
{
|
||||
assert_eq!(
|
||||
output.ciphertext_modulus(),
|
||||
accumulator.ciphertext_modulus()
|
||||
);
|
||||
assert_eq!(accumulator.ciphertext_modulus(), bsk.ciphertext_modulus());
|
||||
|
||||
let (local_accumulator_data, stack) =
|
||||
stack.collect_aligned(CACHELINE_ALIGN, accumulator.as_ref().iter().copied());
|
||||
let mut local_accumulator = GlweCiphertextMutView::from_container(
|
||||
&mut *local_accumulator_data,
|
||||
accumulator.polynomial_size(),
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let log_modulus = accumulator
|
||||
.polynomial_size()
|
||||
.to_blind_rotation_input_modulus_log();
|
||||
|
||||
let msed = lwe_ciphertext_modulus_switch(input.as_view(), log_modulus);
|
||||
|
||||
blind_rotate_karatsuba_assign_mem_optimized(&msed, &mut local_accumulator, bsk, stack);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
|
||||
}
|
||||
|
||||
/// Perform a blind rotation given an input [`modulus switched LWE
|
||||
/// ciphertext`](`ModulusSwitchedLweCiphertext`), modifying a look-up table passed as a [`GLWE
|
||||
/// ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap key`](`LweBootstrapKey`) using the
|
||||
/// karatsuba polynomial multiplication.
|
||||
///
|
||||
/// If you want to manage the computation memory manually you can use
|
||||
/// [`blind_rotate_karatsuba_assign_mem_optimized`].
|
||||
///
|
||||
/// # Warning
|
||||
/// For a more efficient implementation of the blind rotation, see
|
||||
/// [`blind_rotate_assign`](super::blind_rotate_assign)
|
||||
pub fn blind_rotate_karatsuba_assign<OutputScalar, OutputCont, KeyCont>(
|
||||
msed_input: &impl ModulusSwitchedLweCiphertext<usize>,
|
||||
lut: &mut GlweCiphertext<OutputCont>,
|
||||
bsk: &LweBootstrapKey<KeyCont>,
|
||||
) where
|
||||
OutputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = OutputScalar>,
|
||||
GlweCiphertext<OutputCont>: PartialEq<GlweCiphertext<OutputCont>>,
|
||||
{
|
||||
let mut buffers = ComputationBuffers::new();
|
||||
|
||||
buffers.resize(
|
||||
blind_rotate_karatsuba_assign_scratch::<u64>(bsk.glwe_size(), bsk.polynomial_size())
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
|
||||
blind_rotate_karatsuba_assign_mem_optimized(msed_input, lut, bsk, buffers.stack())
|
||||
}
|
||||
|
||||
/// Perform a blind rotation given an input [`modulus switched LWE
|
||||
/// ciphertext`](`ModulusSwitchedLweCiphertext`), modifying a look-up table passed as a [`GLWE
|
||||
/// ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap key`](`LweBootstrapKey`) using the
|
||||
/// karatsuba polynomial multiplication.
|
||||
///
|
||||
/// # Warning
|
||||
/// For a more efficient implementation of the blind rotation, see
|
||||
/// [`blind_rotate_assign`](super::blind_rotate_assign)
|
||||
pub fn blind_rotate_karatsuba_assign_mem_optimized<OutputScalar, OutputCont, KeyCont>(
|
||||
msed_input: &impl ModulusSwitchedLweCiphertext<usize>,
|
||||
lut: &mut GlweCiphertext<OutputCont>,
|
||||
bsk: &LweBootstrapKey<KeyCont>,
|
||||
stack: &mut PodStack,
|
||||
) where
|
||||
OutputScalar: UnsignedTorus + CastInto<usize>,
|
||||
OutputCont: ContainerMut<Element = OutputScalar>,
|
||||
KeyCont: Container<Element = OutputScalar>,
|
||||
GlweCiphertext<OutputCont>: PartialEq<GlweCiphertext<OutputCont>>,
|
||||
{
|
||||
assert!(lut.ciphertext_modulus().is_power_of_two());
|
||||
|
||||
assert_eq!(
|
||||
bsk.input_lwe_dimension(),
|
||||
msed_input.lwe_dimension(),
|
||||
"Mismatched input LweDimension. \
|
||||
LweBootstrapKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
|
||||
bsk.input_lwe_dimension(),
|
||||
msed_input.lwe_dimension(),
|
||||
);
|
||||
assert_eq!(
|
||||
bsk.glwe_size(),
|
||||
lut.glwe_size(),
|
||||
"Mismatched GlweSize. \
|
||||
LweBootstrapKey GlweSize: {:?}, lut GlweSize {:?}.",
|
||||
bsk.glwe_size(),
|
||||
lut.glwe_size(),
|
||||
);
|
||||
assert_eq!(
|
||||
lut.polynomial_size(),
|
||||
bsk.polynomial_size(),
|
||||
"Mismatched PolynomialSize. \
|
||||
LweBootstrapKey PolynomialSize: {:?}, lut PolynomialSize {:?}.",
|
||||
bsk.polynomial_size(),
|
||||
lut.polynomial_size(),
|
||||
);
|
||||
|
||||
let msed_lwe_mask = msed_input.mask();
|
||||
|
||||
let msed_lwe_body = msed_input.body();
|
||||
|
||||
let monomial_degree = MonomialDegree(msed_lwe_body.cast_into());
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
let ciphertext_modulus = lut.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
lut.as_mut_polynomial_list()
|
||||
.iter_mut()
|
||||
.for_each(|mut poly| {
|
||||
let (tmp_poly, _) = stack.make_aligned_raw(poly.as_ref().len(), CACHELINE_ALIGN);
|
||||
|
||||
let mut tmp_poly = Polynomial::from_container(&mut *tmp_poly);
|
||||
tmp_poly.as_mut().copy_from_slice(poly.as_ref());
|
||||
polynomial_wrapping_monic_monomial_div(&mut poly, &tmp_poly, monomial_degree);
|
||||
});
|
||||
|
||||
// We initialize the ct_0 used for the successive cmuxes
|
||||
let ct0 = lut;
|
||||
let (ct1, stack) = stack.make_aligned_raw(ct0.as_ref().len(), CACHELINE_ALIGN);
|
||||
let mut ct1 =
|
||||
GlweCiphertextMutView::from_container(&mut *ct1, lut_poly_size, ciphertext_modulus);
|
||||
|
||||
for (lwe_mask_element, bootstrap_key_ggsw) in izip_eq!(msed_lwe_mask, bsk.iter()) {
|
||||
if lwe_mask_element != 0 {
|
||||
let monomial_degree = MonomialDegree(lwe_mask_element);
|
||||
|
||||
// we effectively inline the body of cmux here, merging the initial subtraction
|
||||
// operation with the monic polynomial multiplication, then performing the
|
||||
// external product manually
|
||||
|
||||
// We rotate ct_1 and subtract ct_0 (first step of cmux) by performing
|
||||
// ct_1 <- (ct_0 * X^a_i) - ct_0
|
||||
for (mut ct1_poly, ct0_poly) in izip_eq!(
|
||||
ct1.as_mut_polynomial_list().iter_mut(),
|
||||
ct0.as_polynomial_list().iter(),
|
||||
) {
|
||||
polynomial_wrapping_monic_monomial_mul_and_subtract(
|
||||
&mut ct1_poly,
|
||||
&ct0_poly,
|
||||
monomial_degree,
|
||||
);
|
||||
}
|
||||
|
||||
// second step of cmux:
|
||||
// ct_0 <- ct_0 + ct1 * s_i
|
||||
// with ct_0 + ct1s_i = ct_0 + ((ct_0 * X^a_i) - ct_0)s_i
|
||||
// = ct_0 if s_i= 0
|
||||
// ct_0 * X^a_i otherwise
|
||||
// = ct_0 * X^(a_i * s_i)
|
||||
//
|
||||
// as_mut_view is required to keep borrow rules consistent
|
||||
karatsuba_add_external_product_assign(
|
||||
ct0.as_mut_view(),
|
||||
bootstrap_key_ggsw,
|
||||
ct1.as_view(),
|
||||
stack,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
ct0.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the external product of `ggsw` and `glwe`, and adds the result to `out`.
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub fn karatsuba_add_external_product_assign<Scalar>(
|
||||
mut out: GlweCiphertextMutView<'_, Scalar>,
|
||||
ggsw: GgswCiphertextView<Scalar>,
|
||||
glwe: GlweCiphertextView<Scalar>,
|
||||
stack: &mut PodStack,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
{
|
||||
// we check that the polynomial sizes match
|
||||
debug_assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size());
|
||||
debug_assert_eq!(ggsw.polynomial_size(), out.polynomial_size());
|
||||
// we check that the glwe sizes match
|
||||
debug_assert_eq!(ggsw.glwe_size(), glwe.glwe_size());
|
||||
debug_assert_eq!(ggsw.glwe_size(), out.glwe_size());
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
let poly_size = ggsw.polynomial_size().0;
|
||||
|
||||
// we round the input mask and body
|
||||
let decomposer = SignedDecomposer::<Scalar>::new(
|
||||
ggsw.decomposition_base_log(),
|
||||
ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
let (output_buffer, substack0) =
|
||||
stack.make_aligned_raw::<Scalar>(poly_size * ggsw.glwe_size().0, align);
|
||||
// output_buffer is initially uninitialized, considered to be implicitly zero, to avoid
|
||||
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
|
||||
// it has been fully initialized for the first time.
|
||||
let output_buffer = &mut *output_buffer;
|
||||
let mut is_output_uninit = true;
|
||||
|
||||
let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new(
|
||||
glwe.as_ref()
|
||||
.iter()
|
||||
.map(|s| decomposer.init_decomposer_state(*s)),
|
||||
DecompositionBaseLog(decomposer.base_log),
|
||||
DecompositionLevelCount(decomposer.level_count),
|
||||
substack0,
|
||||
);
|
||||
|
||||
// We loop through the levels
|
||||
for ggsw_decomp_matrix in ggsw.iter() {
|
||||
// We retrieve the decomposition of this level.
|
||||
let (_glwe_level, glwe_decomp_term, _substack2) =
|
||||
collect_next_term(&mut decomposition, substack1, align);
|
||||
let glwe_decomp_term = GlweCiphertextView::from_container(
|
||||
&*glwe_decomp_term,
|
||||
ggsw.polynomial_size(),
|
||||
out.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
// For each level we have to add the result of the vector-matrix product between the
|
||||
// decomposition of the glwe, and the ggsw level matrix to the output. To do so, we
|
||||
// iteratively add to the output, the product between every line of the matrix, and
|
||||
// the corresponding (scalar) polynomial in the glwe decomposition:
|
||||
//
|
||||
// ggsw_mat ggsw_mat
|
||||
// glwe_dec | - - - - | < glwe_dec | - - - - |
|
||||
// | - - - | x | - - - - | | - - - | x | - - - - | <
|
||||
// ^ | - - - - | ^ | - - - - |
|
||||
//
|
||||
// t = 1 t = 2 ...
|
||||
|
||||
for (ggsw_row, glwe_poly) in izip_eq!(
|
||||
ggsw_decomp_matrix.as_glwe_list().iter(),
|
||||
glwe_decomp_term.as_polynomial_list().iter()
|
||||
) {
|
||||
let row_as_poly_list = ggsw_row.as_polynomial_list();
|
||||
if is_output_uninit {
|
||||
for (mut output_poly, row_poly) in output_buffer
|
||||
.chunks_exact_mut(poly_size)
|
||||
.map(Polynomial::from_container)
|
||||
.zip(row_as_poly_list.iter())
|
||||
{
|
||||
polynomial_wrapping_mul(&mut output_poly, &row_poly, &glwe_poly);
|
||||
}
|
||||
} else {
|
||||
for (mut output_poly, row_poly) in output_buffer
|
||||
.chunks_exact_mut(poly_size)
|
||||
.map(Polynomial::from_container)
|
||||
.zip(row_as_poly_list.iter())
|
||||
{
|
||||
polynomial_wrapping_add_mul_assign(&mut output_poly, &row_poly, &glwe_poly);
|
||||
}
|
||||
}
|
||||
|
||||
is_output_uninit = false;
|
||||
}
|
||||
}
|
||||
|
||||
// We iterate over the polynomials in the output.
|
||||
if !is_output_uninit {
|
||||
izip_eq!(
|
||||
out.as_mut_polynomial_list().iter_mut(),
|
||||
output_buffer
|
||||
.into_chunks(poly_size)
|
||||
.map(Polynomial::from_container),
|
||||
)
|
||||
.for_each(|(mut out, res)| polynomial_wrapping_add_assign(&mut out, &res));
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
pub mod fft128_pbs;
|
||||
pub mod fft64_pbs;
|
||||
pub mod karatsuba_pbs;
|
||||
pub mod ntt64_bnf_pbs;
|
||||
pub mod ntt64_pbs;
|
||||
|
||||
pub use fft128_pbs::*;
|
||||
pub use fft64_pbs::*;
|
||||
pub use karatsuba_pbs::*;
|
||||
pub use ntt64_bnf_pbs::*;
|
||||
pub use ntt64_pbs::*;
|
||||
|
||||
|
||||
@@ -1161,3 +1161,91 @@ fn lwe_encrypt_pbs_ntt64_bnf_decrypt(params: ClassicTestParams<u64>) {
|
||||
create_parameterized_test!(lwe_encrypt_pbs_ntt64_bnf_decrypt {
|
||||
TEST_PARAMS_3_BITS_SOLINAS_U64
|
||||
});
|
||||
|
||||
fn lwe_encrypt_pbs_karatsuba_decrypt_custom_mod(params: ClassicTestParams<u64>) {
|
||||
let lwe_noise_distribution = params.lwe_noise_distribution;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let msg_modulus = 1 << (message_modulus_log.0);
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
let glwe_dimension = params.glwe_dimension;
|
||||
let polynomial_size = params.polynomial_size;
|
||||
|
||||
let mut rsc = TestResources::new();
|
||||
|
||||
let f = |x: u64| x;
|
||||
|
||||
let delta: u64 = encoding_with_padding / msg_modulus;
|
||||
let mut msg = msg_modulus;
|
||||
|
||||
let accumulator = generate_programmable_bootstrap_glwe_lut(
|
||||
polynomial_size,
|
||||
glwe_dimension.to_glwe_size(),
|
||||
msg_modulus.cast_into(),
|
||||
ciphertext_modulus,
|
||||
delta,
|
||||
f,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&accumulator,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
while msg != 0 {
|
||||
msg = msg.wrapping_sub(1);
|
||||
|
||||
let mut keys_gen = |params| generate_keys(params, &mut rsc);
|
||||
let keys = gen_keys_or_get_from_cache_if_enabled(params, &mut keys_gen);
|
||||
let (input_lwe_secret_key, output_lwe_secret_key, bsk) =
|
||||
(keys.small_lwe_sk, keys.big_lwe_sk, keys.bsk);
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
let plaintext = Plaintext(msg * delta);
|
||||
|
||||
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&input_lwe_secret_key,
|
||||
plaintext,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&lwe_ciphertext_in,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let mut out_pbs_ct = LweCiphertext::new(
|
||||
0,
|
||||
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
programmable_bootstrap_karatsuba_lwe_ciphertext(
|
||||
&lwe_ciphertext_in,
|
||||
&mut out_pbs_ct,
|
||||
&accumulator,
|
||||
&bsk,
|
||||
);
|
||||
|
||||
assert!(check_encrypted_content_respects_mod(
|
||||
&out_pbs_ct,
|
||||
ciphertext_modulus
|
||||
));
|
||||
|
||||
let decrypted = decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
|
||||
|
||||
let decoded = round_decode(decrypted.0, delta) % msg_modulus;
|
||||
|
||||
assert_eq!(decoded, f(msg));
|
||||
}
|
||||
|
||||
// In coverage, we break after one while loop iteration, changing message values does not
|
||||
// yield higher coverage
|
||||
#[cfg(tarpaulin)]
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
create_parameterized_test!(lwe_encrypt_pbs_karatsuba_decrypt_custom_mod);
|
||||
|
||||
@@ -540,10 +540,12 @@ pub fn sup_diff(cumulative_bins: &[u64], theoretical_cdf: &[f64]) -> f64 {
|
||||
.iter()
|
||||
.copied()
|
||||
.zip_eq(theoretical_cdf.iter().copied())
|
||||
.map(|(x, theoretical_cdf)| {
|
||||
.enumerate()
|
||||
.map(|(i, (x, theoretical_cdf))| {
|
||||
let empirical_cdf = x as f64 / number_of_samples as f64;
|
||||
|
||||
if theoretical_cdf == 1.0 {
|
||||
if i == cumulative_bins.len() - 1 {
|
||||
assert_eq!(theoretical_cdf, 1.0);
|
||||
assert_eq!(empirical_cdf, 1.0);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
|
||||
use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
|
||||
use crate::high_level_api::traits::{
|
||||
FheEq, Flip, IfThenElse, ReRandomize, ScalarIfThenElse, Tagged,
|
||||
FheEq, Flip, IfThenElse, IfThenZero, ReRandomize, ScalarIfThenElse, Tagged,
|
||||
};
|
||||
use crate::high_level_api::{global_state, CompactPublicKey};
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
@@ -552,6 +552,66 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> IfThenZero<FheUint<Id>> for FheBool
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
/// Conditional selection.
|
||||
///
|
||||
/// The output value returned depends on the value of `self`.
|
||||
///
|
||||
/// - if `self` is true, the output will have the value of `ct_then`
|
||||
/// - if `self` is false, the output will be an encryption of 0
|
||||
fn if_then_zero(&self, ct_then: &FheUint<Id>) -> FheUint<Id> {
|
||||
global_state::with_internal_keys(|sks| match sks {
|
||||
InternalServerKey::Cpu(cpu_sks) => {
|
||||
let ct_condition = self;
|
||||
let mut ct_out = ct_then.ciphertext.on_cpu().clone();
|
||||
cpu_sks.pbs_key().zero_out_if_condition_is_false(
|
||||
&mut ct_out,
|
||||
&ct_condition.ciphertext.on_cpu().0,
|
||||
);
|
||||
FheUint::new(
|
||||
ct_out,
|
||||
cpu_sks.tag.clone(),
|
||||
ReRandomizationMetadata::default(),
|
||||
)
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_) => {
|
||||
panic!("Cuda does not support if_then_zero")
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(device) => {
|
||||
let hpu_then = ct_then.ciphertext.on_hpu(device);
|
||||
let hpu_cond = self.ciphertext.on_hpu(device);
|
||||
|
||||
let (opcode, proto) = {
|
||||
let asm_iop = &hpu_asm::iop::IOP_IF_THEN_ZERO;
|
||||
(
|
||||
asm_iop.opcode(),
|
||||
&asm_iop.format().expect("Unspecified IOP format").proto,
|
||||
)
|
||||
};
|
||||
// These clones are cheap as they are just Arc
|
||||
let hpu_result = HpuRadixCiphertext::exec(
|
||||
proto,
|
||||
opcode,
|
||||
&[hpu_then.clone(), hpu_cond.clone()],
|
||||
&[],
|
||||
)
|
||||
.pop()
|
||||
.unwrap();
|
||||
FheUint::new(
|
||||
hpu_result,
|
||||
device.tag.clone(),
|
||||
ReRandomizationMetadata::default(),
|
||||
)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id: FheIntId> IfThenElse<FheInt<Id>> for FheBool {
|
||||
/// Conditional selection.
|
||||
///
|
||||
|
||||
@@ -4,7 +4,9 @@ use crate::high_level_api::keys::InternalServerKey;
|
||||
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::shortint::MessageModulus;
|
||||
use crate::{FheInt, Seed};
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
impl<Id: FheUintId> FheUint<Id> {
|
||||
/// Generates an encrypted unsigned integer
|
||||
@@ -92,7 +94,7 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Generates an encrypted `num_block` blocks unsigned integer
|
||||
/// Generates an encrypted unsigned integer
|
||||
/// taken uniformly in `[0, 2^random_bits_count[` using the given seed.
|
||||
/// The encrypted value is oblivious to the server.
|
||||
/// It can be useful to make server random generation deterministic.
|
||||
@@ -150,6 +152,103 @@ impl<Id: FheUintId> FheUint<Id> {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Generates an encrypted unsigned integer
|
||||
/// taken almost uniformly in the given range using the given seed.
|
||||
/// Currently the range can only be in the form `[0, excluded_upper_bound[`
|
||||
/// with any `excluded_upper_bound` in `[1, 2^64[`.
|
||||
///
|
||||
/// The encrypted value is oblivious to the server.
|
||||
/// It can be useful to make server random generation deterministic.
|
||||
///
|
||||
/// This function guarantees that the norm-1 distance
|
||||
/// (defined as ∆(P,Q) := 1/2 Sum[ω∈Ω] |P(ω) − Q(ω)|)
|
||||
/// between the actual distribution and the target uniform distribution
|
||||
/// will be below the `max_distance` argument (which must be in ]0, 1[).
|
||||
/// The higher the distance, the more dissimilar the actual distribution is
|
||||
/// from the target uniform distribution.
|
||||
///
|
||||
/// The default value for `max_distance` is `2^-128` if `None` is provided.
|
||||
///
|
||||
/// Higher values allow better performance but must be considered carefully in the context of
|
||||
/// their target application as it may have serious unintended consequences.
|
||||
///
|
||||
/// If the range is a power of 2, the distribution is uniform (for any `max_distance`) and
|
||||
/// the cost is smaller.
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::num::NonZeroU64;
|
||||
/// use tfhe::prelude::FheDecrypt;
|
||||
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8, RangeForRandom, Seed};
|
||||
///
|
||||
/// let config = ConfigBuilder::default().build();
|
||||
/// let (client_key, server_key) = generate_keys(config);
|
||||
///
|
||||
/// set_server_key(server_key);
|
||||
///
|
||||
/// let excluded_upper_bound = NonZeroU64::new(3).unwrap();
|
||||
///
|
||||
/// let range = RangeForRandom::new_from_excluded_upper_bound(excluded_upper_bound);
|
||||
///
|
||||
/// let ct_res = FheUint8::generate_oblivious_pseudo_random_custom_range(Seed(0), &range, None);
|
||||
///
|
||||
/// let dec_result: u16 = ct_res.decrypt(&client_key);
|
||||
/// assert!(dec_result < excluded_upper_bound.get() as u16);
|
||||
/// ```
|
||||
pub fn generate_oblivious_pseudo_random_custom_range(
|
||||
seed: Seed,
|
||||
range: &RangeForRandom,
|
||||
max_distance: Option<f64>,
|
||||
) -> Self {
|
||||
let excluded_upper_bound = range.excluded_upper_bound;
|
||||
|
||||
if excluded_upper_bound.is_power_of_two() {
|
||||
let random_bits_count = excluded_upper_bound.ilog2() as u64;
|
||||
|
||||
Self::generate_oblivious_pseudo_random_bounded(seed, random_bits_count)
|
||||
} else {
|
||||
let max_distance = max_distance.unwrap_or_else(|| 2_f64.powi(-128));
|
||||
|
||||
assert!(
|
||||
0_f64 < max_distance && max_distance < 1_f64,
|
||||
"max_distance (={max_distance}) should be in ]0, 1["
|
||||
);
|
||||
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(key) => {
|
||||
let message_modulus = key.message_modulus();
|
||||
|
||||
let num_input_random_bits = num_input_random_bits_for_max_distance(
|
||||
excluded_upper_bound,
|
||||
max_distance,
|
||||
message_modulus,
|
||||
);
|
||||
|
||||
let num_blocks_output = Id::num_blocks(key.message_modulus()) as u64;
|
||||
|
||||
let ct = key
|
||||
.pbs_key()
|
||||
.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
excluded_upper_bound,
|
||||
num_blocks_output,
|
||||
);
|
||||
|
||||
Self::new(ct, key.tag.clone(), ReRandomizationMetadata::default())
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => {
|
||||
panic!("Gpu does not support this operation yet.")
|
||||
}
|
||||
#[cfg(feature = "hpu")]
|
||||
InternalServerKey::Hpu(_device) => {
|
||||
panic!("Hpu does not support this operation yet.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
/// Returns the amount of memory required to execute generate_oblivious_pseudo_random_bounded
|
||||
///
|
||||
@@ -273,7 +372,7 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
}
|
||||
})
|
||||
}
|
||||
/// Generates an encrypted `num_block` blocks signed integer
|
||||
/// Generates an encrypted signed integer
|
||||
/// taken uniformly in `[0, 2^random_bits_count[` using the given seed.
|
||||
/// The encrypted value is oblivious to the server.
|
||||
/// It can be useful to make server random generation deterministic.
|
||||
@@ -367,10 +466,350 @@ impl<Id: FheIntId> FheInt<Id> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Range of values from which an oblivious pseudo-random integer is drawn.
///
/// Only ranges of the form `[0, excluded_upper_bound[` can be described. The
/// bound is a [`NonZeroU64`], so the range always contains at least the value 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RangeForRandom {
    /// Exclusive upper bound of the range.
    excluded_upper_bound: NonZeroU64,
}

impl RangeForRandom {
    /// Builds the range `[0, excluded_upper_bound[`.
    pub fn new_from_excluded_upper_bound(excluded_upper_bound: NonZeroU64) -> Self {
        Self {
            excluded_upper_bound,
        }
    }
}
|
||||
|
||||
fn num_input_random_bits_for_max_distance(
|
||||
excluded_upper_bound: NonZeroU64,
|
||||
max_distance: f64,
|
||||
message_modulus: MessageModulus,
|
||||
) -> u64 {
|
||||
assert!(message_modulus.0.is_power_of_two());
|
||||
let log_message_modulus = message_modulus.0.ilog2() as u64;
|
||||
|
||||
let mut random_block_count = 1;
|
||||
|
||||
let random_block_count = loop {
|
||||
let random_bit_count = random_block_count * log_message_modulus;
|
||||
|
||||
let distance = distance(excluded_upper_bound.get(), random_bit_count);
|
||||
|
||||
if distance < max_distance {
|
||||
break random_block_count;
|
||||
}
|
||||
|
||||
random_block_count += 1;
|
||||
};
|
||||
|
||||
random_block_count * log_message_modulus
|
||||
}
|
||||
|
||||
fn distance(excluded_upper_bound: u64, random_bit_count: u64) -> f64 {
|
||||
let remainder = mod_pow_2(random_bit_count, excluded_upper_bound);
|
||||
|
||||
remainder as f64 * (excluded_upper_bound - remainder) as f64
|
||||
/ (2_f64.powi(random_bit_count as i32) * excluded_upper_bound as f64)
|
||||
}
|
||||
|
||||
/// Computes `2^exponent % modulus` by binary (square-and-multiply) exponentiation.
///
/// Intermediate products are held in `u128`, so multiplying two residues of a
/// `u64` modulus cannot overflow.
///
/// # Panics
///
/// Panics if `modulus` is 0.
fn mod_pow_2(exponent: u64, modulus: u64) -> u64 {
    assert_ne!(modulus, 0);

    // Every value is congruent to 0 modulo 1.
    if modulus == 1 {
        return 0;
    }

    let wide_modulus = modulus as u128;
    let mut accumulator: u128 = 1;
    // Holds 2^(2^k) mod modulus at the k-th iteration.
    let mut square: u128 = 2;
    let mut remaining = exponent;

    while remaining != 0 {
        // The lowest remaining bit is set: fold the current square into the result.
        if remaining & 1 == 1 {
            accumulator = (accumulator * square) % wide_modulus;
        }

        square = (square * square) % wide_modulus;
        remaining >>= 1;
    }

    accumulator as u64
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
|
||||
use super::*;
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::test_oprf::{
|
||||
oprf_density_function, p_value_upper_bound_oprf_almost_uniformity_from_values,
|
||||
probability_density_function_from_density,
|
||||
};
|
||||
use crate::prelude::FheDecrypt;
|
||||
use crate::shortint::oprf::test::test_uniformity;
|
||||
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
|
||||
use crate::{generate_keys, set_server_key, ClientKey, ConfigBuilder, FheUint8, Seed};
|
||||
use num_bigint::BigUint;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
// Helper: The "Oracle" implementation using BigInt
|
||||
// This is slow but mathematically guaranteed to be correct.
|
||||
fn oracle_mod_pow_2(exponent: u64, modulus: u64) -> u64 {
|
||||
assert_ne!(modulus, 0);
|
||||
|
||||
if modulus == 1 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let base = BigUint::from(2u32);
|
||||
let exp = BigUint::from(exponent);
|
||||
let modu = BigUint::from(modulus);
|
||||
|
||||
let res = base.modpow(&exp, &modu);
|
||||
res.iter_u64_digits().next().unwrap_or(0)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_edge_cases() {
|
||||
// 2^0 % 10 = 1
|
||||
assert_eq!(mod_pow_2(0, 10), 1, "Failed exponent 0");
|
||||
|
||||
// 2^10 % 1 = 0
|
||||
assert_eq!(mod_pow_2(10, 1), 0, "Failed modulus 1");
|
||||
|
||||
// 2^1 % 10 = 2
|
||||
assert_eq!(mod_pow_2(1, 10), 2, "Failed exponent 1");
|
||||
|
||||
// 2^3 % 5 = 8 % 5 = 3
|
||||
assert_eq!(mod_pow_2(3, 5), 3, "Failed small calc");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boundaries_and_overflow() {
|
||||
assert_eq!(mod_pow_2(2, u64::MAX), 4);
|
||||
|
||||
assert_eq!(mod_pow_2(u64::MAX, 3), 2);
|
||||
|
||||
assert_eq!(mod_pow_2(5, 32), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_against_oracle() {
|
||||
let mut rng = thread_rng();
|
||||
for _ in 0..1_000_000 {
|
||||
let exp: u64 = rng.gen();
|
||||
let mod_val: u64 = rng.gen();
|
||||
|
||||
let mod_val = if mod_val == 0 { 1 } else { mod_val };
|
||||
|
||||
let expected = oracle_mod_pow_2(exp, mod_val);
|
||||
let actual = mod_pow_2(exp, mod_val);
|
||||
|
||||
assert_eq!(
|
||||
actual, expected,
|
||||
"Mismatch! 2^{exp} % {mod_val} => Ours: {actual}, Oracle: {expected}",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_distance_with_uniform() {
|
||||
for excluded_upper_bound in 1..20 {
|
||||
for num_input_random_bits in 0..20 {
|
||||
let density = oprf_density_function(excluded_upper_bound, num_input_random_bits);
|
||||
|
||||
let theoretical_pdf = probability_density_function_from_density(&density);
|
||||
|
||||
let p_uniform = 1. / excluded_upper_bound as f64;
|
||||
|
||||
let actual_distance: f64 = 1. / 2.
|
||||
* theoretical_pdf
|
||||
.iter()
|
||||
.map(|p| (*p - p_uniform).abs())
|
||||
.sum::<f64>();
|
||||
|
||||
let theoretical_distance = distance(excluded_upper_bound, num_input_random_bits);
|
||||
|
||||
assert!(
|
||||
(theoretical_distance - actual_distance).abs()
|
||||
<= theoretical_distance / 1_000_000.,
|
||||
"{theoretical_distance} != {actual_distance}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uniformity_scalar_mul_shift() {
|
||||
let max_distance = 2_f64.powi(-20);
|
||||
|
||||
let message_modulus = MessageModulus(4);
|
||||
|
||||
let excluded_upper_bound = 3;
|
||||
|
||||
let num_input_random_bits = num_input_random_bits_for_max_distance(
|
||||
NonZeroU64::new(excluded_upper_bound).unwrap(),
|
||||
max_distance,
|
||||
message_modulus,
|
||||
);
|
||||
|
||||
let sample_count: usize = 10_000_000;
|
||||
|
||||
let p_value_limit: f64 = 0.001;
|
||||
|
||||
// The distribution is not exactly uniform
|
||||
// This check ensures that with the given low max_distance,
|
||||
// the distribution is indistinguishable from the uniform one at the given sample count
|
||||
test_uniformity(sample_count, p_value_limit, excluded_upper_bound, |_seed| {
|
||||
oprf_clear_equivalent(excluded_upper_bound, num_input_random_bits)
|
||||
});
|
||||
}
|
||||
|
||||
fn oprf_clear_equivalent(excluded_upper_bound: u64, num_input_random_bits: u64) -> u64 {
|
||||
let random_input_upper_bound = 1 << num_input_random_bits;
|
||||
|
||||
let random_input = thread_rng().gen_range(0..random_input_upper_bound);
|
||||
|
||||
(random_input * excluded_upper_bound) >> num_input_random_bits
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_uniformity_generate_oblivious_pseudo_random_custom_range() {
|
||||
let base_sample_count: usize = 10_000;
|
||||
|
||||
let p_value_limit: f64 = 0.001;
|
||||
|
||||
let params = PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
|
||||
let config = ConfigBuilder::with_custom_parameters(params).build();
|
||||
|
||||
let (cks, sks) = generate_keys(config);
|
||||
rayon::broadcast(|_| set_server_key(sks.clone()));
|
||||
|
||||
let message_modulus = params.message_modulus;
|
||||
|
||||
// [0.7, 0.1] for `max_distance` chosen to have `num_input_random_bits` be [2, 4]
|
||||
// for any of the listed `excluded_upper_bound`
|
||||
for (expected_num_input_random_bits, max_distance, excluded_upper_bounds) in
|
||||
[(2, 0.7, [3, 5, 6, 7]), (4, 0.1, [3, 5, 6, 7])]
|
||||
{
|
||||
for excluded_upper_bound in excluded_upper_bounds {
|
||||
let sample_count = base_sample_count * excluded_upper_bound as usize;
|
||||
|
||||
let excluded_upper_bound = NonZeroU64::new(excluded_upper_bound).unwrap();
|
||||
|
||||
let num_input_random_bits = num_input_random_bits_for_max_distance(
|
||||
excluded_upper_bound,
|
||||
max_distance,
|
||||
message_modulus,
|
||||
);
|
||||
|
||||
assert_eq!(num_input_random_bits, expected_num_input_random_bits);
|
||||
|
||||
test_uniformity_generate_oblivious_pseudo_random_custom_range2(
|
||||
sample_count,
|
||||
p_value_limit,
|
||||
message_modulus,
|
||||
&cks,
|
||||
excluded_upper_bound,
|
||||
max_distance,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn test_uniformity_generate_oblivious_pseudo_random_custom_range2(
|
||||
sample_count: usize,
|
||||
p_value_limit: f64,
|
||||
message_modulus: MessageModulus,
|
||||
cks: &ClientKey,
|
||||
excluded_upper_bound: NonZeroU64,
|
||||
max_distance: f64,
|
||||
) {
|
||||
let num_input_random_bits = num_input_random_bits_for_max_distance(
|
||||
excluded_upper_bound,
|
||||
max_distance,
|
||||
message_modulus,
|
||||
);
|
||||
|
||||
let range = RangeForRandom::new_from_excluded_upper_bound(excluded_upper_bound);
|
||||
|
||||
let real_values: Vec<u64> = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|_| {
|
||||
let img = FheUint8::generate_oblivious_pseudo_random_custom_range(
|
||||
Seed(rand::thread_rng().gen::<u128>()),
|
||||
&range,
|
||||
Some(max_distance),
|
||||
);
|
||||
|
||||
img.decrypt(cks)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let excluded_upper_bound = excluded_upper_bound.get();
|
||||
|
||||
let uniform_values: Vec<u64> = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|_| thread_rng().gen_range(0..excluded_upper_bound))
|
||||
.collect();
|
||||
|
||||
let clear_oprf_value_lower_num_input_random_bits = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|_| oprf_clear_equivalent(excluded_upper_bound, num_input_random_bits - 1))
|
||||
.collect();
|
||||
|
||||
let clear_oprf_value_same_num_input_random_bits = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|_| oprf_clear_equivalent(excluded_upper_bound, num_input_random_bits))
|
||||
.collect();
|
||||
|
||||
let clear_oprf_value_higher_num_input_random_bits = (0..sample_count)
|
||||
.into_par_iter()
|
||||
.map(|_| oprf_clear_equivalent(excluded_upper_bound, num_input_random_bits + 1))
|
||||
.collect();
|
||||
|
||||
for (values, should_have_low_p_value) in [
|
||||
(&real_values, false),
|
||||
// to test that the same distribution passes
|
||||
(&clear_oprf_value_same_num_input_random_bits, false),
|
||||
// to test that other distributions don't pass
|
||||
// (makes sure the test is statistically powerful)
|
||||
(&uniform_values, true),
|
||||
(&clear_oprf_value_lower_num_input_random_bits, true),
|
||||
(&clear_oprf_value_higher_num_input_random_bits, true),
|
||||
] {
|
||||
let p_value_upper_bound = p_value_upper_bound_oprf_almost_uniformity_from_values(
|
||||
values,
|
||||
num_input_random_bits,
|
||||
excluded_upper_bound,
|
||||
);
|
||||
|
||||
println!("p_value_upper_bound: {p_value_upper_bound}");
|
||||
|
||||
if should_have_low_p_value {
|
||||
assert!(
|
||||
p_value_upper_bound < p_value_limit,
|
||||
"p_value_upper_bound (={p_value_upper_bound}) expected to be smaller than {p_value_limit}"
|
||||
);
|
||||
} else {
|
||||
assert!(
|
||||
p_value_limit < p_value_upper_bound ,
|
||||
"p_value_upper_bound (={p_value_upper_bound}) expected to be bigger than {p_value_limit}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "gpu")]
|
||||
#[allow(unused_imports)]
|
||||
mod test {
|
||||
mod test_gpu {
|
||||
use crate::prelude::*;
|
||||
use crate::{
|
||||
generate_keys, set_server_key, ConfigBuilder, FheInt128, FheUint32, FheUint64, GpuIndex,
|
||||
|
||||
@@ -386,6 +386,12 @@ fn test_if_then_else() {
|
||||
super::test_case_if_then_else(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_if_then_zero() {
|
||||
let client_key = setup_default_cpu();
|
||||
super::test_case_if_then_zero(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flip() {
|
||||
let client_key = setup_default_cpu();
|
||||
|
||||
@@ -89,6 +89,12 @@ fn test_case_if_then_else_hpu() {
|
||||
super::test_case_if_then_else(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_if_then_zero_hpu() {
|
||||
let client_key = setup_default_hpu();
|
||||
super::test_case_if_then_zero(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_case_flip_hpu() {
|
||||
let client_key = setup_default_hpu();
|
||||
|
||||
@@ -568,6 +568,28 @@ fn test_case_if_then_else(client_key: &ClientKey) {
|
||||
);
|
||||
}
|
||||
|
||||
fn test_case_if_then_zero(client_key: &ClientKey) {
|
||||
let clear_a = 42u8;
|
||||
let clear_b = 128u8;
|
||||
|
||||
let a = FheUint8::encrypt(clear_a, client_key);
|
||||
let b = FheUint8::encrypt(clear_b, client_key);
|
||||
|
||||
let result = a.le(&b).if_then_zero(&a);
|
||||
let decrypted_result: u8 = result.decrypt(client_key);
|
||||
assert_eq!(
|
||||
decrypted_result,
|
||||
if clear_a <= clear_b { clear_a } else { 0 }
|
||||
);
|
||||
|
||||
let result = a.ge(&b).if_then_zero(&a);
|
||||
let decrypted_result: u8 = result.decrypt(client_key);
|
||||
assert_eq!(
|
||||
decrypted_result,
|
||||
if clear_a >= clear_b { clear_a } else { 0 }
|
||||
);
|
||||
}
|
||||
|
||||
fn test_case_flip(client_key: &ClientKey) {
|
||||
let clear_a = rand::random::<u32>();
|
||||
let clear_b = rand::random::<u32>();
|
||||
|
||||
@@ -48,6 +48,7 @@ macro_rules! export_concrete_array_types {
|
||||
}
|
||||
|
||||
pub use crate::core_crypto::commons::math::random::{Seed, XofSeed};
|
||||
pub use crate::high_level_api::integers::oprf::RangeForRandom;
|
||||
pub use crate::integer::server_key::MatchValues;
|
||||
use crate::{error, Error, Versionize};
|
||||
use backward_compatibility::compressed_ciphertext_list::SquashedNoiseCiphertextStateVersions;
|
||||
|
||||
@@ -9,8 +9,9 @@
|
||||
pub use crate::high_level_api::traits::{
|
||||
BitSlice, CiphertextList, DivRem, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
|
||||
FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, FheWait, Flip, IfThenElse,
|
||||
OverflowingAdd, OverflowingMul, OverflowingNeg, OverflowingSub, ReRandomize, RotateLeft,
|
||||
RotateLeftAssign, RotateRight, RotateRightAssign, ScalarIfThenElse, SquashNoise, Tagged,
|
||||
IfThenZero, OverflowingAdd, OverflowingMul, OverflowingNeg, OverflowingSub, ReRandomize,
|
||||
RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign, ScalarIfThenElse, SquashNoise,
|
||||
Tagged,
|
||||
};
|
||||
#[cfg(feature = "hpu")]
|
||||
pub use crate::high_level_api::traits::{FheHpu, HpuHandle};
|
||||
|
||||
@@ -149,6 +149,10 @@ pub trait IfThenElse<Ciphertext> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Conditional selection between a ciphertext and zero.
///
/// NOTE(review): the semantics documented here are those of the `FheBool`
/// implementation for `FheUint`; confirm that other implementors follow them.
pub trait IfThenZero<Ciphertext> {
    /// Returns the value of `ct_then` if `self` is true, and an encryption of
    /// 0 if `self` is false.
    fn if_then_zero(&self, ct_then: &Ciphertext) -> Ciphertext;
}
|
||||
|
||||
pub trait ScalarIfThenElse<Lhs, Rhs> {
|
||||
type Output;
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ use super::{RadixCiphertext, ServerKey, SignedRadixCiphertext};
|
||||
use crate::core_crypto::commons::generators::DeterministicSeeder;
|
||||
use crate::core_crypto::prelude::DefaultRandomGenerator;
|
||||
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
|
||||
use std::num::NonZeroU64;
|
||||
|
||||
pub use tfhe_csprng::seeders::{Seed, Seeder};
|
||||
|
||||
@@ -163,6 +164,7 @@ impl ServerKey {
|
||||
/// as `num_input_random_bits`
|
||||
///
|
||||
/// ```rust
|
||||
/// use std::num::NonZeroU64;
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128;
|
||||
/// use tfhe::Seed;
|
||||
@@ -173,7 +175,7 @@ impl ServerKey {
|
||||
/// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128, size);
|
||||
///
|
||||
/// let num_input_random_bits = 5;
|
||||
/// let excluded_upper_bound = 3;
|
||||
/// let excluded_upper_bound = NonZeroU64::new(3).unwrap();
|
||||
/// let num_blocks_output = 8;
|
||||
///
|
||||
/// let ct_res = sks.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
@@ -186,15 +188,17 @@ impl ServerKey {
|
||||
/// // Decrypt:
|
||||
/// let dec_result: u64 = cks.decrypt(&ct_res);
|
||||
///
|
||||
/// assert!(dec_result < excluded_upper_bound);
|
||||
/// assert!(dec_result < excluded_upper_bound.get());
|
||||
/// ```
|
||||
pub fn par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
&self,
|
||||
seed: Seed,
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
excluded_upper_bound: NonZeroU64,
|
||||
num_blocks_output: u64,
|
||||
) -> RadixCiphertext {
|
||||
let excluded_upper_bound = excluded_upper_bound.get();
|
||||
|
||||
assert!(self.message_modulus().0.is_power_of_two());
|
||||
let message_bits_count = self.message_modulus().0.ilog2() as u64;
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientK
|
||||
use crate::shortint::parameters::*;
|
||||
use crate::{ClientKey, CompressedServerKey, MatchValues, Seed, Tag};
|
||||
use std::cmp::{max, min};
|
||||
use std::num::NonZeroU64;
|
||||
use std::sync::Arc;
|
||||
|
||||
create_parameterized_test!(random_op_sequence {
|
||||
@@ -498,7 +499,18 @@ where
|
||||
&ServerKey::par_generate_oblivious_pseudo_random_unsigned_integer_bounded,
|
||||
);
|
||||
let oprf_custom_range_executor = OpSequenceCpuFunctionExecutor::new(
|
||||
&ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
|
||||
&|sk: &ServerKey,
|
||||
seed: Seed,
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
num_blocks_output: u64| {
|
||||
sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
NonZeroU64::new(excluded_upper_bound).unwrap_or(NonZeroU64::new(1).unwrap()),
|
||||
num_blocks_output,
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
let mut oprf_ops: Vec<(OprfExecutor, String)> = vec![(
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}
|
||||
use crate::shortint::parameters::*;
|
||||
use statrs::distribution::ContinuousCDF;
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZeroU64;
|
||||
use std::sync::Arc;
|
||||
use tfhe_csprng::seeders::Seed;
|
||||
|
||||
@@ -36,9 +37,19 @@ fn oprf_any_range_unsigned<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(
|
||||
&ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
|
||||
);
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&|sk: &ServerKey,
|
||||
seed: Seed,
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
num_blocks_output: u64| {
|
||||
sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
NonZeroU64::new(excluded_upper_bound).unwrap(),
|
||||
num_blocks_output,
|
||||
)
|
||||
});
|
||||
oprf_any_range_test(param, executor);
|
||||
}
|
||||
|
||||
@@ -46,9 +57,19 @@ fn oprf_almost_uniformity_unsigned<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(
|
||||
&ServerKey::par_generate_oblivious_pseudo_random_unsigned_custom_range,
|
||||
);
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&|sk: &ServerKey,
|
||||
seed: Seed,
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
num_blocks_output: u64| {
|
||||
sk.par_generate_oblivious_pseudo_random_unsigned_custom_range(
|
||||
seed,
|
||||
num_input_random_bits,
|
||||
NonZeroU64::new(excluded_upper_bound).unwrap(),
|
||||
num_blocks_output,
|
||||
)
|
||||
});
|
||||
oprf_almost_uniformity_test(param, executor);
|
||||
}
|
||||
|
||||
@@ -89,7 +110,7 @@ where
|
||||
);
|
||||
}
|
||||
|
||||
pub fn oprf_uniformity_test<P, E>(param: P, mut executor: E)
|
||||
pub(crate) fn oprf_uniformity_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(Seed, u64, u64), RadixCiphertext>,
|
||||
@@ -113,7 +134,7 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
pub fn oprf_any_range_test<P, E>(param: P, mut executor: E)
|
||||
pub(crate) fn oprf_any_range_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(Seed, u64, u64, u64), RadixCiphertext>,
|
||||
@@ -149,7 +170,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn oprf_almost_uniformity_test<P, E>(param: P, mut executor: E)
|
||||
pub(crate) fn oprf_almost_uniformity_test<P, E>(param: P, mut executor: E)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
E: for<'a> FunctionExecutor<(Seed, u64, u64, u64), RadixCiphertext>,
|
||||
@@ -165,40 +186,70 @@ where
|
||||
let num_input_random_bits: u64 = 4;
|
||||
let num_blocks_output = 64;
|
||||
let excluded_upper_bound = 10;
|
||||
let random_input_upper_bound = 1 << num_input_random_bits;
|
||||
|
||||
let mut density = vec![0_usize; excluded_upper_bound as usize];
|
||||
for i in 0..random_input_upper_bound {
|
||||
let index = ((i * excluded_upper_bound) as f64 / random_input_upper_bound as f64) as usize;
|
||||
density[index] += 1;
|
||||
}
|
||||
|
||||
let theoretical_pdf: Vec<f64> = density
|
||||
.iter()
|
||||
.map(|count| *count as f64 / random_input_upper_bound as f64)
|
||||
.collect();
|
||||
|
||||
let values: Vec<u64> = (0..sample_count)
|
||||
.map(|seed| {
|
||||
let img = executor.execute((
|
||||
Seed(seed as u128),
|
||||
num_input_random_bits,
|
||||
excluded_upper_bound as u64,
|
||||
excluded_upper_bound,
|
||||
num_blocks_output,
|
||||
));
|
||||
cks.decrypt(&img)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let p_value_upper_bound = p_value_upper_bound_oprf_almost_uniformity_from_values(
|
||||
&values,
|
||||
num_input_random_bits,
|
||||
excluded_upper_bound,
|
||||
);
|
||||
|
||||
assert!(p_value_limit < p_value_upper_bound);
|
||||
}
|
||||
|
||||
pub(crate) fn p_value_upper_bound_oprf_almost_uniformity_from_values(
|
||||
values: &[u64],
|
||||
num_input_random_bits: u64,
|
||||
excluded_upper_bound: u64,
|
||||
) -> f64 {
|
||||
let density = oprf_density_function(excluded_upper_bound, num_input_random_bits);
|
||||
|
||||
let theoretical_pdf = probability_density_function_from_density(&density);
|
||||
|
||||
let mut bins = vec![0_u64; excluded_upper_bound as usize];
|
||||
for value in values {
|
||||
for value in values.iter().copied() {
|
||||
bins[value as usize] += 1;
|
||||
}
|
||||
|
||||
let cumulative_bins = cumulate(&bins);
|
||||
let theoretical_cdf = cumulate(&theoretical_pdf);
|
||||
let sup_diff = sup_diff(&cumulative_bins, &theoretical_cdf);
|
||||
let p_value_upper_bound = dkw_alpha_from_epsilon(sample_count as f64, sup_diff);
|
||||
|
||||
assert!(p_value_limit < p_value_upper_bound);
|
||||
dkw_alpha_from_epsilon(values.len() as f64, sup_diff)
|
||||
}
|
||||
|
||||
pub(crate) fn oprf_density_function(
|
||||
excluded_upper_bound: u64,
|
||||
num_input_random_bits: u64,
|
||||
) -> Vec<usize> {
|
||||
let random_input_upper_bound = 1 << num_input_random_bits;
|
||||
|
||||
let mut density = vec![0_usize; excluded_upper_bound as usize];
|
||||
|
||||
for i in 0..random_input_upper_bound {
|
||||
let output = ((i * excluded_upper_bound) >> num_input_random_bits) as usize;
|
||||
|
||||
density[output] += 1;
|
||||
}
|
||||
density
|
||||
}
|
||||
|
||||
pub(crate) fn probability_density_function_from_density(density: &[usize]) -> Vec<f64> {
|
||||
let total_count: usize = density.iter().copied().sum();
|
||||
|
||||
density
|
||||
.iter()
|
||||
.map(|count| *count as f64 / total_count as f64)
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -475,8 +475,12 @@ pub(crate) mod test {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_uniformity<F>(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F)
|
||||
where
|
||||
pub(crate) fn test_uniformity<F>(
|
||||
sample_count: usize,
|
||||
p_value_limit: f64,
|
||||
distinct_values: u64,
|
||||
f: F,
|
||||
) where
|
||||
F: Sync + Fn(usize) -> u64,
|
||||
{
|
||||
let p_value = uniformity_p_value(f, sample_count, distinct_values);
|
||||
@@ -487,7 +491,7 @@ pub(crate) mod test {
|
||||
);
|
||||
}
|
||||
|
||||
fn uniformity_p_value<F>(f: F, sample_count: usize, distinct_values: u64) -> f64
|
||||
pub(crate) fn uniformity_p_value<F>(f: F, sample_count: usize, distinct_values: u64) -> f64
|
||||
where
|
||||
F: Sync + Fn(usize) -> u64,
|
||||
{
|
||||
@@ -495,8 +499,11 @@ pub(crate) mod test {
|
||||
|
||||
let mut values_count = HashMap::new();
|
||||
|
||||
for i in &values {
|
||||
assert!(*i < distinct_values, "i {} dv{}", *i, distinct_values);
|
||||
for i in values.iter().copied() {
|
||||
assert!(
|
||||
i < distinct_values,
|
||||
"i (={i}) is supposed to be smaller than distinct_values (={distinct_values})",
|
||||
);
|
||||
|
||||
*values_count.entry(i).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 252 KiB After Width: | Height: | Size: 3.2 KiB |
@@ -12,14 +12,22 @@ function setButtonsDisabledState(buttonIds, state) {
|
||||
|
||||
async function setup() {
|
||||
let supportsThreads = await threads();
|
||||
if (!supportsThreads) {
|
||||
console.error("This browser does not support threads");
|
||||
return;
|
||||
// This variable is set to true if we are using the `serve.multithreaded.json` config
|
||||
if (crossOriginIsolated) {
|
||||
if (supportsThreads) {
|
||||
console.info("Running in multithreaded mode");
|
||||
} else {
|
||||
console.error("This browser does not support threads");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
console.warn("Running in unsafe coop mode");
|
||||
}
|
||||
|
||||
const worker = new Worker(new URL("worker.js", import.meta.url), {
|
||||
type: "module",
|
||||
});
|
||||
|
||||
const demos = await Comlink.wrap(worker).demos;
|
||||
|
||||
const demoNames = [
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"build": "cp -r ../../tfhe/pkg ./ && webpack build ./index.js --mode production -o dist --output-filename index.js && cp index.html dist/ && cp favicon.ico dist/",
|
||||
"server": "serve --config ../serve.json dist/",
|
||||
"server": "npm run server:multithreaded",
|
||||
"server:multithreaded": "serve --config ../serve.multithreaded.json dist/",
|
||||
"server:unsafe-coop": "serve --config ../serve.unsafe-coop.json dist/",
|
||||
"format": "prettier . --write",
|
||||
"check-format": "prettier . --check"
|
||||
},
|
||||
|
||||
11
tfhe/web_wasm_parallel_tests/serve.unsafe-coop.json
Normal file
11
tfhe/web_wasm_parallel_tests/serve.unsafe-coop.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"headers": [
|
||||
{
|
||||
"source": "**/*.@(js|html)",
|
||||
"headers": [
|
||||
{ "key": "Cross-Origin-Embedder-Policy", "value": "unsafe-none" },
|
||||
{ "key": "Cross-Origin-Opener-Policy", "value": "unsafe-none" }
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import * as Comlink from "comlink";
|
||||
import { threads } from "wasm-feature-detect";
|
||||
import init, {
|
||||
initThreadPool,
|
||||
init_panic_hook,
|
||||
@@ -726,8 +727,15 @@ async function compactPublicKeyZeroKnowledgeBench() {
|
||||
serialized_size = list.safe_serialize(BigInt(10000000)).length;
|
||||
}
|
||||
const mean = timing / bench_loops;
|
||||
|
||||
let base_bench_str = "compact_fhe_uint_proven_encryption_";
|
||||
let supportsThreads = await threads();
|
||||
if (!supportsThreads) {
|
||||
base_bench_str += "unsafe_coop_";
|
||||
}
|
||||
|
||||
const common_bench_str =
|
||||
"compact_fhe_uint_proven_encryption_" +
|
||||
base_bench_str +
|
||||
params.zk_scheme +
|
||||
"_" +
|
||||
bits_to_encrypt +
|
||||
@@ -753,7 +761,10 @@ async function compactPublicKeyZeroKnowledgeBench() {
|
||||
|
||||
async function main() {
|
||||
await init();
|
||||
await initThreadPool(navigator.hardwareConcurrency);
|
||||
let supportsThreads = await threads();
|
||||
if (supportsThreads) {
|
||||
await initThreadPool(navigator.hardwareConcurrency);
|
||||
}
|
||||
await init_panic_hook();
|
||||
|
||||
return Comlink.proxy({
|
||||
|
||||
Reference in New Issue
Block a user