mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-14 00:58:13 -05:00
Compare commits
17 Commits
create-pul
...
am/wip/fft
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87c8c655d5 | ||
|
|
d9723d4ebf | ||
|
|
8055604745 | ||
|
|
a5f2b642ae | ||
|
|
bc510269d7 | ||
|
|
27481fdfcd | ||
|
|
0877bd0353 | ||
|
|
995de8ec82 | ||
|
|
990f32fb99 | ||
|
|
2607e12b13 | ||
|
|
f93cb4ccb4 | ||
|
|
fd9f033709 | ||
|
|
a680570966 | ||
|
|
d7a77359bc | ||
|
|
12b91905e2 | ||
|
|
04b2409ba9 | ||
|
|
ade1d63df0 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -31,3 +31,7 @@ web-test-runner/
|
||||
|
||||
# Dir used for backward compatibility test data
|
||||
tfhe/tfhe-backward-compat-data/
|
||||
|
||||
# Sampling tool stuff
|
||||
/venv/
|
||||
**/*.algo_sample_acquistion
|
||||
|
||||
@@ -9,13 +9,9 @@ members = [
|
||||
"backends/tfhe-cuda-backend",
|
||||
"utils/tfhe-versionable",
|
||||
"utils/tfhe-versionable-derive",
|
||||
"tfhe-rs-cost-model"
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"tfhe/backward_compatibility_tests",
|
||||
"utils/cargo-tfhe-lints-inner",
|
||||
"utils/cargo-tfhe-lints"
|
||||
]
|
||||
exclude = ["tfhe/backward_compatibility_tests"]
|
||||
|
||||
[profile.bench]
|
||||
lto = "fat"
|
||||
|
||||
34
Makefile
34
Makefile
@@ -401,7 +401,7 @@ clippy_versionable: install_rs_check_toolchain
|
||||
.PHONY: clippy_all # Run all clippy targets
|
||||
clippy_all: clippy_rustdoc clippy clippy_boolean clippy_shortint clippy_integer clippy_all_targets \
|
||||
clippy_c_api clippy_js_wasm_api clippy_tasks clippy_core clippy_concrete_csprng clippy_zk_pok clippy_trivium \
|
||||
clippy_versionable
|
||||
clippy_versionable clippy_noise_measurement
|
||||
|
||||
.PHONY: clippy_fast # Run main clippy targets
|
||||
clippy_fast: clippy_rustdoc clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_tasks \
|
||||
@@ -1244,6 +1244,38 @@ sha256_bool: install_rs_check_toolchain
|
||||
--example sha256_bool \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean
|
||||
|
||||
.PHONY: external_product_noise_measurement # Run scripts to run noise measurement for external_product
|
||||
external_product_noise_measurement: setup_venv_noise_measurement install_rs_check_toolchain
|
||||
source venv/bin/activate && \
|
||||
cd tfhe-rs-cost-model/src/ && \
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain $(CARGO_RS_CHECK_TOOLCHAIN) \
|
||||
--chunks "$$(nproc)" -- \
|
||||
--algorithm multi-bit-ext-prod \
|
||||
--multi-bit-grouping-factor 2
|
||||
|
||||
|
||||
.PHONY: external_product_noise_measurement_classic # Run scripts to run noise measurement for external_product
|
||||
external_product_noise_measurement_classic: setup_venv_noise_measurement install_rs_check_toolchain
|
||||
source venv/bin/activate && \
|
||||
cd tfhe-rs-cost-model/src/ && \
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain $(CARGO_RS_CHECK_TOOLCHAIN) \
|
||||
--chunks "$$(nproc)" -- \
|
||||
--algorithm ext-prod
|
||||
|
||||
.PHONY: clippy_noise_measurement # Run clippy lints on noise measurement tool
|
||||
clippy_noise_measurement: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
|
||||
-p tfhe-rs-cost-model -- --no-deps -D warnings
|
||||
|
||||
.PHONY: setup_venv_noise_measurement
|
||||
setup_venv_noise_measurement:
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate && \
|
||||
pip install -U pip wheel setuptools && \
|
||||
pip install -r tfhe-rs-cost-model/src/requirements.txt
|
||||
|
||||
.PHONY: pcc # pcc stands for pre commit checks (except GPU)
|
||||
pcc: no_tfhe_typo no_dbg_log check_fmt check_typos lint_doc check_md_docs_are_tested check_intra_md_links \
|
||||
clippy_all tfhe_lints check_compile_tests
|
||||
|
||||
57
exps.sh
Executable file
57
exps.sh
Executable file
@@ -0,0 +1,57 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
toolchain=$(cat toolchain.txt)
|
||||
source venv/bin/activate
|
||||
cd tfhe-rs-cost-model/src/
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir ext_prod_no_fft -- \
|
||||
--algorithm ext-prod \
|
||||
--sample-size 100
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir multi_bit_gf_2_ext_prod_no_fft -- \
|
||||
--algorithm multi-bit-ext-prod \
|
||||
--multi-bit-grouping-factor 2 \
|
||||
--sample-size 100
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir multi_bit_gf3_ext_prod_no_fft -- \
|
||||
--algorithm multi-bit-ext-prod \
|
||||
--multi-bit-grouping-factor 3 \
|
||||
--sample-size 100
|
||||
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir ext_prod_fft -- \
|
||||
--algorithm ext-prod \
|
||||
--sample-size 100 \
|
||||
--use-fft
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir multi_bit_gf_2_ext_prod_fft -- \
|
||||
--algorithm multi-bit-ext-prod \
|
||||
--multi-bit-grouping-factor 2 \
|
||||
--sample-size 100 \
|
||||
--use-fft
|
||||
|
||||
python3 external_product_correction.py \
|
||||
--rust-toolchain "${toolchain}" \
|
||||
--chunks "$(nproc)" \
|
||||
--dir multi_bit_gf3_ext_prod_fft -- \
|
||||
--algorithm multi-bit-ext-prod \
|
||||
--multi-bit-grouping-factor 3 \
|
||||
--sample-size 100 \
|
||||
--use-fft
|
||||
26
tfhe-rs-cost-model/Cargo.toml
Normal file
26
tfhe-rs-cost-model/Cargo.toml
Normal file
@@ -0,0 +1,26 @@
|
||||
[package]
|
||||
name = "tfhe-rs-cost-model"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
aligned-vec = { version = "0.5", features = ["serde"] }
|
||||
clap = { version = "3.1", features = ["derive"] }
|
||||
itertools = "0.8.0"
|
||||
indicatif = "0.16.2"
|
||||
rand = "0.6.5"
|
||||
rand_chacha = "0.1.1"
|
||||
rayon = "1.9.0"
|
||||
|
||||
[target.'cfg(target_arch = "x86_64")'.dependencies.tfhe]
|
||||
path = "../tfhe"
|
||||
features = ["x86_64-unix"]
|
||||
|
||||
[target.'cfg(target_arch = "aarch64")'.dependencies.tfhe]
|
||||
path = "../tfhe"
|
||||
features = ["aarch64-unix"]
|
||||
|
||||
[features]
|
||||
nightly-avx512 = ["tfhe/nightly-avx512"]
|
||||
47
tfhe-rs-cost-model/README.md
Normal file
47
tfhe-rs-cost-model/README.md
Normal file
@@ -0,0 +1,47 @@
|
||||
|
||||
# Noise Sampling & Assurance Tool
|
||||
|
||||
Until a `Makefile` target is available (**TODO**), run the tool from `./src`. The command below runs it in analysis-only mode, i.e. it only analyzes previously gathered samples:
|
||||
```bash
|
||||
./bin/python3 external_product_correction.py --chunks 192 --rust-toolchain nightly-2024-08-19 --analysis-only --dir multi-bit-sampling/gf2/ -- --algorithm multi-bit-ext-prod --multi-bit-grouping-factor 2
|
||||
```
|
||||
This assumes a local Python virtual environment with the additional libraries installed; the following commands may help to set it up:
|
||||
```bash
|
||||
python3 -m venv .
|
||||
./bin/pip install scipy
|
||||
./bin/pip install scikit-learn
|
||||
```
|
||||
Also, the current Rust toolchain can be found in `/toolchain.txt`
|
||||
|
||||
|
||||
## "Advanced"
|
||||
|
||||
The sampling program that the script invokes can also be run directly:
|
||||
```bash
|
||||
$ RUSTFLAGS="-C target-cpu=native" cargo run --release -- --help
|
||||
```
|
||||
which prints the list of parameters that can also be passed to the analysis script after `--`.
|
||||
|
||||
|
||||
## How It Works
|
||||
|
||||
???
|
||||
|
||||
- all is orchestrated by `external_product_correction.py`
|
||||
- Rust code is compiled & executed ... this generates vector(s) of errors
|
||||
- samples are analyzed and curves are fitted
|
||||
|
||||
|
||||
## Nice-To-Have
|
||||
|
||||
- `Makefile`? part of CI workflow? test report?
|
||||
- for now, improve output: meaning of printed values, ...
|
||||
- rework as an assurance tool for all op's (not only for external product)
|
||||
- make a macro that generates these tests?
|
||||
- put this macro "near" each tested operation (i.e., greatly simplify adding new op's)
|
||||
- use noise formulas extracted from the latest optimizer (was there a PR on that?)
|
||||
|
||||
3
tfhe-rs-cost-model/src/analyze-cmd.sh
Normal file
3
tfhe-rs-cost-model/src/analyze-cmd.sh
Normal file
@@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
./bin/python3 external_product_correction.py --chunks 192 --rust-toolchain nightly-2024-08-19 --analysis-only --dir multi-bit-sampling/gf2/ -- --algorithm multi-bit-ext-prod --multi-bit-grouping-factor 2 > log-real-to-pred-2.dat
|
||||
562
tfhe-rs-cost-model/src/external_product_correction.py
Normal file
562
tfhe-rs-cost-model/src/external_product_correction.py
Normal file
@@ -0,0 +1,562 @@
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import csv
|
||||
import dataclasses
|
||||
import datetime
|
||||
import json
|
||||
import pathlib
|
||||
import subprocess
|
||||
import functools
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
np.set_printoptions(threshold=np.inf)
|
||||
np.set_printoptions(linewidth=np.inf)
|
||||
|
||||
from scipy.optimize import curve_fit
|
||||
from sklearn.ensemble import IsolationForest
|
||||
|
||||
# Command used to run Rust program responsible to perform sampling on external product.
|
||||
BASE_COMMAND = 'RUSTFLAGS="-C target-cpu=native" cargo {} {} --release --features=nightly-avx512'
|
||||
# Leave toolchain empty at first
|
||||
BUILD_COMMAND = BASE_COMMAND.format("{}", "build")
|
||||
RUN_COMMAND = BASE_COMMAND.format("{}", "run") + " -- --tot {} --id {} {}"
|
||||
|
||||
SECS_PER_HOUR = 3600
|
||||
SECS_PER_MINUTES = 60
|
||||
|
||||
# Command-line interface of the script. Everything given after the known options (typically
# after a ``--`` separator) is collected in ``sampling_args`` and forwarded to the Rust
# sampling program.
parser = argparse.ArgumentParser(description="Compute coefficient correction for external product")
parser.add_argument(
    "--chunks",
    type=int,
    help="Total number of chunks the parameter grid is divided into. "
    "Each chunk is run in a sub-process, to speed up processing make sure to"
    " have at least this number of CPU cores to allocate for this task",
)
parser.add_argument(
    "--rust-toolchain",
    type=str,
    help="The rust toolchain to use",
)
parser.add_argument(
    "--output-file",
    "-o",
    type=str,
    dest="output_filename",
    default="correction_coefficients.json",
    help="Output file containing correction coefficients, formatted as JSON"
    " (default: correction_coefficients.json)",
)
parser.add_argument(
    "--analysis-only",
    "-A",
    action="store_true",
    dest="analysis_only",
    help="If this flag is set, no sampling will be done, it will only try to"
    " analyze existing results",
)
parser.add_argument("--dir", type=str, default=".", help="Dir where acquisition files are stored.")
parser.add_argument(
    "--worst-case-analysis",
    "-W",
    dest="worst_case_analysis",
    action="store_true",
    help="Perform 1000 analyses pruning different outliers, "
    "selecting the worst-case parameters for the fft noise fitting",
)
parser.add_argument(
    "sampling_args",
    nargs=argparse.REMAINDER,
    help="Arguments directly passed to sampling program, to get an exhaustive list"
    " of options run command: `cargo run -- --help`",
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(init=False)
class SamplingLine:
    """
    Parsed view of one row of a sampling result file.

    :param line: one result row as a :class:`dict` with the keys ``polynomial_size``,
        ``glwe_dimension``, ``decomposition_level_count``, ``decomposition_base_log``,
        ``input_variance``, ``output_variance``, ``single_ggsw_variance``,
        ``predicted_variance`` and ``ggsw_encrypted_value``
    :raises KeyError: if one of the keys is missing
    :raises ValueError: if a value cannot be converted, or if ``ggsw_encrypted_value`` is not 1
    """

    # [polynomial_size, glwe_dimension, decomposition_level_count, decomposition_base_log]
    parameters: list
    input_variance: float
    # Value of the ``output_variance`` column
    output_variance_exp: float
    # Declared as a field so that it takes part in the generated repr/eq like the others
    single_ggsw_variance: float
    # Value of the ``predicted_variance`` column
    output_variance_th: float

    def __init__(self, line: dict):
        self.input_variance = float(line["input_variance"])
        self.output_variance_exp = float(line["output_variance"])
        self.single_ggsw_variance = float(line["single_ggsw_variance"])
        self.output_variance_th = float(line["predicted_variance"])
        self.parameters = [
            float(line["polynomial_size"]),
            float(line["glwe_dimension"]),
            float(line["decomposition_level_count"]),
            float(line["decomposition_base_log"]),
        ]
        # Rows acquired with an encrypted GGSW value other than 1 are rejected
        ggsw_value = int(line["ggsw_encrypted_value"])
        if ggsw_value != 1:
            raise ValueError(f"GGSW value is not 1, it's: {ggsw_value}")
|
||||
|
||||
|
||||
def concatenate_result_files(dir_):
    """
    Concatenate all the ``*.algo_sample_acquistion`` files of a directory into a single file.

    The header line is taken from the first file (in sorted filename order); the first line of
    every file is treated as a header and is not repeated. Spaces are stripped from everything
    that is written.

    :param dir_: directory holding the acquisition files, as :class:`str` or :class:`Path`
    :return: concatenated filename as :class:`Path`
    :raises FileNotFoundError: if the directory contains no acquisition file
    """
    dir_path = Path(dir_)
    results_filepath = dir_path / "concatenated_sampling_results"
    files = sorted(dir_path.glob("*.algo_sample_acquistion"))
    if not files:
        raise FileNotFoundError(f"No '*.algo_sample_acquistion' file found in {dir_path}")

    # Opening in "w" mode truncates a results file left over from a previous run.
    with results_filepath.open("w", encoding="utf-8") as results:
        for index, file in enumerate(files):
            header, _sep, rows = file.read_text(encoding="utf-8").partition("\n")
            if index == 0:
                results.write(header.replace(" ", "") + "\n")
            rows = rows.replace(" ", "")
            # Make sure the last row of a file cannot be glued to the first row of the next one.
            if rows and not rows.endswith("\n"):
                rows += "\n"
            results.write(rows)

    return results_filepath
|
||||
|
||||
|
||||
def extract_from_acquisitions(filename):
    """
    Retrieve and parse data from sampling results.

    Rows that cannot be parsed are reported on stdout and skipped. Rows whose measured output
    variance is not below the threshold are skipped as well.

    :param filename: sampling results filename as :class:`Path`
    :return: ``None`` if no sample was kept, otherwise a :class:`tuple` of
        :class:`numpy.array`: ``(parameters, exp_output_variance, th_output_variance,
        single_ggsw_variance, input_variance)``, where each row of ``parameters`` holds the
        four sampling parameters followed by the predicted variance
    """
    # NOTE(review): presumably just below 1/12, the variance of a uniform distribution over
    # [0; 1) - confirm the intent of this bound.
    max_output_variance = 0.083

    parameters = []
    exp_output_variance = []
    th_output_variance = []
    single_ggsw_variance = []
    input_variance = []

    # newline="" is what the csv module asks for; the file is written as UTF-8.
    with filename.open(newline="", encoding="utf-8") as csvfile:
        for line in csv.DictReader(csvfile, delimiter=","):
            try:
                sampled_line = SamplingLine(line)
            except Exception as err:
                # If an exception occurs when parsing a result line, we simply discard this one.
                print(f"Exception while parsing line (error: {err}, line: {line})")
                continue

            if sampled_line.output_variance_exp < max_output_variance:
                # Build a new list rather than appending to the parsed line's own list.
                parameters.append([*sampled_line.parameters, sampled_line.output_variance_th])
                exp_output_variance.append(sampled_line.output_variance_exp)
                th_output_variance.append(sampled_line.output_variance_th)
                single_ggsw_variance.append(sampled_line.single_ggsw_variance)
                input_variance.append(sampled_line.input_variance)

    num_samples = len(parameters)

    print(f"There is {num_samples} samples ...")

    if num_samples == 0:
        return None

    return (
        np.array(parameters),
        np.array(exp_output_variance),
        np.array(th_output_variance),
        np.array(single_ggsw_variance),
        np.array(input_variance),
    )
|
||||
|
||||
|
||||
def get_input(filename):
    """
    Build the fitting dataset from a sampling results file.

    :param filename: result filename as :class:`Path`
    :return: ``None`` when no sample could be extracted, otherwise a :class:`tuple`
        ``(x_values, y_values)`` where ``x_values`` are the extracted parameters and
        ``y_values`` the measured output variance minus the input variance, floored at 0
    """
    samples = extract_from_acquisitions(filename)
    if samples is None:
        return None

    x_values, measured_variance, _predicted, _single_ggsw, injected_variance = samples
    # TODO return a NaN if exp_output_variance <= input_variance ??
    y_values = np.maximum(0.0, (measured_variance - injected_variance))
    return x_values, y_values
|
||||
|
||||
|
||||
def get_input_without_outlier(filename, bits):
    """
    Load the fitting dataset of ``filename`` and pass it through :func:`remove_outlier`.

    :param filename: result filename as :class:`Path`
    :param bits: forwarded to :func:`remove_outlier`
    :return: ``None`` when the file yields no sample, otherwise the result of
        :func:`remove_outlier`
    """
    dataset = get_input(filename)
    if dataset is not None:
        x_values, y_values = dataset
        return remove_outlier(bits, x_values, y_values)
    return None
|
||||
|
||||
|
||||
def remove_outlier(bits, x_values, y_values):
    """
    Run an isolation forest on the dataset and rescale it by ``2 ** (2 * bits)``.

    The isolation forest is fitted and an inlier mask is derived from it, but the line applying
    that mask is currently commented out: no sample is dropped and the reported number of
    removed outliers is always 0.

    :param bits: the last column of ``x_values`` and all of ``y_values`` are multiplied by
        ``2 ** (2 * bits)``
    :param x_values: 2D :class:`numpy.array` of samples
    :param y_values: :class:`numpy.array` of the matching target values
    :return: :class:`tuple` ``(x_values, y_values)`` of rescaled ``float64`` copies
    """
    # Contamination value obtained by experience
    detector = IsolationForest(contamination=0.1)
    inlier_mask = detector.fit_predict(x_values) != -1

    size_before = len(x_values)
    # ~ x_values, y_values = x_values[inlier_mask, :], y_values[inlier_mask]
    size_after = len(x_values)
    print(f"Removing {size_before - size_after} outliers ...")

    # Scale the values from variance to modular variance after the filtering was done to avoid
    # overflowing the isolation forest from sklearn
    modular_scale = np.float64(2 ** (bits * 2))
    scaled_x = x_values.astype(np.float64)
    scaled_x[:, -1] = scaled_x[:, -1] * modular_scale
    scaled_y = y_values.astype(np.float64) * modular_scale
    return scaled_x, scaled_y
|
||||
|
||||
|
||||
def fft_noise(x, a, b, c, log2_q):
    """
    Noise model fitted for the 64-bit float FFT.

    :param x: 2D :class:`numpy.array`; columns 0 to 3 are read as polynomial size, GLWE
        dimension, decomposition level count and decomposition base log, the last column as
        the theoretical variance
    :param a: weight of the ``N**2`` term (in theory, not present)
    :param b: weight of the ``N**2 * log2(N)`` term
    :param c: weight of the ``N**2 * log2(N)**2`` term
    :param log2_q: log2 of the ciphertext modulus
    :return: the modelled variance for each row of ``x``
    """
    # 53 bits of mantissa kept at most; the loss is counted twice (round trip)
    roundtrip_lost_bits = 2 * max(0, log2_q - 53)
    precision_loss = 2**roundtrip_lost_bits

    poly_size, glwe_dim, level_count, base_log = (x[:, column] for column in range(4))
    predicted_var = x[:, -1]

    base_squared = 2.0 ** (2 * base_log)
    size_squared = poly_size**2
    log_size = np.log2(poly_size)

    fft_term = (
        a * precision_loss * base_squared * size_squared
        + b * precision_loss * base_squared * size_squared * log_size
        + c * precision_loss * base_squared * size_squared * (log_size**2)
    )
    return glwe_dim * (glwe_dim + 1) * level_count * fft_term + predicted_var
|
||||
|
||||
|
||||
def fft_noise_128(x, a, b, c, log2_q):
    """
    Noise model fitted for the f128 FFT.

    :param x: 2D :class:`numpy.array`; columns 0 to 3 are read as polynomial size, GLWE
        dimension, decomposition level count and decomposition base log, the last column as
        the theoretical variance
    :param a: weight of the ``N**2`` term
    :param b: weight of the ``N**2 * log2(N)`` term
    :param c: weight of the ``N**2 * log2(N)**2`` term
    :param log2_q: log2 of the ciphertext modulus
    :return: the modelled variance for each row of ``x``
    """
    # 106 bits of mantissa kept at most; the loss is counted twice (round trip)
    roundtrip_lost_bits = 2 * max(0, log2_q - 106)
    precision_loss = 2**roundtrip_lost_bits

    poly_size, glwe_dim, level_count, base_log = (x[:, column] for column in range(4))
    predicted_var = x[:, -1]

    base_squared = 2.0 ** (2 * base_log)
    size_squared = poly_size**2
    log_size = np.log2(poly_size)

    fft_term = (
        a * precision_loss * base_squared * size_squared
        + b * precision_loss * base_squared * size_squared * log_size
        + c * precision_loss * base_squared * size_squared * (log_size**2)
    )
    return glwe_dim * (glwe_dim + 1) * level_count * fft_term + predicted_var
|
||||
|
||||
def log_fft_noise_fun(x, a, b, c, fft_noise_fun):
    """
    Evaluate ``fft_noise_fun(x, a, b, c)`` and return its base-2 logarithm.
    """
    modelled_variance = fft_noise_fun(x, a, b, c)
    return np.log2(modelled_variance)
|
||||
|
||||
|
||||
def train(x_values, y_values, fft_noise_fun):
    """
    Fit the three weights of ``fft_noise_fun`` to the samples, in the log2 domain.

    Samples whose target is not a strictly positive finite number are left out of the fit:
    their log2 is not finite, and :func:`scipy.optimize.curve_fit` refuses non-finite data.

    :param x_values: 2D :class:`numpy.array`, one row of parameters per sample
    :param y_values: :class:`numpy.array` of measured values, one per sample
    :param fft_noise_fun: model called as ``fft_noise_fun(x, a, b, c)``
    :return: fitted ``(a, b, c)`` as a :class:`numpy.array`
    """
    x_values = np.asarray(x_values)
    y_values = np.asarray(y_values, dtype=np.float64)

    usable = np.isfinite(y_values) & (y_values > 0)
    ignored = int(y_values.size - np.count_nonzero(usable))
    if ignored:
        print(f"Ignoring {ignored} samples with a non-positive or non-finite value for the fit")

    weights, _ = curve_fit(
        # TODO try changing the formula according to paper
        lambda x, a, b, c: np.log2(fft_noise_fun(x, a, b, c)),
        x_values[usable],
        np.log2(y_values[usable]),
    )
    return weights
|
||||
|
||||
|
||||
def get_weights(filename, fft_noise_fun, bits):
    """
    Fit the noise model on the samples stored in ``filename``.

    The fitted weights are also passed to :func:`test`, which prints its report.

    :param filename: results filename as :class:`Path`
    :param fft_noise_fun: noise model to fit
    :param bits: forwarded to :func:`get_input_without_outlier`
    :return: ``None`` when no sample is available, otherwise a :class:`dict` of weights
        formatted as ``{"a": <float>, "b": <float>, "c": <float>}``
    """
    dataset = get_input_without_outlier(filename, bits)
    if dataset is None:
        return None

    samples, targets = dataset
    fitted = train(samples, targets, fft_noise_fun)
    test(samples, targets, fitted, fft_noise_fun)
    return {name: fitted[position] for position, name in enumerate("abc")}
|
||||
|
||||
|
||||
def write_to_file(filename, obj):
    """
    Serialize ``obj`` as JSON into ``filename``.

    Failures are reported on stdout and are not propagated to the caller.

    :param filename: destination, anything accepted by :class:`Path`
    :param obj: object to write as JSON
    """
    destination = Path(filename)
    try:
        with destination.open("w", encoding="utf-8") as handle:
            json.dump(obj, handle)
    except Exception as err:
        print(f"Exception occurred while writing to {filename}: {err}")
        return
    print(f"Results written to {filename}")
|
||||
|
||||
|
||||
def build_sampler(rust_toolchain) -> bool:
    """
    Build sampling Rust program as a subprocess.

    :param rust_toolchain: toolchain selector inserted right after ``cargo`` in the build
        command (e.g. ``+nightly-...``)
    :return: ``True`` if the build command exited with status 0, ``False`` otherwise
    """
    start_time = datetime.datetime.now()
    print("Building sampling program")

    build_command = BUILD_COMMAND.format(rust_toolchain)

    process = subprocess.run(build_command, shell=True, capture_output=True, check=False)

    elapsed_time = (datetime.datetime.now() - start_time).total_seconds()

    if process.returncode == 0:
        print(f"Building done in {elapsed_time} seconds")
        return True

    # Only report a failure (with the captured output) when the build actually failed.
    stderr = process.stderr.decode()
    stderr_formatted = f"STDERR: {stderr}" if stderr else ""
    print(
        f"Building failed after {elapsed_time} seconds\n"
        f"STDOUT: {process.stdout.decode()}\n"
        f"{stderr_formatted}"
    )
    return False
|
||||
|
||||
|
||||
def run_sampling_chunk(rust_toolchain, total_chunks, identity, input_args) -> bool:
    """
    Run an external product sampling on a chunk of data as a subprocess.

    :param rust_toolchain: toolchain selector inserted right after ``cargo``
    :param total_chunks: number of chunks the parameter is divided into
    :param identity: chunk identifier as :class:`int`
    :param input_args: arguments passed to sampling program
    :return: ``True`` if the sampling process exited with status 0, ``False`` otherwise
    """
    command = RUN_COMMAND.format(rust_toolchain, total_chunks, identity, input_args)
    started_at = datetime.datetime.now()

    print(f"External product sampling chunk #{identity} starting")

    completed = subprocess.run(command, shell=True, capture_output=True, check=False)

    elapsed = (datetime.datetime.now() - started_at).total_seconds()
    whole_hours, within_hour = divmod(elapsed, SECS_PER_HOUR)
    whole_minutes, within_minute = divmod(within_hour, SECS_PER_MINUTES)
    hours, minutes, seconds = int(whole_hours), int(whole_minutes), int(within_minute)

    if completed.returncode != 0:
        stderr = completed.stderr.decode()
        stderr_formatted = f"STDERR: {stderr}" if stderr else ""
        print(
            f"External product sampling chunk #{identity} failed after"
            f" {hours}:{minutes}:{seconds}\n"
            f"STDOUT: {completed.stdout.decode()}\n"
            f"{stderr_formatted}"
        )
        return False

    print(
        f"External product sampling chunk #{identity} successfully done in"
        f" {hours}:{minutes}:{seconds}"
    )
    return True
|
||||
|
||||
|
||||
def log_var(variance):
    """
    Return the base-2 logarithm of ``variance``, or NaN when ``variance`` is zero or negative.
    """
    is_non_positive = variance <= 0
    # was: np.ceil(0.5 * np.log2(variance)) ??
    return np.nan if is_non_positive else np.log2(variance)
|
||||
|
||||
|
||||
def test(x_values, y_values, weights, fft_noise_fun):
    """
    Report how well the fitted noise model matches the samples.

    For every sample whose first parameter is a power of two from ``2**8`` to ``2**14``, one
    comma separated line is printed, grouped by increasing polynomial size:
    ``log2(real) - log2(predicted)``, the five columns of the sample, the real value and the
    predicted value. A summary of the error with and without the fitted correction is printed
    last.

    NOTE: despite its name this is not a unit test; it is called by ``get_weights``.

    :param x_values: 2D :class:`numpy.array` of samples, the last column being the theoretical
        variance
    :param y_values: measured value of each sample
    :param weights: fitted weights, expanded as positional arguments of ``fft_noise_fun``
    :param fft_noise_fun: noise model
    :return: :class:`tuple` ``(err, err_without_correction)``, each being half the square root
        of the mean squared log2 error over all samples
    """
    fitted = list(weights)

    # The prediction only depends on the sample: compute it once and reuse it below.
    samples = []
    for index in range(len(x_values)):
        params = np.array([x_values[index, :]])
        real_out = y_values[index]
        # TODO make sure this const is OK
        pred_out = max(fft_noise_fun(params, *fitted)[0], 0.000001)
        samples.append((params, real_out, pred_out))

    for nu in range(8, 15):
        big_N = 2.0**nu
        for params, real_out, pred_out in samples:
            if params[0, 0] == big_N:
                # log_var(real_out) - log_var(pred_out) == log_var(real_out/pred_out)
                print(f"{log_var(real_out) - log_var(pred_out)}, {params[0][0]}, {params[0][1]}, {params[0][2]}, {params[0][3]}, {params[0][4]}, {real_out}, {pred_out}")

    mse = 0.0
    mse_without_correction = 0.0
    for params, real_out, pred_out in samples:
        mse += (log_var(real_out) - log_var(pred_out)) ** 2
        mse_without_correction += (log_var(real_out) - log_var(params[0, -1])) ** 2

    # Avoid a division by zero on an empty dataset
    count = max(len(samples), 1)

    mse /= count
    mse = .5 * mse ** .5
    mse_without_correction /= count
    mse_without_correction = .5 * mse_without_correction ** .5
    print(f"½ √mse (all N): {mse} .. {2 ** (2*mse)} \n½ √MSE without correction: {mse_without_correction} .. {2 ** (2*mse_without_correction)}")
    return mse, mse_without_correction
|
||||
|
||||
|
||||
def main():
    """
    Entry point: optionally run the sampling, then fit the noise model and write the weights.

    Unless ``--analysis-only`` is given, the sampling program is built and run once per chunk,
    in parallel. The acquisition files of ``--dir`` are then concatenated and analyzed, and the
    resulting weights are written as JSON into ``--dir``.
    """
    import shlex
    import sys

    args = parser.parse_args()

    sampling_args = [arg for arg in args.sampling_args if arg != "--"]

    bits = 64
    fft_noise_fun = fft_noise
    if any(arg in ["ext-prod-u128-split", "ext-prod-u128"] for arg in sampling_args):
        fft_noise_fun = fft_noise_128
        bits = 128

    for idx, flag_or_value in enumerate(sampling_args):
        if flag_or_value in ["-q", "--modulus-log2"]:
            if idx + 1 >= len(sampling_args):
                parser.error(f"{flag_or_value} expects a value")
            bits = int(sampling_args[idx + 1])
            break

    sampling_args.extend(["--dir", args.dir])

    fft_noise_fun = functools.partial(fft_noise_fun, log2_q=bits)
    dest_dir = Path(args.dir).resolve()

    if not args.analysis_only:
        # The toolchain and the number of chunks are only needed to run the sampling.
        if not args.rust_toolchain:
            parser.error("--rust-toolchain is required unless --analysis-only is set")
        if args.chunks is None or args.chunks < 1:
            parser.error("--chunks must be a positive integer unless --analysis-only is set")

        rust_toolchain = args.rust_toolchain
        if not rust_toolchain.startswith("+"):
            rust_toolchain = f"+{rust_toolchain}"

        dest_dir.mkdir(parents=True, exist_ok=True)

        if not build_sampler(rust_toolchain):
            print("Error while building sampler. Exiting")
            sys.exit(1)

        # The arguments end up in a shell command line: quote each of them so that values
        # holding spaces or shell metacharacters reach the sampling program unchanged.
        forwarded_args = " ".join(shlex.quote(arg) for arg in sampling_args)

        with concurrent.futures.ThreadPoolExecutor(max_workers=args.chunks) as executor:
            futures = [
                executor.submit(
                    run_sampling_chunk,
                    rust_toolchain,
                    args.chunks,
                    n,
                    forwarded_args,
                )
                for n in range(args.chunks)
            ]

            # Wait for all sampling chunks to be completed.
            concurrent.futures.wait(futures)

            execution_ok = all(future.result() for future in futures)

            if not execution_ok:
                print("Error while running samplings processes. Check logs.")
                sys.exit(1)

    result_file = concatenate_result_files(args.dir)
    output_file = dest_dir / args.output_filename

    weights = get_weights(result_file, fft_noise_fun, bits)
    if weights is None:
        print("Empty weights after outlier removal, exiting")
        return

    if args.worst_case_analysis:
        # Keep, for each weight, the largest value seen over the repeated analyses.
        worst = dict(weights)
        for _ in range(1000):
            weights = get_weights(result_file, fft_noise_fun, bits)
            if weights is None:
                continue
            for name in ("a", "b", "c"):
                worst[name] = max(worst[name], weights[name])
        write_to_file(output_file, worst)
    else:
        write_to_file(output_file, weights)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
21
tfhe-rs-cost-model/src/histo-plot.sh
Executable file
21
tfhe-rs-cost-model/src/histo-plot.sh
Executable file
@@ -0,0 +1,21 @@
|
||||
#!/usr/bin/env gnuplot
|
||||
|
||||
GF = 2
|
||||
SFX = "-tanh" # -tanh
|
||||
DATAFILE = "log-real-to-pred-".GF.SFX.".dat"
|
||||
|
||||
set term pngcairo size 1200,900 linewidth 2
|
||||
set out "histogram-".GF.SFX.".png"
|
||||
|
||||
set style fill solid 0.5 # fill style
|
||||
set xrange [-4:4]
|
||||
set yrange [0:600]
|
||||
|
||||
min=-3. # min value
|
||||
max= 3. # max value
|
||||
n = 200
|
||||
width=(max-min)/n # interval width
|
||||
set boxwidth width*0.8
|
||||
hist(x,width)=width*floor(x/width)+width/2.0
|
||||
|
||||
plot DATAFILE u (hist($1,width)):(1.0) smooth freq w boxes lc rgb "green" notitle
|
||||
BIN
tfhe-rs-cost-model/src/histogram-2.png
Normal file
BIN
tfhe-rs-cost-model/src/histogram-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.3 KiB |
BIN
tfhe-rs-cost-model/src/histogram-3.png
Normal file
BIN
tfhe-rs-cost-model/src/histogram-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.3 KiB |
BIN
tfhe-rs-cost-model/src/histogram-4.png
Normal file
BIN
tfhe-rs-cost-model/src/histogram-4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.4 KiB |
BIN
tfhe-rs-cost-model/src/histograms.gif
Normal file
BIN
tfhe-rs-cost-model/src/histograms.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
706
tfhe-rs-cost-model/src/ks_pbs_timing.rs
Normal file
706
tfhe-rs-cost-model/src/ks_pbs_timing.rs
Normal file
@@ -0,0 +1,706 @@
|
||||
use super::*;
|
||||
|
||||
use itertools::Itertools;
|
||||
use rand::prelude::*;
|
||||
use rayon::prelude::*;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tfhe::core_crypto::commons::noise_formulas::lwe_keyswitch::keyswitch_additive_variance_132_bits_security_gaussian;
|
||||
use tfhe::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap::pbs_variance_132_bits_security_gaussian;
|
||||
use tfhe::core_crypto::commons::noise_formulas::secure_noise::{
|
||||
minimal_glwe_variance_for_132_bits_security_gaussian,
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian,
|
||||
};
|
||||
|
||||
// pub const SECURITY_LEVEL: u64 = 132;
|
||||
// Variance of uniform distribution over [0; 1)
|
||||
pub const UNIFORM_NOISE_VARIANCE: f64 = 1. / 12.;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct Params {
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
ks_level: DecompositionLevelCount,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
struct ParamsHash {
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log_smaller_than_5: bool,
|
||||
}
|
||||
|
||||
impl std::hash::Hash for ParamsHash {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.lwe_dimension.0.hash(state);
|
||||
self.glwe_dimension.0.hash(state);
|
||||
self.polynomial_size.0.hash(state);
|
||||
self.pbs_level.0.hash(state);
|
||||
self.ks_level.0.hash(state);
|
||||
self.ks_base_log_smaller_than_5.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Params> for ParamsHash {
|
||||
fn from(value: Params) -> Self {
|
||||
Self {
|
||||
lwe_dimension: value.lwe_dimension,
|
||||
glwe_dimension: value.glwe_dimension,
|
||||
polynomial_size: value.polynomial_size,
|
||||
pbs_level: value.pbs_level,
|
||||
ks_level: value.ks_level,
|
||||
ks_base_log_smaller_than_5: value.ks_base_log.0 <= 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NoiseVariances {
|
||||
lwe_noise_variance: Variance,
|
||||
glwe_noise_variance: Variance,
|
||||
estimated_pbs_noise_variance: Variance,
|
||||
estimated_ks_noise_variance: Variance,
|
||||
br_to_ms_noise_variance: Variance,
|
||||
}
|
||||
|
||||
impl NoiseVariances {
|
||||
fn all_noises_are_not_uniformly_random(&self) -> bool {
|
||||
self.lwe_noise_variance.0 < UNIFORM_NOISE_VARIANCE
|
||||
&& self.glwe_noise_variance.0 < UNIFORM_NOISE_VARIANCE
|
||||
&& self.estimated_ks_noise_variance.0 < UNIFORM_NOISE_VARIANCE
|
||||
&& self.estimated_pbs_noise_variance.0 < UNIFORM_NOISE_VARIANCE
|
||||
&& self.br_to_ms_noise_variance.0 < UNIFORM_NOISE_VARIANCE
|
||||
}
|
||||
}
|
||||
|
||||
// TODO
|
||||
// This needs to be updated with the research optimizer
|
||||
// This was taken from concrete CPU temporarily
|
||||
pub fn estimate_modulus_switching_noise_with_binary_key(
|
||||
internal_ks_output_lwe_dimension: LweDimension,
|
||||
glwe_polynomial_size: PolynomialSize,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
let ciphertext_modulus_log = modulus.log2() as u32;
|
||||
|
||||
fn modular_variance_variance_ratio(ciphertext_modulus_log: u32) -> f64 {
|
||||
2_f64.powi(2 * ciphertext_modulus_log as i32)
|
||||
}
|
||||
|
||||
fn modular_variance_to_variance(modular_variance: f64, ciphertext_modulus_log: u32) -> f64 {
|
||||
modular_variance / modular_variance_variance_ratio(ciphertext_modulus_log)
|
||||
}
|
||||
|
||||
let nb_msb = glwe_polynomial_size.0.ilog2() + 1;
|
||||
|
||||
let w = 2_f64.powi(nb_msb as i32);
|
||||
let n = internal_ks_output_lwe_dimension.0 as f64;
|
||||
|
||||
Variance(
|
||||
(1. / 12. + n / 24.) / (w * w)
|
||||
+ modular_variance_to_variance(-1. / 12. + n / 48., ciphertext_modulus_log),
|
||||
)
|
||||
}
|
||||
|
||||
fn lwe_glwe_noise_ap_estimate(
|
||||
Params {
|
||||
lwe_dimension,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
}: Params,
|
||||
ciphertext_modulus_log: u32,
|
||||
) -> NoiseVariances {
|
||||
let modulus = 2.0f64.powi(ciphertext_modulus_log as i32);
|
||||
let lwe_noise_variance =
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian(lwe_dimension, modulus);
|
||||
|
||||
let glwe_noise_variance = minimal_glwe_variance_for_132_bits_security_gaussian(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
modulus,
|
||||
);
|
||||
|
||||
let estimated_pbs_noise_variance = pbs_variance_132_bits_security_gaussian(
|
||||
lwe_dimension,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
pbs_base_log,
|
||||
pbs_level,
|
||||
modulus,
|
||||
);
|
||||
|
||||
let estimated_ks_noise_variance = keyswitch_additive_variance_132_bits_security_gaussian(
|
||||
glwe_dimension.to_equivalent_lwe_dimension(polynomial_size),
|
||||
lwe_dimension,
|
||||
ks_base_log,
|
||||
ks_level,
|
||||
modulus,
|
||||
);
|
||||
|
||||
let ms_noise_variance =
|
||||
estimate_modulus_switching_noise_with_binary_key(lwe_dimension, polynomial_size, modulus);
|
||||
|
||||
let br_to_ms_noise_variance = Variance(
|
||||
estimated_pbs_noise_variance.0 + estimated_ks_noise_variance.0 + ms_noise_variance.0,
|
||||
);
|
||||
|
||||
NoiseVariances {
|
||||
lwe_noise_variance,
|
||||
glwe_noise_variance,
|
||||
estimated_pbs_noise_variance,
|
||||
estimated_ks_noise_variance,
|
||||
br_to_ms_noise_variance,
|
||||
}
|
||||
}
|
||||
|
||||
fn write_results_to_file(
|
||||
params: Params,
|
||||
perf_metrics_array: &[(usize, ThreadCount, usize, PerfMetrics)],
|
||||
out_dir: &Path,
|
||||
) {
|
||||
let exp_name = format!(
|
||||
"n={}_k={}_N={}_brl={}_brb={}_ksl={}_ksb={}",
|
||||
params.lwe_dimension.0,
|
||||
params.glwe_dimension.0,
|
||||
params.polynomial_size.0,
|
||||
params.pbs_level.0,
|
||||
params.pbs_base_log.0,
|
||||
params.ks_level.0,
|
||||
params.ks_base_log.0,
|
||||
);
|
||||
|
||||
let out_file_name = PathBuf::from(format!("{exp_name}.csv"));
|
||||
|
||||
let out_path = out_dir.join(out_file_name);
|
||||
|
||||
if out_path.exists() {
|
||||
std::fs::remove_file(&out_path).unwrap();
|
||||
}
|
||||
|
||||
let mut out = std::fs::File::options()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(&out_path)
|
||||
.unwrap();
|
||||
|
||||
// per_batch_runtime_s: f64,
|
||||
// pbs_per_s: f64,
|
||||
// pbs_per_s_per_thread: f64,
|
||||
// equivalent_monothread_pbs_runtime_s: f64,
|
||||
|
||||
writeln!(
|
||||
&mut out,
|
||||
"chunk_size,threads_used,batch_count,overall_runtime_s,\
|
||||
per_batch_runtime_s,pbs_per_s,pbs_per_s_per_thread,equivalent_monothread_pbs_runtime_s"
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
for (chunk_size, thread_count, batch_count, perf_metrics) in perf_metrics_array {
|
||||
let thread_count = thread_count.0;
|
||||
let PerfMetrics {
|
||||
overall_runtime_s,
|
||||
per_batch_runtime_s,
|
||||
pbs_per_s,
|
||||
pbs_per_s_per_thread,
|
||||
equivalent_monothread_pbs_runtime_s,
|
||||
} = perf_metrics;
|
||||
writeln!(
|
||||
&mut out,
|
||||
"{chunk_size},{thread_count},{batch_count},{overall_runtime_s},\
|
||||
{per_batch_runtime_s},{pbs_per_s},{pbs_per_s_per_thread},{equivalent_monothread_pbs_runtime_s}"
|
||||
).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_b_l_limited(
|
||||
bases: &[usize],
|
||||
levels: &[usize],
|
||||
preserved_mantissa: usize,
|
||||
) -> Vec<BaseLevel> {
|
||||
let mut bases_levels = vec![];
|
||||
for (b, l) in iproduct!(bases, levels) {
|
||||
if b * l <= preserved_mantissa {
|
||||
if *b == 1 {
|
||||
if (b * l) % 5 == 0 {
|
||||
bases_levels.push(BaseLevel {
|
||||
base: DecompositionBaseLog(*b),
|
||||
level: DecompositionLevelCount(*l),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
bases_levels.push(BaseLevel {
|
||||
base: DecompositionBaseLog(*b),
|
||||
level: DecompositionLevelCount(*l),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
bases_levels
|
||||
}
|
||||
|
||||
// preserved_mantissa = number of bits that are in the mantissa of the floating point numbers used
|
||||
pub fn timing_experiment(algorithm: &str, preserved_mantissa: usize, modulus: u128) {
|
||||
assert_eq!(algorithm, EXT_PROD_ALGO);
|
||||
|
||||
let out_dir = Path::new("exp");
|
||||
if !out_dir.exists() {
|
||||
std::fs::create_dir(out_dir).unwrap();
|
||||
}
|
||||
|
||||
let ciphertext_modulus: CiphertextModulus<u64> = match modulus {
|
||||
0 => CiphertextModulus::new_native(),
|
||||
_ => CiphertextModulus::try_new(modulus).unwrap(),
|
||||
};
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let lwe_dimension_search_space = (512..=1024).step_by(64).map(LweDimension);
|
||||
let glwe_dimension_search_space = (1..=5).map(GlweDimension);
|
||||
let polynomial_size_search_space = (8..=14).map(|poly_log2| PolynomialSize(1 << poly_log2));
|
||||
|
||||
let modulus_log2 = if ciphertext_modulus.is_native_modulus() {
|
||||
64usize
|
||||
} else {
|
||||
ciphertext_modulus.get_custom_modulus().ilog2() as usize
|
||||
};
|
||||
|
||||
// TODO: as discussed with Sam, limit to 40
|
||||
let max_base_level_product = 40;
|
||||
|
||||
let preserved_mantissa = preserved_mantissa.min(modulus_log2);
|
||||
|
||||
let (potential_base_logs, potential_levels) = (
|
||||
(1..=modulus_log2).collect::<Vec<_>>(),
|
||||
(1..=modulus_log2).collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let max_base_log_level_prod = preserved_mantissa
|
||||
.min(modulus_log2)
|
||||
.min(max_base_level_product);
|
||||
|
||||
// let base_log_level_pbs = filter_b_l(
|
||||
// &potential_base_logs,
|
||||
// &potential_levels,
|
||||
// max_base_log_level_prod,
|
||||
// );
|
||||
|
||||
let base_log_level_pbs = filter_b_l_limited(
|
||||
&potential_base_logs,
|
||||
&potential_levels,
|
||||
max_base_log_level_prod,
|
||||
);
|
||||
// Same for KS
|
||||
let base_log_level_ks = base_log_level_pbs.clone();
|
||||
|
||||
let hypercube = iproduct!(
|
||||
lwe_dimension_search_space,
|
||||
glwe_dimension_search_space,
|
||||
polynomial_size_search_space,
|
||||
base_log_level_pbs,
|
||||
base_log_level_ks
|
||||
);
|
||||
|
||||
let hypercube: Vec<_> = hypercube
|
||||
.map(
|
||||
|(
|
||||
lwe_dimension,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
pbs_base_log_level,
|
||||
ks_base_log_level,
|
||||
)| {
|
||||
let params = Params {
|
||||
lwe_dimension,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
pbs_base_log: pbs_base_log_level.base,
|
||||
pbs_level: pbs_base_log_level.level,
|
||||
ks_base_log: ks_base_log_level.base,
|
||||
ks_level: ks_base_log_level.level,
|
||||
};
|
||||
let variances =
|
||||
lwe_glwe_noise_ap_estimate(params, modulus_log2.try_into().unwrap());
|
||||
(params, variances)
|
||||
},
|
||||
)
|
||||
.filter(|(_params, variances)| {
|
||||
// let noise_ok = variances.all_noises_are_not_uniformly_random();
|
||||
// let base_logs_not_too_small = params.pbs_base_log.0 != 1 && params.ks_base_log.0 !=
|
||||
// 1;
|
||||
|
||||
// noise_ok && base_logs_not_too_small
|
||||
|
||||
// let glwe_poly_not_too_big = params.polynomial_size.0 < 2048
|
||||
// || (params.polynomial_size.0 >= 2048 && params.glwe_dimension.0 == 1);
|
||||
|
||||
// noise_ok && glwe_poly_not_too_big
|
||||
|
||||
// noise_ok
|
||||
|
||||
variances.all_noises_are_not_uniformly_random()
|
||||
})
|
||||
.collect();
|
||||
|
||||
println!("candidates {}", hypercube.len());
|
||||
|
||||
let mut hypercube: Vec<_> = hypercube
|
||||
.into_iter()
|
||||
.unique_by(|x| ParamsHash::from(x.0))
|
||||
.collect();
|
||||
|
||||
println!("candidates {}", hypercube.len());
|
||||
|
||||
// hypercube.sort_by(|a, b| {
|
||||
// let a = a.0;
|
||||
// let b = b.0;
|
||||
// let cost_a = ks_cost(
|
||||
// a.lwe_dimension,
|
||||
// a.glwe_dimension
|
||||
// .to_equivalent_lwe_dimension(a.polynomial_size),
|
||||
// a.ks_level,
|
||||
// ) + pbs_cost(
|
||||
// a.lwe_dimension,
|
||||
// a.glwe_dimension,
|
||||
// a.pbs_level,
|
||||
// a.polynomial_size,
|
||||
// );
|
||||
// let cost_b = ks_cost(
|
||||
// b.lwe_dimension,
|
||||
// b.glwe_dimension
|
||||
// .to_equivalent_lwe_dimension(b.polynomial_size),
|
||||
// b.ks_level,
|
||||
// ) + pbs_cost(
|
||||
// b.lwe_dimension,
|
||||
// b.glwe_dimension,
|
||||
// b.pbs_level,
|
||||
// b.polynomial_size,
|
||||
// );
|
||||
|
||||
// cost_a.cmp(&cost_b)
|
||||
// });
|
||||
|
||||
let seed = [0u8; 8 * 4];
|
||||
let mut rng = rand_chacha::ChaChaRng::from_seed(seed);
|
||||
hypercube.shuffle(&mut rng);
|
||||
|
||||
// // After the shuffle make the small levels pop first
|
||||
// hypercube.sort_by(|a, b| {
|
||||
// let a = a.0;
|
||||
// let b = b.0;
|
||||
|
||||
// let a_level_prod = a.ks_level.0 * a.pbs_level.0;
|
||||
// let b_level_prod = b.ks_level.0 * b.pbs_level.0;
|
||||
|
||||
// a_level_prod.cmp(&b_level_prod)
|
||||
|
||||
// // let a_size = a
|
||||
// // .glwe_dimension
|
||||
// // .to_equivalent_lwe_dimension(a.polynomial_size)
|
||||
// // .0
|
||||
// // * a.ks_level.0
|
||||
// // * a.lwe_dimension.to_lwe_size().0
|
||||
// // + a.lwe_dimension.0
|
||||
// // * (a.glwe_dimension.to_glwe_size().0.pow(2))
|
||||
// // * a.pbs_level.0
|
||||
// // * a.polynomial_size.0;
|
||||
|
||||
// // let b_size = b
|
||||
// // .glwe_dimension
|
||||
// // .to_equivalent_lwe_dimension(b.polynomial_size)
|
||||
// // .0
|
||||
// // * b.ks_level.0
|
||||
// // * b.lwe_dimension.to_lwe_size().0
|
||||
// // + b.lwe_dimension.0
|
||||
// // * (b.glwe_dimension.to_glwe_size().0.pow(2))
|
||||
// // * b.pbs_level.0
|
||||
// // * b.polynomial_size.0;
|
||||
|
||||
// // a_size.cmp(&b_size)
|
||||
// });
|
||||
|
||||
// {
|
||||
// let mut out = std::fs::File::options()
|
||||
// .create(true)
|
||||
// .truncate(true)
|
||||
// .write(true)
|
||||
// .open(&out_dir.join(&"params.log"))
|
||||
// .unwrap();
|
||||
|
||||
// for (param, _) in &hypercube {
|
||||
// writeln!(&mut out, "{param:?}").unwrap();
|
||||
// }
|
||||
// panic!("lol");
|
||||
// }
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
for (idx, (params, variances)) in hypercube.into_iter().enumerate() {
|
||||
let loop_start = std::time::Instant::now();
|
||||
println!("#{idx} start");
|
||||
println!("{params:#?}");
|
||||
let perf_metrics = run_timing_measurements(params, variances, ciphertext_modulus);
|
||||
println!("{perf_metrics:#?}");
|
||||
write_results_to_file(params, &perf_metrics, out_dir);
|
||||
let loop_elapsed = loop_start.elapsed();
|
||||
println!("#{idx} done in {loop_elapsed:?}");
|
||||
println!("overall runtime {:?}", start_time.elapsed());
|
||||
}
|
||||
}
|
||||
|
||||
pub const CHUNK_SIZE: [usize; 5] = [1, 32, 64, 128, 192];
|
||||
pub const BATCH_COUNT: usize = 100;
|
||||
|
||||
/// Throughput figures derived from one timed run.
#[derive(Debug, Clone, Copy)]
struct PerfMetrics {
    /// Wall-clock time of the whole run, in seconds.
    overall_runtime_s: f64,
    /// `overall_runtime_s` divided by the number of batches.
    per_batch_runtime_s: f64,
    /// PBS executed per second across all threads.
    pbs_per_s: f64,
    /// `pbs_per_s` divided by the number of threads.
    pbs_per_s_per_thread: f64,
    /// Inverse of `pbs_per_s_per_thread`.
    equivalent_monothread_pbs_runtime_s: f64,
}

/// Derives [`PerfMetrics`] from the measured duration of a run.
///
/// * `overall_runtime` - time taken by all batches.
/// * `batch_count` - number of batches executed; must be non-zero and fit in
///   a `u32`.
/// * `pbs_per_batch` - number of PBS in each batch.
/// * `thread_count` - number of threads the work was spread over.
///
/// # Panics
///
/// Panics if `batch_count` does not fit in a `u32` or is zero.
fn compute_perf_metrics(
    overall_runtime: std::time::Duration,
    batch_count: usize,
    pbs_per_batch: usize,
    thread_count: usize,
) -> PerfMetrics {
    let batch_divisor = u32::try_from(batch_count).unwrap();
    let per_batch_runtime_s = (overall_runtime / batch_divisor).as_secs_f64();

    // batches per second, scaled by the amount of PBS in each batch
    let pbs_per_s = per_batch_runtime_s.recip() * pbs_per_batch as f64;
    let pbs_per_s_per_thread = pbs_per_s / thread_count as f64;

    PerfMetrics {
        overall_runtime_s: overall_runtime.as_secs_f64(),
        per_batch_runtime_s,
        pbs_per_s,
        pbs_per_s_per_thread,
        equivalent_monothread_pbs_runtime_s: pbs_per_s_per_thread.recip(),
    }
}
|
||||
|
||||
fn run_timing_measurements(
|
||||
params: Params,
|
||||
variances: NoiseVariances,
|
||||
ciphertext_modulus: CiphertextModulus<u64>,
|
||||
) -> Vec<(usize, ThreadCount, usize, PerfMetrics)> {
|
||||
// let params = Params {
|
||||
// lwe_dimension: LweDimension(742),
|
||||
// glwe_dimension: GlweDimension(1),
|
||||
// polynomial_size: PolynomialSize(2048),
|
||||
// pbs_base_log: DecompositionBaseLog(23),
|
||||
// pbs_level: DecompositionLevelCount(1),
|
||||
// ks_base_log: DecompositionBaseLog(3),
|
||||
// ks_level: DecompositionLevelCount(5),
|
||||
// };
|
||||
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
let mut secret_random_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
let mut encryption_random_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
|
||||
let lwe_noise_distribution =
|
||||
Gaussian::from_dispersion_parameter(variances.lwe_noise_variance, 0.0);
|
||||
let glwe_noise_distribution =
|
||||
Gaussian::from_dispersion_parameter(variances.glwe_noise_variance, 0.0);
|
||||
|
||||
let lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
params.lwe_dimension,
|
||||
&mut secret_random_generator,
|
||||
);
|
||||
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
params.glwe_dimension,
|
||||
params.polynomial_size,
|
||||
&mut secret_random_generator,
|
||||
);
|
||||
|
||||
let ksk = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
&glwe_secret_key.as_lwe_secret_key(),
|
||||
&lwe_secret_key,
|
||||
params.ks_base_log,
|
||||
params.ks_level,
|
||||
lwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let fbsk = {
|
||||
let bsk = allocate_and_generate_new_lwe_bootstrap_key(
|
||||
&lwe_secret_key,
|
||||
&glwe_secret_key,
|
||||
params.pbs_base_log,
|
||||
params.pbs_level,
|
||||
glwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fbsk = FourierLweBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(&bsk, &mut fbsk);
|
||||
|
||||
fbsk
|
||||
};
|
||||
|
||||
let inputs: Vec<_> = (0..BATCH_COUNT * CHUNK_SIZE.last().unwrap())
|
||||
.map(|_| {
|
||||
allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&glwe_secret_key.as_lwe_secret_key(),
|
||||
Plaintext(0),
|
||||
glwe_noise_distribution,
|
||||
ciphertext_modulus,
|
||||
&mut encryption_random_generator,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut output = inputs.clone();
|
||||
|
||||
let fft = Fft::new(fbsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
|
||||
let mut buffers: Vec<_> = (0..*CHUNK_SIZE.last().unwrap())
|
||||
.map(|_| {
|
||||
let buffer_after_ks =
|
||||
LweCiphertext::new(0u64, ksk.output_lwe_size(), ciphertext_modulus);
|
||||
|
||||
let mut computations_buffers = ComputationBuffers::new();
|
||||
computations_buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fbsk.glwe_size(),
|
||||
fbsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.try_unaligned_bytes_required()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
(buffer_after_ks, computations_buffers)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut accumulator = GlweCiphertext::new(
|
||||
0u64,
|
||||
fbsk.glwe_size(),
|
||||
fbsk.polynomial_size(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// Random values in the lut
|
||||
accumulator.as_mut().fill_with(|| rng.gen::<u64>());
|
||||
|
||||
let mut timings = vec![];
|
||||
|
||||
let current_thread_count = rayon::current_num_threads();
|
||||
|
||||
for chunk_size in CHUNK_SIZE {
|
||||
let effective_thread_count = ThreadCount(chunk_size.min(current_thread_count));
|
||||
|
||||
let ciphertext_to_process_count = chunk_size * BATCH_COUNT;
|
||||
|
||||
if chunk_size == 1 {
|
||||
assert_eq!(ciphertext_to_process_count, BATCH_COUNT);
|
||||
|
||||
let (after_ks_buffer, fft_buffer) = &mut buffers[0];
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for (input_lwe, output_lwe) in inputs[..ciphertext_to_process_count]
|
||||
.iter()
|
||||
.zip(output[..ciphertext_to_process_count].iter_mut())
|
||||
{
|
||||
keyswitch_lwe_ciphertext(&ksk, input_lwe, after_ks_buffer);
|
||||
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized(
|
||||
after_ks_buffer,
|
||||
output_lwe,
|
||||
&accumulator,
|
||||
&fbsk,
|
||||
fft,
|
||||
fft_buffer.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
let perf_metrics =
|
||||
compute_perf_metrics(elapsed, BATCH_COUNT, chunk_size, effective_thread_count.0);
|
||||
|
||||
timings.push((
|
||||
chunk_size,
|
||||
effective_thread_count,
|
||||
BATCH_COUNT,
|
||||
perf_metrics,
|
||||
));
|
||||
} else {
|
||||
let mut measurement_count = 0;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
for (input_lwe_chunk, output_lwe_chunk) in inputs[..ciphertext_to_process_count]
|
||||
.chunks_exact(chunk_size)
|
||||
.zip(output[..ciphertext_to_process_count].chunks_exact_mut(chunk_size))
|
||||
{
|
||||
measurement_count += 1;
|
||||
assert_eq!(input_lwe_chunk.len(), chunk_size);
|
||||
assert_eq!(output_lwe_chunk.len(), chunk_size);
|
||||
|
||||
input_lwe_chunk
|
||||
.par_iter()
|
||||
.zip(output_lwe_chunk.par_iter_mut())
|
||||
.zip(buffers.par_iter_mut())
|
||||
.for_each(|((input_lwe, output_lwe), (after_ks_buffer, fft_buffer))| {
|
||||
keyswitch_lwe_ciphertext(&ksk, input_lwe, after_ks_buffer);
|
||||
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized(
|
||||
after_ks_buffer,
|
||||
output_lwe,
|
||||
&accumulator,
|
||||
&fbsk,
|
||||
fft,
|
||||
fft_buffer.stack(),
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert_eq!(measurement_count, BATCH_COUNT);
|
||||
|
||||
let perf_metrics =
|
||||
compute_perf_metrics(elapsed, BATCH_COUNT, chunk_size, effective_thread_count.0);
|
||||
|
||||
timings.push((
|
||||
chunk_size,
|
||||
effective_thread_count,
|
||||
BATCH_COUNT,
|
||||
perf_metrics,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
timings
|
||||
}
|
||||
6225
tfhe-rs-cost-model/src/log-real-to-pred-2.dat
Normal file
6225
tfhe-rs-cost-model/src/log-real-to-pred-2.dat
Normal file
File diff suppressed because it is too large
Load Diff
6171
tfhe-rs-cost-model/src/log-real-to-pred-3.dat
Normal file
6171
tfhe-rs-cost-model/src/log-real-to-pred-3.dat
Normal file
File diff suppressed because it is too large
Load Diff
6076
tfhe-rs-cost-model/src/log-real-to-pred-4.dat
Normal file
6076
tfhe-rs-cost-model/src/log-real-to-pred-4.dat
Normal file
File diff suppressed because it is too large
Load Diff
801
tfhe-rs-cost-model/src/main.rs
Normal file
801
tfhe-rs-cost-model/src/main.rs
Normal file
@@ -0,0 +1,801 @@
|
||||
mod ks_pbs_timing;
|
||||
mod operators;
|
||||
|
||||
use crate::operators::classic_pbs::{
|
||||
classic_pbs_external_product, classic_pbs_external_product_u128,
|
||||
classic_pbs_external_product_u128_split,
|
||||
};
|
||||
use crate::operators::multi_bit_pbs::{
|
||||
multi_bit_pbs_external_product, std_multi_bit_pbs_external_product,
|
||||
};
|
||||
use clap::Parser;
|
||||
use itertools::iproduct;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write;
|
||||
use tfhe::core_crypto::algorithms::misc::torus_modular_diff;
|
||||
use tfhe::core_crypto::commons::noise_formulas::external_product_no_fft::external_product_no_fft_additive_variance132_bits_security_gaussian;
|
||||
use tfhe::core_crypto::commons::noise_formulas::multi_bit_external_product_no_fft::multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian;
|
||||
use tfhe::core_crypto::commons::noise_formulas::secure_noise::minimal_glwe_variance_for_132_bits_security_gaussian;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
|
||||
pub const DEBUG: bool = false;
|
||||
pub const EXT_PROD_ALGO: &str = "ext-prod";
|
||||
pub const MULTI_BIT_EXT_PROD_ALGO: &str = "multi-bit-ext-prod";
|
||||
pub const STD_MULTI_BIT_EXT_PROD_ALGO: &str = "std-multi-bit-ext-prod";
|
||||
pub const EXT_PROD_U128_SPLIT_ALGO: &str = "ext-prod-u128-split";
|
||||
pub const EXT_PROD_U128_ALGO: &str = "ext-prod-u128";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GlweCiphertextGgswCiphertextExternalProductParameters<Scalar: UnsignedInteger> {
|
||||
pub ggsw_noise: Gaussian<f64>,
|
||||
pub glwe_noise: Gaussian<f64>,
|
||||
pub glwe_dimension: GlweDimension,
|
||||
pub ggsw_encrypted_value: Scalar,
|
||||
pub polynomial_size: PolynomialSize,
|
||||
pub decomposition_base_log: DecompositionBaseLog,
|
||||
pub decomposition_level_count: DecompositionLevelCount,
|
||||
pub ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[clap(author,version,about,long_about = None)]
|
||||
struct Args {
|
||||
/// Total number of threads.
|
||||
#[clap(long, short)]
|
||||
tot: usize,
|
||||
/// Current Thread ID
|
||||
#[clap(long, short)]
|
||||
id: usize,
|
||||
/// Number of time a test is repeated for a single set of parameter.
|
||||
/// This indicates the number of different keys since,at each repetition,we re-sample
|
||||
/// everything
|
||||
#[clap(long, short, default_value_t = 10)]
|
||||
repetitions: usize,
|
||||
/// The size of the sample per key
|
||||
#[clap(long, short = 'S', default_value_t = 10)]
|
||||
sample_size: usize,
|
||||
/// Step used for testing levels beyond 20in hypercube.
|
||||
/// Example: with a step of 3,tested levels tested would be: 1 through 20 then 21,24,27,etc
|
||||
#[clap(long, short = 's', default_value_t = 1)]
|
||||
steps: usize,
|
||||
/// Which algorithm to measure fft noise for
|
||||
#[clap(long,short = 'a',value_parser = [
|
||||
EXT_PROD_ALGO,
|
||||
MULTI_BIT_EXT_PROD_ALGO,
|
||||
STD_MULTI_BIT_EXT_PROD_ALGO,
|
||||
EXT_PROD_U128_SPLIT_ALGO,
|
||||
EXT_PROD_U128_ALGO
|
||||
],default_value = "")]
|
||||
algorithm: String,
|
||||
#[clap(long)]
|
||||
use_fft: bool,
|
||||
#[clap(long)]
|
||||
multi_bit_grouping_factor: Option<usize>,
|
||||
#[clap(long, short = 'q')]
|
||||
modulus_log2: Option<u32>,
|
||||
#[clap(long, short = 'd', default_value = ".")]
|
||||
dir: String,
|
||||
#[clap(long, action)]
|
||||
timing_only: bool,
|
||||
}
|
||||
|
||||
fn variance_to_stddev(var: Variance) -> StandardDev {
|
||||
StandardDev::from_standard_dev(var.get_standard_dev())
|
||||
}
|
||||
|
||||
/// Opens `<dir>/<id>.algo_sample_acquistion` in append mode, creating the file
/// when it does not exist yet.
///
/// # Panics
///
/// Panics with the underlying I/O error when the file cannot be opened.
fn get_analysis_output_file(dir: &str, id: usize) -> std::fs::File {
    let path = format!("{dir}/{id}.algo_sample_acquistion");

    let mut options = OpenOptions::new();
    options.read(true).write(true).append(true).create(true);

    options.open(path).unwrap_or_else(|why| panic!("{why}"))
}
|
||||
|
||||
fn prepare_output_file_header(dir: &str, id: usize) {
|
||||
let mut file = get_analysis_output_file(dir, id);
|
||||
let header =
|
||||
"polynomial_size,glwe_dimension,decomposition_level_count,decomposition_base_log,\
|
||||
ggsw_encrypted_value,input_variance,output_variance,predicted_variance,single_ggsw_variance,mean_runtime_ns,\
|
||||
prep_time_ns\n";
|
||||
let _ = file.write(header.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn write_to_file<Scalar: UnsignedInteger + std::fmt::Display>(
|
||||
params: &GlweCiphertextGgswCiphertextExternalProductParameters<Scalar>,
|
||||
input_stddev: StandardDev,
|
||||
output_stddev: StandardDev,
|
||||
single_ggsw_stddev: StandardDev,
|
||||
pred_stddev: StandardDev,
|
||||
mean_runtime_ns: u128,
|
||||
mean_prep_time_ns: u128,
|
||||
dir: &str,
|
||||
id: usize,
|
||||
) {
|
||||
let data_to_save = format!(
|
||||
"{},{},{},{},{},{},{},{},{},{},{}\n",
|
||||
params.polynomial_size.0,
|
||||
params.glwe_dimension.0,
|
||||
params.decomposition_level_count.0,
|
||||
params.decomposition_base_log.0,
|
||||
params.ggsw_encrypted_value,
|
||||
input_stddev.get_variance(),
|
||||
output_stddev.get_variance(),
|
||||
pred_stddev.get_variance(),
|
||||
single_ggsw_stddev.get_variance(),
|
||||
mean_runtime_ns,
|
||||
mean_prep_time_ns,
|
||||
);
|
||||
|
||||
let mut file = get_analysis_output_file(dir, id);
|
||||
|
||||
let _ = file.write(data_to_save.as_bytes()).unwrap();
|
||||
}
|
||||
|
||||
fn minimal_variance_for_security(
|
||||
k: GlweDimension,
|
||||
size: PolynomialSize,
|
||||
modulus_log2: u32,
|
||||
) -> Variance {
|
||||
let modulus = 2.0f64.powi(modulus_log2 as i32);
|
||||
minimal_glwe_variance_for_132_bits_security_gaussian(k, size, modulus)
|
||||
}
|
||||
|
||||
/// Arithmetic mean of `data`, or `None` when `data` is empty.
// adapted from https://rust-lang-nursery.github.io/rust-cookbook/science/mathematics/statistics.html
fn mean(data: &[f64]) -> Option<f64> {
    if data.is_empty() {
        return None;
    }

    let total: f64 = data.iter().sum();
    Some(total / data.len() as f64)
}
|
||||
|
||||
fn std_deviation(data: &[f64]) -> Option<StandardDev> {
|
||||
// from https://rust-lang-nursery.github.io/rust-cookbook/science/mathematics/statistics.html
|
||||
// replacing the mean by 0. as we theoretically know it
|
||||
match (mean(data), data.len()) {
|
||||
(Some(_data_mean), count) if count > 0 => {
|
||||
let variance = data
|
||||
.iter()
|
||||
.map(|&value| {
|
||||
let diff = 0. - value;
|
||||
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f64>()
|
||||
/ count as f64;
|
||||
|
||||
Some(StandardDev::from_standard_dev(variance.sqrt()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_torus_diff<Scalar: UnsignedInteger>(
|
||||
errs: &mut [f64],
|
||||
output: Vec<Scalar>,
|
||||
input: Vec<Scalar>,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
bit: Scalar,
|
||||
) {
|
||||
if bit == Scalar::ONE {
|
||||
for (&out, (&inp, err)) in output.iter().zip(input.iter().zip(errs.iter_mut())) {
|
||||
*err = torus_modular_diff(out, inp, ciphertext_modulus);
|
||||
}
|
||||
} else if bit == Scalar::ZERO {
|
||||
for (&out, err) in output.iter().zip(errs.iter_mut()) {
|
||||
*err = torus_modular_diff(out, Scalar::ZERO, ciphertext_modulus);
|
||||
}
|
||||
} else {
|
||||
panic!("Not a bit: {:?}", bit);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
struct BaseLevel {
|
||||
base: DecompositionBaseLog,
|
||||
level: DecompositionLevelCount,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
|
||||
struct HyperCubeParams {
|
||||
glwe_dimension: GlweDimension,
|
||||
base_level: BaseLevel,
|
||||
polynomial_size: PolynomialSize,
|
||||
}
|
||||
|
||||
fn filter_b_l(bases: &[usize], levels: &[usize], preserved_mantissa: usize) -> Vec<BaseLevel> {
|
||||
let mut bases_levels = vec![];
|
||||
for (b, l) in iproduct!(bases, levels) {
|
||||
if b * l <= preserved_mantissa {
|
||||
bases_levels.push(BaseLevel {
|
||||
base: DecompositionBaseLog(*b),
|
||||
level: DecompositionLevelCount(*l),
|
||||
});
|
||||
}
|
||||
}
|
||||
bases_levels
|
||||
}
|
||||
|
||||
/// Number of scalars in a GGSW ciphertext: `(k + 1)^2 * l * n`.
fn ggsw_scalar_size(k: GlweDimension, l: DecompositionLevelCount, n: PolynomialSize) -> usize {
    let glwe_size = k.0 + 1;
    glwe_size.pow(2) * l.0 * n.0
}
|
||||
|
||||
fn scalar_muls_per_ext_prod(
|
||||
k: GlweDimension,
|
||||
l: DecompositionLevelCount,
|
||||
n: PolynomialSize,
|
||||
) -> usize {
|
||||
// Each coefficient of the ggsw is involved once in an fmadd operation
|
||||
ggsw_scalar_size(k, l, n)
|
||||
}
|
||||
|
||||
fn ext_prod_cost(k: GlweDimension, l: DecompositionLevelCount, n: PolynomialSize) -> usize {
|
||||
// Conversions going from integer to float and from float to integer
|
||||
let conversion_cost = 2 * k.to_glwe_size().0 * n.0;
|
||||
// Fwd and back
|
||||
let fft_cost = 2 * l.0 * k.to_glwe_size().0 * n.0 * n.0.ilog2() as usize;
|
||||
scalar_muls_per_ext_prod(k, l, n) + conversion_cost + fft_cost
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn ks_cost(
|
||||
input_lwe_dimenion: LweDimension,
|
||||
output_lwe_dimension: LweDimension,
|
||||
ks_level_count: DecompositionLevelCount,
|
||||
) -> usize {
|
||||
// times 2 as it's multiply and add
|
||||
2 * input_lwe_dimenion.0 * ks_level_count.0 * output_lwe_dimension.0
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn pbs_cost(
|
||||
w: LweDimension,
|
||||
k: GlweDimension,
|
||||
l: DecompositionLevelCount,
|
||||
n: PolynomialSize,
|
||||
) -> usize {
|
||||
w.0 * ext_prod_cost(k, l, n)
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = Args::parse();
|
||||
let tot = args.tot;
|
||||
let id = args.id;
|
||||
let total_repetitions = args.repetitions;
|
||||
let base_sample_size = args.sample_size;
|
||||
let algo = args.algorithm;
|
||||
let dir = &args.dir;
|
||||
let timing_only = args.timing_only;
|
||||
let use_fft = args.use_fft;
|
||||
|
||||
if algo.is_empty() {
|
||||
panic!("No algorithm provided")
|
||||
}
|
||||
|
||||
let grouping_factor = match algo.as_str() {
|
||||
MULTI_BIT_EXT_PROD_ALGO | STD_MULTI_BIT_EXT_PROD_ALGO => Some(LweBskGroupingFactor(
|
||||
args.multi_bit_grouping_factor
|
||||
.expect("Required multi_bit_grouping_factor when sampling multi bit alogrithms"),
|
||||
)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let modulus: u128 = match args.modulus_log2 {
|
||||
Some(modulus_log2) => {
|
||||
if modulus_log2 > 128 {
|
||||
panic!("Got modulus_log2 > 128,this is not supported");
|
||||
}
|
||||
|
||||
match algo.as_str() {
|
||||
EXT_PROD_ALGO | MULTI_BIT_EXT_PROD_ALGO | STD_MULTI_BIT_EXT_PROD_ALGO => {
|
||||
if modulus_log2 > 64 {
|
||||
panic!("Got modulus_log2 > 64,for 64 bits scalars");
|
||||
}
|
||||
|
||||
1u128 << modulus_log2
|
||||
}
|
||||
EXT_PROD_U128_SPLIT_ALGO | EXT_PROD_U128_ALGO => {
|
||||
if modulus_log2 == 128 {
|
||||
// native
|
||||
0
|
||||
} else {
|
||||
1u128 << modulus_log2
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
// Native
|
||||
None => 0,
|
||||
};
|
||||
|
||||
// TODO manage moduli < 2^53
|
||||
let (stepped_levels_cutoff, max_base_log_inclusive, preserved_mantissa) = match algo.as_str() {
|
||||
EXT_PROD_U128_ALGO | EXT_PROD_U128_SPLIT_ALGO => (41, 128, 106),
|
||||
_ => (21, 64, 53),
|
||||
};
|
||||
|
||||
if timing_only {
|
||||
return ks_pbs_timing::timing_experiment(&algo, preserved_mantissa, modulus);
|
||||
}
|
||||
|
||||
assert_ne!(
|
||||
tot, 0,
|
||||
"Got tot = 0 for noise sampling experiment,unsupported"
|
||||
);
|
||||
|
||||
// Parameter Grid
|
||||
let polynomial_sizes = vec![
|
||||
PolynomialSize(1 << 8),
|
||||
PolynomialSize(1 << 9),
|
||||
PolynomialSize(1 << 10),
|
||||
PolynomialSize(1 << 11),
|
||||
PolynomialSize(1 << 12),
|
||||
PolynomialSize(1 << 13),
|
||||
PolynomialSize(1 << 14),
|
||||
];
|
||||
let max_polynomial_size = polynomial_sizes.iter().copied().max().unwrap();
|
||||
let glwe_dimensions = vec![
|
||||
GlweDimension(1),
|
||||
GlweDimension(2),
|
||||
GlweDimension(3),
|
||||
GlweDimension(4),
|
||||
GlweDimension(5),
|
||||
];
|
||||
|
||||
let base_logs: Vec<_> = (1..=max_base_log_inclusive).collect();
|
||||
let mut levels = (1..stepped_levels_cutoff).collect::<Vec<_>>();
|
||||
let mut stepped_levels = (stepped_levels_cutoff..=max_base_log_inclusive)
|
||||
.step_by(args.steps)
|
||||
.collect::<Vec<_>>();
|
||||
levels.append(&mut stepped_levels);
|
||||
let bases_levels = filter_b_l(&base_logs, &levels, preserved_mantissa);
|
||||
|
||||
let hypercube = iproduct!(glwe_dimensions, bases_levels, polynomial_sizes);
|
||||
let mut hypercube: Vec<HyperCubeParams> = hypercube
|
||||
.map(
|
||||
|(glwe_dimension, base_level, polynomial_size)| HyperCubeParams {
|
||||
glwe_dimension,
|
||||
base_level,
|
||||
polynomial_size,
|
||||
},
|
||||
)
|
||||
.collect();
|
||||
|
||||
hypercube.sort_by(|a, b| {
|
||||
let k_a = a.glwe_dimension;
|
||||
let l_a = a.base_level.level;
|
||||
let n_a = a.polynomial_size;
|
||||
|
||||
let k_b = b.glwe_dimension;
|
||||
let l_b = b.base_level.level;
|
||||
let n_b = b.polynomial_size;
|
||||
|
||||
let muls_a = ext_prod_cost(k_a, l_a, n_a);
|
||||
let muls_b = ext_prod_cost(k_b, l_b, n_b);
|
||||
|
||||
muls_a.cmp(&muls_b)
|
||||
});
|
||||
|
||||
// Pick elements of increasing complexity stepping by the number of threads to balance the
|
||||
// computation cost among threads
|
||||
let chunk: Vec<_> = hypercube.iter().skip(id).step_by(tot).collect();
|
||||
let chunk_size = chunk.len();
|
||||
|
||||
println!(
|
||||
"-> Thread #{id} computing chunk #{id} of length {chunk_size} \
|
||||
(processing elements #{id} + k * {tot})",
|
||||
);
|
||||
|
||||
prepare_output_file_header(dir, id);
|
||||
|
||||
let mut seeder = new_seeder();
|
||||
let seeder = seeder.as_mut();
|
||||
|
||||
let mut secret_random_generator =
|
||||
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
|
||||
let mut encryption_random_generator =
|
||||
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
|
||||
|
||||
let u64_tool =
|
||||
|secret_rng: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encrypt_rng: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>| {
|
||||
for (
|
||||
curr_idx,
|
||||
HyperCubeParams {
|
||||
glwe_dimension,
|
||||
base_level:
|
||||
BaseLevel {
|
||||
base: decomposition_base_log,
|
||||
level: decomposition_level_count,
|
||||
},
|
||||
polynomial_size,
|
||||
},
|
||||
) in chunk.iter().enumerate()
|
||||
{
|
||||
let glwe_dimension = *glwe_dimension;
|
||||
let decomposition_base_log = *decomposition_base_log;
|
||||
let decomposition_level_count = *decomposition_level_count;
|
||||
let polynomial_size = *polynomial_size;
|
||||
let ciphertext_modulus = CiphertextModulus::try_new(modulus).unwrap();
|
||||
|
||||
let modulus_log2 = if ciphertext_modulus.is_native_modulus() {
|
||||
u64::BITS
|
||||
} else if ciphertext_modulus.is_power_of_two() {
|
||||
ciphertext_modulus.get_custom_modulus().ilog2()
|
||||
} else {
|
||||
todo!("Non power of 2 moduli are currently not supported")
|
||||
};
|
||||
|
||||
println!("Chunk part: {:?}/{chunk_size:?} done", curr_idx + 1);
|
||||
let sample_size = base_sample_size * max_polynomial_size.0 / polynomial_size.0;
|
||||
let ggsw_noise = Gaussian::from_dispersion_parameter(
|
||||
minimal_variance_for_security(glwe_dimension, polynomial_size, modulus_log2),
|
||||
0.0,
|
||||
);
|
||||
// We measure the noise added to a GLWE ciphertext,here we can choose to have no
|
||||
// input noise
|
||||
// It also avoid potential cases where the noise is so big it gets decomposed
|
||||
// during computations,it's an assumption we apparently already make ("small noise
|
||||
// regime")
|
||||
let glwe_noise = Gaussian::from_dispersion_parameter(Variance(0.0), 0.0);
|
||||
// minimal_variance_for_security_64(glwe_dimension, poly_size);
|
||||
|
||||
let parameters = GlweCiphertextGgswCiphertextExternalProductParameters::<u64> {
|
||||
ggsw_noise,
|
||||
glwe_noise,
|
||||
glwe_dimension,
|
||||
ggsw_encrypted_value: 1,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
ciphertext_modulus,
|
||||
};
|
||||
|
||||
println!("params: {parameters:?}");
|
||||
|
||||
let noise_prediction = match algo.as_str() {
|
||||
EXT_PROD_ALGO => {
|
||||
external_product_no_fft_additive_variance132_bits_security_gaussian(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
2.0f64.powi(modulus_log2 as i32),
|
||||
)
|
||||
}
|
||||
MULTI_BIT_EXT_PROD_ALGO =>multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
grouping_factor.unwrap().0 as f64,
|
||||
2.0f64.powi(modulus_log2 as i32),
|
||||
),
|
||||
STD_MULTI_BIT_EXT_PROD_ALGO => multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
grouping_factor.unwrap().0 as f64,
|
||||
2.0f64.powi(modulus_log2 as i32),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let fft = Fft::new(parameters.polynomial_size);
|
||||
let mut computation_buffers = ComputationBuffers::new();
|
||||
computation_buffers.resize(
|
||||
add_external_product_assign_mem_optimized_requirement::<u64>(
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
fft.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required()
|
||||
.max(
|
||||
fft.as_view()
|
||||
.forward_scratch()
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut errors = vec![0.; sample_size * polynomial_size.0 * total_repetitions];
|
||||
|
||||
if noise_prediction.get_variance() < 1. / 12. {
|
||||
let mut total_runtime_ns = 0u128;
|
||||
let mut total_prep_time_ns = 0u128;
|
||||
|
||||
for (_, errs) in (0..total_repetitions)
|
||||
.zip(errors.chunks_mut(sample_size * polynomial_size.0))
|
||||
{
|
||||
let mut raw_inputs = Vec::with_capacity(sample_size);
|
||||
let mut outputs = Vec::with_capacity(sample_size);
|
||||
|
||||
let (sample_runtime_ns, prep_time_ns) = match algo.as_str() {
|
||||
EXT_PROD_ALGO => classic_pbs_external_product(
|
||||
¶meters,
|
||||
&mut raw_inputs,
|
||||
&mut outputs,
|
||||
sample_size,
|
||||
secret_rng,
|
||||
encrypt_rng,
|
||||
use_fft,
|
||||
fft.as_view(),
|
||||
&mut computation_buffers,
|
||||
),
|
||||
MULTI_BIT_EXT_PROD_ALGO => multi_bit_pbs_external_product(
|
||||
¶meters,
|
||||
&mut raw_inputs,
|
||||
&mut outputs,
|
||||
sample_size,
|
||||
secret_rng,
|
||||
encrypt_rng,
|
||||
use_fft,
|
||||
fft.as_view(),
|
||||
&mut computation_buffers,
|
||||
grouping_factor.unwrap(),
|
||||
),
|
||||
STD_MULTI_BIT_EXT_PROD_ALGO => std_multi_bit_pbs_external_product(
|
||||
¶meters,
|
||||
&mut raw_inputs,
|
||||
&mut outputs,
|
||||
sample_size,
|
||||
secret_rng,
|
||||
encrypt_rng,
|
||||
use_fft,
|
||||
fft.as_view(),
|
||||
&mut computation_buffers,
|
||||
grouping_factor.unwrap(),
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
total_runtime_ns += sample_runtime_ns;
|
||||
total_prep_time_ns += prep_time_ns;
|
||||
|
||||
let raw_input_plaintext_vector =
|
||||
raw_inputs.into_iter().flatten().collect::<Vec<_>>();
|
||||
let output_plaintext_vector =
|
||||
outputs.into_iter().flatten().collect::<Vec<_>>();
|
||||
|
||||
compute_torus_diff(
|
||||
errs,
|
||||
output_plaintext_vector,
|
||||
raw_input_plaintext_vector,
|
||||
parameters.ciphertext_modulus,
|
||||
parameters.ggsw_encrypted_value,
|
||||
);
|
||||
}
|
||||
let _mean_err = mean(&errors).unwrap();
|
||||
let std_err = std_deviation(&errors).unwrap();
|
||||
let mean_runtime_ns =
|
||||
total_runtime_ns / ((total_repetitions * sample_size) as u128);
|
||||
// GGSW is prepared only once per sample
|
||||
let mean_prep_time_ns = total_prep_time_ns / (total_repetitions as u128);
|
||||
write_to_file(
|
||||
¶meters,
|
||||
parameters.glwe_noise.standard_dev(),
|
||||
std_err,
|
||||
ggsw_noise.standard_dev(),
|
||||
variance_to_stddev(noise_prediction),
|
||||
mean_runtime_ns,
|
||||
mean_prep_time_ns,
|
||||
dir,
|
||||
id,
|
||||
);
|
||||
|
||||
// TODO output raw data
|
||||
} else {
|
||||
write_to_file(
|
||||
¶meters,
|
||||
parameters.glwe_noise.standard_dev(),
|
||||
variance_to_stddev(Variance::from_variance(1. / 12.)),
|
||||
ggsw_noise.standard_dev(),
|
||||
variance_to_stddev(Variance::from_variance(1. / 12.)),
|
||||
0,
|
||||
0,
|
||||
dir,
|
||||
id,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let u128_tool =
|
||||
|secret_rng: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encrypt_rng: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>| {
|
||||
for (
|
||||
curr_idx,
|
||||
HyperCubeParams {
|
||||
glwe_dimension,
|
||||
base_level:
|
||||
BaseLevel {
|
||||
base: decomposition_base_log,
|
||||
level: decomposition_level_count,
|
||||
},
|
||||
polynomial_size,
|
||||
},
|
||||
) in chunk.iter().enumerate()
|
||||
{
|
||||
let glwe_dimension = *glwe_dimension;
|
||||
let decomposition_base_log = *decomposition_base_log;
|
||||
let decomposition_level_count = *decomposition_level_count;
|
||||
let polynomial_size = *polynomial_size;
|
||||
let ciphertext_modulus = CiphertextModulus::try_new(modulus).unwrap();
|
||||
|
||||
let modulus_log2 = if ciphertext_modulus.is_native_modulus() {
|
||||
u128::BITS
|
||||
} else if ciphertext_modulus.is_power_of_two() {
|
||||
ciphertext_modulus.get_custom_modulus().ilog2()
|
||||
} else {
|
||||
todo!("Non power of 2 moduli are currently not supported")
|
||||
};
|
||||
|
||||
println!("Chunk part: {:?}/{chunk_size:?} done", curr_idx + 1);
|
||||
let sample_size = base_sample_size * max_polynomial_size.0 / polynomial_size.0;
|
||||
let ggsw_noise = Gaussian::from_dispersion_parameter(
|
||||
minimal_variance_for_security(glwe_dimension, polynomial_size, modulus_log2),
|
||||
0.0,
|
||||
);
|
||||
// We measure the noise added to a GLWE ciphertext,here we can choose to have no
|
||||
// input noise
|
||||
// It also avoid potential cases where the noise is so big it gets decomposed
|
||||
// during computations,it's an assumption we apparently already make ("small noise
|
||||
// regime")
|
||||
let glwe_noise = Gaussian::from_dispersion_parameter(Variance(0.0), 0.0);
|
||||
// minimal_variance_for_security_64(glwe_dimension, poly_size));
|
||||
|
||||
let parameters = GlweCiphertextGgswCiphertextExternalProductParameters::<u128> {
|
||||
ggsw_noise,
|
||||
glwe_noise,
|
||||
glwe_dimension,
|
||||
ggsw_encrypted_value: 1,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
ciphertext_modulus,
|
||||
};
|
||||
|
||||
println!("params: {parameters:?}");
|
||||
|
||||
let noise_prediction = match algo.as_str() {
|
||||
EXT_PROD_U128_SPLIT_ALGO | EXT_PROD_U128_ALGO => {
|
||||
external_product_no_fft_additive_variance132_bits_security_gaussian(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
2.0f64.powi(modulus_log2 as i32),
|
||||
)
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let fft = Fft128::new(parameters.polynomial_size);
|
||||
let mut computation_buffers = ComputationBuffers::new();
|
||||
computation_buffers.resize(
|
||||
programmable_bootstrap_f128_lwe_ciphertext_mem_optimized_requirement::<u128>(
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
fft.as_view(),
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required()
|
||||
.max(
|
||||
fft.as_view()
|
||||
.backward_scratch()
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut errors = vec![0.; sample_size * polynomial_size.0 * total_repetitions];
|
||||
|
||||
if noise_prediction.get_variance() < 1. / 12. {
|
||||
let mut total_runtime_ns = 0u128;
|
||||
let mut total_prep_time_ns = 0u128;
|
||||
|
||||
for (_, errs) in (0..total_repetitions)
|
||||
.zip(errors.chunks_mut(sample_size * polynomial_size.0))
|
||||
{
|
||||
let mut raw_inputs = Vec::with_capacity(sample_size);
|
||||
let mut outputs = Vec::with_capacity(sample_size);
|
||||
|
||||
let (sample_runtime_ns, prep_time_ns) = match algo.as_str() {
|
||||
EXT_PROD_U128_SPLIT_ALGO => classic_pbs_external_product_u128_split(
|
||||
¶meters,
|
||||
&mut raw_inputs,
|
||||
&mut outputs,
|
||||
sample_size,
|
||||
secret_rng,
|
||||
encrypt_rng,
|
||||
fft.as_view(),
|
||||
&mut computation_buffers,
|
||||
),
|
||||
EXT_PROD_U128_ALGO => classic_pbs_external_product_u128(
|
||||
¶meters,
|
||||
&mut raw_inputs,
|
||||
&mut outputs,
|
||||
sample_size,
|
||||
secret_rng,
|
||||
encrypt_rng,
|
||||
fft.as_view(),
|
||||
&mut computation_buffers,
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
total_runtime_ns += sample_runtime_ns;
|
||||
total_prep_time_ns += prep_time_ns;
|
||||
|
||||
let raw_input_plaintext_vector =
|
||||
raw_inputs.into_iter().flatten().collect::<Vec<_>>();
|
||||
let output_plaintext_vector =
|
||||
outputs.into_iter().flatten().collect::<Vec<_>>();
|
||||
|
||||
compute_torus_diff(
|
||||
errs,
|
||||
output_plaintext_vector,
|
||||
raw_input_plaintext_vector,
|
||||
parameters.ciphertext_modulus,
|
||||
parameters.ggsw_encrypted_value,
|
||||
);
|
||||
}
|
||||
let _mean_err = mean(&errors).unwrap();
|
||||
let std_err = std_deviation(&errors).unwrap();
|
||||
let mean_runtime_ns =
|
||||
total_runtime_ns / ((total_repetitions * sample_size) as u128);
|
||||
// GGSW is prepared only once per sample
|
||||
let mean_prep_time_ns = total_prep_time_ns / (total_repetitions as u128);
|
||||
write_to_file(
|
||||
¶meters,
|
||||
parameters.glwe_noise.standard_dev(),
|
||||
std_err,
|
||||
ggsw_noise.standard_dev(),
|
||||
variance_to_stddev(noise_prediction),
|
||||
mean_runtime_ns,
|
||||
mean_prep_time_ns,
|
||||
dir,
|
||||
id,
|
||||
);
|
||||
|
||||
// TODO output raw data
|
||||
} else {
|
||||
write_to_file(
|
||||
¶meters,
|
||||
parameters.glwe_noise.standard_dev(),
|
||||
variance_to_stddev(Variance::from_variance(1. / 12.)),
|
||||
ggsw_noise.standard_dev(),
|
||||
variance_to_stddev(Variance::from_variance(1. / 12.)),
|
||||
0,
|
||||
0,
|
||||
dir,
|
||||
id,
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match algo.as_str() {
|
||||
EXT_PROD_ALGO | MULTI_BIT_EXT_PROD_ALGO | STD_MULTI_BIT_EXT_PROD_ALGO => u64_tool(
|
||||
&mut secret_random_generator,
|
||||
&mut encryption_random_generator,
|
||||
),
|
||||
EXT_PROD_U128_ALGO | EXT_PROD_U128_SPLIT_ALGO => u128_tool(
|
||||
&mut secret_random_generator,
|
||||
&mut encryption_random_generator,
|
||||
),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
}
|
||||
30
tfhe-rs-cost-model/src/noise_estimation.rs
Normal file
30
tfhe-rs-cost-model/src/noise_estimation.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
|
||||
pub fn classic_pbs_estimate_external_product_noise_with_binary_ggsw_and_glwe<D1>(
|
||||
_polynomial_size: PolynomialSize,
|
||||
_glwe_dimension: GlweDimension,
|
||||
_ggsw_noise: D1,
|
||||
_base_log: DecompositionBaseLog,
|
||||
_level: DecompositionLevelCount,
|
||||
_log2_modulus: u32,
|
||||
) -> Variance
|
||||
where
|
||||
D1: DispersionParameter,
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
|
||||
pub fn multi_bit_pbs_estimate_external_product_noise_with_binary_ggsw_and_glwe<D1>(
|
||||
_polynomial_size: PolynomialSize,
|
||||
_glwe_dimension: GlweDimension,
|
||||
_ggsw_noise: D1,
|
||||
_base_log: DecompositionBaseLog,
|
||||
_level: DecompositionLevelCount,
|
||||
_log2_modulus: u32,
|
||||
_grouping_factor: LweBskGroupingFactor,
|
||||
) -> Variance
|
||||
where
|
||||
D1: DispersionParameter,
|
||||
{
|
||||
todo!()
|
||||
}
|
||||
533
tfhe-rs-cost-model/src/operators/classic_pbs.rs
Normal file
533
tfhe-rs-cost-model/src/operators/classic_pbs.rs
Normal file
@@ -0,0 +1,533 @@
|
||||
use crate::GlweCiphertextGgswCiphertextExternalProductParameters;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use tfhe::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
use tfhe::core_crypto::fft_impl::fft128::crypto::ggsw::{
|
||||
add_external_product_assign, Fourier128GgswCiphertext,
|
||||
};
|
||||
use tfhe::core_crypto::fft_impl::fft128_u128::crypto::ggsw::add_external_product_assign_split;
|
||||
use tfhe::core_crypto::fft_impl::fft128_u128::math::fft::Fft128View;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::FourierGgswCiphertext;
|
||||
use tfhe::core_crypto::fft_impl::fft64::math::fft::FftView;
|
||||
use tfhe::core_crypto::prelude::{
|
||||
add_external_product_assign_mem_optimized, allocate_and_generate_new_binary_glwe_secret_key,
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized, decrypt_glwe_ciphertext,
|
||||
encrypt_constant_ggsw_ciphertext, encrypt_glwe_ciphertext,
|
||||
karatsuba_add_external_product_assign_mem_optimized, ActivatedRandomGenerator,
|
||||
CiphertextModulus, Cleartext, ComputationBuffers, EncryptionRandomGenerator, GgswCiphertext,
|
||||
GlweCiphertext, GlweCiphertextMutView, GlweCiphertextView, Numeric, PlaintextCount,
|
||||
PlaintextList, SecretRandomGenerator,
|
||||
};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn classic_pbs_external_product(
|
||||
parameters: &GlweCiphertextGgswCiphertextExternalProductParameters<u64>,
|
||||
raw_inputs: &mut Vec<Vec<u64>>,
|
||||
outputs: &mut Vec<Vec<u64>>,
|
||||
sample_size: usize,
|
||||
secret_random_generator: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encryption_random_generator: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>,
|
||||
use_fft: bool,
|
||||
fft: FftView,
|
||||
computation_buffers: &mut ComputationBuffers,
|
||||
) -> (u128, u128) {
|
||||
let ciphertext_modulus = parameters.ciphertext_modulus;
|
||||
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension,
|
||||
parameters.polynomial_size,
|
||||
secret_random_generator,
|
||||
);
|
||||
|
||||
let mut std_ggsw = GgswCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.decomposition_base_log,
|
||||
parameters.decomposition_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut std_ggsw,
|
||||
Cleartext(parameters.ggsw_encrypted_value),
|
||||
parameters.ggsw_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fourier_ggsw = FourierGgswCiphertext::new(
|
||||
std_ggsw.glwe_size(),
|
||||
std_ggsw.polynomial_size(),
|
||||
std_ggsw.decomposition_base_log(),
|
||||
std_ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
if use_fft {
|
||||
convert_standard_ggsw_ciphertext_to_fourier_mem_optimized(
|
||||
&std_ggsw,
|
||||
&mut fourier_ggsw,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut sample_runtime_ns = 0u128;
|
||||
|
||||
for _ in 0..sample_size {
|
||||
let input_plaintext_list =
|
||||
PlaintextList::new(0u64, PlaintextCount(parameters.polynomial_size.0));
|
||||
// encryption_random_generator
|
||||
// .fill_slice_with_random_uniform_mask(input_plaintext_list.as_mut());
|
||||
// let scaling_to_native_torus = parameters
|
||||
// .ciphertext_modulus
|
||||
// .get_power_of_two_scaling_to_native_torus();
|
||||
// // Shift to match the behavior of the previous concrete-core fixtures
|
||||
// // Divide as encryption will encode the power of two in the MSBs
|
||||
// input_plaintext_list.as_mut().iter_mut().for_each(|x| {
|
||||
// *x = (*x << (<u64 as Numeric>::BITS - parameters.decomposition_base_log.0))
|
||||
// / scaling_to_native_torus
|
||||
// });
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus: u64 = ciphertext_modulus.get_custom_modulus() as u64;
|
||||
assert!(input_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
let mut input_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut input_glwe_ciphertext,
|
||||
&input_plaintext_list,
|
||||
parameters.glwe_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if use_fft {
|
||||
add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&fourier_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
} else {
|
||||
karatsuba_add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&std_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
output_glwe_ciphertext
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed().as_nanos();
|
||||
sample_runtime_ns += elapsed;
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
decrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&output_glwe_ciphertext,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus: u64 = ciphertext_modulus.get_custom_modulus() as u64;
|
||||
assert!(output_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
raw_inputs.push(input_plaintext_list.into_container());
|
||||
outputs.push(output_plaintext_list.into_container());
|
||||
}
|
||||
|
||||
// No prep time in this case
|
||||
(sample_runtime_ns, 0)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn classic_pbs_external_product_u128_split(
|
||||
parameters: &GlweCiphertextGgswCiphertextExternalProductParameters<u128>,
|
||||
raw_inputs: &mut Vec<Vec<u128>>,
|
||||
outputs: &mut Vec<Vec<u128>>,
|
||||
sample_size: usize,
|
||||
secret_random_generator: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encryption_random_generator: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>,
|
||||
fft: Fft128View,
|
||||
computation_buffers: &mut ComputationBuffers,
|
||||
) -> (u128, u128) {
|
||||
let ciphertext_modulus = parameters.ciphertext_modulus;
|
||||
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension,
|
||||
parameters.polynomial_size,
|
||||
secret_random_generator,
|
||||
);
|
||||
|
||||
let mut std_ggsw = GgswCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.decomposition_base_log,
|
||||
parameters.decomposition_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut std_ggsw,
|
||||
Cleartext(parameters.ggsw_encrypted_value),
|
||||
parameters.ggsw_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fourier_ggsw = Fourier128GgswCiphertext::new(
|
||||
std_ggsw.glwe_size(),
|
||||
std_ggsw.polynomial_size(),
|
||||
std_ggsw.decomposition_base_log(),
|
||||
std_ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
fourier_ggsw
|
||||
.as_mut_view()
|
||||
.fill_with_forward_fourier(&std_ggsw, fft);
|
||||
|
||||
let mut sample_runtime_ns = 0u128;
|
||||
|
||||
for _ in 0..sample_size {
|
||||
let mut input_plaintext_list =
|
||||
PlaintextList::new(0u128, PlaintextCount(parameters.polynomial_size.0));
|
||||
encryption_random_generator
|
||||
.fill_slice_with_random_uniform_mask(input_plaintext_list.as_mut());
|
||||
let scaling_to_native_torus = parameters
|
||||
.ciphertext_modulus
|
||||
.get_power_of_two_scaling_to_native_torus();
|
||||
// Shift to match the behavior of the previous concrete-core fixtures
|
||||
// Divide as encryption will encode the power of two in the MSBs
|
||||
input_plaintext_list.as_mut().iter_mut().for_each(|x| {
|
||||
*x = (*x << (<u128 as Numeric>::BITS - parameters.decomposition_base_log.0))
|
||||
/ scaling_to_native_torus
|
||||
});
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus = ciphertext_modulus.get_custom_modulus();
|
||||
assert!(input_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
let mut input_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut input_glwe_ciphertext,
|
||||
&input_plaintext_list,
|
||||
parameters.glwe_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let stack = computation_buffers.stack();
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
|
||||
let (input_glwe_lo, stack) = stack.collect_aligned(
|
||||
align,
|
||||
input_glwe_ciphertext.as_ref().iter().map(|i| *i as u64),
|
||||
);
|
||||
let (input_glwe_hi, stack) = stack.collect_aligned(
|
||||
align,
|
||||
input_glwe_ciphertext
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|i| (*i >> 64) as u64),
|
||||
);
|
||||
|
||||
let input_glwe_lo = GlweCiphertextView::from_container(
|
||||
&*input_glwe_lo,
|
||||
input_glwe_ciphertext.polynomial_size(),
|
||||
// Here we split a u128 to two u64 containers and the ciphertext modulus does not
|
||||
// match anymore in terms of the underlying Scalar type, so we'll provide a dummy
|
||||
// native modulus
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
let input_glwe_hi = GlweCiphertextView::from_container(
|
||||
&*input_glwe_hi,
|
||||
input_glwe_ciphertext.polynomial_size(),
|
||||
// Here we split a u128 to two u64 containers and the ciphertext modulus does not
|
||||
// match anymore in terms of the underlying Scalar type, so we'll provide a dummy
|
||||
// native modulus
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
let (output_glwe_lo, stack) = stack.collect_aligned(
|
||||
align,
|
||||
output_glwe_ciphertext.as_ref().iter().map(|i| *i as u64),
|
||||
);
|
||||
let (output_glwe_hi, stack) = stack.collect_aligned(
|
||||
align,
|
||||
output_glwe_ciphertext
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|i| (*i >> 64) as u64),
|
||||
);
|
||||
|
||||
let mut output_glwe_lo = GlweCiphertextMutView::from_container(
|
||||
&mut *output_glwe_lo,
|
||||
output_glwe_ciphertext.polynomial_size(),
|
||||
// Here we split a u128 to two u64 containers and the ciphertext modulus does not
|
||||
// match anymore in terms of the underlying Scalar type, so we'll provide a dummy
|
||||
// native modulus
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
let mut output_glwe_hi = GlweCiphertextMutView::from_container(
|
||||
&mut *output_glwe_hi,
|
||||
output_glwe_ciphertext.polynomial_size(),
|
||||
// Here we split a u128 to two u64 containers and the ciphertext modulus does not
|
||||
// match anymore in terms of the underlying Scalar type, so we'll provide a dummy
|
||||
// native modulus
|
||||
CiphertextModulus::new_native(),
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
add_external_product_assign_split(
|
||||
&mut output_glwe_lo,
|
||||
&mut output_glwe_hi,
|
||||
&fourier_ggsw,
|
||||
&input_glwe_lo,
|
||||
&input_glwe_hi,
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
|
||||
let elapsed = start.elapsed().as_nanos();
|
||||
sample_runtime_ns += elapsed;
|
||||
|
||||
output_glwe_ciphertext
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.zip(
|
||||
output_glwe_lo
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(output_glwe_hi.as_ref().iter()),
|
||||
)
|
||||
.for_each(|(out, (&lo, &hi))| *out = lo as u128 | ((hi as u128) << 64));
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
output_glwe_ciphertext
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
decrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&output_glwe_ciphertext,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus = ciphertext_modulus.get_custom_modulus();
|
||||
assert!(output_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
raw_inputs.push(input_plaintext_list.into_container());
|
||||
outputs.push(output_plaintext_list.into_container());
|
||||
}
|
||||
|
||||
// No prep time in this case
|
||||
(sample_runtime_ns, 0)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn classic_pbs_external_product_u128(
|
||||
parameters: &GlweCiphertextGgswCiphertextExternalProductParameters<u128>,
|
||||
raw_inputs: &mut Vec<Vec<u128>>,
|
||||
outputs: &mut Vec<Vec<u128>>,
|
||||
sample_size: usize,
|
||||
secret_random_generator: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encryption_random_generator: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>,
|
||||
fft: Fft128View,
|
||||
computation_buffers: &mut ComputationBuffers,
|
||||
) -> (u128, u128) {
|
||||
let ciphertext_modulus = parameters.ciphertext_modulus;
|
||||
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension,
|
||||
parameters.polynomial_size,
|
||||
secret_random_generator,
|
||||
);
|
||||
|
||||
let mut std_ggsw = GgswCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.decomposition_base_log,
|
||||
parameters.decomposition_level_count,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_constant_ggsw_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut std_ggsw,
|
||||
Cleartext(parameters.ggsw_encrypted_value),
|
||||
parameters.ggsw_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fourier_ggsw = Fourier128GgswCiphertext::new(
|
||||
std_ggsw.glwe_size(),
|
||||
std_ggsw.polynomial_size(),
|
||||
std_ggsw.decomposition_base_log(),
|
||||
std_ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
fourier_ggsw
|
||||
.as_mut_view()
|
||||
.fill_with_forward_fourier(&std_ggsw, fft);
|
||||
|
||||
let mut sample_runtime_ns = 0u128;
|
||||
|
||||
for _ in 0..sample_size {
|
||||
let mut input_plaintext_list =
|
||||
PlaintextList::new(0u128, PlaintextCount(parameters.polynomial_size.0));
|
||||
encryption_random_generator
|
||||
.fill_slice_with_random_uniform_mask(input_plaintext_list.as_mut());
|
||||
let scaling_to_native_torus = parameters
|
||||
.ciphertext_modulus
|
||||
.get_power_of_two_scaling_to_native_torus();
|
||||
// Shift to match the behavior of the previous concrete-core fixtures
|
||||
// Divide as encryption will encode the power of two in the MSBs
|
||||
input_plaintext_list.as_mut().iter_mut().for_each(|x| {
|
||||
*x = (*x << (<u128 as Numeric>::BITS - parameters.decomposition_base_log.0))
|
||||
/ scaling_to_native_torus
|
||||
});
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus = ciphertext_modulus.get_custom_modulus();
|
||||
assert!(input_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
let mut input_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut input_glwe_ciphertext,
|
||||
&input_plaintext_list,
|
||||
parameters.glwe_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u128,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
add_external_product_assign(
|
||||
&mut output_glwe_ciphertext,
|
||||
&fourier_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
output_glwe_ciphertext
|
||||
.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed().as_nanos();
|
||||
sample_runtime_ns += elapsed;
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
decrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&output_glwe_ciphertext,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
// Sanity check
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let modulus = ciphertext_modulus.get_custom_modulus();
|
||||
assert!(output_plaintext_list.as_ref().iter().all(|x| *x < modulus));
|
||||
}
|
||||
|
||||
raw_inputs.push(input_plaintext_list.into_container());
|
||||
outputs.push(output_plaintext_list.into_container());
|
||||
}
|
||||
|
||||
// No prep time in this case
|
||||
(sample_runtime_ns, 0)
|
||||
}
|
||||
2
tfhe-rs-cost-model/src/operators/mod.rs
Normal file
2
tfhe-rs-cost-model/src/operators/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod classic_pbs;
|
||||
pub mod multi_bit_pbs;
|
||||
393
tfhe-rs-cost-model/src/operators/multi_bit_pbs.rs
Normal file
393
tfhe-rs-cost-model/src/operators/multi_bit_pbs.rs
Normal file
@@ -0,0 +1,393 @@
|
||||
use crate::GlweCiphertextGgswCiphertextExternalProductParameters;
|
||||
use tfhe::core_crypto::algorithms::polynomial_algorithms;
|
||||
use tfhe::core_crypto::fft_impl::common::pbs_modulus_switch;
|
||||
use tfhe::core_crypto::fft_impl::fft64::crypto::ggsw::FourierGgswCiphertext;
|
||||
use tfhe::core_crypto::fft_impl::fft64::math::fft::FftView;
|
||||
use tfhe::core_crypto::fft_impl::fft64::math::polynomial::FourierPolynomial;
|
||||
use tfhe::core_crypto::prelude::{
|
||||
add_external_product_assign_mem_optimized, allocate_and_generate_new_binary_glwe_secret_key,
|
||||
allocate_and_generate_new_lwe_multi_bit_bootstrap_key,
|
||||
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized, decrypt_glwe_ciphertext,
|
||||
encrypt_glwe_ciphertext, karatsuba_add_external_product_assign_mem_optimized,
|
||||
modulus_switch_multi_bit, prepare_multi_bit_ggsw_mem_optimized, std_prepare_multi_bit_ggsw,
|
||||
ActivatedRandomGenerator, ComputationBuffers, ContiguousEntityContainer,
|
||||
EncryptionRandomGenerator, FourierLweMultiBitBootstrapKey, GgswCiphertext, GlweCiphertext,
|
||||
LweBskGroupingFactor, LweSecretKey, MonomialDegree, Numeric, PlaintextCount, PlaintextList,
|
||||
SecretRandomGenerator,
|
||||
};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn multi_bit_pbs_external_product(
|
||||
parameters: &GlweCiphertextGgswCiphertextExternalProductParameters<u64>,
|
||||
raw_inputs: &mut Vec<Vec<u64>>,
|
||||
outputs: &mut Vec<Vec<u64>>,
|
||||
sample_size: usize,
|
||||
secret_random_generator: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encryption_random_generator: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>,
|
||||
use_fft: bool,
|
||||
fft: FftView,
|
||||
computation_buffers: &mut ComputationBuffers,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
) -> (u128, u128) {
|
||||
let lwe_sk = LweSecretKey::from_container(vec![1u64; grouping_factor.0]);
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension,
|
||||
parameters.polynomial_size,
|
||||
secret_random_generator,
|
||||
);
|
||||
|
||||
let bsk = allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_secret_key,
|
||||
parameters.decomposition_base_log,
|
||||
parameters.decomposition_level_count,
|
||||
grouping_factor,
|
||||
parameters.ggsw_noise,
|
||||
parameters.ciphertext_modulus,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut fbsk = FourierLweMultiBitBootstrapKey::new(
|
||||
bsk.input_lwe_dimension(),
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
bsk.grouping_factor(),
|
||||
);
|
||||
|
||||
if use_fft {
|
||||
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized(
|
||||
&bsk,
|
||||
&mut fbsk,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let std_ggsw_vec: Vec<_> = bsk.iter().collect();
|
||||
let ggsw_vec: Vec<_> = fbsk.ggsw_iter().collect();
|
||||
|
||||
let grouping_factor = fbsk.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
assert_eq!(ggsw_vec.len(), ggsw_per_multi_bit_element.0);
|
||||
|
||||
let mut random_mask = vec![0u64; grouping_factor.0];
|
||||
encryption_random_generator.fill_slice_with_random_uniform_mask(&mut random_mask);
|
||||
|
||||
// Recompute it here to rotate and negate the input or output vector to compute errors that make
|
||||
// sense, this corresponds to all key bits == 1, which is a worse case on a single ext prod
|
||||
let equivalent_monomial_degree = MonomialDegree(pbs_modulus_switch(
|
||||
random_mask.iter().sum::<u64>(),
|
||||
parameters.polynomial_size,
|
||||
));
|
||||
|
||||
let mut fourier_a_monomial = FourierPolynomial::new(fbsk.polynomial_size());
|
||||
|
||||
let mut std_ggsw = GgswCiphertext::new(
|
||||
0u64,
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut tmp_std_ggsw = GgswCiphertext::new(
|
||||
0u64,
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut fourier_ggsw = FourierGgswCiphertext::new(
|
||||
fbsk.glwe_size(),
|
||||
fbsk.polynomial_size(),
|
||||
fbsk.decomposition_base_log(),
|
||||
fbsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
let prep_start = std::time::Instant::now();
|
||||
if use_fft {
|
||||
prepare_multi_bit_ggsw_mem_optimized(
|
||||
&mut fourier_ggsw,
|
||||
&ggsw_vec,
|
||||
modulus_switch_multi_bit(
|
||||
fbsk.polynomial_size().to_blind_rotation_input_modulus_log(),
|
||||
grouping_factor,
|
||||
&random_mask,
|
||||
),
|
||||
&mut fourier_a_monomial,
|
||||
fft,
|
||||
);
|
||||
} else {
|
||||
std_prepare_multi_bit_ggsw(
|
||||
&mut std_ggsw,
|
||||
&mut tmp_std_ggsw,
|
||||
&std_ggsw_vec,
|
||||
modulus_switch_multi_bit(
|
||||
bsk.polynomial_size().to_blind_rotation_input_modulus_log(),
|
||||
grouping_factor,
|
||||
&random_mask,
|
||||
),
|
||||
);
|
||||
}
|
||||
let prep_time_ns = prep_start.elapsed().as_nanos();
|
||||
|
||||
let mut sample_runtime_ns = 0u128;
|
||||
|
||||
for _ in 0..sample_size {
|
||||
let input_plaintext_list =
|
||||
PlaintextList::new(0u64, PlaintextCount(parameters.polynomial_size.0));
|
||||
// encryption_random_generator
|
||||
// .fill_slice_with_random_uniform_mask(input_plaintext_list.as_mut());
|
||||
// // Shift to match the behavior of the previous concrete-core fixtures
|
||||
// input_plaintext_list
|
||||
// .as_mut()
|
||||
// .iter_mut()
|
||||
// .for_each(|x| *x <<= <u64 as Numeric>::BITS - parameters.decomposition_base_log.0);
|
||||
|
||||
let mut input_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut input_glwe_ciphertext,
|
||||
&input_plaintext_list,
|
||||
parameters.glwe_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.ciphertext_modulus,
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if use_fft {
|
||||
add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&fourier_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
} else {
|
||||
karatsuba_add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&std_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed().as_nanos();
|
||||
sample_runtime_ns += elapsed;
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
decrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&output_glwe_ciphertext,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
let mut output_pt_list_as_polynomial = output_plaintext_list.as_mut_polynomial();
|
||||
|
||||
// As we performed a monomial multiplication, we need to apply a monomial div to get outputs
|
||||
// in the right order
|
||||
polynomial_algorithms::polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut output_pt_list_as_polynomial,
|
||||
equivalent_monomial_degree,
|
||||
);
|
||||
|
||||
raw_inputs.push(input_plaintext_list.into_container());
|
||||
outputs.push(output_plaintext_list.into_container());
|
||||
}
|
||||
|
||||
(sample_runtime_ns, prep_time_ns)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn std_multi_bit_pbs_external_product(
|
||||
parameters: &GlweCiphertextGgswCiphertextExternalProductParameters<u64>,
|
||||
raw_inputs: &mut Vec<Vec<u64>>,
|
||||
outputs: &mut Vec<Vec<u64>>,
|
||||
sample_size: usize,
|
||||
secret_random_generator: &mut SecretRandomGenerator<ActivatedRandomGenerator>,
|
||||
encryption_random_generator: &mut EncryptionRandomGenerator<ActivatedRandomGenerator>,
|
||||
use_fft: bool,
|
||||
fft: FftView,
|
||||
computation_buffers: &mut ComputationBuffers,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
) -> (u128, u128) {
|
||||
let lwe_sk = LweSecretKey::from_container(vec![1u64; grouping_factor.0]);
|
||||
let glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
parameters.glwe_dimension,
|
||||
parameters.polynomial_size,
|
||||
secret_random_generator,
|
||||
);
|
||||
|
||||
let bsk = allocate_and_generate_new_lwe_multi_bit_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_secret_key,
|
||||
parameters.decomposition_base_log,
|
||||
parameters.decomposition_level_count,
|
||||
grouping_factor,
|
||||
parameters.ggsw_noise,
|
||||
parameters.ciphertext_modulus,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let ggsw_vec: Vec<_> = bsk.iter().collect();
|
||||
|
||||
let grouping_factor = bsk.grouping_factor();
|
||||
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
|
||||
|
||||
assert_eq!(ggsw_vec.len(), ggsw_per_multi_bit_element.0);
|
||||
|
||||
let mut random_mask = vec![0u64; grouping_factor.0];
|
||||
encryption_random_generator.fill_slice_with_random_uniform_mask(&mut random_mask);
|
||||
|
||||
// Recompute it here to rotate and negate the input or output vector to compute errors that make
|
||||
// sense
|
||||
let equivalent_monomial_degree = MonomialDegree(pbs_modulus_switch(
|
||||
random_mask.iter().sum::<u64>(),
|
||||
parameters.polynomial_size,
|
||||
));
|
||||
|
||||
let mut fourier_ggsw = FourierGgswCiphertext::new(
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
);
|
||||
|
||||
let mut std_ggsw = GgswCiphertext::new(
|
||||
0u64,
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let mut tmp_std_ggsw = GgswCiphertext::new(
|
||||
0u64,
|
||||
bsk.glwe_size(),
|
||||
bsk.polynomial_size(),
|
||||
bsk.decomposition_base_log(),
|
||||
bsk.decomposition_level_count(),
|
||||
bsk.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let prep_start = std::time::Instant::now();
|
||||
std_prepare_multi_bit_ggsw(
|
||||
&mut std_ggsw,
|
||||
&mut tmp_std_ggsw,
|
||||
&ggsw_vec,
|
||||
modulus_switch_multi_bit(
|
||||
bsk.polynomial_size().to_blind_rotation_input_modulus_log(),
|
||||
grouping_factor,
|
||||
&random_mask,
|
||||
),
|
||||
);
|
||||
|
||||
if use_fft {
|
||||
fourier_ggsw.as_mut_view().fill_with_forward_fourier(
|
||||
std_ggsw.as_view(),
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let prep_time_ns = prep_start.elapsed().as_nanos();
|
||||
|
||||
let mut sample_runtime_ns = 0u128;
|
||||
|
||||
for _ in 0..sample_size {
|
||||
let input_plaintext_list =
|
||||
PlaintextList::new(0u64, PlaintextCount(parameters.polynomial_size.0));
|
||||
// encryption_random_generator
|
||||
// .fill_slice_with_random_uniform_mask(input_plaintext_list.as_mut());
|
||||
// // Shift to match the behavior of the previous concrete-core fixtures
|
||||
// input_plaintext_list
|
||||
// .as_mut()
|
||||
// .iter_mut()
|
||||
// .for_each(|x| *x <<= <u64 as Numeric>::BITS - parameters.decomposition_base_log.0);
|
||||
|
||||
let mut input_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.ciphertext_modulus,
|
||||
);
|
||||
|
||||
encrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&mut input_glwe_ciphertext,
|
||||
&input_plaintext_list,
|
||||
parameters.glwe_noise,
|
||||
encryption_random_generator,
|
||||
);
|
||||
|
||||
let mut output_glwe_ciphertext = GlweCiphertext::new(
|
||||
0u64,
|
||||
parameters.glwe_dimension.to_glwe_size(),
|
||||
parameters.polynomial_size,
|
||||
parameters.ciphertext_modulus,
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
if use_fft {
|
||||
add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&fourier_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
fft,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
} else {
|
||||
karatsuba_add_external_product_assign_mem_optimized(
|
||||
&mut output_glwe_ciphertext,
|
||||
&std_ggsw,
|
||||
&input_glwe_ciphertext,
|
||||
computation_buffers.stack(),
|
||||
);
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed().as_nanos();
|
||||
sample_runtime_ns += elapsed;
|
||||
|
||||
let mut output_plaintext_list = input_plaintext_list.clone();
|
||||
decrypt_glwe_ciphertext(
|
||||
&glwe_secret_key,
|
||||
&output_glwe_ciphertext,
|
||||
&mut output_plaintext_list,
|
||||
);
|
||||
|
||||
let mut output_pt_list_as_polynomial = output_plaintext_list.as_mut_polynomial();
|
||||
|
||||
// As we performed a monomial multiplication, we need to apply a monomial div to get outputs
|
||||
// in the right order
|
||||
polynomial_algorithms::polynomial_wrapping_monic_monomial_div_assign(
|
||||
&mut output_pt_list_as_polynomial,
|
||||
equivalent_monomial_degree,
|
||||
);
|
||||
|
||||
raw_inputs.push(input_plaintext_list.into_container());
|
||||
outputs.push(output_plaintext_list.into_container());
|
||||
}
|
||||
|
||||
(sample_runtime_ns, prep_time_ns)
|
||||
}
|
||||
3
tfhe-rs-cost-model/src/requirements.txt
Normal file
3
tfhe-rs-cost-model/src/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
numpy
|
||||
scipy
|
||||
scikit-learn
|
||||
30
tfhe-rs-cost-model/src/tanh-plot.sh
Executable file
30
tfhe-rs-cost-model/src/tanh-plot.sh
Executable file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env gnuplot

# Renders one PNG per (N, k) pair, with N = 2**nu for nu in 9..14 and k in 1..2.
# Each picture holds five curves, one per value 1..5 of data column 5 (titled 'B = 2^1' to
# 'B = 2^5'), plotting data column 1 against data column 4.
# NOTE(review): presumably columns 2 and 3 hold the polynomial size and the GLWE dimension,
# column 4 the decomposition level count and column 5 the base log -- confirm against the
# tool that writes the .dat file.

GF = 2
SFX = "-tanh" # -tanh
# sort by 4-th column
DATAFILE = "< sort -nk4 log-real-to-pred-".GF.SFX.".dat"

set terminal pngcairo size 1800,600 linewidth 2

set grid
set xtics 2

do for [nu=9:14] {
    do for [k=1:2] {
        N = 2**nu
        set output "logratio-".GF."-k=".k."-N=".N.SFX.".png"

        # A row matching (N, k, B) stores its coordinates in (x0, y0) and is plotted as is.
        # A non matching row is plotted at the last stored (x0, y0) again, so every curve
        # only goes through its own rows.
        # The untitled passes in between reset x0 and y0 to NaN before the next curve.
        x0 = y0 = NaN
        plot \
            DATAFILE using (($2 == N && $3 == k && $5 == 1) ? (y0=$1,x0=$4) : x0):(y0) with linespoints linetype 1 title 'B = 2^1', \
            '' using (x0 = NaN):(y0 = NaN) notitle, \
            '' using (($2 == N && $3 == k && $5 == 2) ? (y0=$1,x0=$4) : x0):(y0) with linespoints linetype 2 title 'B = 2^2', \
            '' using (x0 = NaN):(y0 = NaN) notitle, \
            '' using (($2 == N && $3 == k && $5 == 3) ? (y0=$1,x0=$4) : x0):(y0) with linespoints linetype 3 title 'B = 2^3', \
            '' using (x0 = NaN):(y0 = NaN) notitle, \
            '' using (($2 == N && $3 == k && $5 == 4) ? (y0=$1,x0=$4) : x0):(y0) with linespoints linetype 4 title 'B = 2^4', \
            '' using (x0 = NaN):(y0 = NaN) notitle, \
            '' using (($2 == N && $3 == k && $5 == 5) ? (y0=$1,x0=$4) : x0):(y0) with linespoints linetype 5 title 'B = 2^5'
    }
}
|
||||
@@ -14,6 +14,7 @@ use crate::core_crypto::fft_impl::fft64::crypto::ggsw::{
|
||||
add_external_product_assign as impl_add_external_product_assign,
|
||||
add_external_product_assign_scratch as impl_add_external_product_assign_scratch, cmux,
|
||||
cmux_scratch,
|
||||
karatsuba_add_external_product_assign as impl_karatsuba_add_external_product_assign,
|
||||
};
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
use concrete_fft::c64;
|
||||
@@ -489,6 +490,48 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
|
||||
}
|
||||
}
|
||||
|
||||
pub fn karatsuba_add_external_product_assign_mem_optimized<
|
||||
Scalar,
|
||||
OutputGlweCont,
|
||||
InputGlweCont,
|
||||
GgswCont,
|
||||
>(
|
||||
out: &mut GlweCiphertext<OutputGlweCont>,
|
||||
ggsw: &GgswCiphertext<GgswCont>,
|
||||
glwe: &GlweCiphertext<InputGlweCont>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
OutputGlweCont: ContainerMut<Element = Scalar>,
|
||||
GgswCont: Container<Element = Scalar>,
|
||||
InputGlweCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(out.ciphertext_modulus(), glwe.ciphertext_modulus());
|
||||
let ciphertext_modulus = out.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
impl_karatsuba_add_external_product_assign(
|
||||
out.as_mut_view(),
|
||||
ggsw.as_view(),
|
||||
glwe.as_view(),
|
||||
stack,
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
// fill the MSBs and leave the LSBs empty, this usage of the signed decomposer allows to
|
||||
// round while keeping the data in the MSBs
|
||||
let signed_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(ciphertext_modulus.get_custom_modulus().ilog2() as usize),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
out.as_mut()
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = signed_decomposer.closest_representable(*x));
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the required memory for [`add_external_product_assign_mem_optimized`].
|
||||
pub fn add_external_product_assign_mem_optimized_requirement<Scalar>(
|
||||
glwe_size: GlweSize,
|
||||
|
||||
@@ -131,7 +131,7 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
|
||||
}
|
||||
|
||||
// Fills the slice with random uniform values, using the mask generator.
|
||||
pub(crate) fn fill_slice_with_random_uniform_mask<Scalar>(&mut self, output: &mut [Scalar])
|
||||
pub fn fill_slice_with_random_uniform_mask<Scalar>(&mut self, output: &mut [Scalar])
|
||||
where
|
||||
Scalar: RandomGenerable<Uniform>,
|
||||
{
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
// This file was autogenerated, do not modify by hand.
|
||||
use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn external_product_fft_additive_variance132_bits_security_gaussian(
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
external_product_fft_additive_variance132_bits_security_gaussian_impl(
|
||||
glwe_dimension.0 as f64,
|
||||
polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluates the additive variance formula of the FFT external product on plain `f64` values.
///
/// `decomposition_base` is the base itself, not its logarithm.
///
/// The formula only holds when the proper noise distributions are used and when the keys are
/// encrypted with the secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn external_product_fft_additive_variance132_bits_security_gaussian_impl(
    glwe_dimension: f64,
    polynomial_size: f64,
    decomposition_base: f64,
    decomposition_level_count: f64,
    modulus: f64,
) -> f64 {
    // Sub-expressions shared by several terms of the formula.
    let base_sq = decomposition_base.powf(2.0);
    let inv_modulus_sq = modulus.powf(-2.0);
    let base_pow_neg_2l = decomposition_base.powf(-2.0 * decomposition_level_count);
    let k_plus_one = glwe_dimension + 1.0;

    // Term that the no-FFT formula does not have.
    let fft_term = 2.06537277069845e-33
        * base_sq
        * decomposition_level_count
        * polynomial_size.powf(2.0)
        * k_plus_one;

    let exp_term =
        (-0.0497829131652661 * glwe_dimension * polynomial_size + 5.31469187675068).exp2();
    let level_term = decomposition_level_count
        * polynomial_size
        * (exp_term + 16.0 * inv_modulus_sq)
        * ((1_f64 / 12.0) * base_sq + 0.166666666666667)
        * k_plus_one;

    let mask_term = (1_f64 / 2.0)
        * glwe_dimension
        * polynomial_size
        * (0.0208333333333333 * inv_modulus_sq + 0.0416666666666667 * base_pow_neg_2l);

    let modulus_term = (1_f64 / 12.0) * inv_modulus_sq;
    let base_term = (1_f64 / 24.0) * base_pow_neg_2l;

    // Summed left to right, in the same order as the generated expression.
    fft_term + level_term + mask_term + modulus_term + base_term
}
|
||||
@@ -0,0 +1,51 @@
|
||||
// This file was autogenerated, do not modify by hand.
|
||||
use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn external_product_no_fft_additive_variance132_bits_security_gaussian(
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
external_product_no_fft_additive_variance132_bits_security_gaussian_impl(
|
||||
glwe_dimension.0 as f64,
|
||||
polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluates the additive variance formula of the FFT-free external product on plain `f64`
/// values.
///
/// `decomposition_base` is the base itself, not its logarithm.
///
/// The formula only holds when the proper noise distributions are used and when the keys are
/// encrypted with the secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn external_product_no_fft_additive_variance132_bits_security_gaussian_impl(
    glwe_dimension: f64,
    polynomial_size: f64,
    decomposition_base: f64,
    decomposition_level_count: f64,
    modulus: f64,
) -> f64 {
    // Sub-expressions shared by several terms of the formula.
    let inv_modulus_sq = modulus.powf(-2.0);
    let base_pow_neg_2l = decomposition_base.powf(-2.0 * decomposition_level_count);

    let exp_term =
        (-0.0497829131652661 * glwe_dimension * polynomial_size + 5.31469187675068).exp2();
    let level_term = decomposition_level_count
        * polynomial_size
        * (exp_term + 16.0 * inv_modulus_sq)
        * ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
        * (glwe_dimension + 1.0);

    let mask_term = (1_f64 / 2.0)
        * glwe_dimension
        * polynomial_size
        * (0.0208333333333333 * inv_modulus_sq + 0.0416666666666667 * base_pow_neg_2l);

    let modulus_term = (1_f64 / 12.0) * inv_modulus_sq;
    let base_term = (1_f64 / 24.0) * base_pow_neg_2l;

    // Summed left to right, in the same order as the generated expression.
    level_term + mask_term + modulus_term + base_term
}
|
||||
@@ -1,4 +1,7 @@
|
||||
// This file was autogenerated, do not modify by hand.
|
||||
pub mod external_product_fft;
|
||||
pub mod external_product_no_fft;
|
||||
pub mod lwe_keyswitch;
|
||||
pub mod lwe_programmable_bootstrap;
|
||||
pub mod multi_bit_external_product_no_fft;
|
||||
pub mod secure_noise;
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
// This file was autogenerated, do not modify by hand.
|
||||
use crate::core_crypto::commons::dispersion::Variance;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
|
||||
/// This formula is only valid if the proper noise distributions are used and
|
||||
/// if the keys used are encrypted using secure noise given by the
|
||||
/// [`minimal_glwe_variance`](`super::secure_noise`)
|
||||
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
|
||||
pub fn multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian(
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
grouping_factor: f64,
|
||||
modulus: f64,
|
||||
) -> Variance {
|
||||
Variance(
|
||||
multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian_impl(
|
||||
glwe_dimension.0 as f64,
|
||||
polynomial_size.0 as f64,
|
||||
2.0f64.powi(decomposition_base_log.0 as i32),
|
||||
decomposition_level_count.0 as f64,
|
||||
grouping_factor,
|
||||
modulus,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Evaluates the additive variance formula of the FFT-free multi-bit external product on plain
/// `f64` values.
///
/// `decomposition_base` is the base itself, not its logarithm. The first term of the formula
/// is scaled by `2^grouping_factor`.
///
/// The formula only holds when the proper noise distributions are used and when the keys are
/// encrypted with the secure noise given by the
/// [`minimal_glwe_variance`](`super::secure_noise`)
/// and [`minimal_lwe_variance`](`super::secure_noise`) family of functions.
pub fn multi_bit_external_product_no_fft_additive_variance_132_bits_security_gaussian_impl(
    glwe_dimension: f64,
    polynomial_size: f64,
    decomposition_base: f64,
    decomposition_level_count: f64,
    grouping_factor: f64,
    modulus: f64,
) -> f64 {
    // Sub-expressions shared by several terms of the formula.
    let inv_modulus_sq = modulus.powf(-2.0);
    let base_pow_neg_2l = decomposition_base.powf(-2.0 * decomposition_level_count);

    let exp_term =
        (-0.0497829131652661 * glwe_dimension * polynomial_size + 5.31469187675068).exp2();
    let level_term = grouping_factor.exp2()
        * decomposition_level_count
        * polynomial_size
        * (exp_term + 16.0 * inv_modulus_sq)
        * ((1_f64 / 12.0) * decomposition_base.powf(2.0) + 0.166666666666667)
        * (glwe_dimension + 1.0);

    let mask_term = (1_f64 / 2.0)
        * glwe_dimension
        * polynomial_size
        * (0.0208333333333333 * inv_modulus_sq + 0.0416666666666667 * base_pow_neg_2l);

    let modulus_term = (1_f64 / 12.0) * inv_modulus_sq;
    let base_term = (1_f64 / 24.0) * base_pow_neg_2l;

    // Summed left to right, in the same order as the generated expression.
    level_term + mask_term + modulus_term + base_term
}
|
||||
@@ -15,7 +15,7 @@ use crate::core_crypto::entities::ggsw_ciphertext::{
|
||||
use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertext, GlweCiphertextView};
|
||||
use crate::core_crypto::fft_impl::fft64::math::decomposition::TensorSignedDecompositionLendingIter;
|
||||
use crate::core_crypto::prelude::ContainerMut;
|
||||
use aligned_vec::CACHELINE_ALIGN;
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use concrete_fft::fft128::f128;
|
||||
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
|
||||
use tfhe_versionable::Versionize;
|
||||
@@ -169,6 +169,38 @@ impl<C: Container<Element = f64>> Fourier128GgswCiphertext<C> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A [`Fourier128GgswCiphertext`] whose backing containers are owned `ABox<[f64]>` boxed slices.
pub type Fourier128GgswCiphertextOwned = Fourier128GgswCiphertext<ABox<[f64]>>;
|
||||
|
||||
impl Fourier128GgswCiphertext<ABox<[f64]>> {
|
||||
pub fn new(
|
||||
glwe_size: GlweSize,
|
||||
polynomial_size: PolynomialSize,
|
||||
decomposition_base_log: DecompositionBaseLog,
|
||||
decomposition_level_count: DecompositionLevelCount,
|
||||
) -> Self {
|
||||
let container_len = polynomial_size.to_fourier_polynomial_size().0
|
||||
* decomposition_level_count.0
|
||||
* glwe_size.0
|
||||
* glwe_size.0;
|
||||
|
||||
let boxed_re0 = avec![0.0f64; container_len].into_boxed_slice();
|
||||
let boxed_re1 = avec![0.0f64; container_len].into_boxed_slice();
|
||||
let boxed_im0 = avec![0.0f64; container_len].into_boxed_slice();
|
||||
let boxed_im1 = avec![0.0f64; container_len].into_boxed_slice();
|
||||
|
||||
Fourier128GgswCiphertext::from_container(
|
||||
boxed_re0,
|
||||
boxed_re1,
|
||||
boxed_im0,
|
||||
boxed_im1,
|
||||
polynomial_size,
|
||||
glwe_size,
|
||||
decomposition_base_log,
|
||||
decomposition_level_count,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Container<Element = f64>> Fourier128GgswLevelMatrix<C> {
|
||||
pub fn from_container(
|
||||
data_re0: C,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use super::super::super::{fft128, fft128_u128};
|
||||
use super::super::math::fft::Fft128View;
|
||||
use crate::core_crypto::fft_impl::common::tests::{
|
||||
gen_keys_or_get_from_cache_if_enabled, generate_keys,
|
||||
};
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use super::super::math::decomposition::TensorSignedDecompositionLendingIter;
|
||||
use super::super::math::fft::{FftView, FourierPolynomialList};
|
||||
use super::super::math::polynomial::FourierPolynomialMutView;
|
||||
use crate::core_crypto::algorithms::polynomial_algorithms::{
|
||||
polynomial_wrapping_add_assign, polynomial_wrapping_add_mul_assign, polynomial_wrapping_mul,
|
||||
};
|
||||
use crate::core_crypto::backward_compatibility::fft_impl::FourierGgswCiphertextVersions;
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
@@ -15,6 +18,7 @@ use crate::core_crypto::entities::ggsw_ciphertext::{
|
||||
fourier_ggsw_level_matrix_size, GgswCiphertextView,
|
||||
};
|
||||
use crate::core_crypto::entities::glwe_ciphertext::{GlweCiphertextMutView, GlweCiphertextView};
|
||||
use crate::core_crypto::entities::polynomial::Polynomial;
|
||||
use aligned_vec::{avec, ABox, CACHELINE_ALIGN};
|
||||
use concrete_fft::c64;
|
||||
use dyn_stack::{PodStack, ReborrowMut, SizeOverflow, StackReq};
|
||||
@@ -600,6 +604,145 @@ pub fn add_external_product_assign<Scalar>(
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform the external product of `ggsw` and `glwe`, and adds the result to `out`.
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub fn karatsuba_add_external_product_assign<Scalar>(
|
||||
mut out: GlweCiphertextMutView<'_, Scalar>,
|
||||
ggsw: GgswCiphertextView<Scalar>,
|
||||
glwe: GlweCiphertextView<Scalar>,
|
||||
stack: PodStack<'_>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
{
|
||||
// we check that the polynomial sizes match
|
||||
debug_assert_eq!(ggsw.polynomial_size(), glwe.polynomial_size());
|
||||
debug_assert_eq!(ggsw.polynomial_size(), out.polynomial_size());
|
||||
// we check that the glwe sizes match
|
||||
debug_assert_eq!(ggsw.glwe_size(), glwe.glwe_size());
|
||||
debug_assert_eq!(ggsw.glwe_size(), out.glwe_size());
|
||||
|
||||
let align = CACHELINE_ALIGN;
|
||||
let poly_size = ggsw.polynomial_size().0;
|
||||
|
||||
// we round the input mask and body
|
||||
let decomposer = SignedDecomposer::<Scalar>::new(
|
||||
ggsw.decomposition_base_log(),
|
||||
ggsw.decomposition_level_count(),
|
||||
);
|
||||
|
||||
let (output_buffer, mut substack0) =
|
||||
stack.make_aligned_raw::<Scalar>(poly_size * ggsw.glwe_size().0, align);
|
||||
// output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
|
||||
// the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
|
||||
// it has been fully initialized for the first time.
|
||||
let output_buffer = &mut *output_buffer;
|
||||
let mut is_output_uninit = true;
|
||||
|
||||
{
|
||||
// ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER DOMAIN
|
||||
// In this section, we perform the external product in the fourier domain, and accumulate
|
||||
// the result in the output_fft_buffer variable.
|
||||
let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
|
||||
glwe.as_ref()
|
||||
.iter()
|
||||
.map(|s| decomposer.init_decomposer_state(*s)),
|
||||
DecompositionBaseLog(decomposer.base_log),
|
||||
DecompositionLevelCount(decomposer.level_count),
|
||||
substack0.rb_mut(),
|
||||
);
|
||||
|
||||
// We loop through the levels (we reverse to match the order of the decomposition iterator.)
|
||||
ggsw.iter().rev().for_each(|ggsw_decomp_matrix| {
|
||||
// We retrieve the decomposition of this level.
|
||||
let (_glwe_level, glwe_decomp_term, _substack2) =
|
||||
collect_next_term(&mut decomposition, &mut substack1, align);
|
||||
let glwe_decomp_term = GlweCiphertextView::from_container(
|
||||
&*glwe_decomp_term,
|
||||
ggsw.polynomial_size(),
|
||||
out.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
// For each level we have to add the result of the vector-matrix product between the
|
||||
// decomposition of the glwe, and the ggsw level matrix to the output. To do so, we
|
||||
// iteratively add to the output, the product between every line of the matrix, and
|
||||
// the corresponding (scalar) polynomial in the glwe decomposition:
|
||||
//
|
||||
// ggsw_mat ggsw_mat
|
||||
// glwe_dec | - - - - | < glwe_dec | - - - - |
|
||||
// | - - - | x | - - - - | | - - - | x | - - - - | <
|
||||
// ^ | - - - - | ^ | - - - - |
|
||||
//
|
||||
// t = 1 t = 2 ...
|
||||
|
||||
izip!(
|
||||
ggsw_decomp_matrix.as_glwe_list().iter(),
|
||||
glwe_decomp_term.as_polynomial_list().iter()
|
||||
)
|
||||
.for_each(|(ggsw_row, glwe_poly)| {
|
||||
// let (fourier, substack3) =
|
||||
// substack2.rb_mut().make_aligned_raw::<c64>(poly_size, align);
|
||||
// // We perform the forward fft transform for the glwe polynomial
|
||||
// let fourier = fft
|
||||
// .forward_as_integer(
|
||||
// FourierPolynomialMutView { data: fourier },
|
||||
// glwe_poly,
|
||||
// substack3,
|
||||
// )
|
||||
// .data;
|
||||
// // Now we loop through the polynomials of the output, and add the
|
||||
// // corresponding product of polynomials.
|
||||
|
||||
// update_with_fmadd(
|
||||
// output_buffer,
|
||||
// ggsw_row.data(),
|
||||
// fourier,
|
||||
// is_output_uninit,
|
||||
// poly_size,
|
||||
// );
|
||||
|
||||
// // we initialized `output_fft_buffer, so we can set this to false
|
||||
// is_output_uninit = false;
|
||||
|
||||
let row_as_poly_list = ggsw_row.as_polynomial_list();
|
||||
if is_output_uninit {
|
||||
for (mut output_poly, row_poly) in output_buffer
|
||||
.chunks_exact_mut(poly_size)
|
||||
.map(Polynomial::from_container)
|
||||
.zip(row_as_poly_list.iter())
|
||||
{
|
||||
polynomial_wrapping_mul(&mut output_poly, &row_poly, &glwe_poly);
|
||||
}
|
||||
} else {
|
||||
for (mut output_poly, row_poly) in output_buffer
|
||||
.chunks_exact_mut(poly_size)
|
||||
.map(Polynomial::from_container)
|
||||
.zip(row_as_poly_list.iter())
|
||||
{
|
||||
polynomial_wrapping_add_mul_assign(&mut output_poly, &row_poly, &glwe_poly);
|
||||
}
|
||||
}
|
||||
|
||||
is_output_uninit = false;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// -------------------------------------------- TRANSFORMATION OF RESULT TO STANDARD DOMAIN
|
||||
// In this section, we bring the result from the fourier domain, back to the standard
|
||||
// domain, and add it to the output.
|
||||
//
|
||||
// We iterate over the polynomials in the output.
|
||||
if !is_output_uninit {
|
||||
izip!(
|
||||
out.as_mut_polynomial_list().iter_mut(),
|
||||
output_buffer
|
||||
.into_chunks(poly_size)
|
||||
.map(Polynomial::from_container),
|
||||
)
|
||||
.for_each(|(mut out, res)| polynomial_wrapping_add_assign(&mut out, &res));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "__profiling", inline(never))]
|
||||
pub(crate) fn collect_next_term<'a, Scalar: UnsignedTorus>(
|
||||
decomposition: &mut TensorSignedDecompositionLendingIter<'_, Scalar>,
|
||||
|
||||
@@ -6,4 +6,4 @@ pub mod common;
|
||||
pub mod fft64;
|
||||
|
||||
pub mod fft128;
|
||||
mod fft128_u128;
|
||||
pub mod fft128_u128;
|
||||
|
||||
@@ -14,6 +14,6 @@ pub use super::commons::math::random::{ActivatedRandomGenerator, Gaussian, TUnif
|
||||
pub use super::commons::parameters::*;
|
||||
pub use super::commons::traits::*;
|
||||
pub use super::entities::*;
|
||||
pub use super::fft_impl::fft128::math::fft::Fft128;
|
||||
pub use super::fft_impl::fft64::math::fft::Fft;
|
||||
pub use super::fft_impl::fft128::math::fft::{Fft128, Fft128View};
|
||||
pub use super::fft_impl::fft64::math::fft::{Fft, FftView};
|
||||
pub use super::seeders::*;
|
||||
|
||||
Reference in New Issue
Block a user