Sunscreen's TFHE implementation (#349)

Co-authored-by: Sam Tay <samctay@pm.me>
This commit is contained in:
Ryan Orendorff
2024-02-16 15:29:35 -07:00
committed by GitHub
parent ab6a01e0b2
commit dc8fdeab81
81 changed files with 14732 additions and 10 deletions

View File

@@ -60,14 +60,14 @@ jobs:
# Build package in prep for user docs
- name: Build sunscreen and bincode
run: cargo build --release --features bulletproofs,linkedproofs --package sunscreen --package bincode
run: cargo build --profile mdbook --features bulletproofs,linkedproofs --package sunscreen --package bincode
# Build mdbook
- name: Build mdBook
run: cargo build --release
working-directory: ./mdBook
# Build user documentation
- name: Test docs
run: ../mdBook/target/release/mdbook test -L dependency=../target/release/deps --extern sunscreen=../target/release/libsunscreen.rlib --extern bincode=../target/release/libbincode.rlib
run: ../mdBook/target/release/mdbook test -L dependency=../target/mdbook/deps --extern sunscreen=../target/mdbook/libsunscreen.rlib --extern bincode=../target/mdbook/libbincode.rlib
working-directory: ./sunscreen_docs
lint:

188
.gitignore vendored
View File

@@ -1,5 +1,189 @@
sunscreen_docs/book
/target
.DS_Store
__bindgen.ii
__bindgen.cpp
# Ignore output from the proptest crate
proptest-regressions/
# Created by https://www.toptal.com/developers/gitignore/api/emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust
# Edit at https://www.toptal.com/developers/gitignore?templates=emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust
### direnv ###
.direnv
.envrc
### Emacs ###
# -*- mode: gitignore; -*-
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*
# Org-mode
.org-id-locations
*_archive
# flymake-mode
*_flymake.*
# eshell files
/eshell/history
/eshell/lastdir
# elpa packages
/elpa/
# reftex files
*.rel
# AUCTeX auto folder
/auto/
# cask packages
.cask/
dist/
# Flycheck
flycheck_*.el
# server auth directory
/server/
# projectiles files
.projectile
# directory configuration
.dir-locals.el
# network security
/network-security.data
### Linux ###
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### macOS Patch ###
# iCloud generated files
*.icloud
### Rust ###
# Generated by Cargo
# will have compiled files and executables
debug/
target/
# These are backup files generated by rustfmt
**/*.rs.bk
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb
### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session
Session.vim
Sessionx.vim
# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets
# Local History for Visual Studio Code
.history/
# Built Visual Studio Code Extensions
*.vsix
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk

88
Cargo.lock generated
View File

@@ -643,6 +643,8 @@ dependencies = [
"num-traits 0.2.17",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
@@ -1593,6 +1595,12 @@ dependencies = [
"cc",
]
[[package]]
name = "linked-list"
version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4dacf969043dc69f1f731b5042eb05e030d264bcf34f2242889fcbdc7a65f06"
[[package]]
name = "linux-raw-sys"
version = "0.4.10"
@@ -2196,6 +2204,15 @@ dependencies = [
"syn 2.0.38",
]
[[package]]
name = "primal-check"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0"
dependencies = [
"num-integer",
]
[[package]]
name = "private_tx_linkedproof"
version = "0.1.0"
@@ -2220,9 +2237,9 @@ checksum = "f89dff0959d98c9758c88826cc002e2c3d0b9dfac4139711d1f30de442f1139b"
[[package]]
name = "proptest"
version = "1.3.1"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e"
checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf"
dependencies = [
"bit-set",
"bit-vec",
@@ -2232,7 +2249,7 @@ dependencies = [
"rand",
"rand_chacha",
"rand_xorshift",
"regex-syntax 0.7.5",
"regex-syntax 0.8.2",
"rusty-fork",
"tempfile",
"unarray",
@@ -2356,6 +2373,15 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "realfft"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953d9f7e5cdd80963547b456251296efc2626ed4e3cbf36c869d9564e0220571"
dependencies = [
"rustfft",
]
[[package]]
name = "redox_syscall"
version = "0.2.11"
@@ -2393,9 +2419,9 @@ checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "regex-syntax"
version = "0.7.5"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
[[package]]
name = "remove_dir_all"
@@ -2485,6 +2511,21 @@ dependencies = [
"semver",
]
[[package]]
name = "rustfft"
version = "6.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43806561bc506d0c5d160643ad742e3161049ac01027b5e6d7524091fd401d86"
dependencies = [
"num-complex",
"num-integer",
"num-traits 0.2.17",
"primal-check",
"strength_reduce",
"transpose",
"version_check",
]
[[package]]
name = "rustix"
version = "0.38.20"
@@ -2773,6 +2814,12 @@ dependencies = [
"rand",
]
[[package]]
name = "strength_reduce"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
[[package]]
name = "strsim"
version = "0.10.0"
@@ -2994,6 +3041,27 @@ dependencies = [
"thiserror",
]
[[package]]
name = "sunscreen_tfhe"
version = "0.1.0"
dependencies = [
"bytemuck",
"criterion 0.5.1",
"linked-list",
"logproof",
"merlin",
"num",
"paste",
"proptest",
"rand",
"rand_distr",
"realfft",
"rustfft",
"serde",
"sunscreen_math",
"thiserror",
]
[[package]]
name = "sunscreen_zkp_backend"
version = "0.8.1"
@@ -3195,6 +3263,16 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "transpose"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6522d49d03727ffb138ae4cbc1283d3774f0d10aa7f9bf52e6784c45daf9b23"
dependencies = [
"num-integer",
"strength_reduce",
]
[[package]]
name = "try-lock"
version = "0.2.3"

View File

@@ -17,7 +17,7 @@ members = [
"sunscreen_math",
"sunscreen_math_macros",
"sunscreen_runtime",
"sunscreen_compiler_common",
"sunscreen_tfhe",
"sunscreen_zkp_backend",
]
exclude = ["mdBook", "rust-playground"]
@@ -25,6 +25,17 @@ exclude = ["mdBook", "rust-playground"]
[profile.release]
split-debuginfo = "packed"
debug = true
lto = "fat"
codegen-units = 1
[profile.bench]
lto = "fat"
codegen-units = 1
[profile.mdbook]
inherits = "release"
lto = false
codegen-units = 16
[workspace.dependencies]
bytemuck = "1.13.0"

81
sunscreen_tfhe/.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,81 @@
{
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'sunscreen_tfhe'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=sunscreen_tfhe"
],
"filter": {
"name": "sunscreen_tfhe",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug benchmark 'tfhe_proof'",
"cargo": {
"args": [
"test",
"--no-run",
"--bench=tfhe_proof",
"--package=sunscreen_tfhe"
],
"filter": {
"name": "tfhe_proof",
"kind": "bench"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug benchmark 'fft'",
"cargo": {
"args": [
"test",
"--no-run",
"--bench=fft",
"--package=sunscreen_tfhe"
],
"filter": {
"name": "fft",
"kind": "bench"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug benchmark 'ops'",
"cargo": {
"args": [
"test",
"--no-run",
"--bench=ops",
"--package=sunscreen_tfhe"
],
"filter": {
"name": "ops",
"kind": "bench"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

3
sunscreen_tfhe/.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"rust-analyzer.showUnlinkedFileNotification": false
}

52
sunscreen_tfhe/Cargo.toml Normal file
View File

@@ -0,0 +1,52 @@
[package]
name = "sunscreen_tfhe"
version = "0.1.0"
edition = "2021"
authors = ["Sunscreen"]
rust-version = "1.56.0"
license = "AGPL-3.0-only"
description = "This crate contains the Sunscreen Torus FHE (TFHE) implementation"
homepage = "https://sunscreen.tech"
repository = "https://github.com/Sunscreen-tech/Sunscreen"
documentation = "https://docs.sunscreen.tech"
keywords = ["FHE", "TFHE", "lattice", "cryptography"]
categories = ["cryptography"]
readme = "crates-io.md"
[dependencies]
bytemuck = { workspace = true }
# TODO: Remove when Rust stabilizes Cursor API
linked-list = "0.0.3"
logproof = { workspace = true, optional = true }
num = { workspace = true }
paste = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true }
realfft = "3.3.0"
rustfft = "6.1.0"
serde = { workspace = true }
sunscreen_math = { workspace = true }
thiserror = { workspace = true }
[dev-dependencies]
criterion = "0.5.1"
merlin = "3.0.0"
proptest = "1.4.0"
[features]
logproof = ["dep:logproof"]
[[bench]]
name = "tfhe_proof"
harness = false
required-features= ["logproof"]
[[bench]]
name = "fft"
harness = false
[[bench]]
name = "ops"
harness = false

35
sunscreen_tfhe/barrett.py Normal file
View File

@@ -0,0 +1,35 @@
import math
import sys
radix = 10
if len(sys.argv) < 3:
print("Usage: barrett <number of 64-bit limbs> <modulus> [<modulus radix>]")
n = int(sys.argv[1])
if len(sys.argv) == 4:
radix = int(sys.argv[3])
p = int(sys.argv[2], radix)
def compute_vals(n, p):
r = math.floor(2**(64 * n) // p)
s = math.floor(2**(128 * n) // p) - 2**(64 * n) * r
t = 2**(64*n) - r * p
return (n, r, s, t)
(_, r, s, t) = compute_vals(n, p)
def print_value(n, name, x):
print(name + " = [")
for i in range(n):
print(" " + str((x >> (64 * i)) & 0xFFFFFFFFFFFFFFFF) + ",")
print("]")
print_value(n, "r", r)
print_value(n, "s", s)
print_value(n, "t", t)

View File

@@ -0,0 +1,47 @@
use criterion::{criterion_group, criterion_main, Criterion};
use num::Complex;
use sunscreen_tfhe::{math::fft::negacyclic::TwistedFft, FrequencyTransform};
fn negacyclic_fft(c: &mut Criterion) {
let n = 2048;
let plan = TwistedFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut y = vec![Complex::from(0.0); x.len() / 2];
c.bench_function("FFT 2048", |s| {
s.iter(|| {
plan.forward(&x, &mut y);
});
});
let n = 1024;
let plan = TwistedFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut y = vec![Complex::from(0.0); x.len() / 2];
c.bench_function("FFT 1024", |s| {
s.iter(|| {
plan.forward(&x, &mut y);
});
});
let n: usize = 256;
let plan = TwistedFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut y = vec![Complex::from(0.0); x.len() / 2];
c.bench_function("FFT 256", |s| {
s.iter(|| {
plan.forward(&x, &mut y);
});
});
}
criterion_group!(benches, negacyclic_fft);
criterion_main!(benches);

View File

@@ -0,0 +1,235 @@
use criterion::{
criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion,
};
use sunscreen_tfhe::{
entities::{
GgswCiphertext, GgswCiphertextFft, GlweCiphertext, Polynomial, UnivariateLookupTable,
},
high_level::*,
ops::bootstrapping::circuit_bootstrap,
rand::Stddev,
GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PlaintextBits, PolynomialDegree,
RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80, LWE_512_80,
};
fn cmux(c: &mut Criterion) {
struct CmuxParams {
gsw_radix: RadixDecomposition,
glwe: GlweDef,
}
fn cmux_params(params: &CmuxParams, c: &mut Criterion) {
let sk = keygen::generate_binary_glwe_sk(&params.glwe);
let bits = PlaintextBits(1);
let msg = (0..params.glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>();
let msg = Polynomial::new(&msg);
let a = encryption::encrypt_glwe(&msg, &sk, &params.glwe, bits);
let b = a.clone();
let sel = encryption::encrypt_ggsw(1, &sk, &params.glwe, &params.gsw_radix, bits);
let mut sel_fft = GgswCiphertextFft::new(&params.glwe, &params.gsw_radix);
sel.fft(&mut sel_fft, &params.glwe, &params.gsw_radix);
let name = format!(
"cmux N={} k={} l={}",
params.glwe.dim.polynomial_degree.0, params.glwe.dim.size.0, params.gsw_radix.count.0
);
let mut result = GlweCiphertext::new(&params.glwe);
c.bench_function(&name, |bench| {
bench.iter(|| {
sunscreen_tfhe::ops::fft_ops::cmux(
&mut result,
&a,
&b,
&sel_fft,
&params.glwe,
&params.gsw_radix,
);
});
});
}
let params = CmuxParams {
gsw_radix: RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(10),
},
glwe: GLWE_5_256_80,
};
cmux_params(&params, c);
let params = CmuxParams {
gsw_radix: RadixDecomposition {
count: RadixCount(1),
radix_log: RadixLog(11),
},
glwe: GLWE_1_1024_80,
};
cmux_params(&params, c);
}
fn programmable_bootstrapping(c: &mut Criterion) {
fn run_bench(
name: &str,
g: &mut BenchmarkGroup<WallTime>,
lwe: &LweDef,
glwe: &GlweDef,
bs_radix: &RadixDecomposition,
) {
let lwe_sk = keygen::generate_binary_lwe_sk(lwe);
let glwe_sk = keygen::generate_binary_glwe_sk(glwe);
let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, lwe, glwe, bs_radix);
let bsk = fft::fft_bootstrap_key(&bsk, lwe, glwe, bs_radix);
let ct = lwe_sk.encrypt(1, &glwe.as_lwe_def(), PlaintextBits(1)).0;
let lut = UnivariateLookupTable::trivial_from_fn(|x| x, glwe, PlaintextBits(1));
g.bench_function(name, |b| {
b.iter(|| {
evaluation::univariate_programmable_bootstrap(&ct, &lut, &bsk, lwe, glwe, bs_radix);
});
});
}
let mut g = c.benchmark_group("Bootstrapping");
// CBS parameters
let radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(16),
};
run_bench(
"CBS parameters",
&mut g,
&LWE_512_80,
&GLWE_5_256_80,
&radix,
);
// Binary PBS parameters
let bs_radix = RadixDecomposition {
count: RadixCount(3),
radix_log: RadixLog(6),
};
run_bench(
"boolean PBS parameters",
&mut g,
&LweDef {
dim: LweDimension(722),
std: Stddev(0.000013071021089943935),
},
&GlweDef {
dim: GlweDimension {
size: GlweSize(2),
polynomial_degree: PolynomialDegree(512),
},
std: Stddev(0.00000004990272175010415),
},
&bs_radix,
);
// 3-bit message 1-bit carry PBS parameters
let bs_radix = RadixDecomposition {
count: RadixCount(1),
radix_log: RadixLog(23),
};
run_bench(
"3+1 message PBS parameters",
&mut g,
&LweDef {
dim: LweDimension(742),
std: Stddev(0.000007069849454709433),
},
&GlweDef {
dim: GlweDimension {
size: GlweSize(1),
polynomial_degree: PolynomialDegree(2048),
},
std: Stddev(0.00000000000000029403601535432533),
},
&bs_radix,
);
}
fn circuit_bootstrapping(c: &mut Criterion) {
let pbs_radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(16),
};
let cbs_radix = RadixDecomposition {
count: RadixCount(1),
radix_log: RadixLog(11),
};
let pfks_radix = RadixDecomposition {
count: RadixCount(3),
radix_log: RadixLog(11),
};
let level_2_params = GLWE_5_256_80;
let level_1_params = GLWE_1_1024_80;
let level_0_params = LWE_512_80;
let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
let bsk = keygen::generate_bootstrapping_key(
&sk_0,
&sk_2,
&level_0_params,
&level_2_params,
&pbs_radix,
);
let bsk = fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
let cbsksk = keygen::generate_cbs_ksk(
sk_2.to_lwe_secret_key(),
&sk_1,
&level_2_params.as_lwe_def(),
&level_1_params,
&pfks_radix,
);
let val = 0;
let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix);
c.bench_function("Circuit bootstrap", |b| {
b.iter(|| {
circuit_bootstrap(
&mut actual,
&ct,
&bsk,
&cbsksk,
&level_0_params,
&level_1_params,
&level_2_params,
&pbs_radix,
&cbs_radix,
&pfks_radix,
);
});
});
}
criterion_group!(
benches,
cmux,
programmable_bootstrapping,
circuit_bootstrapping
);
criterion_main!(benches);

View File

@@ -0,0 +1,128 @@
use criterion::{criterion_group, criterion_main, Criterion};
use logproof::{
InnerProductVerifierKnowledge, LogProof, LogProofGenerators, LogProofProverKnowledge,
};
use merlin::Transcript;
use sunscreen_tfhe::{
high_level::*,
zkp::{generate_tfhe_sdlp_prover_knowledge, ProofStatement, TorusZq, Witness},
PlaintextBits, Torus, TorusOps, LWE_512_80,
};
fn make_proof<S: TorusOps + TorusZq>(pk: &LogProofProverKnowledge<S::Zq>) -> LogProof {
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
let u = InnerProductVerifierKnowledge::get_u();
let mut p_t = Transcript::new(b"test");
LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u)
}
fn tfhe_secret_proof(c: &mut Criterion) {
let params = LWE_512_80;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let enc_data = (0..32)
.map(|_| encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, &params, bits))
.collect::<Vec<_>>();
let msg = vec![Torus::from(1u64); 32];
let statement = enc_data
.iter()
.enumerate()
.map(|(i, d)| ProofStatement::PrivateKeyEncryption {
message_id: i,
ciphertext: &d.0,
})
.collect::<Vec<_>>();
let witness = enc_data
.iter()
.enumerate()
.map(|(_i, d)| Witness::PrivateKeyEncryption {
randomness: d.1,
private_key: &sk,
})
.collect::<Vec<_>>();
let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, &params, bits);
let p = make_proof::<u64>(&pk);
let mut g = c.benchmark_group("Secret key encryption");
g.sample_size(10);
g.bench_function("Prove 32-bit secret encryption", |b| {
b.iter(|| {
let _ = make_proof::<u64>(&pk);
});
});
g.bench_function("Verify 32-bit secret encryption", |b| {
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
let u = InnerProductVerifierKnowledge::get_u();
b.iter(|| {
let mut t = Transcript::new(b"test");
p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
});
});
}
fn tfhe_public_proof(c: &mut Criterion) {
let params = LWE_512_80;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let public = keygen::generate_lwe_pk(&sk, &params);
let enc_data = (0..32)
.map(|_| encryption::encrypt_lwe_and_return_randomness(1, &public, &params, bits))
.collect::<Vec<_>>();
let msg = vec![Torus::from(1u64); 32];
let statement = enc_data
.iter()
.enumerate()
.map(|(i, d)| ProofStatement::PublicKeyEncryption {
public_key: &public,
message_id: i,
ciphertext: &d.0,
})
.collect::<Vec<_>>();
let witness = enc_data
.iter()
.enumerate()
.map(|(_i, d)| Witness::PublicKeyEncryption { randomness: &d.1 })
.collect::<Vec<_>>();
let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, &params, bits);
let p = make_proof::<u64>(&pk);
let mut g = c.benchmark_group("Public key encryption");
g.sample_size(10);
g.bench_function("Prove 32-bit public encryption", |b| {
b.iter(|| {
let _ = make_proof::<u64>(&pk);
});
});
g.bench_function("Verify 32-bit public encryption", |b| {
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
let u = InnerProductVerifierKnowledge::get_u();
b.iter(|| {
let mut t = Transcript::new(b"test");
p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
});
});
}
criterion_group!(benches, tfhe_secret_proof, tfhe_public_proof,);
criterion_main!(benches);

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

34
sunscreen_tfhe/mont.py Normal file
View File

@@ -0,0 +1,34 @@
p = 13
r = 16
rinv = (r**(p-2)) % p
p_inv = (p**(p-2)) % r
p_prime = (p_inv * (r - 1)) % r
print((p_inv * p) % r)
print((r * rinv) % p)
def to_mont(x):
return (r * x) % p
def from_mont(x):
return (x * rinv) % p
def mont_add(x, y):
return (x + y) % p
def mont_mul(x, y):
return (x * y * rinv) % p
def redc():
pass
a = to_mont(5)
b = to_mont(6)
c = mont_mul(a, b)
d = mont_add(a, b)
print(from_mont(c), from_mont(d))

293
sunscreen_tfhe/src/dst.rs Normal file
View File

@@ -0,0 +1,293 @@
use crate::scratch::Pod;
macro_rules! dst {
($(#[$meta:meta])* $t:ty, $ref_t:ty, $wrapper:ty, ($($derive:ident),* $(,)? ), ($($t_bounds:ty),* $(,)? )) => {
paste::paste! {
$(#[$meta])*
#[derive($($derive,)*)]
pub struct $t<T> where T: Clone $(+ $t_bounds)* {
data: Vec<$wrapper<T>>
}
/// A reference to the data structure.
#[repr(transparent)]
pub struct $ref_t<T> where T: Clone $(+ $t_bounds)* {
data: [$wrapper<T>],
}
impl<T> $ref_t<T> where T: Clone $(+ $t_bounds)* {
/// Clones the contents of rhs into self
pub fn clone_from_ref(&mut self, rhs: &$ref_t<T>) {
for (l, r) in self.data.iter_mut().zip(rhs.data.iter()) {
*l = r.clone();
}
}
/// Returns a slice view of the data representing a $t.
pub fn as_slice(&self) -> &[$wrapper<T>] {
&self.data
}
/// Returns a mutable slice view of the data representing a $t.
pub fn as_mut_slice(&mut self) -> &mut [$wrapper<T>] {
&mut self.data
}
/// Move the contents of rhs into self.
pub fn move_from(&mut self, rhs: $t<T>) {
for (l, r) in self.data.iter_mut().zip(rhs.data.into_iter()) {
*l = r;
}
}
}
impl<T> crate::dst::FromSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
fn from_slice(s: &[$wrapper<T>]) -> &$ref_t<T> {
unsafe { &*(s as *const [$wrapper<T>] as *const $ref_t<T>) }
}
}
impl<T> crate::dst::FromMutSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
fn from_mut_slice(s: &mut [$wrapper<T>]) -> &mut $ref_t<T> {
unsafe { &mut *(s as *mut [$wrapper<T>] as *mut $ref_t<T>) }
}
}
impl<T> $ref_t<T> where T: Clone $(+ $t_bounds)*, $wrapper<T>: num::Zero {
/// Clears the contents of self to contain zero
pub fn clear(&mut self) {
for x in self.as_mut_slice() {
*x = <$wrapper<T> as num::Zero>::zero();
}
}
}
impl<T> std::borrow::Borrow< $ref_t <T>> for $t<T> where T: Clone $(+ $t_bounds)* {
fn borrow(&self) -> &$ref_t<T> {
let ptr = self.data.as_slice() as *const [$wrapper<T>] as *const $ref_t<T>;
unsafe { &*ptr }
}
}
impl<T> std::convert::AsRef< $ref_t <T>> for $t<T> where T: Clone $(+ $t_bounds)*
{
fn as_ref(&self) -> &$ref_t<T> {
<Self as std::borrow::Borrow<$ref_t <T>>>::borrow(self)
}
}
impl<T> std::borrow::BorrowMut< $ref_t<T>> for $t<T> where T: Clone $(+ $t_bounds)* {
fn borrow_mut(&mut self) -> &mut $ref_t<T> {
let ptr = self.data.as_mut_slice() as *mut [$wrapper<T>] as *mut $ref_t<T>;
unsafe { &mut *ptr }
}
}
impl<T> std::borrow::ToOwned for $ref_t<T> where T: Clone $(+ $t_bounds)* {
type Owned = $t<T>;
fn to_owned(&self) -> Self::Owned {
$t { data: self.data.to_owned() }
}
}
impl<T> std::ops::Deref for $t<T> where T: Clone $(+ $t_bounds)* {
type Target = $ref_t<T>;
fn deref(&self) -> &Self::Target {
<Self as std::borrow::Borrow::<$ref_t<T>>>::borrow(&self)
}
}
impl<T> std::ops::DerefMut for $t<T> where T: Clone $(+ $t_bounds)* {
fn deref_mut(&mut self) -> &mut Self::Target {
<Self as std::borrow::BorrowMut::<$ref_t<T>>>::borrow_mut(self)
}
}
}
};
}
macro_rules! dst_iter {
($t:ty, $t_mut:ty, $wrapper_type: ty, $item_ref:ty, ($($t_bounds:ty,)*)) => {
paste::paste!{
/// An iterator to access references to an underlying type.
pub struct $t<'a, T> where T: Clone $(+ $t_bounds)* {
data: &'a [$wrapper_type<T>],
stride: usize,
front_idx: usize,
back_idx: i64
}
impl<'a, T> $t<'a, T> where T: Clone $(+ $t_bounds)* {
/// Create a new iterator that emits references to the contained type
/// by striding over the underlying data.
pub fn new(data: &'a [$wrapper_type<T>], stride: usize) -> Self {
assert_eq!(data.len() % stride, 0);
Self {
data,
stride,
front_idx: 0,
back_idx: (data.len() as i64) - (stride as i64)
}
}
#[inline]
/// The total number of items this iterator will emit and has emitted.
///
/// # Remarks
/// This method returns the same value regardless of how many times
/// `next` has been called.
///
/// This operation does not consume the iterator.
pub fn len(&self) -> usize {
self.data.len() / self.stride
}
/// Returns true if the iterator is empty.
#[inline]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
impl<'a, T> std::iter::Iterator for $t<'a, T> where T: Clone $(+ $t_bounds)* {
type Item = &'a $item_ref<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.front_idx >= self.data.len() {
return None;
}
if (self.front_idx as i64) == self.back_idx + self.stride as i64 {
return None;
}
let data = <$item_ref<T> as crate::dst::FromSlice<$wrapper_type<T>>>::from_slice(
&self.data[self.front_idx..self.front_idx + self.stride]
);
self.front_idx += self.stride;
Some(data)
}
}
impl<'a, T> std::iter::DoubleEndedIterator for $t<'a, T> where T: Clone $(+ $t_bounds)* {
fn next_back(&mut self) -> Option<Self::Item> {
if self.back_idx < 0 {
return None;
}
if (self.front_idx as i64) == self.back_idx + self.stride as i64 {
return None;
}
let start = self.back_idx as usize;
let data = <$item_ref<T> as crate::dst::FromSlice<$wrapper_type<T>>>::from_slice(
&self.data[start..start + self.stride]
);
self.back_idx -= (self.stride as i64);
Some(data)
}
}
/// A mutable iterator to access references to an underlying type.
pub struct $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
data: *mut $wrapper_type<T>,
len: usize,
stride: usize,
idx: usize,
_phantom: std::marker::PhantomData<&'a T>,
}
impl<'a, T> $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
/// Create a new iterator that emits references to the contained type
/// by striding over the underlying data, mutably.
pub fn new(data: &'a mut [$wrapper_type<T>], stride: usize) -> Self {
assert_eq!(data.len() % stride, 0);
Self {
idx: 0,
stride,
data: data.as_mut_ptr(),
len: data.len(),
_phantom: std::marker::PhantomData
}
}
#[inline]
/// The total number of items this iterator will emit and has emitted.
///
/// # Remarks
/// This method returns the same value regardless of how many times
/// `next` has been called.
///
/// This operation does not consume the iterator.
pub fn len(&self) -> usize {
self.len / self.stride
}
/// Returns true if the iterator is empty.
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
impl<'a, T> std::iter::Iterator for $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
type Item = &'a mut $item_ref<T>;
fn next(&mut self) -> Option<Self::Item> {
if self.idx == self.len {
return None;
}
// Since the slices emitted by this iterator don't overlap, this is sound.
let data = unsafe {
let slice = self.data.add(self.idx);
std::slice::from_raw_parts_mut(slice, self.stride)
};
self.idx += self.stride;
Some(<$item_ref<T> as crate::dst::FromMutSlice<$wrapper_type<T>>>::from_mut_slice(data))
}
}
}
};
}
pub type NoWrapper<T> = T;
pub trait OverlaySize {
type Inputs: Copy + Clone;
fn size(t: Self::Inputs) -> usize;
}
impl<S: Pod> OverlaySize for [S] {
type Inputs = usize;
fn size(t: Self::Inputs) -> usize {
t
}
}
pub trait FromSlice<T> {
fn from_slice(data: &[T]) -> &Self;
}
pub trait FromMutSlice<T> {
fn from_mut_slice(data: &mut [T]) -> &mut Self;
}

View File

@@ -0,0 +1,98 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{
dst::{FromMutSlice, FromSlice, OverlaySize},
entities::PolynomialRef,
ops::{bootstrapping::generate_bivariate_lut, encryption::trivially_encrypt_glwe_ciphertext},
scratch::allocate_scratch_ref,
CarryBits, GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps,
};
use super::{GlweCiphertextRef, UnivariateLookupTableRef};
dst! {
/// Lookup table for a bivariate function used during
/// [`programmable_bootstrap_bivariate`](crate::ops::bootstrapping::programmable_bootstrap_bivariate)
/// See that function for more details.
BivariateLookupTable,
BivariateLookupTableRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for BivariateLookupTableRef<S> {
type Inputs = GlweDimension;
fn size(t: Self::Inputs) -> usize {
GlweCiphertextRef::<S>::size(t)
}
}
impl<S: TorusOps> BivariateLookupTable<S> {
/// Creates a [BivariateLookupTable] filled with the result of
/// a function applied to every possible pair of plaintext inputs.
pub fn trivial_from_fn<F>(
map: F,
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
carry_bits: CarryBits,
) -> Self
where
F: Fn(u64, u64) -> u64,
{
let mut lut = BivariateLookupTable {
data: vec![Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)],
};
lut.fill_trivial_from_fn(map, glwe, plaintext_bits, carry_bits);
lut
}
}
impl<S: TorusOps> BivariateLookupTableRef<S> {
/// Convert a [BivariateLookupTableRef] to a [UnivariateLookupTableRef].
pub fn as_univariate(&self) -> &UnivariateLookupTableRef<S> {
// This works because a bivariate lookup table is just a univariate
// lookup table.
UnivariateLookupTableRef::from_slice(&self.data)
}
/// Gets a copy of the underlying [GlweCiphertextRef] from the
/// [BivariateLookupTableRef].
pub fn glwe(&self) -> &GlweCiphertextRef<S> {
GlweCiphertextRef::from_slice(&self.data)
}
/// Gets a mutable copy of the underlying [GlweCiphertextRef] from the
/// [BivariateLookupTableRef].
pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef<S> {
GlweCiphertextRef::from_mut_slice(&mut self.data)
}
/// Fills the [BivariateLookupTableRef] with the result of a bivariate
/// function.
pub fn fill_trivial_from_fn<F: Fn(u64, u64) -> u64>(
&mut self,
map: F,
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
carry_bits: CarryBits,
) {
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));
generate_bivariate_lut(poly, map, glwe, plaintext_bits, carry_bits);
trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
}
/// Creates a lookup table filled with the same value at every entry.
pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) {
self.clear();
for o in self.glwe_mut().b_mut(glwe).coeffs_mut() {
*o = Torus::encode(val, plaintext_bits);
}
}
}

View File

@@ -0,0 +1,137 @@
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::{NoWrapper, OverlaySize},
entities::{
GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef,
GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef,
},
GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};
dst! {
/// An encrypted amount to rotate the polynomials in a GLWE ciphertext by.
/// The [BlindRotationShiftFft] variant of this type is used by the
/// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function.
BlindRotationShift,
BlindRotationShiftRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for BlindRotationShiftRef<S> {
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize;
GgswCiphertextRef::<S>::size(t) * n_bits
}
}
impl<S: TorusOps> BlindRotationShift<S> {
/// Create a new zero [BlindRotationShift] with the given parameters.
///
/// A blind rotation shift is a collection of GGSW ciphertexts, each of
/// which encrypts a single bit in position `i` representing how much to
/// shift the input by as `2^i`.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = BlindRotationShiftRef::<S>::size((params.dim, radix.count));
Self {
data: vec![Torus::zero(); len],
}
}
}
impl<S: TorusOps> BlindRotationShiftRef<S> {
/// Iterate over the rows of the [BlindRotationShift].
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator<S> {
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
GgswCiphertextIterator::new(self.as_slice(), stride)
}
/// Iterate over the rows of the [BlindRotationShift] mutably.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextIteratorMut<S> {
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride)
}
}
dst! {
/// An encrypted amount to rotate the polynomials in a GLWE ciphertext by.
/// Used by the
/// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function.
/// The non-FFT version of this type is [BlindRotationShift].
BlindRotationShiftFft,
BlindRotationShiftFftRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
impl OverlaySize for BlindRotationShiftFftRef<Complex<f64>> {
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize;
GgswCiphertextFftRef::<Complex<f64>>::size(t) * n_bits
}
}
impl BlindRotationShiftFft<Complex<f64>> {
/// Create a new zero [BlindRotationShiftFft] with the given parameters.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = BlindRotationShiftFftRef::size((params.dim, radix.count));
Self {
data: vec![Complex::zero(); len],
}
}
}
impl BlindRotationShiftFftRef<Complex<f64>> {
/// Iterate over the rows of the [BlindRotationShiftFft].
pub fn rows(
&self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextFftIterator<Complex<f64>> {
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GgswCiphertextFftIterator::new(self.as_slice(), stride)
}
/// Iterate over the rows of the [BlindRotationShiftFft] mutably.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextFftIteratorMut<Complex<f64>> {
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
}
/// Compute the inverse FFT of the [BlindRotationShiftFft] and store the
/// result in the result [BlindRotationShift].
pub fn ifft<S: TorusOps>(
&self,
result: &mut BlindRotationShiftRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) {
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
s.ifft(r, params, radix);
}
}
}

View File

@@ -0,0 +1,160 @@
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::{NoWrapper, OverlaySize},
entities::{
GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef,
GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef,
},
GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};
dst! {
/// Keys used for bootstrapping. The [BootstrapKeyFft] variant of this type
/// is used by the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
BootstrapKey,
BootstrapKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for BootstrapKeyRef<S> {
type Inputs = (LweDimension, GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GgswCiphertextRef::<S>::size((t.1, t.2)) * t.0 .0
}
}
impl<S: TorusOps> BootstrapKey<S> {
/// Create a new zero [BootstrapKey] with the given parameters.
///
/// A bootstrapping key is a collection of GGSW ciphertexts, each of which
/// encrypts a single bit of an LWE secret key. This representation cannot
/// be directly with the bootstrapping functions, but the FFT version of the
/// bootstrapping key that can be used with the bootstrapping functions can
/// be created by calling the [BootstrapKeyRef::fft] method
pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = BootstrapKeyRef::<S>::size((lwe_params.dim, glwe_params.dim, radix.count));
Self {
data: vec![Torus::zero(); len],
}
}
}
impl<S: TorusOps> BootstrapKeyRef<S> {
/// Iterate over the rows of the [BootstrapKey].
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator<S> {
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
GgswCiphertextIterator::new(self.as_slice(), stride)
}
/// Iterate over the rows of the [BootstrapKey] mutably.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextIteratorMut<S> {
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride)
}
/// Perform an FFT on the [BootstrapKey] to obtain a [BootstrapKeyFft].
pub fn fft(
&self,
result: &mut BootstrapKeyFftRef<Complex<f64>>,
params: &GlweDef,
radix: &RadixDecomposition,
) {
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
s.fft(r, params, radix);
}
}
}
dst! {
/// Keys used for bootstrapping. Used by the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
/// The non-FFT variant of this type is [BootstrapKey].
BootstrapKeyFft,
BootstrapKeyFftRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
impl OverlaySize for BootstrapKeyFftRef<Complex<f64>> {
type Inputs = (LweDimension, GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GgswCiphertextFftRef::<Complex<f64>>::size((t.1, t.2)) * t.0 .0
}
}
impl BootstrapKeyFft<Complex<f64>> {
/// Create a new zero [BootstrapKeyFft] with the given parameters.
///
/// A bootstrapping key is a collection of GGSW ciphertexts, each of which
/// encrypts a single bit of an LWE secret key. In this representation, the
/// GGSW ciphertexts are in the frequency domain and can be used directly by
/// the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count));
Self {
data: vec![Complex::zero(); len],
}
}
}
impl BootstrapKeyFftRef<Complex<f64>> {
/// Iterate over the rows of the [BootstrapKeyFft].
pub fn rows(
&self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextFftIterator<Complex<f64>> {
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GgswCiphertextFftIterator::new(self.as_slice(), stride)
}
/// Iterate over the rows of the [BootstrapKeyFft] mutably.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextFftIteratorMut<Complex<f64>> {
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
}
/// Perform an IFFT on the [BootstrapKeyFft] to obtain a [BootstrapKey].
pub fn ifft<S: TorusOps>(
&self,
result: &mut BootstrapKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) {
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
s.ifft(r, params, radix);
}
}
/// Asserts that the [BootstrapKeyFft] is valid for the given parameters.
#[inline(always)]
pub(crate) fn assert_valid(&self, lwe: &LweDef, glwe: &GlweDef, radix: &RadixDecomposition) {
assert_eq!(
self.as_slice().len(),
BootstrapKeyFftRef::size((lwe.dim, glwe.dim, radix.count))
);
}
}

View File

@@ -0,0 +1,98 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{
dst::OverlaySize, GlweDef, GlweDimension, LweDef, LweDimension,
PrivateFunctionalKeyswitchLweCount, RadixCount, RadixDecomposition, Torus, TorusOps,
};
use super::{
PrivateFunctionalKeyswitchKeyIter, PrivateFunctionalKeyswitchKeyIterMut,
PrivateFunctionalKeyswitchKeyRef,
};
dst! {
/// Key for Circuit Bootstrapping Key Switching.
CircuitBootstrappingKeyswitchKeys,
CircuitBootstrappingKeyswitchKeysRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for CircuitBootstrappingKeyswitchKeysRef<S> {
type Inputs = (LweDimension, GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
PrivateFunctionalKeyswitchKeyRef::<S>::size((
t.0,
t.1,
t.2,
PrivateFunctionalKeyswitchLweCount(1),
)) * (t.1.size.0 + 1)
}
}
impl<S: TorusOps> CircuitBootstrappingKeyswitchKeys<S> {
/// Allocate a new [`CircuitBootstrappingKeyswitchKeys`] for the given parameters.
pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = CircuitBootstrappingKeyswitchKeysRef::<S>::size((
from_lwe.dim,
to_glwe.dim,
radix.count,
));
Self {
data: vec![Torus::zero(); len],
}
}
}
impl<S: TorusOps> CircuitBootstrappingKeyswitchKeysRef<S> {
/// Get an iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
pub fn keys(
&self,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> PrivateFunctionalKeyswitchKeyIter<S> {
let stride = PrivateFunctionalKeyswitchKeyRef::<S>::size((
lwe.dim,
glwe.dim,
radix.count,
PrivateFunctionalKeyswitchLweCount(1),
));
PrivateFunctionalKeyswitchKeyIter::new(self.as_slice(), stride)
}
/// Get a mutable iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
pub fn keys_mut(
&mut self,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> PrivateFunctionalKeyswitchKeyIterMut<S> {
let stride = PrivateFunctionalKeyswitchKeyRef::<S>::size((
lwe.dim,
glwe.dim,
radix.count,
PrivateFunctionalKeyswitchLweCount(1),
));
PrivateFunctionalKeyswitchKeyIterMut::new(self.as_mut_slice(), stride)
}
#[inline(always)]
/// Assert these keys are valid under the given parameters.
pub fn assert_valid(&self, from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) {
assert_eq!(
self.as_slice().len(),
CircuitBootstrappingKeyswitchKeysRef::<S>::size((
from_lwe.dim,
to_glwe.dim,
radix.count,
))
);
}
}

View File

@@ -0,0 +1,114 @@
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::OverlaySize, ops::ciphertext::external_product_ggsw_glwe, GlweDef, GlweDimension,
RadixCount, RadixDecomposition, Torus, TorusOps,
};
use super::{
GgswCiphertextFftRef, GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef,
GlweCiphertext, GlweCiphertextRef,
};
dst! {
/// A GGSW ciphertext. For the FFT variant, see
/// [`GgswCiphertextFft`](crate::entities::GgswCiphertextFft).
GgswCiphertext,
GgswCiphertextRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
dst_iter! { GgswCiphertextIterator, GgswCiphertextIteratorMut, Torus, GgswCiphertextRef, (TorusOps,)}
impl<S> OverlaySize for GgswCiphertextRef<S>
where
S: TorusOps,
{
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlevCiphertextRef::<S>::size(t) * (t.0.size.0 + 1)
}
}
impl<S> GgswCiphertext<S>
where
S: TorusOps,
{
/// Create a new zero GGSW ciphertext with the given parameters.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
Self {
data: vec![Torus::zero(); elems],
}
}
/// Create a new GGSW ciphertext from a slice of Torus elements.
pub fn from_slice(data: &[Torus<S>], params: &GlweDef, radix: &RadixDecomposition) -> Self {
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
assert_eq!(data.len(), elems);
Self {
data: data.to_vec(),
}
}
/// Computes the external product between a GGSW ciphertext and a GLWE ciphertext.
/// GGSW ⊡ GLWE -> GLWE
pub fn external_product(
&self,
glwe: &GlweCiphertextRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlweCiphertext<S> {
external_product_ggsw_glwe(self, glwe, params, radix)
}
}
impl<S> GgswCiphertextRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the rows of the GGSW ciphertext, which are
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator<S> {
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
GlevCiphertextIterator::new(&self.data, stride)
}
/// Returns a mutable iterator over the rows of the GGSW ciphertext, which are
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIteratorMut<S> {
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
GlevCiphertextIteratorMut::new(&mut self.data, stride)
}
/// Compute the FFT of each of the GLWE ciphertexts in the GGSW ciphertext.
/// The result is stored in `result`.
pub fn fft(
&self,
result: &mut GgswCiphertextFftRef<Complex<f64>>,
params: &GlweDef,
radix: &RadixDecomposition,
) {
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
s.fft(r, params);
}
}
/// Assert that the GGSW ciphertext is valid for the given parameters.
#[inline(always)]
pub(crate) fn assert_valid(&self, glwe: &GlweDef, radix: &RadixDecomposition) {
assert_eq!(self.as_slice().len(), Self::size((glwe.dim, radix.count)));
}
}

View File

@@ -0,0 +1,79 @@
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::{NoWrapper, OverlaySize},
entities::GgswCiphertextRef,
GlweDef, GlweDimension, RadixCount, RadixDecomposition, TorusOps,
};
use super::{GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, GlevCiphertextFftRef};
dst! {
/// The FFT variant of a GGSW ciphertext. See
/// [`GgswCiphertext`](crate::entities::GgswCiphertext) for more details.
GgswCiphertextFft,
GgswCiphertextFftRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
dst_iter! { GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, NoWrapper, GgswCiphertextFftRef, ()}
impl OverlaySize for GgswCiphertextFftRef<Complex<f64>> {
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlevCiphertextFftRef::<Complex<f64>>::size(t) * (t.0.size.0 + 1)
}
}
impl GgswCiphertextFft<Complex<f64>> {
/// Creates a new GGSW ciphertext with FFT representation.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextFft<Complex<f64>> {
let len = GgswCiphertextFftRef::size((params.dim, radix.count));
GgswCiphertextFft {
data: vec![Complex::zero(); len],
}
}
}
impl GgswCiphertextFftRef<Complex<f64>> {
/// Returns an iterator over the rows of the GGSW ciphertext, which are
/// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s.
pub fn rows(
&self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextFftIterator<Complex<f64>> {
let stride = GlevCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GlevCiphertextFftIterator::new(self.as_slice(), stride)
}
/// Returns a mutable iterator over the rows of the GGSW ciphertext, which are
/// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextFftIteratorMut<Complex<f64>> {
let stride = GlevCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
GlevCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
}
/// Computes the inverse FFT of the GGSW ciphertexts and stores computation
/// in `result`.
pub fn ifft<S: TorusOps>(
&self,
result: &mut GgswCiphertextRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) {
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
s.ifft(r, params);
}
}
}

View File

@@ -0,0 +1,58 @@
use num::Complex;
use serde::{Deserialize, Serialize};
use crate::{dst::OverlaySize, GlweDef, GlweDimension, RadixCount, Torus, TorusOps};
use super::{
GlevCiphertextFftRef, GlweCiphertextIterator, GlweCiphertextIteratorMut, GlweCiphertextRef,
};
dst! {
/// A GLEV ciphertext. For the FFT variant, see
/// [`GlevCiphertextFft`](crate::entities::GlevCiphertextFft).
GlevCiphertext,
GlevCiphertextRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps,)
}
dst_iter! { GlevCiphertextIterator, GlevCiphertextIteratorMut, Torus, GlevCiphertextRef, (TorusOps,)}
impl<S> OverlaySize for GlevCiphertextRef<S>
where
S: TorusOps,
{
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlweCiphertextRef::<S>::size(t.0) * t.1 .0
}
}
impl<S> GlevCiphertextRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the rows of the GLEV ciphertext, which are
/// [`GlweCiphertext`](crate::entities::GlweCiphertext)s.
pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextIterator<S> {
GlweCiphertextIterator::new(&self.data, GlweCiphertextRef::<S>::size(params.dim))
}
/// Returns a mutable iterator over the rows of the GLEV ciphertext, which are
/// [`GlweCiphertext](crate::entities::GlweCiphertext)s.
pub fn glwe_ciphertexts_mut(&mut self, params: &GlweDef) -> GlweCiphertextIteratorMut<S> {
GlweCiphertextIteratorMut::new(&mut self.data, GlweCiphertextRef::<S>::size(params.dim))
}
/// Compute the FFT of each of the GLWE ciphertexts in the GLEV ciphertext.
/// The result is stored in `result`.
pub fn fft(&self, result: &mut GlevCiphertextFftRef<Complex<f64>>, params: &GlweDef) {
for (i, fft) in self
.glwe_ciphertexts(params)
.zip(result.glwe_ciphertexts_mut(params))
{
i.fft(fft, params);
}
}
}

View File

@@ -0,0 +1,65 @@
use num::Complex;
use serde::{Deserialize, Serialize};
use crate::{
dst::{NoWrapper, OverlaySize},
GlweDef, GlweDimension, RadixCount, TorusOps,
};
use super::{
GlevCiphertextRef, GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut,
GlweCiphertextFftRef,
};
dst! {
/// The FFT variant of a GLEV ciphertext. See
/// [GlevCiphertext](crate::entities::GlevCiphertext) for more details.
GlevCiphertextFft,
GlevCiphertextFftRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
dst_iter! { GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, NoWrapper, GlevCiphertextFftRef, ()}
impl OverlaySize for GlevCiphertextFftRef<Complex<f64>> {
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlweCiphertextFftRef::<Complex<f64>>::size(t.0) * t.1 .0
}
}
impl GlevCiphertextFftRef<Complex<f64>> {
/// Returns an iterator over the rows of the GLEV ciphertext, which are
/// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s.
pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextFftIterator<Complex<f64>> {
GlweCiphertextFftIterator::new(
&self.data,
GlweCiphertextFftRef::<Complex<f64>>::size(params.dim),
)
}
/// Returns a mutable iterator over the rows of the GLEV ciphertext, which are
/// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s.
pub fn glwe_ciphertexts_mut(
&mut self,
params: &GlweDef,
) -> GlweCiphertextFftIteratorMut<Complex<f64>> {
GlweCiphertextFftIteratorMut::new(
&mut self.data,
GlweCiphertextFftRef::<Complex<f64>>::size(params.dim),
)
}
/// Computes the inverse FFT of the GLEV ciphertexts and stores computation
/// in `result`.
pub fn ifft<S: TorusOps>(&self, result: &mut GlevCiphertextRef<S>, params: &GlweDef) {
for (i, ifft) in self
.glwe_ciphertexts(params)
.zip(result.glwe_ciphertexts_mut(params))
{
i.ifft(ifft, params);
}
}
}

View File

@@ -0,0 +1,159 @@
use num::{Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::{FromMutSlice, FromSlice, OverlaySize},
entities::GgswCiphertextRef,
macros::{impl_binary_op, impl_unary_op},
ops::ciphertext::external_product_ggsw_glwe,
GlweDef, GlweDimension, RadixDecomposition, Torus, TorusOps,
};
use super::{
GlweCiphertextFftRef, GlweSecretKeyRef, PolynomialIterator, PolynomialIteratorMut,
PolynomialRef,
};
dst! {
/// A GLWE ciphertext.
GlweCiphertext,
GlweCiphertextRef,
Torus,
(Debug, Clone, Serialize, Deserialize),
(TorusOps)
}
dst_iter! { GlweCiphertextIterator, GlweCiphertextIteratorMut, Torus, GlweCiphertextRef, (TorusOps,) }
// Also implements the assign operators.
impl_binary_op!(Add, GlweCiphertext, (TorusOps,));
impl_binary_op!(Sub, GlweCiphertext, (TorusOps,));
impl_unary_op!(Neg, GlweCiphertext);
impl<S> OverlaySize for GlweCiphertextRef<S>
where
S: TorusOps,
{
type Inputs = GlweDimension;
fn size(t: Self::Inputs) -> usize {
// We have `n` a polynomials plus 1 b polynomial each of degree d.
GlweSecretKeyRef::<S>::size(t) + t.polynomial_degree.0
}
}
impl<S> GlweCiphertext<S>
where
S: TorusOps,
{
/// Initialize an empty (zero) GLWE ciphertext
pub fn new(params: &GlweDef) -> GlweCiphertext<S> {
params.dim.assert_valid();
let len = GlweCiphertextRef::<S>::size(params.dim);
let data = (0..len).map(|_| Torus::<S>::zero()).collect::<Vec<_>>();
GlweCiphertext { data }
}
/// Computes the external product of a GLWE ciphertext and a GGSW ciphertext.
/// GGSW ⊡ GLWE -> GLWE
pub fn external_product(
&self,
ggsw: &GgswCiphertextRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlweCiphertext<S> {
external_product_ggsw_glwe(ggsw, self, params, radix)
}
/// Generate a GLWE ciphertext from a slice of [crate::TorusOps] elements.
pub fn from_slice(data: &[S], params: &GlweDef) -> GlweCiphertext<S> {
assert_eq!(data.len(), GlweCiphertextRef::<S>::size(params.dim));
GlweCiphertext {
data: data
.iter()
.map(|x| Torus::from(*x))
.collect::<Vec<Torus<S>>>(),
}
}
}
impl<S> GlweCiphertextRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
pub fn a_b(
&self,
params: &GlweDef,
) -> (PolynomialIterator<Torus<S>>, &PolynomialRef<Torus<S>>) {
let (a, b) = self.data.as_ref().split_at(self.split_idx(params));
(
PolynomialIterator::new(a, params.dim.polynomial_degree.0),
PolynomialRef::from_slice(b),
)
}
/// Returns an interator over the a polynomials in a GLWE ciphertext.
pub fn a(&self, params: &GlweDef) -> PolynomialIterator<Torus<S>> {
self.a_b(params).0
}
/// Returns a reference to the b polynomial in a GLWE ciphertext.
pub fn b(&self, params: &GlweDef) -> &PolynomialRef<Torus<S>> {
self.a_b(params).1
}
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
pub fn a_b_mut(
&mut self,
params: &GlweDef,
) -> (
PolynomialIteratorMut<Torus<S>>,
&mut PolynomialRef<Torus<S>>,
) {
let polynomial_degree = params.dim.polynomial_degree;
let split_idx = self.split_idx(params);
let (a, b) = self.data.as_mut().split_at_mut(split_idx);
(
PolynomialIteratorMut::new(a, polynomial_degree.0),
PolynomialRef::from_mut_slice(b),
)
}
/// Returns a mutable iterator over the a polynomials in a GLWE ciphertext.
pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut<Torus<S>> {
self.a_b_mut(params).0
}
/// Returns a mutable reference to the b polynomial in a GLWE ciphertext.
pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialRef<Torus<S>> {
self.a_b_mut(params).1
}
fn split_idx(&self, params: &GlweDef) -> usize {
params.dim.size.0 * params.dim.polynomial_degree.0
}
/// Create an FFT transformed version of `self` stored to result.
pub fn fft(&self, result: &mut GlweCiphertextFftRef<Complex<f64>>, params: &GlweDef) {
for (a, fft) in self.a(params).zip(result.a_mut(params)) {
a.fft(fft);
}
self.b(params).fft(result.b_mut(params));
}
#[inline(always)]
pub(crate) fn assert_valid(&self, params: &GlweDef) {
assert_eq!(
self.as_slice().len(),
GlweCiphertextRef::<S>::size(params.dim)
)
}
}

View File

@@ -0,0 +1,139 @@
use num::{complex::Complex64, Complex, Zero};
use serde::{Deserialize, Serialize};
use crate::{
dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize},
GlweDef, GlweDimension, TorusOps,
};
use super::{GlweCiphertextRef, PolynomialFftIterator, PolynomialFftIteratorMut, PolynomialFftRef};
dst! {
/// The FFT variant of a GLWE ciphertext. See
/// [`GlweCiphertext`](crate::entities::GlweCiphertext) for more details.
GlweCiphertextFft,
GlweCiphertextFftRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
dst_iter! { GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut, NoWrapper, GlweCiphertextFftRef, ()}
impl OverlaySize for GlweCiphertextFftRef<Complex<f64>> {
type Inputs = GlweDimension;
fn size(t: Self::Inputs) -> usize {
// FFT polynomials are half the length of their standard counterparts.
PolynomialFftRef::<Complex<f64>>::size(t.polynomial_degree) * (t.size.0 + 1)
}
}
impl GlweCiphertextFft<Complex<f64>> {
/// Creates a new zero GLWE ciphertext in the frequency domain.
pub fn new(params: &GlweDef) -> Self {
let len = GlweCiphertextFftRef::size(params.dim);
Self {
data: vec![Complex::zero(); len],
}
}
}
impl GlweCiphertextFftRef<Complex<f64>> {
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
pub fn a_b(
&self,
params: &GlweDef,
) -> (
PolynomialFftIterator<Complex<f64>>,
&PolynomialFftRef<Complex<f64>>,
) {
let (a, b) = self.as_slice().split_at(self.split_idx(params));
(
PolynomialFftIterator::new(a, params.dim.polynomial_degree.0 / 2),
PolynomialFftRef::from_slice(b),
)
}
/// Returns an interator over the a polynomials in a GLWE ciphertext.
pub fn a(&self, params: &GlweDef) -> PolynomialFftIterator<Complex<f64>> {
self.a_b(params).0
}
/// Returns a reference to the b polynomial in a GLWE ciphertext.
pub fn b(&self, params: &GlweDef) -> &PolynomialFftRef<Complex64> {
self.a_b(params).1
}
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
pub fn a_b_mut(
&mut self,
params: &GlweDef,
) -> (
PolynomialFftIteratorMut<Complex<f64>>,
&mut PolynomialFftRef<Complex<f64>>,
) {
let polynomial_degree = params.dim.polynomial_degree;
let split_idx = self.split_idx(params);
let (a, b) = self.as_mut_slice().split_at_mut(split_idx);
(
PolynomialFftIteratorMut::new(a, polynomial_degree.0 / 2),
PolynomialFftRef::from_mut_slice(b),
)
}
/// Returns a mutable iterator over the a polynomials in a GLWE ciphertext.
pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialFftIteratorMut<Complex<f64>> {
self.a_b_mut(params).0
}
/// Returns a mutable reference to the b polynomial in a GLWE ciphertext.
pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialFftRef<Complex<f64>> {
self.a_b_mut(params).1
}
#[inline(always)]
fn split_idx(&self, params: &GlweDef) -> usize {
params.dim.size.0 * params.dim.polynomial_degree.0 / 2
}
/// Computes the inverse FFT of the GLWE ciphertext and stores the
/// computation in `result`.
pub fn ifft<T: TorusOps>(&self, result: &mut GlweCiphertextRef<T>, params: &GlweDef) {
for (a, fft) in self.a(params).zip(result.a_mut(params)) {
a.ifft(fft);
}
self.b(params).ifft(result.b_mut(params));
}
}
#[cfg(test)]
mod tests {
use crate::{entities::Polynomial, high_level::*, PlaintextBits, GLWE_1_1024_80};
#[test]
fn can_decrypt_glwe_after_fft_roundtrip() {
let params = GLWE_1_1024_80;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_glwe_sk(&params);
let pt = (0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>();
let pt = Polynomial::new(&pt);
let mut ct = encryption::encrypt_glwe(&pt, &sk, &params, bits);
let fft = fft::fft_glwe(&ct, &params);
fft.ifft(&mut ct, &params);
let actual = encryption::decrypt_glwe(&ct, &sk, &params, bits);
assert_eq!(actual, pt);
}
}

View File

@@ -0,0 +1,75 @@
use num::Zero;
use serde::{Deserialize, Serialize};
use crate::{
dst::OverlaySize, GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};
use super::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef};
// TODO: This GLWE keyswitch only works for switching to a new key with the same
// parameter. Copy what is above but changed for polynomials to enable
// converting to a different key parameter set.
dst! {
/// A GLWE keyswitch key used to switch a ciphertext from one key to another.
/// See [`module`](crate::ops::keyswitch) documentation for more details.
GlweKeyswitchKey,
GlweKeyswitchKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps,)
}
impl<S> OverlaySize for GlweKeyswitchKeyRef<S>
where
S: TorusOps,
{
type Inputs = (GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlevCiphertextRef::<S>::size(t) * (t.0.size.0)
}
}
impl<S> GlweKeyswitchKey<S>
where
S: TorusOps,
{
/// Creates a new GLWE keyswitch key. This enables switching to a new key as
/// well as switching from the `original_params` that define the first key
/// to the `new_params` that define the second key.
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
// TODO: Shouldn't this function take 2 GlweDefs?
// Ryan: to whoever wrote this, yes, see the above todo next to the dst.
let elems = GlweKeyswitchKeyRef::<S>::size((params.dim, radix.count));
Self {
data: vec![Torus::zero(); elems],
}
}
}
impl<S> GlweKeyswitchKeyRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the rows of the GLWE keyswitch key, which are
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator<S> {
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
GlevCiphertextIterator::new(&self.data, stride)
}
/// Returns a mutable iterator over the rows of the GLWE keyswitch key, which are
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
pub fn rows_mut(
&mut self,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIteratorMut<S> {
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
GlevCiphertextIteratorMut::new(&mut self.data, stride)
}
}

View File

@@ -0,0 +1,385 @@
use serde::{Deserialize, Serialize};
use crate::{
dst::{FromSlice, NoWrapper, OverlaySize},
entities::GgswCiphertext,
macros::{impl_binary_op, impl_unary_op},
ops::encryption::{
decrypt_glwe_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret,
},
rand::{binary, uniform_torus},
GlweDef, GlweDimension, PlaintextBits, RadixDecomposition, Torus, TorusOps,
};
use super::{
GlweCiphertext, GlweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialIterator,
PolynomialIteratorMut, PolynomialRef,
};
dst! {
/// A GLWE secret key. This is a list of `s` polynomials of degree `n` where
/// `n` is the polynomial degree. The length of the list is `k` where `k` is
/// the size of the GLWE secret key.
GlweSecretKey,
GlweSecretKeyRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
(TorusOps,)
}
// We can use these macros (which does operations on the underlying data vector)
// because addition, subtraction, and negation of polynomials is just element
// wise addition, subtraction, and negation of the underlying data vector.
impl_binary_op!(Add, GlweSecretKey, (TorusOps,));
impl_binary_op!(Sub, GlweSecretKey, (TorusOps,));
impl_unary_op!(Neg, GlweSecretKey);
impl<S> OverlaySize for GlweSecretKeyRef<S>
where
S: TorusOps,
{
type Inputs = GlweDimension;
fn size(t: Self::Inputs) -> usize {
PolynomialRef::<S>::size(t.polynomial_degree) * t.size.0
}
}
impl<S> GlweSecretKey<S>
where
S: TorusOps,
{
fn generate(params: &GlweDef, torus_element_generator: impl Fn() -> S) -> GlweSecretKey<S> {
params.dim.assert_valid();
let len = GlweSecretKeyRef::<S>::size(params.dim);
GlweSecretKey {
data: (0..len)
.map(|_| torus_element_generator())
.collect::<Vec<_>>(),
}
}
/// Generate a random binary GLWE secret key.
pub fn generate_binary(params: &GlweDef) -> GlweSecretKey<S> {
Self::generate(params, binary)
}
/// Generate a secret key with uniformly random coefficients. This can be
/// used when performing threshold decryption, which needs random secret
/// keys that are uniform over the entire ciphertext modulus. Uniform
/// secret keys are also valid keys for encryption/decryption but are not
/// widely used.
pub fn generate_uniform(params: &GlweDef) -> GlweSecretKey<S> {
Self::generate(params, || uniform_torus::<S>().inner())
}
}
impl<S> GlweSecretKeyRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the `s` polynomials in a GLWE secret key.
pub fn s(&self, params: &GlweDef) -> PolynomialIterator<S> {
PolynomialIterator::new(&self.data, params.dim.polynomial_degree.0)
}
/// Decrypts and decodes a GLWE ciphertext into a polynomial.
pub fn decrypt_decode_glwe(
&self,
ct: &GlweCiphertextRef<S>,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> Polynomial<S>
where
S: TorusOps,
{
let mut result = Polynomial::zero(ct.a_b(params).1.len());
decrypt_glwe_ciphertext(&mut result, ct, self, params);
result.map(|x| x.decode(plaintext_bits))
}
/// Encodes and encrypts a message as a GLWE ciphertext using a secret key.
pub fn encode_encrypt_glwe(
&self,
plaintext: &PolynomialRef<S>,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> GlweCiphertext<S>
where
S: TorusOps,
{
let plaintext = plaintext.map(|x| Torus::encode(*x, plaintext_bits));
let mut ct = GlweCiphertext::new(params);
encrypt_glwe_ciphertext_secret(&mut ct, &plaintext, self, params);
ct
}
/// Encodes and encrypts a message as a GGSW ciphertext using a secret key.
pub fn encode_encrypt_ggsw(
&self,
msg: &PolynomialRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) -> GgswCiphertext<S>
where
S: TorusOps,
{
let mut ggsw = GgswCiphertext::new(params, radix);
encrypt_ggsw_ciphertext(&mut ggsw, msg, self, params, radix, plaintext_bits);
ggsw
}
/// Returns a representation of a GLWE secret key as an LWE secret key.
/// This is a LWE secret key with dimension `N * k` where `N` is the
/// polynomial degree and `k` is the size of the GLWE secret key.
pub fn to_lwe_secret_key(&self) -> &LweSecretKeyRef<S> {
LweSecretKeyRef::from_slice(&self.data)
}
#[inline(always)]
pub(crate) fn assert_valid(&self, params: &GlweDef) {
assert_eq!(
self.as_slice().len(),
GlweSecretKeyRef::<S>::size(params.dim)
);
}
}
impl<S> GlweSecretKeyRef<S>
where
S: TorusOps,
{
/// Returns an mutable iterator over the `s` polynomials in a GLWE secret
/// key.
pub fn s_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut<S> {
PolynomialIteratorMut::new(&mut self.data, params.dim.polynomial_degree.0)
}
}
#[cfg(test)]
mod tests {
use crate::{high_level::*, GLWE_1_1024_80};
use num::traits::{WrappingAdd, WrappingNeg, WrappingSub};
#[test]
fn secret_key_dimensions() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_binary_glwe_sk(&params);
assert_eq!(sk.s(&params).count(), params.dim.size.0);
for s_i in sk.s(&params) {
assert_eq!(s_i.len(), params.dim.polynomial_degree.0);
for s in s_i.coeffs() {
assert!(*s == 0 || *s == 1);
}
}
}
// Addition
#[test]
fn add_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk + sk2;
assert_eq!(sk3_expected, sk3.data)
}
#[test]
fn add_assign_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let mut sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk2_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
sk2 += sk;
assert_eq!(sk2_expected, sk2.data)
}
#[test]
fn add_secret_key_refs() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk.as_ref() + sk2.as_ref();
assert_eq!(sk3_expected, sk3.data)
}
#[test]
fn wrapping_add_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk.wrapping_add(&sk2);
assert_eq!(sk3_expected, sk3.data)
}
// Subtraction
#[test]
fn sub_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk - sk2;
assert_eq!(sk3_expected, sk3.data)
}
#[test]
fn sub_assign_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let mut sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk2_expected = sk2
.data
.iter()
.zip(sk.data.iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
sk2 -= sk;
assert_eq!(sk2_expected, sk2.data)
}
#[test]
fn sub_secret_key_refs() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk.as_ref() - sk2.as_ref();
assert_eq!(sk3_expected, sk3.data)
}
#[test]
fn wrapping_sub_secret_keys() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
let sk3_expected = sk
.data
.iter()
.zip(sk2.data.iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk.wrapping_sub(&sk2);
assert_eq!(sk3_expected, sk3.data)
}
// Negation
#[test]
fn neg_secret_key() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = -sk;
assert_eq!(sk2_expected, sk2.data)
}
#[test]
fn neg_secret_key_ref() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = -sk.as_ref();
assert_eq!(sk2_expected, sk2.data)
}
#[test]
fn wrapping_neg_secret_key() {
let params = GLWE_1_1024_80;
let sk = keygen::generate_binary_glwe_sk(&params);
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = sk.wrapping_neg();
assert_eq!(sk2_expected, sk2.data)
}
}

View File

@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};
use crate::{dst::OverlaySize, LweDef, LweDimension, RadixCount, Torus, TorusOps};
use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef};
// Iteration over LWE ciphertexts
dst! {
/// A Lev Ciphertext is a collection of LWE ciphertexts.
LevCiphertext,
LevCiphertextRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps,)
}
dst_iter! { LevCiphertextIterator, LevCiphertextIteratorMut, Torus, LevCiphertextRef, (TorusOps,)}
impl<S> OverlaySize for LevCiphertextRef<S>
where
S: TorusOps,
{
type Inputs = (LweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
LweCiphertextRef::<S>::size(t.0) * t.1 .0
}
}
impl<S> LevCiphertextRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the rows of the Lev ciphertext, which are
/// [`LweCiphertext`](crate::entities::LweCiphertext)s.
pub fn lwe_ciphertexts(&self, params: &LweDef) -> LweCiphertextIterator<S> {
LweCiphertextIterator::new(&self.data, LweCiphertextRef::<S>::size(params.dim))
}
/// Returns a mutable iterator over the rows of the Lev ciphertext, which are
/// [`LweCiphertext`](crate::entities::LweCiphertext)s.
pub fn lwe_ciphertexts_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut<S> {
LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::<S>::size(params.dim))
}
}

View File

@@ -0,0 +1,192 @@
use num::Zero;
use serde::{Deserialize, Serialize};
use crate::{
dst::OverlaySize,
macros::{impl_binary_op, impl_unary_op},
LweDef, LweDimension, Torus, TorusOps,
};
dst! {
/// An LWE ciphertext.
LweCiphertext,
LweCiphertextRef,
Torus,
(Clone, Debug,Serialize, Deserialize),
(TorusOps)
}
dst_iter! { LweCiphertextIterator, LweCiphertextIteratorMut, Torus, LweCiphertextRef, (TorusOps,) }
impl_binary_op!(Add, LweCiphertext, (TorusOps,));
impl_binary_op!(Sub, LweCiphertext, (TorusOps,));
impl_unary_op!(Neg, LweCiphertext);
impl<S> OverlaySize for LweCiphertextRef<S>
where
S: TorusOps,
{
type Inputs = LweDimension;
fn size(t: Self::Inputs) -> usize {
t.0 + 1
}
}
impl<S: TorusOps> LweCiphertext<S> {
/// Create a new LWE ciphertext with all coefficients set to zero.
pub fn new(params: &LweDef) -> Self {
Self::zero(params)
}
/// Create a new LWE ciphertext with all coefficients set to zero.
pub fn zero(params: &LweDef) -> Self {
let data = vec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
Self { data }
}
}
impl<S: TorusOps> LweCiphertextRef<S> {
fn split_at_idx(&self, params: &LweDef) -> usize {
params.dim.0
}
/// Returns a reference to the mask A and body B in an LWE ciphertext.
pub fn a_b(&self, params: &LweDef) -> (&[Torus<S>], &Torus<S>) {
let (a, b) = self.data.split_at(self.split_at_idx(params));
(a, &b[0])
}
/// Returns a reference to the mask A in an LWE ciphertext.
pub fn a(&self, params: &LweDef) -> &[Torus<S>] {
let (a, _) = self.a_b(params);
a
}
/// Returns a reference to the body B in an LWE ciphertext.
pub fn b(&self, params: &LweDef) -> &Torus<S> {
let (_, b) = self.a_b(params);
b
}
/// Returns a mutable reference to the mask A and body B in an LWE ciphertext.
pub fn a_b_mut(&mut self, params: &LweDef) -> (&mut [Torus<S>], &mut Torus<S>) {
let (a, b) = self.data.split_at_mut(self.split_at_idx(params));
(a, &mut b[0])
}
/// Returns a mutable reference to the mask A in an LWE ciphertext.
pub fn a_mut(&mut self, params: &LweDef) -> &mut [Torus<S>] {
let (a, _) = self.a_b_mut(params);
a
}
/// Returns a mutable reference to the body B in an LWE ciphertext.
pub fn b_mut(&mut self, params: &LweDef) -> &mut Torus<S> {
let (_, b) = self.a_b_mut(params);
b
}
/// Asserts that the LWE ciphertext is valid for a given LWE dimension.
#[inline(always)]
pub(crate) fn assert_valid(&self, params: &LweDef) {
assert_eq!(
self.as_slice().len(),
LweCiphertextRef::<S>::size(params.dim)
);
}
}
#[cfg(test)]
mod tests {
use crate::{high_level::*, PlaintextBits, LWE_512_80};
use proptest::prelude::*;
// Test that the negation of a ciphertext is the same as the negation of the
// plaintext.
proptest! {
#[test]
fn negation_homomorphism(a in any::<u64>()) {
let bits = PlaintextBits(4);
let params = LWE_512_80;
let sk = keygen::generate_binary_lwe_sk(&params);
let a_enc = encryption::encrypt_lwe_secret(a, &sk, &params, bits);
let a_enc_neg = -a_enc;
prop_assert_eq!(encryption::decrypt_lwe(&a_enc_neg, &sk, &params, bits), a.wrapping_neg() % (0x1 << bits.0 as u64));
}
}
// Test that the addition of ciphertexts is the same as the addition of the
// plaintexts.
proptest! {
#[test]
fn additive_homomorphism(a in any::<u64>(), b in any::<u64>()) {
let params = LWE_512_80;
let sk = keygen::generate_binary_lwe_sk(&params);
let bits = PlaintextBits(4);
let a_enc = encryption::encrypt_lwe_secret(a, &sk, &params, bits);
let b_enc = encryption::encrypt_lwe_secret(b, &sk, &params, bits);
let c_enc = a_enc + b_enc;
prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, &params, bits), a.wrapping_add(b) % (0x1 << bits.0 as u64));
}
}
// Test that the subtraction of ciphertexts is the same as the subtraction
// of the plaintexts.
proptest! {
#[test]
fn subtraction_homomorphism(a in any::<u64>(), b in any::<u64>()) {
let params = LWE_512_80;
let sk = keygen::generate_binary_lwe_sk(&params);
let bits = PlaintextBits(4);
let a_enc = encryption::encrypt_lwe_secret(a, &sk, &params, bits);
let b_enc = encryption::encrypt_lwe_secret(b, &sk, &params, bits);
let c_enc = a_enc - b_enc;
prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, &params, bits), a.wrapping_sub(b) % (0x1 << bits.0 as u64));
}
}
// Testing that the addition of a ciphertext and a negated ciphertext is the
// same as the subtraction of the ciphertexts.
proptest! {
#[test]
fn add_negative_is_subtraction(a in any::<u64>(), b in any::<u64>()) {
let params = LWE_512_80;
let sk = keygen::generate_binary_lwe_sk(&params);
let bits = PlaintextBits(4);
let a_enc = encryption::encrypt_lwe_secret(a, &sk, &params, bits);
let b_enc = encryption::encrypt_lwe_secret(b, &sk, &params, bits);
let c_enc_by_add_neg = a_enc.as_ref() + (-(b_enc.as_ref())).as_ref();
let c_enc_by_sub = a_enc.as_ref() - b_enc.as_ref();
// Test that the a values are the same
for (a_enc_by_add_neg_i, a_enc_by_sub_i) in c_enc_by_add_neg.a(&params).iter().zip(c_enc_by_sub.a(&params).iter()) {
assert_eq!(a_enc_by_add_neg_i, a_enc_by_sub_i);
}
// Test that the b values are the same
assert_eq!(c_enc_by_add_neg.b(&params), c_enc_by_sub.b(&params));
}
}
}

View File

@@ -0,0 +1,49 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{dst::OverlaySize, LweDef, LweDimension, Torus, TorusOps};
use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef};
dst! {
/// A list of LWE ciphertexts. Used during
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
LweCiphertextList,
LweCiphertextListRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for LweCiphertextListRef<S> {
type Inputs = (LweDimension, usize);
#[inline(always)]
fn size(t: Self::Inputs) -> usize {
LweCiphertextRef::<S>::size(t.0) * t.1
}
}
impl<S: TorusOps> LweCiphertextList<S> {
/// Create a new zero [LweCiphertextList] with the given parameters.
///
/// This data structure represents is a list of LWE ciphertexts, used for
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
pub fn new(lwe: &LweDef, count: usize) -> Self {
Self {
data: vec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
}
}
}
impl<S: TorusOps> LweCiphertextListRef<S> {
/// Iterate over the LWE ciphertexts in the list.
pub fn ciphertexts(&self, lwe: &LweDef) -> LweCiphertextIterator<S> {
LweCiphertextIterator::new(self.as_slice(), LweCiphertextRef::<S>::size(lwe.dim))
}
/// Iterate over the LWE ciphertexts in the list mutably.
pub fn ciphertexts_mut(&mut self, lwe: &LweDef) -> LweCiphertextIteratorMut<S> {
LweCiphertextIteratorMut::new(self.as_mut_slice(), LweCiphertextRef::<S>::size(lwe.dim))
}
}

View File

@@ -0,0 +1,170 @@
use num::Zero;
use serde::{Deserialize, Serialize};
use crate::{
dst::OverlaySize, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};
use super::{LevCiphertextIterator, LevCiphertextIteratorMut, LevCiphertextRef};
dst! {
/// A LWE keyswitch key used to switch a ciphertext from one key to another.
/// See [`module`](crate::ops::keyswitch) documentation for more details.
LweKeyswitchKey,
LweKeyswitchKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps,)
}
impl<S> OverlaySize for LweKeyswitchKeyRef<S>
where
S: TorusOps,
{
// Old LWE dimension, new LWE dimension, radix count
type Inputs = (LweDimension, LweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
// Number of rows should be equal to the number of elements in the original key
let num_rows = t.0 .0;
// Each row is made up of encryptions under the new key
let len_row = LevCiphertextRef::<S>::size((t.1, t.2));
// Encrypt the secret key s_i in each row
len_row * (num_rows)
}
}
impl<S> LweKeyswitchKey<S>
where
S: TorusOps,
{
/// Creates a new LWE keyswitch key. This enables switching to a new key as
/// well as switching from the `original_params` that define the first key
/// to the `new_params` that define the second key.
pub fn new(original_params: &LweDef, new_params: &LweDef, radix: &RadixDecomposition) -> Self {
let elems =
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count));
Self {
data: vec![Torus::zero(); elems],
}
}
}
impl<S> LweKeyswitchKeyRef<S>
where
S: TorusOps,
{
/// Returns an iterator over the rows of the LWE keyswitch key, which are
/// [`LevCiphertext`](crate::entities::LevCiphertext)s.
pub fn rows(
&self,
new_params: &LweDef,
radix: &RadixDecomposition,
) -> LevCiphertextIterator<S> {
let stride = LevCiphertextRef::<S>::size((new_params.dim, radix.count));
LevCiphertextIterator::new(&self.data, stride)
}
/// Returns a mutable iterator over the rows of the LWE keyswitch key, which are
/// [`LevCiphertext`](crate::entities::LevCiphertext)s.
pub fn rows_mut(
&mut self,
new_params: &LweDef,
radix: &RadixDecomposition,
) -> LevCiphertextIteratorMut<S> {
let stride = LevCiphertextRef::<S>::size((new_params.dim, radix.count));
LevCiphertextIteratorMut::new(&mut self.data, stride)
}
/// Asserts that the keyswitch key is valid for the given parameters.
#[inline(always)]
pub(crate) fn assert_valid(
&self,
original_params: &LweDef,
new_params: &LweDef,
radix: &RadixDecomposition,
) {
assert_eq!(
self.as_slice().len(),
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count))
);
}
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{LweCiphertext, LweKeyswitchKey},
high_level::*,
high_level::{TEST_LWE_DEF_1, TEST_LWE_DEF_2, TEST_RADIX},
ops::keyswitch::{
lwe_keyswitch::keyswitch_lwe_to_lwe, lwe_keyswitch_key::generate_keyswitch_key_lwe,
},
PlaintextBits,
};
#[test]
fn keyswitch_lwe() {
let bits = PlaintextBits(4);
let from_lwe = TEST_LWE_DEF_1;
let to_lwe = TEST_LWE_DEF_2;
for _ in 0..50 {
let original_sk = keygen::generate_binary_lwe_sk(&from_lwe);
let new_sk = keygen::generate_binary_lwe_sk(&to_lwe);
let mut ksk = LweKeyswitchKey::<u64>::new(&from_lwe, &to_lwe, &TEST_RADIX);
generate_keyswitch_key_lwe(&mut ksk, &original_sk, &new_sk, &to_lwe, &TEST_RADIX);
let msg = thread_rng().next_u64() % (1 << bits.0);
let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0;
let mut new_ct = LweCiphertext::new(&to_lwe);
keyswitch_lwe_to_lwe(
&mut new_ct,
&original_ct,
&ksk,
&from_lwe,
&to_lwe,
&TEST_RADIX,
);
let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits);
assert_eq!(new_decrypted, msg);
}
}
#[test]
fn lwe_keyswitch_keygen() {
let from_lwe = TEST_LWE_DEF_1;
let to_lwe = TEST_LWE_DEF_2;
for _ in 0..10 {
let sk_1 = keygen::generate_binary_lwe_sk(&from_lwe);
let sk_2 = keygen::generate_binary_lwe_sk(&to_lwe);
let mut ksk = LweKeyswitchKey::<u64>::new(&from_lwe, &to_lwe, &TEST_RADIX);
generate_keyswitch_key_lwe(&mut ksk, &sk_1, &sk_2, &to_lwe, &TEST_RADIX);
for (i, r) in ksk.rows(&to_lwe, &TEST_RADIX).enumerate() {
for (j, l) in r.lwe_ciphertexts(&to_lwe).enumerate() {
let decomp = (j + 1) * TEST_RADIX.radix_log.0;
let res = sk_2.decrypt(l, &to_lwe, PlaintextBits(decomp as u32));
assert_eq!(res, sk_1.s()[i]);
}
}
}
}
}

View File

@@ -0,0 +1,158 @@
use num::Zero;
use serde::{Deserialize, Serialize};
use crate::{
dst::OverlaySize,
ops::encryption::encode_and_encrypt_lwe_ciphertext,
rand::{binary, normal_torus},
LweDef, LweDimension, PlaintextBits, Torus, TorusOps,
};
use super::{
LweCiphertext, LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef,
LweSecretKeyRef,
};
/// Randomness used to encrypt a message with a public key.
#[derive(Debug)]
pub struct TlwePublicEncRandomness<S: TorusOps> {
/// The binary selectors of the encryptions of zero in the public key.
pub r: Vec<S>,
/// The gaussian noise added to make the LWE problem.
pub e: LweCiphertext<S>,
}
dst! {
/// An LWE public key.
LwePublicKey,
LwePublicKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S> OverlaySize for LwePublicKeyRef<S>
where
S: TorusOps,
{
type Inputs = LweDimension;
fn size(t: Self::Inputs) -> usize {
LweCiphertextRef::<S>::size(t) * t.0
}
}
impl<S> LwePublicKey<S>
where
S: TorusOps,
{
/// Generate an LWE public key from a given secret key. This is done by
/// encrypting the LWE dimension number of zeros under the secret key, and
/// then using the resulting ciphertext as the public key.
pub fn generate(sk: &LweSecretKeyRef<S>, params: &LweDef) -> Self {
let mut pk = LwePublicKey {
data: vec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
};
let enc_zeros = pk.enc_zeros_mut(params);
for z in enc_zeros {
encode_and_encrypt_lwe_ciphertext(z, sk, <S as Zero>::zero(), params, PlaintextBits(1));
}
pk
}
}
impl<S> LwePublicKeyRef<S>
where
S: TorusOps,
{
/// Get the public key data as an iterator.
pub fn enc_zeros(&self, params: &LweDef) -> LweCiphertextIterator<S> {
LweCiphertextIterator::new(&self.data, LweCiphertextRef::<S>::size(params.dim))
}
/// Get the public key data as a mutable iterator.
pub fn enc_zeros_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut<S> {
LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::<S>::size(params.dim))
}
/// Encrypt a message as an LWE ciphertext using a public key, returning the
/// encrypted message and the randomness used.
pub fn encrypt(
&self,
msg: S,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> (LweCiphertext<S>, TlwePublicEncRandomness<S>) {
let msg = Torus::<S>::encode(msg, plaintext_bits);
let lwe_dimension = params.dim.0;
let mut acc = LweCiphertext::zero(params);
let (acc_a, acc_b) = acc.a_b_mut(params);
let mut r_noise = vec![];
let mut e = LweCiphertext::zero(params);
let (e_a, e_b) = e.a_b_mut(params);
for z in self.enc_zeros(params) {
let (a, b) = z.a_b(params);
let r = binary::<S>();
r_noise.push(r);
for i in 0..lwe_dimension {
acc_a[i] += a[i] * r;
}
*acc_b += *b * r;
}
for i in 0..lwe_dimension {
let a_noise = normal_torus(params.std);
e_a[i] = a_noise;
acc_a[i] += a_noise;
}
*acc_b += msg;
*e_b = normal_torus(params.std);
*acc_b += *e_b;
let noise = TlwePublicEncRandomness { r: r_noise, e };
(acc, noise)
}
}
#[cfg(test)]
mod tests {
use crate::{
high_level::{encryption, keygen, TEST_LWE_DEF_1},
PlaintextBits,
};
#[test]
fn public_key_is_zeros() {
let params = TEST_LWE_DEF_1;
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
for ct in pk.enc_zeros(&params) {
let pt = encryption::decrypt_lwe(ct, &sk, &params, PlaintextBits(1));
assert_eq!(pt, 0);
}
}
#[test]
fn can_public_key_encrypt() {
let params = TEST_LWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
let ct = encryption::encrypt_lwe(5, &pk, &params, bits);
assert_eq!(encryption::decrypt_lwe(&ct, &sk, &params, bits), 5);
}
}

View File

@@ -0,0 +1,337 @@
use num::Zero;
use serde::{Deserialize, Serialize};
use crate::{
dst::{NoWrapper, OverlaySize},
macros::{impl_binary_op, impl_unary_op},
ops::encryption::encode_and_encrypt_lwe_ciphertext,
rand::{binary, uniform_torus},
LweDef, LweDimension, PlaintextBits, Torus, TorusOps,
};
use super::{LweCiphertext, LweCiphertextRef};
dst! {
/// An LWE secret key.
LweSecretKey,
LweSecretKeyRef,
NoWrapper,
(Clone, Debug, Serialize, Deserialize),
()
}
impl_binary_op!(Add, LweSecretKey, (TorusOps,));
impl_binary_op!(Sub, LweSecretKey, (TorusOps,));
impl_unary_op!(Neg, LweSecretKey);
impl<S> OverlaySize for LweSecretKeyRef<S>
where
S: TorusOps,
{
type Inputs = LweDimension;
fn size(t: Self::Inputs) -> usize {
t.0
}
}
impl<S> LweSecretKey<S>
where
S: TorusOps,
{
fn generate(params: &LweDef, torus_element_generator: fn() -> S) -> Self {
let len = LweSecretKeyRef::<S>::size(params.dim);
LweSecretKey {
data: (0..len)
.map(|_| torus_element_generator())
.collect::<Vec<_>>(),
}
}
/// Generate a random binary LWE secret key
pub fn generate_binary(params: &LweDef) -> Self {
Self::generate(params, binary)
}
/// Generate a secret key with uniformly random coefficients. This can be
/// used when performing threshold decryption, which needs random secret
/// keys that are uniform over the entire ciphertext modulus. Uniform
/// secret keys are also valid keys for encryption/decryption but are not
/// widely used.
pub fn generate_uniform(params: &LweDef) -> Self {
Self::generate(params, || uniform_torus::<S>().inner())
}
}
impl<S> LweSecretKeyRef<S>
where
S: TorusOps,
{
/// Create an LWE ciphertext from a given message with a private key. The
/// message should be in the plaintext space, and will be encoded onto the
/// Torus automatically.
pub fn encrypt(
&self,
msg: S,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> (LweCiphertext<S>, Torus<S>) {
let mut ct = LweCiphertext::<S>::zero(params);
let e = encode_and_encrypt_lwe_ciphertext(&mut ct, self, msg, params, plaintext_bits);
(ct, e)
}
/// Decrypts the given ciphertext, returning the message. The message will
/// not be decoded into the plaintext space; the caller is responsible for
/// performing operations like shifting by delta and rounding. See
/// [Self::decrypt] for a function that performs the decoding automatically.
pub fn decrypt_without_decode(&self, ct: &LweCiphertextRef<S>, params: &LweDef) -> Torus<S> {
ct.assert_valid(params);
let (a, b) = ct.a_b(params);
let mut dot = Torus::<S>::zero();
for (a_i, d_i) in a.iter().zip(self.data.iter()) {
dot += a_i * d_i
}
b - dot
}
/// Decrypts and decodes a ciphertext, returning the message. The message
/// will be decoded into the plaintext space. See
/// [Self::decrypt_without_decode] for a function that does not perform the
/// decoding.
pub fn decrypt(
&self,
ct: &LweCiphertextRef<S>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> S {
let msg = self.decrypt_without_decode(ct, params);
msg.decode(plaintext_bits)
}
/// Asserts that a given secret key is valid for a given LWE dimension.
pub fn assert_valid(&self, params: &LweDef) {
assert_eq!(
self.as_slice().len(),
LweSecretKeyRef::<S>::size(params.dim)
);
}
}
impl<S> LweSecretKeyRef<S>
where
S: TorusOps,
{
/// Returns the secret key data as a slice.
pub fn s(&self) -> &[S] {
&self.data
}
}
#[cfg(test)]
mod tests {
use crate::high_level::*;
use num::traits::{WrappingAdd, WrappingNeg, WrappingSub};
// Addition
#[test]
fn add_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk + sk2;
assert_eq!(sk3_expected, sk3.s())
}
#[test]
fn add_assign_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let mut sk2 = keygen::generate_uniform_lwe_sk(params);
let sk2_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
sk2 += sk;
assert_eq!(sk2_expected, sk2.s())
}
#[test]
fn add_secret_key_refs() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk.as_ref() + sk2.as_ref();
assert_eq!(sk3_expected, sk3.s())
}
#[test]
fn wrapping_add_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_add(b))
.collect::<Vec<_>>();
let sk3 = sk.wrapping_add(&sk2);
assert_eq!(sk3_expected, sk3.s())
}
// Subtraction
#[test]
fn sub_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk - sk2;
assert_eq!(sk3_expected, sk3.s())
}
#[test]
fn sub_assign_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let mut sk2 = keygen::generate_uniform_lwe_sk(params);
let sk2_expected = sk2
.s()
.iter()
.zip(sk.s().iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
sk2 -= sk;
assert_eq!(sk2_expected, sk2.s())
}
#[test]
fn sub_secret_key_refs() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk.as_ref() - sk2.as_ref();
assert_eq!(sk3_expected, sk3.s())
}
#[test]
fn wrapping_sub_secret_keys() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2 = keygen::generate_uniform_lwe_sk(params);
let sk3_expected = sk
.s()
.iter()
.zip(sk2.s().iter())
.map(|(a, b)| a.wrapping_sub(b))
.collect::<Vec<_>>();
let sk3 = sk.wrapping_sub(&sk2);
assert_eq!(sk3_expected, sk3.s())
}
// Negation
#[test]
fn neg_secret_key() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = -sk;
assert_eq!(sk2_expected, sk2.s())
}
#[test]
fn neg_secret_key_ref() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = -sk.as_ref();
assert_eq!(sk2_expected, sk2.s())
}
#[test]
fn wrapping_neg_secret_key() {
let params = &TEST_LWE_DEF_1;
let sk = keygen::generate_uniform_lwe_sk(params);
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
let sk2 = sk.wrapping_neg();
assert_eq!(sk2_expected, sk2.s())
}
}

View File

@@ -0,0 +1,71 @@
mod public_functional_keyswitch_key;
pub use public_functional_keyswitch_key::*;
mod blind_rotation_shift;
pub use blind_rotation_shift::*;
mod lwe_ciphertext_list;
pub use lwe_ciphertext_list::*;
mod private_functional_keyswitch_key;
pub use private_functional_keyswitch_key::*;
mod circuit_bootstrapping_private_keyswitch_keys;
pub use circuit_bootstrapping_private_keyswitch_keys::*;
mod bootstrap_key;
pub use bootstrap_key::*;
mod univariate_lookup_table;
pub use univariate_lookup_table::*;
mod bivariate_lookup_table;
pub use bivariate_lookup_table::*;
mod glwe_secret_key;
pub use glwe_secret_key::*;
mod glwe_ciphertext;
pub use glwe_ciphertext::*;
mod glwe_ciphertext_fft;
pub use glwe_ciphertext_fft::*;
mod lwe_ciphertext;
pub use lwe_ciphertext::*;
mod lwe_secret_key;
pub use lwe_secret_key::*;
mod lwe_public_key;
pub use lwe_public_key::*;
mod glev_ciphertext;
pub use glev_ciphertext::*;
mod glev_ciphertext_fft;
pub use glev_ciphertext_fft::*;
mod ggsw_ciphertext;
pub use ggsw_ciphertext::*;
mod ggsw_ciphertext_fft;
pub use ggsw_ciphertext_fft::*;
mod lev_ciphertext;
pub use lev_ciphertext::*;
mod lwe_keyswitch_key;
pub use lwe_keyswitch_key::*;
mod glwe_keyswitch_key;
pub use glwe_keyswitch_key::*;
mod polynomial;
pub use polynomial::*;
mod polynomial_fft;
pub use polynomial_fft::*;
mod polynomial_list;
pub use polynomial_list::*;

View File

@@ -0,0 +1,588 @@
use std::{
num::Wrapping,
ops::{Add, AddAssign, Mul, Sub, SubAssign},
};
use num::{Complex, Zero};
use crate::{
dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize},
fft::negacyclic::get_fft,
polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign},
scratch::allocate_scratch,
FrequencyTransform, PolynomialDegree, ReinterpretAsSigned, ToF64, Torus, TorusOps,
};
use super::PolynomialFftRef;
dst! {
/// A type representing a polynomial.
Polynomial,
PolynomialRef,
NoWrapper,
(Debug, Clone, PartialEq, Eq,),
()
}
dst_iter! { PolynomialIterator, PolynomialIteratorMut, NoWrapper, PolynomialRef, () }
impl<T> OverlaySize for PolynomialRef<T>
where
T: Clone,
{
type Inputs = PolynomialDegree;
fn size(t: Self::Inputs) -> usize {
t.0
}
}
impl<T> Polynomial<T>
where
T: Clone,
{
/// Create a new polynomial from a slice of coefficients.
pub fn new(data: &[T]) -> Polynomial<T> {
Polynomial {
data: data.to_owned(),
}
}
/// Create a new polynomial filled with zeros of a specified length.
pub fn zero(len: usize) -> Polynomial<T>
where
T: Zero,
{
Polynomial {
data: vec![T::zero(); len],
}
}
}
impl<T> FromIterator<T> for Polynomial<T>
where
T: Clone,
{
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
Self {
data: iter.into_iter().collect::<Vec<_>>(),
}
}
}
impl<T> PolynomialRef<T>
where
T: Clone,
{
/// Returns the coefficients of the polynomial in ascending order.
pub fn coeffs(&self) -> &[T] {
&self.data
}
/// Returns the mutable coefficients of the polynomial in ascending order.
pub fn coeffs_mut(&mut self) -> &mut [T] {
&mut self.data
}
/// Apply a function to each coefficient of the polynomial and return a new
/// polynomial.
pub fn map<F, U>(&self, f: F) -> Polynomial<U>
where
F: Fn(&T) -> U,
U: Clone,
{
Polynomial {
data: self.data.iter().map(f).collect::<Vec<_>>(),
}
}
/// Maps this polynomial using f into the dst [`PolynomialRef`].
///
/// # Panics
/// If `dst.len() != self.len()`
pub fn map_into<F, U>(&self, dst: &mut PolynomialRef<U>, f: F)
where
F: Fn(&T) -> U,
U: Clone,
{
assert_eq!(dst.len(), self.len());
dst.coeffs_mut()
.iter_mut()
.zip(self.coeffs().iter())
.for_each(|(d, s)| *d = f(s));
}
/// Returns the number of coefficients in the polynomial.
#[inline]
pub fn len(&self) -> usize {
self.data.len()
}
/// Returns true if the polynomial has no coefficients.
#[inline]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
impl<T> PolynomialRef<T>
where
T: TorusOps,
{
/// Reinterpret the this polynomial as a polynomial of torus elements.
pub fn as_torus(&self) -> &PolynomialRef<Torus<T>> {
let as_torus = bytemuck::cast_slice(&self.data);
PolynomialRef::from_slice(as_torus)
}
/// Reinterpret this polynomial of integers as having wrapping semantics.
pub fn as_wrapping(&self) -> &PolynomialRef<Wrapping<T>> {
let as_wrapping = bytemuck::cast_slice(&self.data);
PolynomialRef::from_slice(as_wrapping)
}
/// Reinterpret this polynomial of integers as having wrapping semantics.
pub fn as_wrapping_mut(&mut self) -> &mut PolynomialRef<Wrapping<T>> {
let as_wrapping = bytemuck::cast_slice_mut(&mut self.data);
PolynomialRef::from_mut_slice(as_wrapping)
}
}
impl<T> PolynomialRef<Torus<T>>
where
T: TorusOps,
{
/**
* Multiply by a monomial X^-degree, returning a new Polynomial.
*/
pub fn mul_by_negative_monomial_negacyclic(&mut self, degree: usize) {
let len = self.len();
// The behavior of the rotation is the same for every degree + k*2N for
// k >= 0.
let degree = degree % (2 * len);
// If the degree is 0 (or a multiple of 2N), the polynomial is unchanged.
if degree == 0 {
return;
}
// If the degree is N (or of the form N + k*2N), the polynomial is negated.
if degree == len {
self.data
.iter_mut()
.for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x));
return;
}
let shift = degree % len;
self.data.rotate_left(shift);
let negate_segment = if degree < len {
(len - shift)..len
} else {
0..(len - shift)
};
for i in negate_segment {
self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]);
}
}
/**
* Multiply by a monomial X^degree, returning a new Polynomial.
*/
pub fn mul_by_positive_monomial_negacyclic(&mut self, degree: usize) {
let len = self.len();
// The behavior of the rotation is the same for every degree + k*2N for
// k >= 0.
let degree = degree % (2 * len);
// If the degree is 0 (or a multiple of 2N), the polynomial is unchanged.
if degree == 0 {
return;
}
// If the degree is N (or of the form N + k*2N), the polynomial is negated.
if degree == len {
self.data
.iter_mut()
.for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x));
return;
}
let shift = degree % len;
self.data.rotate_right(shift);
let negate_segment = if degree < len { 0..degree } else { shift..len };
for i in negate_segment {
self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]);
}
}
/**
* Multiply by a monomial X^degree, returning a new Polynomial. The degree
* can be either positive or negative.
*/
pub fn mul_by_monomial_negacyclic(&mut self, degree: isize) {
if degree < 0 {
self.mul_by_negative_monomial_negacyclic(degree.unsigned_abs());
} else {
self.mul_by_positive_monomial_negacyclic(degree as usize);
}
}
}
impl<T, U> PolynomialRef<T>
where
U: ToF64,
T: Clone + Copy + ReinterpretAsSigned<Output = U>,
{
/// Compute the FFT of the polynomial.
pub fn fft(&self, out: &mut PolynomialFftRef<Complex<f64>>) {
assert!(self.len().is_power_of_two());
assert_eq!(self.len(), out.len() * 2);
let mut self_f64 = allocate_scratch::<f64>(self.len());
let self_f64 = self_f64.as_mut_slice();
for (o, i) in self_f64.iter_mut().zip(self.coeffs().iter()) {
// Reinterperet [0, 1) to [-q/2, q/2) to slightly increase
// precision.
*o = (*i).reinterpret_as_signed().to_f64();
}
let log_n = self.len().ilog2() as usize;
let fft = get_fft(log_n);
fft.forward(self_f64, out.as_mut_slice());
}
}
impl<S> Add<Polynomial<S>> for Polynomial<S>
where
S: Add<S, Output = S> + Copy,
{
type Output = Polynomial<S>;
fn add(self, rhs: Polynomial<S>) -> Self::Output {
self.as_ref().add(rhs.as_ref())
}
}
impl<S> Add<&PolynomialRef<S>> for &PolynomialRef<S>
where
S: Add<S, Output = S> + Copy,
{
type Output = Polynomial<S>;
fn add(self, rhs: &PolynomialRef<S>) -> Self::Output {
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
let coeffs = self
.coeffs()
.as_ref()
.iter()
.zip(rhs.coeffs().as_ref().iter())
.map(|(a, b)| *a + *b)
.collect::<Vec<_>>();
Polynomial { data: coeffs }
}
}
impl<S> AddAssign<&PolynomialRef<S>> for PolynomialRef<S>
where
S: AddAssign<S> + Copy,
{
fn add_assign(&mut self, rhs: &PolynomialRef<S>) {
polynomial_add_assign(self, rhs)
}
}
impl<S> Sub<Polynomial<S>> for Polynomial<S>
where
S: Sub<S, Output = S> + Copy,
{
type Output = Polynomial<S>;
fn sub(self, rhs: Polynomial<S>) -> Self::Output {
self.as_ref().sub(rhs.as_ref())
}
}
impl<S> Sub<&PolynomialRef<S>> for &PolynomialRef<S>
where
S: Sub<S, Output = S> + Copy,
{
type Output = Polynomial<S>;
fn sub(self, rhs: &PolynomialRef<S>) -> Self::Output {
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
let coeffs = self
.coeffs()
.as_ref()
.iter()
.zip(rhs.coeffs().as_ref().iter())
.map(|(a, b)| *a - *b)
.collect::<Vec<_>>();
Polynomial { data: coeffs }
}
}
impl<S> SubAssign<&PolynomialRef<S>> for PolynomialRef<S>
where
S: SubAssign<S> + Copy,
{
fn sub_assign(&mut self, rhs: &PolynomialRef<S>) {
polynomial_sub_assign(self, rhs)
}
}
impl<S> Mul<&PolynomialRef<S>> for &PolynomialRef<Torus<S>>
where
S: TorusOps,
{
type Output = Polynomial<Torus<S>>;
/// External product of T\[X\]/f * Z\[X\]/f
/// TODO: use NTT to do in nlog(n) time.
fn mul(self, rhs: &PolynomialRef<S>) -> Self::Output {
assert_eq!(rhs.len(), self.len());
let mut c = Polynomial {
data: vec![Torus::zero(); rhs.len()],
};
polynomial_external_mad(&mut c, self, rhs);
c
}
}
#[cfg(test)]
mod tests {
use std::ops::Deref;
use crate::{entities::Polynomial, Torus};
#[test]
fn can_add_polynomials() {
let a = Polynomial::new(&[1, 2, 3]);
let b = Polynomial::new(&[4, 5, 6]);
let expected = Polynomial::new(&[5, 7, 9]);
let c = a.deref() + b.deref();
assert_eq!(c, expected);
}
// A golden test but easier than recoding the logic again.
#[test]
fn can_multiply_by_positive_monomial_negacyclic() {
let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::<u64>::from));
let mut shift_0 = original.clone();
shift_0.mul_by_positive_monomial_negacyclic(0);
let expected_0 = original.clone();
assert_eq!(shift_0, expected_0);
let mut shift_1 = original.clone();
shift_1.mul_by_positive_monomial_negacyclic(1);
let expected_1 = Polynomial::new(&[
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
Torus::from(2),
Torus::from(3),
]);
assert_eq!(shift_1, expected_1);
let mut shift_2 = original.clone();
shift_2.mul_by_positive_monomial_negacyclic(2);
let expected_2 = Polynomial::new(&[
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
Torus::from(2),
]);
assert_eq!(shift_2, expected_2);
let mut shift_3 = original.clone();
shift_3.mul_by_positive_monomial_negacyclic(3);
let expected_3 = Polynomial::new(&[
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
]);
assert_eq!(shift_3, expected_3);
let mut shift_4 = original.clone();
shift_4.mul_by_positive_monomial_negacyclic(4);
let expected_4 = Polynomial::new(&[
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
]);
assert_eq!(shift_4, expected_4);
let mut shift_5 = original.clone();
shift_5.mul_by_positive_monomial_negacyclic(5);
let expected_5 = Polynomial::new(&[
Torus::from(4u64),
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
]);
assert_eq!(shift_5, expected_5);
let mut shift_6 = original.clone();
shift_6.mul_by_positive_monomial_negacyclic(6);
let expected_6 = Polynomial::new(&[
Torus::from(3u64),
Torus::from(4u64),
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
]);
assert_eq!(shift_6, expected_6);
let mut shift_7 = original.clone();
shift_7.mul_by_positive_monomial_negacyclic(7);
let expected_7 = Polynomial::new(&[
Torus::from(2u64),
Torus::from(3u64),
Torus::from(4u64),
Torus::from(1u64.wrapping_neg()),
]);
assert_eq!(shift_7, expected_7);
let mut shift_8 = original.clone();
shift_8.mul_by_positive_monomial_negacyclic(8);
let expected_8 = original.clone();
assert_eq!(shift_8, expected_8);
let mut shift_9 = original.clone();
shift_9.mul_by_positive_monomial_negacyclic(9);
let expected_9 = shift_1.clone();
assert_eq!(shift_9, expected_9);
}
#[test]
fn can_multiply_by_negative_monomial_negacyclic() {
let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::<u64>::from));
let mut shift_0 = original.clone();
shift_0.mul_by_negative_monomial_negacyclic(0);
let expected_0 = original.clone();
assert_eq!(shift_0, expected_0);
let mut shift_1 = original.clone();
shift_1.mul_by_negative_monomial_negacyclic(1);
let expected_1 = Polynomial::new(&[
Torus::from(2),
Torus::from(3),
Torus::from(4),
Torus::from(1u64.wrapping_neg()),
]);
assert_eq!(shift_1, expected_1);
let mut shift_2 = original.clone();
shift_2.mul_by_negative_monomial_negacyclic(2);
let expected_2 = Polynomial::new(&[
Torus::from(3),
Torus::from(4),
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
]);
assert_eq!(shift_2, expected_2);
let mut shift_3 = original.clone();
shift_3.mul_by_negative_monomial_negacyclic(3);
let expected_3 = Polynomial::new(&[
Torus::from(4),
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
]);
assert_eq!(shift_3, expected_3);
let mut shift_4 = original.clone();
shift_4.mul_by_negative_monomial_negacyclic(4);
let expected_4 = Polynomial::new(&[
Torus::from(1u64.wrapping_neg()),
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
]);
assert_eq!(shift_4, expected_4);
let mut shift_5 = original.clone();
shift_5.mul_by_negative_monomial_negacyclic(5);
let expected_5 = Polynomial::new(&[
Torus::from(2u64.wrapping_neg()),
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
]);
assert_eq!(shift_5, expected_5);
let mut shift_6 = original.clone();
shift_6.mul_by_negative_monomial_negacyclic(6);
let expected_6 = Polynomial::new(&[
Torus::from(3u64.wrapping_neg()),
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
Torus::from(2),
]);
assert_eq!(shift_6, expected_6);
let mut shift_7 = original.clone();
shift_7.mul_by_negative_monomial_negacyclic(7);
let expected_7 = Polynomial::new(&[
Torus::from(4u64.wrapping_neg()),
Torus::from(1),
Torus::from(2),
Torus::from(3),
]);
assert_eq!(shift_7, expected_7);
let mut shift_8 = original.clone();
shift_8.mul_by_negative_monomial_negacyclic(8);
let expected_8 = original.clone();
assert_eq!(shift_8, expected_8);
let mut shift_9 = original.clone();
shift_9.mul_by_negative_monomial_negacyclic(9);
let expected_9 = shift_1.clone();
assert_eq!(shift_9, expected_9);
}
}

View File

@@ -0,0 +1,158 @@
use num::{traits::MulAdd, Complex};
use crate::{
dst::{NoWrapper, OverlaySize},
fft::negacyclic::get_fft,
scratch::allocate_scratch,
FrequencyTransform, FromF64, NumBits, PolynomialDegree,
};
use super::PolynomialRef;
dst! {
/// The FFT of a polynomial. See [`Polynomial`](crate::entities::Polynomial)
/// for the non-FFT variant.
PolynomialFft,
PolynomialFftRef,
NoWrapper,
(Debug, Clone, PartialEq, Eq,),
()
}
dst_iter!(
PolynomialFftIterator,
PolynomialFftIteratorMut,
NoWrapper,
PolynomialFftRef,
()
);
impl<T> OverlaySize for PolynomialFftRef<T>
where
T: Clone,
{
type Inputs = PolynomialDegree;
fn size(t: Self::Inputs) -> usize {
t.0 / 2
}
}
impl<T> PolynomialFft<T>
where
T: Clone,
{
/// Create a new polynomial with the given length in the fourier domain.
pub fn new(data: &[T]) -> Self {
Self {
data: data.to_owned(),
}
}
}
impl<T> PolynomialFftRef<T>
where
T: Clone,
{
/// Returns the coefficients of the polynomial in the fourier domain.
pub fn coeffs(&self) -> &[T] {
&self.data
}
/// Returns the mutable coefficients of the polynomial in the fourier domain.
pub fn coeffs_mut(&mut self) -> &mut [T] {
&mut self.data
}
/// Returns the number of coefficients in the polynomial.
pub fn len(&self) -> usize {
self.data.len()
}
/// Returns true if the polynomial has no coefficients.
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
}
impl PolynomialFftRef<Complex<f64>> {
/// Compute the inverse FFT of the polynomial.
pub fn ifft<T>(&self, poly: &mut PolynomialRef<T>)
where
T: Clone + FromF64 + NumBits,
{
assert!(self.len().is_power_of_two());
assert_eq!(self.len() * 2, poly.len());
let log_n = poly.len().ilog2() as usize;
let fft = get_fft(log_n);
let mut ifft = allocate_scratch::<f64>(poly.len());
let ifft = ifft.as_mut_slice();
fft.reverse(&self.data, ifft);
// When the exponent != 0 && exponent != 1024,
// IEEE-754 doubles are represented as -1**s * 1.m * 2**(e - 1023).
//
// m is 52 bits, e is 11 bits, and s is 1 bit.
//
// Thus, to compute 2**x, we set e = 1023 + x, m=0, and s = 0. So, we just
// need to fill in EXP and shift it up 52 places.
//
// We first reduce modulo q
let exp: u64 = 1023 + T::BITS as u64;
let q: f64 = f64::from_bits(exp << 52);
let exp_div_2 = exp - 1;
let q_div_2 = f64::from_bits(exp_div_2 << 52);
// Exploit the fact that q is a power of 2 when performing the modulo
// reduction. Could possibly be even faster by masking and shifting
// the mantissa and tweaking the exponent. However, profiling on ARM
// indicates this is no longer a bottleneck with the code below.
//
// See https://stackoverflow.com/questions/49139283/are-there-any-numbers-that-enable-fast-modulo-calculation-on-floats
//
// Don't know why Rust decides not to inline this. Inlining allows
// the below loop to get unrolled, vectorized, and division gets
// replaced with multiplication since q is a known constant.
#[inline(always)]
fn mod_q(val: f64, q: f64) -> f64 {
f64::mul_add(-(val / q).trunc(), q, val)
}
for (o, ifft) in poly.coeffs_mut().iter_mut().zip(ifft.iter()) {
let mut ifft = mod_q(*ifft, q);
// Next, we need to adjust x outside [-q/2, q/2) to wrap to the correct torus
// point.
if ifft >= q_div_2 {
ifft -= q;
} else if ifft <= -q_div_2 {
ifft += q;
}
*o = T::from_f64(ifft);
}
}
/// Computes the multiplication of two polynomials as `c += a * b`. This is
/// more efficient than the naive method, and has a runtime of O(N). Note
/// that performing the FFT and IFFT to get in and out of the fourier domain
/// costs O(N log N).
pub fn multiply_add(
&mut self,
a: &PolynomialFftRef<Complex<f64>>,
b: &PolynomialFftRef<Complex<f64>>,
) {
for ((c, a), b) in self
.coeffs_mut()
.iter_mut()
.zip(a.coeffs().iter())
.zip(b.coeffs().iter())
{
*c = a.mul_add(b, c);
}
}
}

View File

@@ -0,0 +1,52 @@
use num::Zero;
use crate::{
dst::{NoWrapper, OverlaySize},
PolynomialDegree,
};
use super::{PolynomialIterator, PolynomialIteratorMut, PolynomialRef};
dst! {
/// A list of polynomials.
PolynomialList,
PolynomialListRef,
NoWrapper,
(Debug, Clone, PartialEq, Eq,),
()
}
impl<S: Clone> OverlaySize for PolynomialListRef<S> {
type Inputs = (PolynomialDegree, usize);
fn size(t: Self::Inputs) -> usize {
PolynomialRef::<S>::size(t.0) * t.1
}
}
impl<S> PolynomialList<S>
where
S: Clone + Zero,
{
/// Create a new polynomial list, where each polynomial has the same degree.
pub fn new(degree: PolynomialDegree, count: usize) -> Self {
Self {
data: vec![S::zero(); degree.0 * count],
}
}
}
impl<S> PolynomialListRef<S>
where
S: Clone + Zero,
{
/// Iterate over the polynomials in the list.
pub fn iter(&self, degree: PolynomialDegree) -> PolynomialIterator<S> {
PolynomialIterator::new(&self.data, PolynomialRef::<S>::size(degree))
}
/// Iterate over the polynomials in the list mutably.
pub fn iter_mut(&mut self, degree: PolynomialDegree) -> PolynomialIteratorMut<S> {
PolynomialIteratorMut::new(&mut self.data, PolynomialRef::<S>::size(degree))
}
}

View File

@@ -0,0 +1,134 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{
dst::OverlaySize,
entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef},
GlweDef, GlweDimension, LweDef, LweDimension, PrivateFunctionalKeyswitchLweCount, RadixCount,
RadixDecomposition, Torus, TorusOps,
};
use super::LweSecretKeyRef;
dst! {
/// Key for Private Functional Key Switching. See
/// [`module`](crate::ops::keyswitch::private_functional_keyswitch)
/// documentation for more details.
PrivateFunctionalKeyswitchKey,
PrivateFunctionalKeyswitchKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
dst_iter!(
PrivateFunctionalKeyswitchKeyIter,
PrivateFunctionalKeyswitchKeyIterMut,
Torus,
PrivateFunctionalKeyswitchKeyRef,
(TorusOps,)
);
impl<S: TorusOps> OverlaySize for PrivateFunctionalKeyswitchKeyRef<S> {
type Inputs = (
LweDimension,
GlweDimension,
RadixCount,
PrivateFunctionalKeyswitchLweCount,
);
fn size(t: Self::Inputs) -> usize {
GlevCiphertextRef::<S>::size((t.1, t.2)) * (LweSecretKeyRef::<S>::size(t.0) + 1) * t.3 .0
}
}
impl<S: TorusOps> PrivateFunctionalKeyswitchKey<S> {
/// Construct a new uninitialized [`PrivateFunctionalKeyswitchKey`]. This key is used
/// compute a secret function mapping `lwe_count`
/// [`LweCiphertext`](crate::entities::LweCiphertext)s to a
/// [`GlweCiphertext`](crate::entities::GlweCiphertext).
///
/// # Remarks
/// The key is composed of a `(from_lwe.dim + 1) * lwe_count.0` matrix of
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. The leading
/// dimension contains radix-scaled encryptions of the bits in a
/// [`LweSecretKey`](crate::entities::LweSecretKey) followed by a scaled
/// encryption of -1.
///
/// Plaintexts in the GLEV are scaled by the usual `q/beta^(j+1)`.
///
/// The trailing dimension iterates over `0..lwe_count.0`.
pub fn new(
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
lwe_count: &PrivateFunctionalKeyswitchLweCount,
) -> Self {
Self {
data: vec![
Torus::zero();
PrivateFunctionalKeyswitchKeyRef::<S>::size((
from_lwe.dim,
to_glwe.dim,
radix.count,
*lwe_count
))
],
}
}
}
impl<S: TorusOps> PrivateFunctionalKeyswitchKeyRef<S> {
/// Returns an iterator over the
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that
/// compose this key.
///
/// # See also
/// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks).
pub fn glevs(
&self,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIterator<S> {
GlevCiphertextIterator::new(
self.as_slice(),
GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count)),
)
}
/// Returns a muitable iterator over the
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that compose this
/// key.
///
/// # See also
/// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks).
pub fn glevs_mut(
&mut self,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIteratorMut<S> {
GlevCiphertextIteratorMut::new(
self.as_mut_slice(),
GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count)),
)
}
#[inline(always)]
/// Assert this value is correct for the given parameters.
pub fn assert_valid(
&self,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
lwe_count: &PrivateFunctionalKeyswitchLweCount,
) {
assert_eq!(
self.as_slice().len(),
PrivateFunctionalKeyswitchKeyRef::<S>::size((
from_lwe.dim,
to_glwe.dim,
radix.count,
*lwe_count
))
)
}
}

View File

@@ -0,0 +1,79 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{
dst::OverlaySize,
entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef},
GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
};
use super::GlweCiphertextRef;
dst! {
/// Public Functional Key Switching Key. See
/// [`module`](crate::ops::keyswitch::public_functional_keyswitch)
/// documentation for more details.
PublicFunctionalKeyswitchKey,
PublicFunctionalKeyswitchKeyRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for PublicFunctionalKeyswitchKeyRef<S> {
type Inputs = (LweDimension, GlweDimension, RadixCount);
fn size(t: Self::Inputs) -> usize {
GlweCiphertextRef::<S>::size(t.1) * t.0 .0 * t.2 .0
}
}
impl<S: TorusOps> PublicFunctionalKeyswitchKey<S> {
/// Construct a new uninitialized [`PublicFunctionalKeyswitchKey`]. This key is used
/// when performing a [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch).
pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self {
let len =
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count));
Self {
data: vec![Torus::zero(); len],
}
}
}
impl<S: TorusOps> PublicFunctionalKeyswitchKeyRef<S> {
/// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`].
pub fn glevs(
&self,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIterator<S> {
let stride = GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count));
GlevCiphertextIterator::new(self.as_slice(), stride)
}
/// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`] mutably.
pub fn glevs_mut(
&mut self,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) -> GlevCiphertextIteratorMut<S> {
let stride = GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count));
GlevCiphertextIteratorMut::new(self.as_mut_slice(), stride)
}
/// Asserts that the key is valid for the given parameters.
pub(crate) fn assert_valid(
&self,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) {
assert_eq!(
self.as_slice().len(),
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count))
);
}
}

View File

@@ -0,0 +1,83 @@
use serde::{Deserialize, Serialize};
use sunscreen_math::Zero;
use crate::{
dst::{FromMutSlice, FromSlice, OverlaySize},
entities::PolynomialRef,
ops::{bootstrapping::generate_lut, encryption::trivially_encrypt_glwe_ciphertext},
scratch::allocate_scratch_ref,
GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps,
};
use super::GlweCiphertextRef;
dst! {
/// Lookup table for a univariate function used during
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap)
/// and [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
UnivariateLookupTable,
UnivariateLookupTableRef,
Torus,
(Clone, Debug, Serialize, Deserialize),
(TorusOps)
}
impl<S: TorusOps> OverlaySize for UnivariateLookupTableRef<S> {
type Inputs = GlweDimension;
fn size(t: Self::Inputs) -> usize {
GlweCiphertextRef::<S>::size(t)
}
}
impl<S: TorusOps> UnivariateLookupTable<S> {
/// Creates a lookup table that is trivially encrypted.
pub fn trivial_from_fn<F>(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self
where
F: Fn(u64) -> u64,
{
let mut lut = UnivariateLookupTable {
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};
lut.fill_trivial_from_fn(map, glwe, plaintext_bits);
lut
}
}
impl<S: TorusOps> UnivariateLookupTableRef<S> {
/// Return the underlying GLWE representation of a lookup table.
pub fn glwe(&self) -> &GlweCiphertextRef<S> {
GlweCiphertextRef::from_slice(&self.data)
}
/// Return a mutable representation of the underlying GLWE representation of
/// a lookup table.
pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef<S> {
GlweCiphertextRef::from_mut_slice(&mut self.data)
}
/// Generates a look up table filled with the values from the provided map,
/// and trivially encrypts the lookup table.
pub fn fill_trivial_from_fn<F: Fn(u64) -> u64>(
&mut self,
map: F,
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
) {
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));
generate_lut(poly, map, glwe, plaintext_bits);
trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
}
/// Creates a lookup table filled with the same value at every entry.
pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) {
self.clear();
for o in self.glwe_mut().b_mut(glwe).coeffs_mut() {
*o = Torus::encode(val, plaintext_bits);
}
}
}

View File

@@ -0,0 +1,4 @@
#[derive(thiserror::Error)]
pub enum Error {
OutOfRange
}

View File

@@ -0,0 +1,939 @@
#![allow(dead_code)]
use crate::{
rand::Stddev, GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PolynomialDegree,
RadixCount, RadixDecomposition, RadixLog,
};
#[doc(hidden)]
pub const TEST_RADIX: RadixDecomposition = RadixDecomposition {
count: RadixCount(3),
radix_log: RadixLog(4),
};
#[doc(hidden)]
pub const TEST_GLWE_DEF_1: GlweDef = GlweDef {
dim: GlweDimension {
polynomial_degree: PolynomialDegree(128),
size: GlweSize(2),
},
std: Stddev(1e-16),
};
#[doc(hidden)]
pub const TEST_GLWE_DEF_2: GlweDef = GlweDef {
dim: GlweDimension {
polynomial_degree: PolynomialDegree(256),
size: GlweSize(3),
},
std: Stddev(1e-16),
};
#[doc(hidden)]
pub const TEST_LWE_DEF_1: LweDef = LweDef {
dim: LweDimension(128),
std: Stddev(1e-16),
};
#[doc(hidden)]
pub const TEST_LWE_DEF_2: LweDef = LweDef {
dim: LweDimension(256),
std: Stddev(1e-16),
};
#[doc(hidden)]
pub const TEST_LWE_DEF_3: LweDef = LweDef {
dim: LweDimension(128),
std: Stddev(0.0),
};
/// TFHE functionality related to key generation.
pub mod keygen {
use crate::{
entities::{
BootstrapKey, CircuitBootstrappingKeyswitchKeys, GlweSecretKey, GlweSecretKeyRef,
LweKeyswitchKey, LwePublicKey, LweSecretKey, LweSecretKeyRef,
},
ops::{
bootstrapping::generate_bootstrap_key,
keyswitch::{
lwe_keyswitch_key::generate_keyswitch_key_lwe,
private_functional_keyswitch::generate_circuit_bootstrapping_pfks_keys,
},
},
GlweDef, LweDef, RadixDecomposition,
};
/// Generate a new binary [`LweSecretKey`] under the given LWE parameters.
///
/// # Remarks
/// Any functions that use this key will need the same [`LweDef`].
///
/// These keys may be used to create bootstrapping keys.
///
/// # Panics
/// If [`LweDef`] is invalid.
///
/// # Security
/// These keys are *not* secure under some threshold cryptography settings.
/// Under those settings, you should use [`generate_uniform_lwe_sk`].
///
/// This key is secret and care should be taken as to which parties
/// possess it. Anyone who possesses the returned [`LweSecretKey`]
/// can decrypt any messages encrypted under it.
pub fn generate_binary_lwe_sk(params: &LweDef) -> LweSecretKey<u64> {
LweSecretKey::generate_binary(params)
}
/// Generate a new binary [`LweSecretKey`] under the given LWE parameters.
///
/// # Remarks
/// Any functions that use this key will need the same [`LweDef`].
///
/// These keys may *not* directly be used to create bootstrapping keys.
/// However, in threshold schemes that use them, usually you derive
/// a binary key from uniform key shares.
///
/// # Panics
/// If [`LweDef`] is invalid.
///
/// # Security
/// This key is secret and care should be taken as to which parties
/// possess it. Anyone who possesses the returned [`LweSecretKey`]
/// can decrypt any messages encrypted under it.
pub fn generate_uniform_lwe_sk(params: &LweDef) -> LweSecretKey<u64> {
LweSecretKey::generate_uniform(params)
}
/// Generate a new [`LwePublicKey`] under the given parameters. This
/// public key is paired with `sk` - that is messages encrypted under
/// this public key can be decrypted with `sk`.
///
/// # Remarks
/// Any functions that use this key will need to use the same `params`.
///
/// # Panics
/// If `params` is invalid
/// If `params` doesn't correspond with `sk`.
///
/// # Security
/// This key is public and sharing it does not compromise semantic
/// security.
pub fn generate_lwe_pk(sk: &LweSecretKeyRef<u64>, params: &LweDef) -> LwePublicKey<u64> {
LwePublicKey::generate(sk, params)
}
/// Generate a new GLWE secret key under the given GLWE parameters.
/// The key will consist of {0,1}-valued coefficients.
///
/// # Remarks
/// Any functions that use this key will need the same [`GlweDef`].
///
/// # Panics
/// If [`GlweDef`] is invalid.
///
/// # Security
/// Binary [GlweSecretKey]s are insecure in some threshold cryptography
/// settings. Under those settings, you should use
/// [`generate_uniform_glwe_sk`].
///
/// This key is secret and care should be taken as to which parties
/// possess it. Anyone who possesses the returned [`GlweSecretKey`]
/// can decrypt any messages encrypted under it.
pub fn generate_binary_glwe_sk(params: &GlweDef) -> GlweSecretKey<u64> {
GlweSecretKey::generate_binary(params)
}
/// Generate a new GLWE secret key under the given GLWE parameters.
/// The key will consist of uniform coefficients over the Torus's
/// isomorphic ring.
///
/// # Remarks
/// Any functions that use this key will need the same [`GlweDef`].
///
/// This is used over binary in threshold cryptography settings.
///
/// # Panics
/// If [`GlweDef`] is invalid.
///
/// # Security
/// This key is secret and care should be taken as to which parties
/// possess it. Anyone who possesses the returned [`GlweSecretKey`]
/// can decrypt any messages encrypted under it.
pub fn generate_uniform_glwe_sk(params: &GlweDef) -> GlweSecretKey<u64> {
GlweSecretKey::generate_uniform(params)
}
/// Generate a new bootstrapping key, which is used in bootstrapping operations.
///
/// See also
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap)
/// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap).
///
/// # Remarks
/// A bootstrapping key is an encryption of an LWE secret key under a different
/// GLWE secret key reinterpreted as an LWE key.
///
/// `lwe` and `glwe` are the LWE and GLWE parameters used when you generated the
/// [`LweSecretKey`] and [`GlweSecretKey`], respectively.
///
/// `radix` specifies the decomposition to use during bootstrapping.
///
/// You should use the same `lwe`, `glwe`, `radix` values here as when you
/// call
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap).
///
/// The returned bootstrapping key is not immediately useful outside of serialization.
/// You need to FFT transform is first (see [fft_bootstrap_key](super::fft::fft_bootstrap_key)).
///
/// ## Circuit bootstrapping
/// When using [`circuit_bootstrap`](super::evaluation::circuit_bootstrap), `pbs_radix`
/// must match this `radix`, `lwe_0` must match `lwe`, and `glwe_2` must match `glwe`.
///
/// # Panics
/// If `lwe`, `glwe`, or `radix` are invalid.
/// If `glwe_key` isn't valid under `glwe`.
/// If `sk` isn't valid under `lwe`.
///
/// # Security
/// The returned key is public and does not compromise semantic security.
/// However, anyone who possesses `glwe_key` can easily use the returned
/// [`BootstrapKey`] to recover `sk`.
pub fn generate_bootstrapping_key(
sk: &LweSecretKey<u64>,
glwe_key: &GlweSecretKey<u64>,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> BootstrapKey<u64> {
let mut bsk = BootstrapKey::new(lwe, glwe, radix);
generate_bootstrap_key(&mut bsk, sk, glwe_key, glwe, radix);
bsk
}
/// Generate an LWE keyswitch key. LWE keyswitching allows you take an encryption of `m`
/// under [LWESecretKey](crate::entities::LweSecretKey) `from_sk` and turn it into an
/// encryption of `m` under `to_sk`.
///
/// # Remarks
/// `from_lwe` and `to_lwe` are the parameters under which you generated `from_sk` and
/// `to_sk`, respectively.
///
/// `radix` specifies the decomposition to use during LWE keyswitching.
///
/// TODO: mention keyswitch operation.
///
/// # Panics
/// If the `from_lwe` parameters aren't valid for `from_sk`.
/// If the `to_lwe` parameters aren't valid for `to_sk`.
/// If `from_lwe`, `to_lwe`, or `radix` parameters are invalid.
///
/// # Security
/// The returned key is public and sharing it does not compromise semantic
/// security. However, anyone who possesses `to_lwe` will effectively
/// be able to decrypt any message encrypted under `from_lwe` with the
/// returned [`LweKeyswitchKey`].
pub fn generate_ksk(
from_sk: &LweSecretKeyRef<u64>,
to_sk: &LweSecretKeyRef<u64>,
from_lwe: &LweDef,
to_lwe: &LweDef,
radix: &RadixDecomposition,
) -> LweKeyswitchKey<u64> {
let mut ksk = LweKeyswitchKey::new(from_lwe, to_lwe, radix);
generate_keyswitch_key_lwe(&mut ksk, from_sk, to_sk, to_lwe, radix);
ksk
}
/// Generate a set of [`CircuitBootstrappingKeyswitchKeys`] to use during
/// [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations.
///
/// # Remarks
/// Internally, [`CircuitBootstrappingKeyswitchKeys`] is a list of
/// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
///
/// During [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations,
/// these keys are used to convert [`LweCiphertext`](crate::entities::LweCiphertext)s
/// encrypted under `from_sk` into [`GlweCiphertext`](crate::entities::GlweCiphertext)s
/// encrypted under `to_sk`. These [`GlweCiphertext`](crate::entities::GlweCiphertext)s
/// together form a [`GgswCiphertext`](crate::entities::GgswCiphertext).
///
/// The `from_lwe` and `to_glwe` parameters correspond to those used when you generated
/// `from_sk` and `to_sk`, respectively.
///
/// The `radix` parameter describes the [`RadixDecomposition`] to use during the private
/// functional keyswitch operation. When performing a
/// [circuit_bootstrap](super::evaluation::circuit_bootstrap), these same parameters should
/// be passed as `pfks_radix`.
///
/// # Panics
/// If `from_lwe`, `to_glwe`, or `radix` are invalid.
/// If `from_lwe` or `to_glwe` parameters don't correspond with `to_sk` or `to_glwe`, respectively.
///
/// # Security
/// The returned [`CircuitBootstrappingKeyswitchKeys`] are public and
/// don't in of themselves compromise semantic security. However,
/// anyone who possesses `to_sk` can easily recover `from_sk` using this
/// information.
pub fn generate_cbs_ksk(
from_sk: &LweSecretKeyRef<u64>,
to_sk: &GlweSecretKeyRef<u64>,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) -> CircuitBootstrappingKeyswitchKeys<u64> {
let mut cbs_ksk = CircuitBootstrappingKeyswitchKeys::new(from_lwe, to_glwe, radix);
generate_circuit_bootstrapping_pfks_keys(
&mut cbs_ksk,
from_sk,
to_sk,
from_lwe,
to_glwe,
radix,
);
cbs_ksk
}
}
/// TFHE functionality related to encryption.
pub mod encryption {
use crate::{
entities::{
GgswCiphertext, GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, GlweSecretKeyRef,
LweCiphertext, LweCiphertextRef, LwePublicKeyRef, LweSecretKeyRef, Polynomial,
PolynomialRef, TlwePublicEncRandomness,
},
ops::encryption::{encrypt_ggsw_ciphertext_scalar, trivially_encrypt_lwe_ciphertext},
CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus,
};
/// Create an [`LweCiphertext`] encryption of `val` under
/// [LweSecretKey](crate::entities::LweSecretKey) `sk`.
///
/// # Remarks
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk) to
/// generate an [`LweSecretKey`](crate::entities::LweSecretKey).
///
/// `params` should be the same parameters used when generating the secret key.
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
///
/// # Panics
/// If `params` is invalid.
/// If `params` don't correspond with `sk`.
pub fn encrypt_lwe_secret(
val: u64,
sk: &LweSecretKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> LweCiphertext<u64> {
sk.encrypt(val, params, plaintext_bits).0
}
/// Create a tuple containing an [`LweCiphertext`] encryption of `val`
/// under [LweSecretKey](crate::entities::LweSecretKey) `sk` and the
/// randomness used to generate it.
///
/// This randomness can be used to produce zero-knowledge proofs.
///
/// # Remarks
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk)
/// to generate an [`LweSecretKey`](crate::entities::LweSecretKey).
///
/// `params` should be the same parameters used when generating the secret key.
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
///
/// # Panics
/// If `params` is invalid.
/// If `params` don't correspond with `sk`.
///
/// # Security
/// Revealing the returned randomness compromises the confidentiality
/// of the returned [`LweCiphertext`].
pub fn encrypt_lwe_secret_and_return_randomness(
val: u64,
sk: &LweSecretKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> (LweCiphertext<u64>, Torus<u64>) {
sk.encrypt(val, params, plaintext_bits)
}
/// Create an [LweCiphertext] encryption of `val` under the secret
/// key that pairs with `pk`.
///
/// This function uses the [`LwePublicKey`](crate::entities::LwePublicKey),
/// which may be freely distributed without compromising security.
///
/// # Remarks
/// Use [`generate_lwe_pk`](super::keygen::generate_lwe_pk) to generate
/// a [`LwePublicKey`](crate::entities::LwePublicKey).
///
/// # Panics
/// If `params` is invalid.
/// If `params` doesn't correspond with `pk`.
pub fn encrypt_lwe(
val: u64,
pk: &LwePublicKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> LweCiphertext<u64> {
pk.encrypt(val, params, plaintext_bits).0
}
/// Create a tuple containing an [`LweCiphertext`] encryption of `val`
/// under the [LweSecretKey](crate::entities::LweSecretKey) paired with
/// `pk` and the randomness used to generate it.
///
/// This randomness can be used to produce zero-knowledge proofs.
///
/// # Remarks
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk)
/// to generate an [`LweSecretKey`](crate::entities::LweSecretKey).
///
/// `params` should be the same parameters used when generating the secret key.
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
///
/// # Panics
/// If `params` is invalid.
/// If `params` don't correspond with `pk`.
///
/// # Security
/// Revealing the returned randomness compromises the confidentiality
/// of the returned [`LweCiphertext`].
pub fn encrypt_lwe_and_return_randomness(
val: u64,
pk: &LwePublicKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> (LweCiphertext<u64>, TlwePublicEncRandomness<u64>) {
pk.encrypt(val, params, plaintext_bits)
}
/// Create a [`GlweCiphertext`] encryption of `pt` under `sk`.
///
/// # Remarks
/// `params` should be a same as those used when creating `sk`.
/// `pt.len()` should equal `sk.dim.polynomial_degree.0`.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// polynomial coefficients will contain the message. No coefficient in
/// `pt` should exceed `2^plaintext_bits.0`.
///
/// # Panics
/// If `params` is invalid.
/// If `params` doesn't correspond with `sk`
/// If `pt` doesn't have the same number of coefficients as
/// `params.dim.polynomial_degree.0`.
pub fn encrypt_glwe(
pt: &PolynomialRef<u64>,
sk: &GlweSecretKeyRef<u64>,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> GlweCiphertext<u64> {
sk.encode_encrypt_glwe(pt, params, plaintext_bits)
}
/// Create a trivial LWE encryption. Trivial encryptions have no noise and are thus
/// insecure. However, they are useful for creating public constants in TFHE computations.
///
/// Trivial encryptions are valid encryptions of `val` under *every* LWE secret key
/// using the same `params`.
///
/// # Remarks
/// `params` are the parameters under which to create this encryption. Note that homomorphic
/// operations require every operand use the same [`LweDef`] parameters.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`. Truncation
/// of `val`'s most-significant bits will occur otherwise.
///
/// # Panics
/// If `plaintext_bits > 63`.
/// If `params` are invalid.
///
/// # Security
/// Trivial encryptions provide no cryptographic security.
pub fn trivial_lwe(
val: u64,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> LweCiphertext<u64> {
let mut ct = LweCiphertext::new(params);
trivially_encrypt_lwe_ciphertext(&mut ct, &Torus::encode(val, plaintext_bits), params);
ct
}
/// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey)
/// `sk`. Decode this decrypted value and return it.
///
/// # Remarks
/// `params` must correspond with `ct` and `sk`.
///
/// If a different `sk` is used than the one that produced `ct`, the result will be
/// garbage.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
/// you should use the same `plaintext_bits` that were used during encryption.
///
/// # Panics
/// If `params` doesn't correspond with either `ct` or `sk`.
/// If `params` is invalid.
pub fn decrypt_lwe(
ct: &LweCiphertextRef<u64>,
sk: &LweSecretKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> u64 {
sk.decrypt(ct, params, plaintext_bits)
}
/// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey)
/// `sk` and decode it, handling carry bits.
///
/// # Remarks
/// `params` must correspond with `ct` and `sk`.
///
/// If a different `sk` is used than the one that produced `ct`, the result will be
/// garbage.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
/// you should use the same `plaintext_bits` that were used during encryption.
///
/// `carry_bits` describes how many of the most-significant bits of the
/// [`Torus`] will contain the carry.
///
/// # Panics
/// If `params` doesn't correspond with either `ct` or `sk`.
/// If `params` is invalid.
pub fn decrypt_lwe_with_carry(
ct: &LweCiphertextRef<u64>,
sk: &LweSecretKeyRef<u64>,
params: &LweDef,
plaintext_bits: PlaintextBits,
carry_bits: CarryBits,
) -> u64 {
let decrypted = sk.decrypt_without_decode(ct, params);
// We manually decode here because the padding bit.
let plain_bits = plaintext_bits;
let round_bit = decrypted
.inner()
.wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1)
& 0x1;
let mask = (0x1 << plain_bits.0) - 1;
(decrypted
.inner()
.wrapping_shr(64 - plain_bits.0 - carry_bits.0)
+ round_bit)
& mask
}
/// Decrypt [GlweCiphertext] `ct` encrypted under [GlweSecretKey](crate::entities::GlweSecretKey)
/// `sk`. Decode this decrypted value and return it.
///
/// # Remarks
/// `params` must correspond with `ct` and `sk`.
///
/// If a different `sk` is used than the one that produced `ct`, the result will be
/// garbage.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] polynomial's
/// coefficients contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
/// you should use the same `plaintext_bits` that were used during encryption.
///
/// # Panics
/// If `params` doesn't correspond with either `ct` or `sk`.
/// If `params` is invalid.
pub fn decrypt_glwe(
ct: &GlweCiphertextRef<u64>,
sk: &GlweSecretKeyRef<u64>,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> Polynomial<u64> {
sk.decrypt_decode_glwe(ct, params, plaintext_bits)
}
/// Create a trivial encryption of `pt` as a [GlweCiphertext]. Trivial encryptions contain
/// no noise and are thus insecure. However, they are useful as public constants in
/// a TFHE computation.
///
/// # Remarks
/// Trivial encryption are valid under *every* [GlweSecretKey](crate::entities::GlweSecretKey)
/// using `params`. Note that homomorphic operations require every operand to use the same
/// `params` and secret key.
///
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
/// polynomial coefficients will contain the message. `val` should not exceed
/// `2^plaintext_bits.0`. Truncation of `val`'s most-significant bits will occur otherwise.
///
/// # Panics
/// If `params` are invalid.
/// If `plaintext_bits >= 64`.
///
/// # Security
/// Trivial encryptions provide no cryptographic security.
pub fn trivial_glwe(
pt: &PolynomialRef<u64>,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> GlweCiphertext<u64> {
let mut result = GlweCiphertext::new(params);
for (b_out, b_in) in result
.b_mut(params)
.coeffs_mut()
.iter_mut()
.zip(pt.coeffs().iter())
{
*b_out = Torus::encode(*b_in, plaintext_bits);
}
result
}
/// Create a [`GgswCiphertext`] encrypting `msg` under
/// [GlweSecretKey](crate::entities::GlweSecretKey) `sk`.
///
/// # Remarks
/// This encrypts `msg` as a constant coefficient polynomial. While [`GgswCiphertext`]s
/// theoretically support encrypting arbitrary polynomial messages, such ciphertexts have
/// no known uses in TFHE.
///
/// Typically, you'll want to set `plaintext_bits` to 1 and encrypt a binary
/// `msg`. This allows you to use the [`GgswCiphertext`] as the select input to
/// a [`cmux`](super::evaluation::cmux) operation.
///
/// `params` should match those under which you generated `sk`.
///
/// `radix` defines how many decompositions to include in the result. Subsequent
/// [`cmux`](super::evaluation::cmux) operations using this result must use the
/// same `radix`.
///
/// [`GgswCiphertext`]s are not immediate useful outside of serialization. You must
/// first take its Fourier transform using [`fft_ggsw`](super::fft::fft_ggsw) before
/// using it in a [`cmux`](super::evaluation::cmux) operation.
///
/// # Panics
/// If `params` or `radix` are invalid.
/// If `params` doesn't correspond to `sk`.
/// If `plaintext_bits >= 64`.
pub fn encrypt_ggsw(
msg: u64,
sk: &GlweSecretKeyRef<u64>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) -> GgswCiphertext<u64> {
let mut result = GgswCiphertext::new(params, radix);
encrypt_ggsw_ciphertext_scalar(&mut result, msg, sk, params, radix, plaintext_bits);
result
}
/// Decrypt a [`GgswCiphertext`] encrypted under `sk`.
/// Since GGSW ciphertexts generally contain binary, you should
/// usually set `plaintext_bits` to 1.
///
/// # Remarks
/// `params` should correspond with `ct` and `sk`.
/// `radix` should correspond with `ct`.
///
/// # Panics
/// If `params` or `radix` are invalid.
/// If `params` or `radix` don't correspond to `ct`
/// If `params` don't correspond to `sk`.
pub fn decrypt_ggsw(
ct: &GgswCiphertextRef<u64>,
sk: &GlweSecretKeyRef<u64>,
params: &GlweDef,
radix: &RadixDecomposition,
_plaintext_bits: PlaintextBits,
) -> Polynomial<u64> {
let mut msg = Polynomial::zero(params.dim.polynomial_degree.0);
crate::ops::encryption::decrypt_ggsw_ciphertext(&mut msg, ct, sk, params, radix);
msg.map(|x| x.inner())
}
}
/// Operations for producing Fourier-transformed versions of entities.
pub mod fft {
use num::Complex;
use crate::{
entities::{
BootstrapKeyFft, BootstrapKeyRef, GgswCiphertextFft, GgswCiphertextRef,
GlweCiphertextFft, GlweCiphertextRef,
},
GlweDef, LweDef, RadixDecomposition,
};
/// Take the fourier transform of a [`GlweCiphertext`](crate::entities::GlweCiphertext).
///
/// # Remarks
/// `params` must be the same parameters that produce `ct`.
///
/// # Panics
/// `params` is invalid.
/// `params` doesn't correspond with `ct`.
pub fn fft_glwe(
ct: &GlweCiphertextRef<u64>,
params: &GlweDef,
) -> GlweCiphertextFft<Complex<f64>> {
let mut fft = GlweCiphertextFft::new(params);
ct.fft(&mut fft, params);
fft
}
/// Take the fourier transform of a [`GgswCiphertext`](crate::entities::GgswCiphertext).
///
/// # Remarks
/// `glwe` and `radix` must be the same parameters that produced `ggsw`.
///
/// For [`GgswCiphertext`](crate::entities::GgswCiphertext)s that result from a
/// [`circuit_bootstrap`](super::evaluation::circuit_bootstrap) operation, these
/// must match `glwe_1` and `cbs_radix` respectively.
///
/// # Panics
/// If `glwe` and `radix` don't correspond with `ggsw`.
/// If `glwe` or `radix` are invalid.
pub fn fft_ggsw(
ggsw: &GgswCiphertextRef<u64>,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> GgswCiphertextFft<Complex<f64>> {
let mut fft = GgswCiphertextFft::new(glwe, radix);
ggsw.fft(&mut fft, glwe, radix);
fft
}
/// Take the fourier transform of a [BootstrapKey](crate::entities::BootstrapKey).
/// The resulting [`BootstrapKeyFft`] may be used in
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap)
/// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap)
/// operations.
///
/// # Remarks
/// `glwe` and `radix` must be the same parameters that produced `bsk`.
///
/// # Panics
/// If `glwe` and `radix` don't correspond with `bsk`.
/// If `glwe` or `radix` are invalid.
pub fn fft_bootstrap_key(
bsk: &BootstrapKeyRef<u64>,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> BootstrapKeyFft<Complex<f64>> {
let mut bsk_fft = BootstrapKeyFft::new(lwe, glwe, radix);
bsk.fft(&mut bsk_fft, glwe, radix);
bsk_fft
}
}
/// TFHE operations for performing computation.
pub mod evaluation {
use num::Complex;
use crate::{
entities::{
BootstrapKeyFft, BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef,
GgswCiphertext, GgswCiphertextFftRef, GlweCiphertext, GlweCiphertextRef, LweCiphertext,
LweCiphertextRef, LweKeyswitchKeyRef, UnivariateLookupTableRef,
},
GlweDef, LweDef, RadixDecomposition,
};
/// Perform a multiplexing operation. When `b_fft` encrypts a zero polynomial,
/// the resulting [`GlweCiphertext`] will the same message as `d_0`. When `b_fft`
/// encrypts the 1 polynomial, the result will contain the same message as `d_1`.
///
/// # Remarks
/// `b_fft`, `d_0`, and `d_1` must all be encrypted under the same
/// [`GlweSecretKey`](crate::entities::GlweSecretKey). This implies `params` must
/// correspond with all three values.
///
/// Additionally, `radix` must correspond to `b_fft`.
///
/// For
/// [`GgswCiphertext`] resulting from [`circuit_bootstrap`] operations,
/// `radix` must be the same as `cbs_radix` and `params` must be the same as
/// `glwe_1`.
///
/// # Panics
/// If `params` doesn't correspond with `b_fft`, `d_0`, `d_1`.
/// If `radix` doesn't correspond with `b_fft`.
/// If `radix` or `params` are invalid.
pub fn cmux(
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
d_0: &GlweCiphertextRef<u64>,
d_1: &GlweCiphertextRef<u64>,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlweCiphertext<u64> {
let mut result = GlweCiphertext::new(params);
crate::ops::fft_ops::cmux(&mut result, d_0, d_1, b_fft, params, radix);
result
}
#[allow(clippy::too_many_arguments)]
/// Perform a programmable bootstrapping operation. Bootstrapping takes
/// `input` and produces a new ciphertext with a fixed noise level, applying
/// a univariate function defined by `lut` in the process.
///
/// This new ciphertext is encrypted under a (usually) different
/// [`LweSecretKey`](crate::entities::LweSecretKey) defined by `glwe` interpreted
/// as an [`LweDef`].
///
/// See also [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable).
///
/// To switch the message back to the original key, you need to perform an LWE
/// keyswitch operation. TODO: hotlink
///
/// # Remarks
/// `input` must be valid under the `lwe` parameters.
/// `lwe`, `glwe`, and `radix` parameters must be the same as those used when
/// first creating the `bsk`.
/// `lut` must be valid under `glwe` parameters.
///
/// # Panics
/// If `lwe`, `glwe`, or `radix` parameters are invalid.
/// If `input` doesn't correspond to `lwe` parameters.
/// If `bsk` doesn't correspond to `lwe`, `glwe`, `radix` parameters.
/// If `lut` doesn't correspond to `glwe` parameters.
pub fn univariate_programmable_bootstrap(
input: &LweCiphertextRef<u64>,
lut: &UnivariateLookupTableRef<u64>,
bsk: &BootstrapKeyFft<Complex<f64>>,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
) -> LweCiphertext<u64> {
let mut out = LweCiphertext::new(&glwe.as_lwe_def());
crate::ops::bootstrapping::programmable_bootstrap(
&mut out, input, lut, bsk, lwe, glwe, radix,
);
out
}
#[allow(clippy::too_many_arguments)]
/// Perform a circuit bootstrapping operation. Circuit bootstrapping takes
/// `input` [LweCiphertext] encrypted under a [LweSecretKey](crate::entities::LweSecretKey)
/// constructed with `lwe_0` parameters and produces a [GgswCiphertext] encrypted
/// under a [GlweSecretKey](crate::entities::GlweSecretKey) constructed with `glwe_1`
/// parameters.
///
/// See also [generate_bootstrapping_key](super::keygen::generate_bootstrapping_key) and
/// [generate_cbs_pfks](super::keygen::generate_cbs_ksk) for how to generate the required
/// keys.
///
/// # Remarks
/// Internally, circuit bootstrapping occurs in 2 steps. First we perform programmable
/// bootstraps (PBS) from `lwe_0` parameters to `glwe_2` reinterpreted as LWE parameters. We
/// do this `cbs_radix.count` times using univariate functions that map the message in
/// input to its corresponding radix decomposition.
///
/// For step 2, we use private functional keyswitching (PFKS) to transform the
/// `cbs_radix.count` [LweCiphertext]s encrypted under `glwe_2` into
/// `cbs_radix.count * glwe_2.size + 1` [GlweCiphertext]s. The PFKS operations multiply each
/// [GlevCiphertext](crate::entities::GlevCiphertext) by the corresponding polynomial in
/// the `glwe_1` [GlweSecretKey](crate::entities::GlweSecretKey) to create a valid
/// [GgswCiphertext].
///
/// To summarize, we use PBS to turn an `lwe_0` LWE ciphertext into `glwe_2` LWE ciphertexts.
/// We then use PFKS to turn the `glwe_2` LWE ciphertexts into `glwe_1` GLWE ciphertexts
/// arranged as a valid [GgswCiphertext] encrypting the same value as `input` as a constant
/// coefficient polynomial.
///
/// See [crate::ops::bootstrapping::circuit_bootstrap] for more details.
///
/// `pbs_radix` parameterizes the bootstrapping operation (step 1).
///
/// `pfks_radix` parameterizes the PFKS operation (step 2). These should
///
/// `cbs_radix` parameterizes the final decomposition of the resulting [`GgswCiphertext`]. This
/// should match the radix used when creating the
/// [CircuitBootstrappingKeyswitchKeys](crate::entities::CircuitBootstrappingKeyswitchKeys)
///
///
/// # Panics
/// If `pbs_radix`, `cbs_radix`, `pfksk_radix`, `lwe_0`, `glwe_2`, or `glwe_1` are invalid.
/// If `pbs_radix`, `lwe_0`, `glwe_2` don't match the `radix`, `lwe`, `glwe` (respectively) used to generate `bsk`.
/// If `pfks_radix`, `glwe_2.as_lwe_def()`, `glwe_1` parameters don't match the `radix`, `from_lwe`, `to_lwe`
/// (respectively) used to generate `cbsksk`.
pub fn circuit_bootstrap(
input: &LweCiphertextRef<u64>,
bsk: &BootstrapKeyFftRef<Complex<f64>>,
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<u64>,
lwe_0: &LweDef,
glwe_1: &GlweDef,
glwe_2: &GlweDef,
pbs_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
pfks_radix: &RadixDecomposition,
) -> GgswCiphertext<u64> {
let mut out = GgswCiphertext::new(glwe_1, cbs_radix);
crate::ops::bootstrapping::circuit_bootstrap(
&mut out, input, bsk, cbsksk, lwe_0, glwe_1, glwe_2, pbs_radix, cbs_radix, pfks_radix,
);
out
}
/// Perform LWE keyswitching to produce a new [`LweCiphertext`] encrypted
/// under a different [`LweSecretKey`](crate::entities::LweSecretKey).
///
/// # Remarks
/// When creating `ksk` with [`generate_ksk`](super::keygen::generate_ksk),
/// you pass 2 different [`LweSecretKey`](crate::entities::LweSecretKey)
/// values: a `from_sk` and `to_sk`. `ct` should be encrypted under the
/// same `from_sk` and ciphertext this function returns will be encrypted
/// under `to_sk`.
pub fn keyswitch_lwe_to_lwe(
ct: &LweCiphertextRef<u64>,
ksk: &LweKeyswitchKeyRef<u64>,
from_lwe: &LweDef,
to_lwe: &LweDef,
radix: &RadixDecomposition,
) -> LweCiphertext<u64> {
let mut new_ct = LweCiphertext::new(to_lwe);
crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe(
&mut new_ct,
ct,
ksk,
from_lwe,
to_lwe,
radix,
);
new_ct
}
}

35
sunscreen_tfhe/src/lib.rs Normal file
View File

@@ -0,0 +1,35 @@
//! TFHE low-level library
#![deny(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
#[macro_use]
mod dst;
/// The entities module contains the main data structures used in the library.
pub mod entities;
/// Higher level operations in TFHE such as programmable bootstrapping, keyswitching, etc.
pub mod ops;
/// Parameters that define a TFHE scheme.
pub mod params;
pub use params::*;
/// Math operations on various math primitives such as polynomials.
pub mod math;
pub use math::*;
mod macros;
/// Random number generation.
pub mod rand;
mod scratch;
/// A high-level API for interfacing with TFHE. Allocates, computes with and returns
/// objects as you would expect from a Rust API.
pub mod high_level;
/// Zero Knowledge proofs for TFHE.
#[cfg(feature = "logproof")]
pub mod zkp;

View File

@@ -0,0 +1,108 @@
macro_rules! impl_binary_op {
($op:ident, $type:ty, ($($t_bounds:ty),* $(,)? )) => {
paste::paste! {
// Ex: AddAssign for LweSecretKey
impl<S> std::ops::[<$op Assign>] for $type<S>
where
S: $($t_bounds)*
{
fn [<$op:lower _assign>](&mut self, rhs: Self) {
self.data.iter_mut().zip(rhs.data.iter()).for_each(|(a, b)| {
*a = num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a, b);
});
}
}
// Ex: Add for LweSecretKey
// Calls Add for &LweSecretKeyRef
impl<S> std::ops::$op for $type<S>
where
S: TorusOps,
{
type Output = Self;
fn [<$op:lower >](self, rhs: Self) -> Self::Output {
std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref())
}
}
// Ex: WrappingAdd for LweSecretKey
// Calls Add for &LweSecretKeyRef
impl<S> num::traits::[<Wrapping $op>] for $type<S>
where
S: TorusOps,
{
fn [<wrapping_ $op:lower>](&self, rhs: &Self) -> Self {
std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref())
}
}
// Ex: Add for &LweSecretKeyRef
// Calls AddAssign for LweSecretKey
impl<S> std::ops::$op for &[<$type Ref>]<S>
where
S: TorusOps,
{
type Output = $type<S>;
fn [< $op:lower >](self, rhs: Self) -> Self::Output {
let mut a = self.to_owned();
std::ops::[< $op Assign >]::[< $op:lower _assign>](&mut a, rhs.to_owned());
a
}
}
}
};
}
macro_rules! impl_unary_op {
($op:ident, $type:ty) => {
paste::paste! {
// Ex: Neg for LweSecretKey
// Calls Neg for &LweSecretKeyRef
impl<S> std::ops::$op for $type<S>
where
S: TorusOps,
{
type Output = Self;
fn [<$op:lower>](self) -> Self::Output {
std::ops::$op::[<$op:lower>](self.as_ref())
}
}
// Ex: Neg for &LweSecretKeyRef
impl<S> std::ops::$op for &[<$type Ref>]<S>
where
S: TorusOps,
{
type Output = $type<S>;
fn [<$op:lower>](self) -> Self::Output {
// We call the wrapping trait instead of using the dot
// syntax because the dot syntax can dereference the value
// and can cause problems with Deref.
let data = self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)).collect();
$type { data }
}
}
// Ex: WrappingNeg for LweSecretKey
// Calls Neg for &LweSecretKeyRef
impl<S> num::traits::[<Wrapping $op>] for $type<S>
where
S: TorusOps,
{
fn [<wrapping_ $op:lower>](&self) -> Self {
std::ops::$op::[<$op:lower>](self.as_ref())
}
}
}
};
}
pub(crate) use impl_binary_op;
pub(crate) use impl_unary_op;

View File

@@ -0,0 +1,64 @@
/// An integer type that supports rounding division.
pub trait RoundedDiv {
/// Divides two numbers and rounds the result to the nearest integer.
fn div_rounded(&self, divisor: Self) -> Self;
}
macro_rules! div_rounded {
($t:ty) => {
impl RoundedDiv for $t {
#[inline(always)]
fn div_rounded(&self, divisor: $t) -> $t {
// There are a few ways to do this, but we chose the following
// because it allows the entire range of a type to be used. The
// other common method is to add half the divisor to the
// numerator, but that effectively cuts the size of the possible
// inputs in half before overflow.
let q = self / divisor;
let r = self % divisor;
if r >= divisor / 2 {
q + 1
} else {
q
}
}
}
};
}
div_rounded!(u8);
div_rounded!(u16);
div_rounded!(u32);
div_rounded!(u64);
div_rounded!(u128);
div_rounded!(usize);
div_rounded!(i8);
div_rounded!(i16);
div_rounded!(i32);
div_rounded!(i64);
div_rounded!(i128);
div_rounded!(isize);
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use super::*;
#[test]
fn test_div_rounded() {
for _ in 0..1_000 {
let a = thread_rng().next_u64();
let mut b = thread_rng().next_u64();
if b == 0 {
b = 1;
}
let expected = ((a as f64) / (b as f64)).round() as u64;
let actual = a.div_rounded(b);
assert_eq!(expected, actual);
}
}
}

View File

@@ -0,0 +1,184 @@
use std::ops::{Add, Mul, Neg};
use std::sync::Arc;
use num::{complex::Complex, Float};
use realfft::{ComplexToReal, FftNum, RealFftPlanner, RealToComplex};
use sunscreen_math::{One, Zero as SunscreenZero};
use crate::{FrequencyTransform, Inverse, Pow, RootOfUnity};
/// A struct that can perform a real FFT.
pub struct RealFft<T>
where
T: FftNum,
{
pub(crate) fplan: Arc<dyn RealToComplex<T>>,
pub(crate) rplan: Arc<dyn ComplexToReal<T>>,
pub(crate) scale: T,
}
impl<T> RealFft<T>
where
T: FftNum + Float,
{
/// Create a new [RealFft] with the given size.
pub fn new(n: usize) -> Self {
let mut plan = RealFftPlanner::<T>::new();
let fplan = plan.plan_fft_forward(n);
let rplan = plan.plan_fft_inverse(n);
let scale = T::from(1.0).unwrap() / T::from(n).unwrap();
Self {
fplan,
rplan,
scale,
}
}
}
impl<T> FrequencyTransform for RealFft<T>
where
T: FftNum + Float,
{
type BaseRepr = T;
type FrequencyRepr = Complex<T>;
fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) {
assert_eq!(data.len() / 2 + 1, output.len());
self.fplan.process(&mut data.to_owned(), output).unwrap();
}
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) {
self.rplan.process(&mut data.to_owned(), output).unwrap();
output.iter_mut().for_each(|x| {
*x = *x * self.scale;
});
}
}
/// A struct that can perform a NTT using a naive algorithm.
#[allow(unused)]
pub struct NaiveNtt<T> {
twiddle: Vec<T>,
inv_twiddle: Vec<T>,
n_inv: T,
}
impl<T> NaiveNtt<T>
where
T: RootOfUnity
+ Copy
+ SunscreenZero
+ One
+ Mul<T, Output = T>
+ Add<T, Output = T>
+ Neg<Output = T>
+ Pow<u64>
+ From<u64>
+ Inverse,
{
/// Create a new [NaiveNtt] with the given size.
#[allow(unused)]
pub fn new(n: usize) -> Self {
let mut twiddle = vec![];
let mut inv_twiddle = vec![];
let root = T::nth_root_of_unity(n as u64);
let inv_root = root.inverse();
for i in 0..n as u64 {
twiddle.push(root.pow(i));
inv_twiddle.push(inv_root.pow(i));
}
Self {
twiddle,
inv_twiddle,
n_inv: T::from(n as u64).inverse(),
}
}
}
#[cfg(test)]
mod tests {
use num::complex::ComplexFloat;
use crate::FrequencyTransform;
use super::*;
#[test]
fn can_roundtrip_real_fft() {
let n = 256;
let fft = RealFft::<f64>::new(n);
let input = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut points = vec![Complex::from(0.0); input.len() / 2 + 1];
let mut result = vec![0.0; input.len()];
fft.forward(&input, &mut points);
fft.reverse(&points, &mut result);
for (l, r) in input.iter().zip(result.iter()) {
assert!((l - r).abs() < 1e-12);
}
}
#[test]
fn negacyclic_gives_odd_harmonics() {
let n = 512;
let nega_len = n / 2;
let fft = RealFft::<f64>::new(n);
let input = (0..n)
.enumerate()
.map(|(i, x)| {
let val = x % nega_len;
if i < nega_len {
val as f64
} else {
-(val as f64)
}
})
.collect::<Vec<_>>();
let mut points = vec![Complex::from(0.0); input.len() / 2 + 1];
fft.forward(&input, &mut points);
for (i, x) in points.iter().enumerate() {
if i % 2 == 0 {
assert!(x.re().abs() < 1e-12);
assert!(x.im().abs() < 1e-12);
} else {
assert!(x.re().abs() > 1.0);
assert!(x.im().abs() > 1.0);
}
}
}
#[test]
fn can_cyclic_convolution() {
let n = 4;
let fft = RealFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut actual = vec![0.0; n];
let mut y = vec![Complex::from(0.0); n / 2 + 1];
fft.forward(&x, &mut y);
let z = y.iter().map(|y| y * y).collect::<Vec<_>>();
fft.reverse(&z, &mut actual);
let expected = [10.0, 12.0, 10.0, 4.0];
for (e, a) in expected.iter().zip(actual.iter()) {
assert!((e - a).abs() < 1e-12);
}
}
}

View File

@@ -0,0 +1,5 @@
/// FFT based operations over real numbers.
pub mod cyclic;
/// FFT based operations over twisted cyclotomics.
pub mod negacyclic;

View File

@@ -0,0 +1,176 @@
use std::{
f64::consts::PI,
sync::{Arc, Once},
};
use num::{Complex, Float, One};
use realfft::FftNum;
use rustfft::{Fft, FftPlanner};
use crate::{scratch::allocate_scratch, FrequencyTransform};
static FFT_CACHE_INIT: Once = Once::new();
static mut FFT_CACHE: Vec<TwistedFft<f64>> = vec![];
/// Get a [TwistedFft] for a given log N.
pub fn get_fft(log_n: usize) -> &'static TwistedFft<f64> {
// Can FFT powers of 2 from N=1 up to 4096.
assert!(log_n < 13);
FFT_CACHE_INIT.call_once(|| {
for i in 0..13 {
unsafe { FFT_CACHE.push(TwistedFft::new(0x1 << i)) };
}
});
unsafe { &FFT_CACHE[log_n] }
}
/// Perform FFT with a twist so points can be used for
/// negacyclic convolution.
///
/// # Remarks
/// See `<https://jeremykun.com/2022/12/09/negacyclic-polynomial-multiplication/>` for algorithm.
pub struct TwistedFft<T>
where
T: FftNum,
{
fwd: Arc<dyn Fft<T>>,
rev: Arc<dyn Fft<T>>,
twist: Vec<Complex<T>>,
twist_inv: Vec<Complex<T>>,
}
impl<T> TwistedFft<T>
where
T: FftNum + Float,
{
/// Create a new [TwistedFft] with the given size.
pub fn new(n: usize) -> Self {
assert!(n.is_power_of_two());
let n_2 = T::from(n * 2).unwrap();
let k = n / 2;
// The true length of the negacyclic sequence is 2N
let mut plan = FftPlanner::new();
let fwd = plan.plan_fft_forward(k);
let rev = plan.plan_fft_inverse(k);
let two_pi = T::from(PI).unwrap() * T::from(2.0).unwrap();
let w_2n = (Complex::from(two_pi / n_2) * Complex::i()).exp();
let twist = (0..k)
.map(|x| w_2n.powf(T::from(x).unwrap()))
.collect::<Vec<_>>();
let twist_inv = twist
.iter()
.copied()
.map(|t| t.powf(-T::one()))
.collect::<Vec<_>>();
debug_assert!(twist.iter().zip(twist_inv.iter()).all(|(a, b)| {
let a_a_inv = a * b;
let err = a_a_inv - Complex::one();
err.re.abs() < T::from(1e-12).unwrap() && err.im.abs() < T::from(1e-12).unwrap()
}));
Self {
fwd,
rev,
twist,
twist_inv,
}
}
}
impl<T> FrequencyTransform for TwistedFft<T>
where
T: FftNum + Float,
{
type BaseRepr = T;
type FrequencyRepr = Complex<T>;
fn forward(&self, x: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) {
assert_eq!(x.len(), self.fwd.len() * 2);
let n_div_2 = x.len() / 2;
for i in 0..n_div_2 {
output[i] = Complex::new(x[i], x[i + n_div_2]) * self.twist[i];
}
let mut scratch = allocate_scratch(self.fwd.get_inplace_scratch_len());
let scratch_slice = scratch.as_mut_slice();
self.fwd.process_with_scratch(output, scratch_slice);
}
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) {
assert_eq!(data.len(), self.rev.len());
let mut ifft = allocate_scratch(data.len());
let ifft_slice = ifft.as_mut_slice();
ifft_slice.copy_from_slice(data);
let mut scratch = allocate_scratch(self.rev.get_inplace_scratch_len());
let scratch_slice = scratch.as_mut_slice();
self.rev.process_with_scratch(ifft_slice, scratch_slice);
let n_inv = T::one() / T::from(data.len()).unwrap();
for (i, x) in ifft_slice.iter().enumerate() {
let tmp = *x * n_inv * self.twist_inv[i];
output[i] = tmp.re.round();
output[i + data.len()] = tmp.im.round();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn can_roundtrip_negacyclic_fft() {
let n = 8;
let plan = TwistedFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut y = vec![Complex::from(0.0); x.len() / 2];
let mut actual = vec![0.0; x.len()];
plan.forward(&x, &mut y);
plan.reverse(&y, &mut actual);
for (l, r) in actual.iter().zip(x.iter()) {
assert!((l - r).abs() < 1e-12);
}
}
#[test]
fn can_negacyclic_conv() {
let n = 4;
let plan = TwistedFft::<f64>::new(n);
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
let mut y = vec![Complex::from(0.0); x.len() / 2];
let mut actual = vec![0.0; x.len()];
plan.forward(&x, &mut y);
let z = y.iter().map(|y| y * y).collect::<Vec<_>>();
plan.reverse(&z, &mut actual);
assert_eq!(actual, vec![-10.0, -12.0, -8.0, 4.0]);
}
}

View File

@@ -0,0 +1,359 @@
use std::ops::{Add, Mul, Neg, Sub};
use num::traits::{WrappingAdd, WrappingMul, WrappingNeg, WrappingSub};
use sunscreen_math::{refify_binary_op, One, Zero};
use crate::{Inverse, Pow, RootOfUnity};
/// 2^64 - 2^32 + 1
pub const GOLDILOCKS_PRIME: u64 = 0xFFFFFFFF00000001;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(transparent)]
/// A value in the Goldilocks field (F_p where p = 2^64 - 2^32 + 1).
/// See
/// https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/
/// for why this field is so magical.
pub struct Fg(u64);
impl RootOfUnity for Fg {
fn nth_root_of_unity(n: u64) -> Self {
assert!(
n.is_power_of_two(),
"Goldilocks prime requires power of 2 < 2^32 for n when computing nth root of unity."
);
// See https://nufhe.readthedocs.io/en/latest/implementation_details.html for where this constant comes
// from.
const C: Fg = Fg(12037493425763644479);
let exp = 0x1_0000_0000u64 / n;
C.pow(exp)
}
}
impl Fg {
/// Returns `x % GOLDILOCKS_PRIME`
pub fn new(x: u64) -> Self {
if x > GOLDILOCKS_PRIME {
Self(x - GOLDILOCKS_PRIME)
} else {
Self(x)
}
}
#[inline]
pub fn unreduced_add(self, rhs: Self) -> Fg96 {
let (c, carry) = self.0.overflowing_add(rhs.0);
Fg96 {
lo: c,
hi: carry.into(),
}
}
#[inline]
pub fn unreduced_sub(self, rhs: Self) -> Fg96 {
self.unreduced_add(Fg(GOLDILOCKS_PRIME - rhs.0))
}
#[inline]
pub fn unreduced_mul(self, rhs: Self) -> Fg159 {
let res = self.0 as u128 * rhs.0 as u128;
res.into()
}
#[inline]
/// Compute `self * b + c` and don't reduce the result.
pub fn unreduced_mad(self, b: Self, c: Self) -> Fg159 {
let res = self.0 as u128 * b.0 as u128 + c.0 as u128;
res.into()
}
#[inline]
/// Compute `self * b + c`.
pub fn mad(self, b: Self, c: Self) -> Self {
self.unreduced_mad(b, c).reduce()
}
}
impl From<u64> for Fg {
fn from(value: u64) -> Self {
Self::new(value)
}
}
#[refify_binary_op]
impl Add<&Fg> for &Fg {
type Output = Fg;
#[inline]
fn add(self, rhs: &Fg) -> Self::Output {
let (res, c) = self.0.overflowing_add(rhs.0);
// If overflow occurred, then we have 1 carry bit. This gives us
// s = 1 * 2^64 + res = (2^32 - 1) + res.
//
// If overflow didn't occur, but res >= g, then we need to compute
// s - g = s + (2^32 + 1). Note that -g (mod 2^64) = 2^32 - 1 (mod 2^64).
// Thus, we can add (2^32 + 1) in both cases!
if c || res >= GOLDILOCKS_PRIME {
Fg(res.wrapping_add(0xFFFFFFFFu64))
} else {
Fg(res)
}
}
}
#[refify_binary_op]
impl Mul<&Fg> for &Fg {
type Output = Fg;
#[inline]
fn mul(self, rhs: &Fg) -> Self::Output {
self.unreduced_mul(*rhs).reduce()
}
}
#[refify_binary_op]
impl Sub<&Fg> for &Fg {
type Output = Fg;
#[inline]
fn sub(self, rhs: &Fg) -> Self::Output {
let (res, c) = self.0.overflowing_sub(rhs.0);
let offset = 0u32.wrapping_sub(u32::from(c)) as u64;
// If we underflow, then we need to add g.
// Note that g (mod 2^64) = -(2^32 + 1) (mod 32)
// Thus, this is equivalent to subtracting -(2^32 + 1).
Fg(res.wrapping_sub(offset))
}
}
impl Neg for Fg {
type Output = Fg;
fn neg(self) -> Self::Output {
Self(GOLDILOCKS_PRIME - self.0)
}
}
impl WrappingAdd for Fg {
#[inline]
fn wrapping_add(&self, v: &Self) -> Self {
self + v
}
}
impl WrappingMul for Fg {
#[inline]
fn wrapping_mul(&self, v: &Self) -> Self {
self * v
}
}
impl WrappingSub for Fg {
#[inline]
fn wrapping_sub(&self, v: &Self) -> Self {
self - v
}
}
impl WrappingNeg for Fg {
#[inline]
fn wrapping_neg(&self) -> Self {
-*self
}
}
impl Zero for Fg {
#[inline]
fn vartime_is_zero(&self) -> bool {
self.0 == 0
}
#[inline]
fn zero() -> Self {
Fg(0)
}
}
impl One for Fg {
#[inline]
fn one() -> Self {
Fg(1)
}
}
impl Inverse for Fg {
fn inverse(&self) -> Self {
<Fg as Pow<u64>>::pow(self, GOLDILOCKS_PRIME - 2)
}
}
#[derive(Copy, Clone, Debug)]
/// A number of the form hi << 64 + C, where B is 32-bit and C = 64-bit.
pub struct Fg96 {
lo: u64,
hi: u64,
}
impl Fg96 {
#[inline]
/// Reduce the number mod [`GOLDILOCKS_PRIME`].
///
/// /// # Remarks
/// Note that 2^64 = 2^32 - 1 (mod g)
///
/// Thus, we can rewrite as (2^32 - 1) * B + C.
pub fn reduce(self) -> Fg {
let prod = (self.hi << 32) - self.hi;
let mut res = prod.wrapping_add(self.lo);
if res < prod || res >= GOLDILOCKS_PRIME {
res = res.wrapping_sub(GOLDILOCKS_PRIME);
}
Fg(res)
}
}
/// An unreduced value of 159 or fewer bits, where lo is 64-bit, mid is 32-bit and hi is 63-bit.
/// We can represent this value as `lo + mid << 64 + hi << 96`.
pub struct Fg159 {
lo: u64,
mid: u64,
hi: u64,
}
impl Fg159 {
#[inline]
/// Reduce the number mod [`GOLDILOCKS_PRIME`].
///
/// /// # Remarks
/// Note that
/// * 2^64 = 2^32 - 1 (mod g)
/// * 2^96 = -1
///
/// Thus, we can rewrite as (2^32 - 1) * B + (C - A).
pub fn reduce(self) -> Fg {
let mut lo2 = self.lo.wrapping_sub(self.hi);
if self.hi > self.lo {
lo2 = lo2.wrapping_add(GOLDILOCKS_PRIME);
}
let prod = (self.mid << 32) - self.mid;
let mut res = lo2.wrapping_add(prod);
if res < prod || res >= GOLDILOCKS_PRIME {
res = res.wrapping_sub(GOLDILOCKS_PRIME);
}
Fg(res)
}
}
impl From<u128> for Fg159 {
#[inline(always)]
fn from(res: u128) -> Self {
Fg159 {
lo: res as u64,
mid: (res >> 64) as u32 as u64,
hi: (res >> 96) as u64,
}
}
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use super::*;
#[test]
fn can_add_fg() {
for _ in 0..1000 {
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let c = a + b;
let expected = ((a.0 as u128 + b.0 as u128) % GOLDILOCKS_PRIME as u128) as u64;
assert_eq!(c, Fg(expected));
}
}
#[test]
fn can_sub_fg() {
for _ in 0..1000 {
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let c = a - b;
let expected = ((a.0 as u128 + (GOLDILOCKS_PRIME - b.0) as u128)
% GOLDILOCKS_PRIME as u128) as u64;
assert_eq!(c, Fg(expected));
}
}
#[test]
fn can_neg_fg() {
for _ in 0..1000 {
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let c = -a;
let expected = Fg(0) - a;
assert_eq!(c, expected);
}
}
#[test]
fn can_mul_fg() {
fn test_case(a: u64, b: u64) {
let c = Fg(a) * Fg(b);
let expected = ((a as u128 * b as u128) % GOLDILOCKS_PRIME as u128) as u64;
assert_eq!(c, Fg(expected));
}
test_case(GOLDILOCKS_PRIME - 1, GOLDILOCKS_PRIME - 1);
for _ in 0..1000 {
test_case(
thread_rng().next_u64() % GOLDILOCKS_PRIME,
thread_rng().next_u64() % GOLDILOCKS_PRIME,
);
}
}
#[test]
fn can_mad_fg() {
for _ in 0..1000 {
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let c = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
let acutal = a.mad(b, c);
let expected =
((a.0 as u128 * b.0 as u128 + c.0 as u128) % GOLDILOCKS_PRIME as u128) as u64;
assert_eq!(acutal, Fg(expected));
}
}
#[test]
fn nth_root_of_unity() {
for i in 1..16u64 {
let root = Fg::nth_root_of_unity(0x1 << i);
assert_eq!(root.pow(0x1u64 << i), Fg::one());
}
}
}

View File

@@ -0,0 +1,165 @@
use std::ops::{Add, BitAnd, Mul, Shr};
pub use sunscreen_math::{One, Zero};
/// FFT based operations.
pub mod fft;
mod goldilocks_field;
/// Math operations on polynomials.
pub mod polynomial;
/// Operations for performing radix decompositions.
pub mod radix;
mod torus;
pub use torus::*;
mod basic;
pub use basic::*;
/// Types where the roots of unity in the given field can be found.
pub trait RootOfUnity
where
Self: Sized,
{
/// Find the n-th root of unity for the given field.
///
/// # Remarks
/// `w` is an n-th root of unity if `w^n = 1`.
///
/// # Panics
/// Implementers may choose to panic if no n-th root of unity
/// exists. You should take care in choosing your field modulus
/// to ensure the root does exist.
fn nth_root_of_unity(n: u64) -> Self;
}
/// Numbers that have an inverse element.
pub trait Inverse
where
Self: Sized,
{
/// Find the inverse of the given number.
fn inverse(&self) -> Self;
}
/// Numbers that can be raised to a power.
pub trait Pow<T>
where
Self: Sized,
T: Shr<u32, Output = T> + BitAnd<T, Output = T> + One + Eq,
{
/// Raise the number to the given power.
fn pow(&self, exp: T) -> Self;
}
impl<T, U> Pow<T> for U
where
T: Shr<u32, Output = T> + BitAnd<T, Output = T> + One + Eq + Copy,
U: sunscreen_math::One + Copy + Add<U, Output = U> + Mul<U, Output = U>,
{
fn pow(&self, exp: T) -> U {
let mut result = U::one();
let mut power = *self;
for i in 0..64 {
if (exp >> i) & T::one() == T::one() {
result = result * power
}
power = power * power;
}
result
}
}
/// Reinterpret the bits of an unsigned value as signed.
///
/// # Remarks
/// This is different than a cast. For example, UINT_MAX becomes -1
/// when interpreted as a 2's complement signed value.
pub trait ReinterpretAsSigned {
/// The output type of the reinterpretation.
type Output: ToF64;
/// Reinterpret the bits of an unsigned value as signed.
fn reinterpret_as_signed(self) -> Self::Output;
}
macro_rules! impl_reinterpret_signed {
($ut:ty, $st:ty) => {
impl ReinterpretAsSigned for $ut {
type Output = $st;
#[inline(always)]
fn reinterpret_as_signed(self) -> Self::Output {
unsafe { std::mem::transmute(self) }
}
}
};
}
impl_reinterpret_signed!(u8, i8);
impl_reinterpret_signed!(u16, i16);
impl_reinterpret_signed!(u32, i32);
impl_reinterpret_signed!(u64, i64);
impl_reinterpret_signed!(u128, i128);
/// Reinterpret the bits of a signed value as unsigned.
///
/// # Remarks
/// This is different than a cast. For example, -1 becomes UINT_MAX
/// when interpreted as an unsigned integer.
pub trait ReinterpretAsUnsigned {
/// The output type of the reinterpretation.
type Output;
/// Reinterpret the bits of a signed value as unsigned.
fn reinterpret_as_unsigned(self) -> Self::Output;
}
macro_rules! impl_reinterpret_unsigned {
($st:ty, $ut:ty) => {
impl ReinterpretAsUnsigned for $st {
type Output = $ut;
#[inline(always)]
fn reinterpret_as_unsigned(self) -> Self::Output {
unsafe { std::mem::transmute(self) }
}
}
};
}
impl_reinterpret_unsigned!(i8, u8);
impl_reinterpret_unsigned!(i16, u16);
impl_reinterpret_unsigned!(i32, u32);
impl_reinterpret_unsigned!(i64, u64);
impl_reinterpret_unsigned!(i128, u128);
/// A type that can be converted from and to the fourier domain.
pub trait FrequencyTransform {
/// Original domain representation.
type BaseRepr;
/// Fourier domain representation.
type FrequencyRepr;
/// Perform a fourier transform.
fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]);
/// Perform an inverse fourier transform.
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]);
}
/// A trait that allows types to specify how many leading zeros or ones are in
/// the value.
pub trait LeadingBits {
/// Count the number of leading zeros in the value.
fn leading_zeros(self) -> u32;
/// Count the number of leading ones in the value.
fn leading_ones(self) -> u32;
}

View File

@@ -0,0 +1,353 @@
use std::{
num::Wrapping,
ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
};
use num::traits::MulAdd;
use crate::{
dst::FromMutSlice, entities::PolynomialRef, scratch::allocate_scratch, ToF64, Torus, TorusOps,
};
/// Polynomial subtraction in place. This is equivalent to `a -= b` for each
/// coefficient in the polynomial.
pub fn polynomial_sub_assign<S>(lhs: &mut PolynomialRef<S>, rhs: &PolynomialRef<S>)
where
S: SubAssign + Copy,
{
for (a, b) in lhs
.coeffs_mut()
.iter_mut()
.zip(rhs.coeffs().iter().copied())
{
*a -= b;
}
}
/// Negate a polynomial in place. This is equivalent to `a = -a` for each
/// coefficient in the polynomial.
pub fn polynomial_negate<S>(c: &mut PolynomialRef<S>)
where
S: Clone + Copy + Neg<Output = S>,
{
for c in c.coeffs_mut().iter_mut() {
*c = -*c;
}
}
/// Compute `c = a * s` where `s` is scalar.
pub fn polynomial_scalar_mul<S, T, U>(c: &mut PolynomialRef<S>, a: &PolynomialRef<T>, s: U)
where
S: Clone,
T: Clone + Copy + Mul<U, Output = S>,
U: Clone + Copy,
{
for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) {
*c = *a * s
}
}
/// Compute `c += a * s`, where `s` is scalar.
pub fn polynomial_scalar_mad<S, T, U>(c: &mut PolynomialRef<S>, a: &PolynomialRef<T>, s: U)
where
S: Clone + Copy + Add<S, Output = S>,
T: Clone + Copy + MulAdd<U, S, Output = S>,
U: Clone + Copy,
{
for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) {
*c = a.mul_add(s, *c);
}
}
/// Compute `c = a + b` where a, b, and c are polynomials.
pub fn polynomial_add<S>(c: &mut PolynomialRef<S>, a: &PolynomialRef<S>, b: &PolynomialRef<S>)
where
S: Clone + Copy + Add<S, Output = S>,
{
assert_eq!(c.len(), a.len());
assert_eq!(c.len(), b.len());
for (c, (a, b)) in c
.as_mut_slice()
.iter_mut()
.zip(a.as_slice().iter().zip(b.as_slice().iter()))
{
*c = *a + *b;
}
}
/// Polynomial addition in place. This is equivalent to `a += b` for each
/// coefficient in the polynomial.
pub fn polynomial_add_assign<S>(lhs: &mut PolynomialRef<S>, rhs: &PolynomialRef<S>)
where
S: AddAssign + Copy,
{
for (a, b) in lhs
.coeffs_mut()
.iter_mut()
.zip(rhs.coeffs().iter().copied())
{
*a += b;
}
}
/// Compute `c = a - b` where a, b, and c are polynomials.
pub fn polynomial_sub<S>(c: &mut PolynomialRef<S>, a: &PolynomialRef<S>, b: &PolynomialRef<S>)
where
S: Clone + Copy + Sub<S, Output = S>,
{
assert_eq!(c.len(), a.len());
assert_eq!(c.len(), b.len());
for (c, (a, b)) in c
.as_mut_slice()
.iter_mut()
.zip(a.as_slice().iter().zip(b.as_slice().iter()))
{
*c = *a - *b;
}
}
/// Compute `c += a \[*\] b` where `a` in Z\[X\]/f and `c, b` in T\[X\]/f, and
/// \[*\] is the external product between these rings.
pub fn polynomial_external_mad<S>(
c: &mut PolynomialRef<Torus<S>>,
a: &PolynomialRef<Torus<S>>,
b: &PolynomialRef<S>,
) where
S: TorusOps,
{
polynomial_mad_impl(c, a, b);
}
/// Compute `c += a * b` where `*` the multiplication of the two polynomials of
/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and
/// hence has O(N^2) time.
pub fn polynomial_mad<S>(
c: &mut PolynomialRef<Wrapping<S>>,
a: &PolynomialRef<Wrapping<S>>,
b: &PolynomialRef<Wrapping<S>>,
) where
S: TorusOps,
Wrapping<S>: Sub<Wrapping<S>, Output = Wrapping<S>>
+ Add<Wrapping<S>, Output = Wrapping<S>>
+ Mul<Wrapping<S>, Output = Wrapping<S>>,
{
polynomial_mad_impl(c, a, b);
}
/// Compute `c += a * b` where `*` the multiplication of the two polynomials of
/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and
/// hence has O(N^2) time.
fn polynomial_mad_impl<S, T, U>(
c: &mut PolynomialRef<S>,
a: &PolynomialRef<T>,
b: &PolynomialRef<U>,
) where
U: Clone + Copy + ToF64,
T: Mul<U, Output = S> + Clone + Copy + ToF64,
S: Sub<S, Output = S> + Add<S, Output = S> + Clone + Copy,
{
assert!(a.len().is_power_of_two());
assert_eq!(a.len(), b.len());
assert_eq!(a.len(), c.len());
let len = a.len();
// Polynomial's length is a power of 2, so use mask to perform modulus.
let mask = len - 1;
let coeffs = c.coeffs_mut();
let mut poly_f64 = allocate_scratch::<f64>(a.len());
let poly_f64_ref = PolynomialRef::from_mut_slice(poly_f64.as_mut_slice());
a.map_into(poly_f64_ref, |x| x.to_f64());
for (i, l) in a.coeffs().iter().copied().enumerate() {
for (j, r) in b.coeffs().iter().copied().enumerate() {
if i + j >= len {
// Reduce mod len and subtract due to negacyclic
// polynomials.
let index = (i + j) & mask;
coeffs[index] = coeffs[index] - l * r;
} else {
let index = i + j;
coeffs[index] = coeffs[index] + l * r;
}
}
}
}
#[cfg(test)]
mod tests {
#[derive(BarrettConfig)]
#[barrett_config(modulus = "18446744073709551616", num_limbs = 2)]
pub struct U64Config;
pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>;
use crate::{
entities::{Polynomial, PolynomialFft},
normalized_torus_distance, ReinterpretAsSigned, ReinterpretAsUnsigned,
};
use super::*;
use num::{Complex, Zero};
use rand::{thread_rng, RngCore};
use sunscreen_math::{
poly::Polynomial as BasePolynomial,
ring::{BarrettBackend, Zq},
BarrettConfig, One,
};
#[test]
fn can_multiply_polynomials() {
// Compare our negacyclic polynomial implementation against
// sunscreen_math computing (a * b) mod (X^N + 1).
fn case(a: &PolynomialRef<Torus<u64>>, b: &PolynomialRef<u64>) {
let actual = a * b;
let mut f = vec![<Zq64 as sunscreen_math::Zero>::zero(); a.len() + 1];
let len = a.len();
f[0] = Zq64::one();
f[a.len()] = Zq64::one();
let f = BasePolynomial { coeffs: f };
let a = BasePolynomial {
coeffs: a
.coeffs()
.iter()
.map(|x| Zq64::from(x.inner()))
.collect::<Vec<_>>(),
};
let b = BasePolynomial {
coeffs: b.coeffs().iter().map(|x| Zq64::from(*x)).collect(),
};
let expected = a * b;
let (_, expected) = expected.vartime_div_rem_restricted_rhs(&f);
let expected = expected
.coeffs
.iter()
.take(len)
.map(|x| Torus::from(x.val.as_words()[0]))
.collect::<Vec<_>>();
let expected = Polynomial::new(&expected);
assert_eq!(actual, expected);
}
for _ in 0..50 {
let len = thread_rng().next_u64() % 8 + 1;
let len = 0x1 << len;
let a = (0..len)
.map(|_| Torus::from(thread_rng().next_u64()))
.collect::<Vec<_>>();
let a = Polynomial::new(&a);
let b = (0..len)
.map(|_| thread_rng().next_u64())
.collect::<Vec<_>>();
let b = Polynomial::new(&b);
case(&a, &b);
}
}
#[test]
fn can_roundtrip_polynomial() {
let poly = (0..1024u64).collect::<Vec<_>>();
let poly = Polynomial::new(&poly);
let mut actual = Polynomial::<u64>::zero(poly.len());
let mut out = PolynomialFft::new(&vec![Complex::zero(); poly.len() / 2]);
poly.fft(&mut out);
out.ifft(&mut actual);
assert_eq!(poly, actual);
}
#[test]
fn can_multiply_polynomials_fft() {
for _ in 0..100 {
let a = (0..1024u64)
.map(|x| Torus::from(x % 0x8000_0000))
.collect::<Vec<_>>();
let a = Polynomial::new(&a);
let b = (0..1024u64).map(|x| x % 16).collect::<Vec<_>>();
let b = Polynomial::new(&b);
let mut expected = Polynomial::<Torus<u64>>::zero(a.len());
polynomial_external_mad(&mut expected, &a, &b);
let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]);
let mut b_fft = a_fft.clone();
let mut c_fft = a_fft.clone();
a.fft(&mut a_fft);
b.fft(&mut b_fft);
c_fft.multiply_add(&a_fft, &b_fft);
let mut actual = Polynomial::<Torus<u64>>::zero(a.len());
c_fft.ifft(&mut actual);
assert_eq!(actual, expected);
}
}
#[test]
fn can_approx_multiply_large_polynomials_fft() {
let n = 1024;
for _ in 0..100 {
// a is uniform torus elements, b is "small" as we'll encounter
// during radix decomposition.
let a = (0..n)
.map(|_| Torus::from(rand::thread_rng().next_u64()))
.collect::<Vec<_>>();
let a = Polynomial::new(&a);
let b = (0..n)
.map(|_| {
let signed = (rand::thread_rng().next_u64() % 32).reinterpret_as_signed() - 16;
signed.reinterpret_as_unsigned()
})
.collect::<Vec<_>>();
let b = Polynomial::new(&b);
let mut expected = Polynomial::<Torus<u64>>::zero(a.len());
polynomial_external_mad(&mut expected, &a, &b);
let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]);
let mut b_fft = a_fft.clone();
let mut c_fft = a_fft.clone();
a.fft(&mut a_fft);
b.fft(&mut b_fft);
c_fft.multiply_add(&a_fft, &b_fft);
let mut actual = Polynomial::<Torus<u64>>::zero(a.len());
c_fft.ifft(&mut actual);
for (a, e) in actual.coeffs().iter().zip(expected.coeffs().iter()) {
let err = normalized_torus_distance(a, e).abs();
assert!(err < 1e-12);
}
}
}
}

View File

@@ -0,0 +1,399 @@
use crate::{
entities::{PolynomialIterator, PolynomialRef},
math::{Torus, TorusOps},
polynomial::polynomial_scalar_mad,
RadixCount, RadixDecomposition, RadixLog,
};
// Needed by allow_scratch_ref
#[allow(unused_imports)]
use crate::dst::FromMutSlice;
/// An iterator from least to most significant radix decomposition of a value.
pub struct ScalarRadixIterator<S>
where
S: TorusOps,
{
cur: S,
level: usize,
radix: RadixDecomposition,
}
impl<S: TorusOps> ScalarRadixIterator<S> {
/// Creates a new [`ScalarRadixIterator`] for the given [Torus] value.
#[inline(always)]
pub fn new(val: Torus<S>, radix: &RadixDecomposition) -> Self {
Self {
cur: round(val, radix),
level: 0,
radix: *radix,
}
}
}
#[inline(always)]
fn get_next_digit<S: TorusOps>(cur: &mut S, radix_log: usize) -> S {
let mask = S::from_u64((0x1u64 << radix_log) - 1);
// Interpreting the digits over [-B/2,B/2) reduces noise by half a bit on average.
let mut digit = *cur & mask;
*cur = *cur >> radix_log;
let carry = digit >> (radix_log - 1);
*cur = *cur + carry;
digit = digit.wrapping_sub(&(carry << radix_log));
digit
}
impl<S: TorusOps> Iterator for ScalarRadixIterator<S> {
type Item = S;
#[inline(always)]
fn next(&mut self) -> Option<Self::Item> {
if self.level == self.radix.count.0 {
return None;
}
let digit = get_next_digit(&mut self.cur, self.radix.radix_log.0);
self.level += 1;
Some(digit)
}
}
/// An iterator from least to most significant radix decomposition of the coefficients
/// of a polynomial.
pub struct PolynomialRadixIterator<'a, S>
where
S: TorusOps,
{
scratch: &'a mut PolynomialRef<S>,
level: usize,
radix: RadixDecomposition,
}
impl<'a, S> PolynomialRadixIterator<'a, S>
where
S: TorusOps,
{
/// Creates a new [`PolynomialRadixIterator`] for the given polynomial.
pub fn new(
poly: &PolynomialRef<Torus<S>>,
scratch: &'a mut PolynomialRef<S>,
radix: &RadixDecomposition,
) -> Self {
assert!(radix.radix_log.0 * radix.count.0 < S::BITS as usize);
assert_ne!(radix.radix_log.0 * radix.count.0, 0);
poly.map_into(scratch, |x| round(*x, radix));
Self {
scratch,
level: 0,
radix: radix.to_owned(),
}
}
/// Writes the next polynomial decomposition to `dst` and returns `Some(())` if there is a next digit.
pub fn write_next(&mut self, dst: &mut PolynomialRef<S>) -> Option<()> {
if self.level == self.radix.count.0 {
return None;
}
self.level += 1;
for (s, r) in self
.scratch
.coeffs_mut()
.iter_mut()
.zip(dst.coeffs_mut().iter_mut())
{
*r = get_next_digit(s, self.radix.radix_log.0);
}
Some(())
}
}
/// Recomposes a polynomial from its `digits` decomposition and adds it to `dst`.
///
/// # Remarks
/// The digits should iterate from least to most significant.
pub fn recompose_and_add<S>(
dst: &mut PolynomialRef<Torus<S>>,
digits: &mut PolynomialIterator<S>,
radix: RadixLog,
count: RadixCount,
) where
S: TorusOps,
{
let shift_amount = S::BITS as usize - radix.0 * count.0;
let mut cur_radix = S::from_u64(0x1 << shift_amount);
let mut actual_count = 0;
for d in digits {
polynomial_scalar_mad(dst, d.as_torus(), cur_radix);
actual_count += 1;
cur_radix = cur_radix << radix.0;
}
assert_eq!(count.0, actual_count);
}
#[inline(always)]
/// Multiply `val` by q / B^(j + 1).
pub fn scale_by_decomposition_factor<S: TorusOps>(
val: S,
j: usize,
radix: &RadixDecomposition,
) -> S {
let shift = S::BITS as usize - radix.radix_log.0 * (j + 1);
let factor = S::one() << shift;
val.wrapping_mul(&factor)
}
#[inline(always)]
/// Rounds the given [`Torus`] element and returns the value interpreted as an integer.
fn round<S: TorusOps>(x: Torus<S>, radix: &RadixDecomposition) -> S {
let shift = S::BITS as usize - radix.radix_log.0 * radix.count.0;
let round_bit = (x.inner() >> (shift - 1)) & S::from_u64(0x1);
(x.inner() >> shift).wrapping_add(&round_bit)
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{Polynomial, PolynomialList},
rand::uniform_torus,
scratch::allocate_scratch_ref,
PlaintextBits, PolynomialDegree,
};
use super::*;
#[test]
fn can_round_values() {
assert_eq!(
round(
Torus::from(0x12348FFF_FFFFFFFFu64),
&RadixDecomposition {
radix_log: RadixLog(4),
count: RadixCount(4)
}
),
0x1235
);
assert_eq!(
round(
Torus::from(0x12347FFF_FFFFFFFFu64),
&RadixDecomposition {
radix_log: RadixLog(4),
count: RadixCount(4)
}
),
0x1234
);
}
#[test]
fn can_decompose() {
let x = Polynomial::new(&[Torus::encode(7u64, PlaintextBits(4))]);
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
let mut radix_iter = PolynomialRadixIterator::new(
&x,
scratch,
&RadixDecomposition {
radix_log: RadixLog(2),
count: RadixCount(2),
},
);
let mut dst = Polynomial::zero(1);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1));
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(2));
assert_eq!(radix_iter.write_next(&mut dst), None);
let x = Polynomial::new(&[Torus::encode(1u64, PlaintextBits(1))]);
let radix_log = 4;
let mut radix_iter = PolynomialRadixIterator::new(
&x,
scratch,
&RadixDecomposition {
radix_log: RadixLog(radix_log),
count: RadixCount(3),
},
);
let mut dst = Polynomial::zero(1);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(dst.coeffs()[0], 0u64);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(dst.coeffs()[0], 0u64);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
// 2^(beta - 1) - 2^beta
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1 << (radix_log - 1)));
assert_eq!(radix_iter.write_next(&mut dst), None);
}
#[test]
fn can_decompose_polynomial() {
let x = Polynomial::new(&[
// Decomposes to 1 [for 2^1] + 0 [for 2^2], or 1 + 0
Torus::encode(1u64, PlaintextBits(4)),
// Decomposes to -2 [for 2^1] + 1 [for 2^2], or -2 + 4
Torus::encode(2u64, PlaintextBits(4)),
// Decomposes to -1 [for 2^1] + 1 [for 2^2], or -1 + 4
Torus::encode(3u64, PlaintextBits(4)),
// Decomposes to 0 [for 2^1] + 1 [for 2^2], or 0 + 4
Torus::encode(4u64, PlaintextBits(4)),
]);
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
let mut radix_iter = PolynomialRadixIterator::new(
&x,
scratch,
&RadixDecomposition {
radix_log: RadixLog(2),
count: RadixCount(2),
},
);
let mut dst = Polynomial::zero(4);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(
dst.coeffs(),
[1u64, 0u64.wrapping_sub(2), 0u64.wrapping_sub(1), 0]
);
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
assert_eq!(dst.coeffs(), [0, 1, 1, 1]);
assert_eq!(radix_iter.write_next(&mut dst), None);
}
#[test]
fn can_decompose_recompose() {
let d = PolynomialDegree(8);
for _ in 0..50 {
let radix = (thread_rng().next_u32() % 7) + 1;
let radix = RadixLog(radix as usize);
let count = loop {
let count = (thread_rng().next_u32() % 7) + 1;
if (radix.0 * count as usize) < u64::BITS as usize {
break count;
}
};
let count = RadixCount(count as usize);
let x = (0..d.0).map(|_| uniform_torus::<u64>()).collect::<Vec<_>>();
let x = Polynomial::new(&x);
let expected = x.map(|c| {
let radix_bits = radix.0 * count.0;
let lsb = u64::BITS as usize - radix_bits;
let round_loc = lsb - 1;
let round = (c.inner() >> round_loc) & 0x1;
let mask = 0xFFFFFFFF_FFFFFFFFu64 << lsb;
Torus::from((c.inner() & mask).wrapping_add(round << lsb))
});
let mut digits = PolynomialList::new(d, count.0);
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
let mut radix_iter = PolynomialRadixIterator::new(
&x,
scratch,
&RadixDecomposition {
radix_log: radix,
count,
},
);
for d in digits.iter_mut(d) {
radix_iter.write_next(d);
}
let mut result = Polynomial::zero(x.len());
recompose_and_add(&mut result, &mut digits.iter(d), radix, count);
assert_eq!(expected, result);
}
}
fn random_radix() -> RadixDecomposition {
let radix = (thread_rng().next_u32() % 7) + 1;
let radix = RadixLog(radix as usize);
let count = loop {
let count = (thread_rng().next_u32() % 7) + 1;
if (radix.0 * count as usize) < u64::BITS as usize {
break count;
}
};
let count = RadixCount(count as usize);
RadixDecomposition {
radix_log: radix,
count,
}
}
#[test]
fn can_decompose_scalar() {
for _ in 0..50 {
let radix = random_radix();
let val = Torus::from(thread_rng().next_u64());
/*let radix = RadixDecomposition {
count: RadixCount(3),
radix_log: RadixLog(4),
};
let val = Torus::from(0xDEADBEEF_FEEDF00Du64);
*/
let decomp = ScalarRadixIterator::new(val, &radix);
let mut digits = vec![];
for digit in decomp {
digits.push(digit);
}
let actual = digits.iter().enumerate().fold(0u64, |s, (i, d)| {
let shift_amount = u64::BITS as usize - radix.radix_log.0 * (radix.count.0 - i);
let cur_radix = 0x1u64 << shift_amount;
(cur_radix.wrapping_mul(*d)).wrapping_add(s)
});
// Round shifts our value down to the LSB places, so move it back up to compare
// against a torus element.
let expected =
round(val, &radix) << (u64::BITS as usize - radix.radix_log.0 * (radix.count.0));
assert_eq!(actual, expected);
}
}
}

View File

@@ -0,0 +1,619 @@
use bytemuck::{Pod as BytemuckPod, Zeroable};
use num::traits::{
Bounded, MulAdd, Num, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr,
WrappingSub,
};
use serde::{Deserialize, Serialize};
use std::{
fmt::{Binary, Debug, LowerHex, UpperHex},
num::Wrapping,
ops::{Add, AddAssign, BitAnd, Deref, Mul, Neg, Shl, Shr, Sub, SubAssign},
};
use sunscreen_math::{refify_binary_op, Zero};
use crate::{
math::{ReinterpretAsSigned, ReinterpretAsUnsigned},
scratch::Pod,
PlaintextBits,
};
/// Number of bits used in the representation of a type.
pub trait NumBits {
/// The number of bits used in the representation of a type.
const BITS: u32;
}
impl NumBits for u32 {
const BITS: u32 = u32::BITS;
}
impl NumBits for u64 {
const BITS: u32 = u64::BITS;
}
/// Convert a type into a 64-bit floating point number.
pub trait ToF64 {
/// Approximately convert the value to an f64.
fn to_f64(self) -> f64;
}
/// Convert a 64-bit floating point number into a type.
pub trait FromF64
where
Self: Sized,
{
/// Approximately convert an f64 into a value.
fn from_f64(x: f64) -> Self;
}
/// A type that supports operations on a Torus.
pub trait TorusOps:
BitAnd<Self, Output = Self>
+ WrappingAdd
+ WrappingSub
+ WrappingMul
+ WrappingShl
+ WrappingShr
+ WrappingNeg
+ BitAnd
+ ReinterpretAsSigned
+ Num
+ NumBits
+ From<u32>
+ TryFrom<u64>
+ FromU64
+ Clone
+ Copy
+ Binary
+ LowerHex
+ UpperHex
+ std::fmt::Debug
+ Ord
+ Zero
+ Pod
+ BytemuckPod
+ Bounded
+ ToF64
+ FromF64
+ ToU64
+ NumBits
{
}
// Sound since Torus is a transparent wrapper and `S` impl `Pod`
unsafe impl<S: TorusOps> Zeroable for Torus<S> {}
unsafe impl<S: TorusOps> BytemuckPod for Torus<S> {}
/// Convert a type into a 64-bit unsigned integer.
pub trait ToU64 {
/// Convert the value to a 64-bit unsigned integer.
fn to_u64(self) -> u64;
}
impl ToU64 for u32 {
fn to_u64(self) -> u64 {
self as u64
}
}
impl ToU64 for u64 {
fn to_u64(self) -> u64 {
self
}
}
impl<T> ToU64 for Torus<T>
where
T: TorusOps,
{
fn to_u64(self) -> u64 {
self.0.to_u64()
}
}
/// Convert a 64-bit unsigned integer into a type.
pub trait FromU64 {
/// For the given 64-bit value, take the N most significant bits, where
/// N is the bitlength of this type.
fn from_u64(val: u64) -> Self;
}
impl FromU64 for u32 {
fn from_u64(val: u64) -> Self {
(val & 0xFFFFFFFF) as Self
}
}
impl FromU64 for u64 {
fn from_u64(val: u64) -> Self {
val
}
}
macro_rules! impl_tof64 {
($t:ty) => {
impl ToF64 for $t {
fn to_f64(self) -> f64 {
self as f64
}
}
};
}
impl_tof64!(i8);
impl_tof64!(i16);
impl_tof64!(i32);
impl_tof64!(i64);
impl_tof64!(i128);
impl_tof64!(u8);
impl_tof64!(u16);
impl_tof64!(u32);
impl_tof64!(u64);
impl_tof64!(u128);
impl<T> ToF64 for Wrapping<T>
where
T: ToF64,
{
fn to_f64(self) -> f64 {
self.0.to_f64()
}
}
impl<T> ToF64 for Torus<T>
where
T: TorusOps,
{
fn to_f64(self) -> f64 {
self.0.to_f64()
}
}
macro_rules! impl_unsigned_fromf64 {
($t:ty,$st:ty) => {
impl FromF64 for $t {
#[inline(always)]
fn from_f64(x: f64) -> $t {
let x = x as $st;
x.reinterpret_as_unsigned()
}
}
};
}
impl_unsigned_fromf64!(u128, i128);
impl_unsigned_fromf64!(u64, i64);
impl_unsigned_fromf64!(u32, i32);
impl_unsigned_fromf64!(u16, i16);
impl<T> FromF64 for Torus<T>
where
T: TorusOps,
{
fn from_f64(x: f64) -> Self {
Self(T::from_f64(x))
}
}
impl<T> FromF64 for Wrapping<T>
where
T: FromF64,
{
fn from_f64(x: f64) -> Self {
Wrapping(T::from_f64(x))
}
}
impl TorusOps for u64 {}
impl TorusOps for u32 {}
/// A wrapper around a type that supports Torus operations.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Torus<S: TorusOps = u64>(S);
/// Compute the distance between two Torus values, normalized to the unit torus.
/// The first argument is taken as the reference point.
///
/// For example, if a is at a normalized position of 0.2, and b is at a
/// normalized position of 0.8, then there are two distances between them:
/// 0.6 and -0.4. This function will return the shorter of the two distances,
/// -0.4.
///
/// In other words:
///
/// ```text
/// b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b) (mod 1.0)
/// ```
pub fn normalized_torus_distance<S: TorusOps>(a: &Torus<S>, b: &Torus<S>) -> f64 {
let a_minus_b = a - b;
let b_minus_a = b - a;
let modulus = 2_f64.powi(S::BITS as i32);
let difference = if a_minus_b < b_minus_a {
-a_minus_b.to_f64()
} else {
b_minus_a.to_f64()
};
difference / modulus
}
impl<S: TorusOps> Deref for Torus<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<S: TorusOps> NumBits for Torus<S> {
const BITS: u32 = S::BITS;
}
impl<S: TorusOps> Torus<S> {
/// Compute the normalized position of this [Torus] value on the unit torus [0.0, 1.0).
pub fn normalized_torus(&self) -> f64 {
let modulus: f64 = 2_f64.powi(S::BITS as i32);
let val: f64 = self.0.to_f64();
val / modulus
}
/// Compute the distance between this [Torus] value and another, normalized to
/// the unit torus.
pub fn normalized_torus_distance(&self, other: &Self) -> f64 {
normalized_torus_distance(self, other)
}
/// Encode a value into a [Torus] that supports up to plain_bits of values.
/// This encodes the value on the equispaced `2^plaintext_bits` positions on
/// a larger torus.
pub fn encode(val: S, plain_bits: PlaintextBits) -> Self {
assert!(plain_bits.0 < S::BITS);
let encoded = val.wrapping_shl(S::BITS - plain_bits.0);
Self(encoded)
}
/// Decode a value from a [Torus] that supports up to plain_bits of values.
pub fn decode(&self, plain_bits: PlaintextBits) -> S {
assert!(plain_bits.0 < S::BITS);
let round_bit = self.0.wrapping_shr(S::BITS - plain_bits.0 - 1) & S::from(0x1);
let mask = S::from((0x1 << plain_bits.0) - 1);
(self.0.wrapping_shr(S::BITS - plain_bits.0) + round_bit) & mask
}
/// Scale a Torus element to a different modulus. Assumes that the two moduli are
/// powers of 2.
pub fn switch_modulus_smaller<T>(&self) -> Torus<T>
where
T: TorusOps + TryFrom<S>,
<T as TryFrom<S>>::Error: Debug,
{
let shift = (S::BITS - T::BITS) as usize;
// We don't wrap the bits we don't need
let y = self.0 >> shift;
Torus(T::try_from(y).expect("impossible error on switch_modulus_smaller"))
}
/// Return the underlying type of the Torus.
#[inline(always)]
pub fn inner(&self) -> S {
self.0
}
}
impl<S: TorusOps> From<S> for Torus<S> {
#[inline(always)]
fn from(value: S) -> Self {
Self(value)
}
}
impl<S: TorusOps> Zero for Torus<S> {
fn zero() -> Self {
Self(S::from(0))
}
fn vartime_is_zero(&self) -> bool {
self.inner() == <S as sunscreen_math::Zero>::zero()
}
}
impl<S: TorusOps> Neg for Torus<S> {
type Output = Self;
fn neg(self) -> Self::Output {
Self::Output::from(self.0.wrapping_neg())
}
}
impl<S: TorusOps> WrappingNeg for Torus<S> {
fn wrapping_neg(&self) -> Self {
Self::from(self.0.wrapping_neg())
}
}
#[refify_binary_op]
impl<S: TorusOps> Add<&Torus<S>> for &Torus<S> {
type Output = Torus<S>;
fn add(self, rhs: &Torus<S>) -> Self::Output {
Self::Output::from(self.0.wrapping_add(&rhs.0))
}
}
impl<S: TorusOps> WrappingAdd for Torus<S> {
fn wrapping_add(&self, rhs: &Self) -> Self {
self + rhs
}
}
#[refify_binary_op]
impl<S: TorusOps> Sub<&Torus<S>> for &Torus<S> {
type Output = Torus<S>;
fn sub(self, rhs: &Torus<S>) -> Self::Output {
Self::Output::from(self.0.wrapping_sub(&rhs.0))
}
}
impl<S: TorusOps> WrappingSub for Torus<S> {
fn wrapping_sub(&self, rhs: &Self) -> Self {
self - rhs
}
}
#[refify_binary_op]
impl<S: TorusOps> Mul<&S> for &Torus<S> {
type Output = Torus<S>;
fn mul(self, rhs: &S) -> Self::Output {
Self::Output::from(self.wrapping_mul(rhs))
}
}
#[refify_binary_op]
impl<S: TorusOps> BitAnd<&Torus<S>> for &Torus<S> {
type Output = Torus<S>;
fn bitand(self, rhs: &Torus<S>) -> Self::Output {
Torus::from(self.0 & rhs.0)
}
}
#[refify_binary_op]
impl<S: TorusOps> Shr<&usize> for &Torus<S> {
type Output = Torus<S>;
fn shr(self, rhs: &usize) -> Self::Output {
Torus::from(self.0 >> *rhs)
}
}
#[refify_binary_op]
impl<S: TorusOps> Shl<&usize> for &Torus<S> {
type Output = Torus<S>;
fn shl(self, rhs: &usize) -> Self::Output {
Torus::from(self.0 << *rhs)
}
}
impl<S: TorusOps> AddAssign<Self> for Torus<S> {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs
}
}
impl<S: TorusOps> AddAssign<&Self> for Torus<S> {
fn add_assign(&mut self, rhs: &Self) {
*self = *self + rhs
}
}
impl<S: TorusOps> SubAssign<Self> for Torus<S> {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl<S: TorusOps> SubAssign<&Self> for Torus<S> {
fn sub_assign(&mut self, rhs: &Self) {
*self = *self - rhs;
}
}
impl<S: TorusOps> std::iter::Sum for Torus<S> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.fold(Self::zero(), |acc, x| acc + x)
}
}
impl<S> ReinterpretAsSigned for Torus<S>
where
S: TorusOps,
{
type Output = <S as ReinterpretAsSigned>::Output;
#[inline(always)]
fn reinterpret_as_signed(self) -> Self::Output {
self.0.reinterpret_as_signed()
}
}
impl<S: TorusOps> num::Zero for Torus<S> {
fn zero() -> Self {
Self(<S as num::Zero>::zero())
}
fn is_zero(&self) -> bool {
self.0.is_zero()
}
}
impl<S: TorusOps> MulAdd<S, Self> for Torus<S> {
type Output = Self;
#[inline(always)]
fn mul_add(self, a: S, b: Self) -> Self::Output {
self * a + b
}
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use super::*;
#[test]
fn can_negate() {
let x = Torus::<u64>::from(0);
assert_eq!(-x, Torus::from(0));
let x = Torus::<u64>::from(1);
assert_eq!(-x, Torus::from(u64::MAX));
let x = Torus::<u64>::from(u64::MAX);
assert_eq!(-x, Torus::from(1));
}
#[test]
fn can_encode_decode() {
assert_eq!(
Torus::<u64>::encode(7, PlaintextBits(4)).0,
0x70000000_00000000
);
let x = Torus::<u64>::from(0x70000000_00000000);
assert_eq!(x.decode(PlaintextBits(4)), 7);
let x = Torus::<u64>::from(0x7FFFFFFF_FFFFFFFF);
assert_eq!(x.decode(PlaintextBits(4)), 8);
}
#[test]
fn can_decode_off_center() {
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.6) as u64);
let r = t.decode(PlaintextBits(1));
assert_eq!(r, 1);
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.3) as u64);
let r = t.decode(PlaintextBits(1));
assert_eq!(r, 1);
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.8) as u64);
let r = t.decode(PlaintextBits(1));
assert_eq!(r, 0);
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.2) as u64);
let r = t.decode(PlaintextBits(1));
assert_eq!(r, 0);
}
#[test]
fn can_normalize() {
let x = Torus::<u64>::from(0);
assert_eq!(x.normalized_torus(), 0.0);
let x = Torus::<u64>::from(u64::MAX / 4);
assert_eq!(x.normalized_torus(), 0.25);
let x = Torus::<u64>::from(u64::MAX / 2);
assert_eq!(x.normalized_torus(), 0.5);
let x = Torus::<u64>::from(u64::MAX / 4 * 3);
assert_eq!(x.normalized_torus(), 0.75);
let x = Torus::<u64>::from(u64::MAX / 8 * 7);
assert_eq!(x.normalized_torus(), 0.875);
}
#[test]
fn can_compute_distance() {
let a = Torus::<u64>::from(0);
let b = Torus::<u64>::from(u64::MAX / 4);
assert_eq!(normalized_torus_distance(&a, &b), 0.25);
assert_eq!(normalized_torus_distance(&b, &a), -0.25);
let a = Torus::<u64>::from(u64::MAX / 4);
let b = Torus::<u64>::from(u64::MAX / 2);
assert_eq!(normalized_torus_distance(&a, &b), 0.25);
assert_eq!(normalized_torus_distance(&b, &a), -0.25);
let a = Torus::<u64>::from(0);
let b = Torus::<u64>::from(u64::MAX / 4 * 3);
assert_eq!(normalized_torus_distance(&a, &b), -0.25);
assert_eq!(normalized_torus_distance(&b, &a), 0.25);
let a = Torus::<u64>::from(u64::MAX / 8);
let b = Torus::<u64>::from(u64::MAX / 4 * 3);
assert_eq!(normalized_torus_distance(&a, &b), -0.375);
assert_eq!(normalized_torus_distance(&b, &a), 0.375);
}
#[test]
fn test_normalized_relation() {
// Tests the relation:
// b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b)
for _ in 0..100 {
let a = Torus::<u64>::from(thread_rng().next_u64());
let b = Torus::<u64>::from(thread_rng().next_u64());
let a_norm = a.normalized_torus();
let b_norm = b.normalized_torus();
let dist = normalized_torus_distance(&a, &b);
let b_norm_from_dist = a_norm + dist;
let b_norm_from_dist = if b_norm_from_dist < 0.0 {
b_norm_from_dist + 1.0
} else if b_norm_from_dist >= 1.0 {
b_norm_from_dist - 1.0
} else {
b_norm_from_dist
};
let diff = (b_norm - b_norm_from_dist).abs();
assert!(
diff < 1e-12,
"Normalized torus relation test failed: a = {:?}, b = {:?}, a_norm = {}, b_norm = {}, dist: {}, b_norm_from_dist = {}, diff = {}",
a,
b,
a_norm,
b_norm,
dist,
b_norm_from_dist,
diff
);
}
}
#[test]
fn can_modulus_switch() {
let x = Torus::<u64>::from(0x12345678_9ABCDEF0);
let y = x.switch_modulus_smaller::<u32>();
assert_eq!(y.0, 0x12345678);
}
}

View File

@@ -0,0 +1,473 @@
use num::Complex;
use crate::{
dst::FromMutSlice,
entities::{BlindRotationShiftFftRef, GgswCiphertext, GlweCiphertextRef, GlweSecretKeyRef},
ops::{encryption::encrypt_ggsw_ciphertext_scalar, fft_ops::cmux},
scratch::allocate_scratch_ref,
GlweDef, PlaintextBits, RadixDecomposition, TorusOps,
};
/// Rotate the given ciphertext message polynomial by the given amount as if it
/// had been multiplied by a monomial with either positive or negative degree.
/// Since this is a negacyclic rotation, a rotation to the left negates the last
/// `rotation` coefficients, while a rotation to the right negates the first
/// `rotation` coefficients.
///
/// Mathematically, this is equivalent to multiplying the underlying polynomial
/// message by X^{rotation} mod (X^N + 1), where N is the polynomial degree.
/// There are some convenient relations to remember over any integer $k$ and
/// $i$:
///
/// - Rotations are modulus 2N: $X^{k*2N} = 1$ and $X^{k*2N ± i} = X^{± i}$.
/// - Rotations that are equivalent to a shift by -N are equivalent to negating
/// the polynomial: $X^{k*2N + N} = -1$.
/// - A negative rotation has an equivalent positive rotation: $X^{-i} = -X^{N - i}$.
///
/// # Example
///
/// ```
/// use sunscreen_tfhe::{
/// high_level::{keygen, encryption},
/// entities::{GlweCiphertext, Polynomial},
/// ops::bootstrapping::rotate_glwe_monomial_negacyclic,
/// params::{
/// GlweDef,
/// GlweSize,
/// GlweDimension,
/// PlaintextBits,
/// PolynomialDegree,
/// },
/// rand::Stddev,
/// };
///
/// // Define the GLWE parameters
/// let params = GlweDef {
/// dim: GlweDimension {
/// size: GlweSize(1),
/// polynomial_degree: PolynomialDegree(8),
/// },
/// std: Stddev(0.0000000444778278004718),
/// };
/// let plaintext_bits = PlaintextBits(4);
///
/// // Generate the GLWE secret key
/// let sk = keygen::generate_binary_glwe_sk(&params);
///
/// // Define and encrypt a message
/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
/// let ct = encryption::encrypt_glwe(&msg, &sk, &params, plaintext_bits);
///
/// // Rotate the message polynomial by 1 to the right
/// let mut rotated_ct = GlweCiphertext::new(&params);
/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, 1, &params);
///
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, &params, plaintext_bits);
///
/// assert_eq!(decrypted_msg, Polynomial::new(&[8, 1, 2, 3, 4, 5, 6, 7]));
///
/// // Rotate the message polynomial by 1 to the left
/// let mut rotated_ct = GlweCiphertext::new(&params);
/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, -1, &params);
///
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, &params, plaintext_bits);
///
/// // Since this is a negacyclic rotation, the element moved to the end is
/// // negated.
/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15]));
/// ```
pub fn rotate_glwe_monomial_negacyclic<S>(
output: &mut GlweCiphertextRef<S>,
ct: &GlweCiphertextRef<S>,
rotation: isize,
params: &GlweDef,
) where
S: TorusOps,
{
let (output_a, output_b) = output.a_b_mut(params);
let (ct_a, ct_b) = ct.a_b(params);
let output_all_coefficients = output_a.chain(std::iter::once(output_b));
let ct_all_coefficients = ct_a.chain(std::iter::once(ct_b));
for (o, a) in output_all_coefficients.zip(ct_all_coefficients) {
o.clone_from_ref(a);
o.mul_by_monomial_negacyclic(rotation);
}
}
/// Rotate the given ciphertext message polynomial by the given amount as if it
/// had been multiplied by a monomial. This is equivalent to shifting
/// all the coefficients left by `rotation` and negating the last `rotation`
/// coefficients. Mathematically, this is equivalent to multiplying by
/// X^{-rotation} mod (X^N + 1), or equivalently -X^{N - rotation}.
///
/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both
/// positive and negative rotations, plus an example.
pub fn rotate_glwe_negative_monomial_negacyclic<S>(
output: &mut GlweCiphertextRef<S>,
ct: &GlweCiphertextRef<S>,
rotation: usize,
params: &GlweDef,
) where
S: TorusOps,
{
rotate_glwe_monomial_negacyclic(output, ct, -(rotation as isize), params)
}
/// Rotate the given ciphertext message polynomial by the given amount as if it
/// had been multiplied by a positive monomial. This is equivalent to shifting
/// all the coefficients right by `rotation` and negating the first `rotation`
/// coefficients. Mathematically, this is equivalent to multiplying by
/// X^{rotation} mod (X^N + 1).
///
/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both
/// positive and negative rotations, plus an example.
pub fn rotate_glwe_positive_monomial_negacyclic<S>(
output: &mut GlweCiphertextRef<S>,
ct: &GlweCiphertextRef<S>,
rotation: usize,
params: &GlweDef,
) where
S: TorusOps,
{
rotate_glwe_monomial_negacyclic(output, ct, rotation as isize, params)
}
/// Rotate the given ciphertext message polynomial by negative encrypted shift.
/// In practice this is not often used on its own; bootstrapping performs a
/// different procedure.
///
/// See
/// - [`generate_blind_rotation_shift`] for a way to encrypt a rotation amount.
/// - [`rotate_glwe_monomial_negacyclic`] for how negacyclic rotation works
/// when the rotation amount is public.
///
/// # Example
///
/// ```
/// use sunscreen_tfhe::{
/// high_level::{keygen, encryption},
/// entities::{GlweCiphertext, Polynomial},
/// ops::bootstrapping::{blind_rotation, generate_blind_rotation_shift},
/// params::{
/// GlweDef,
/// GlweSize,
/// GlweDimension,
/// PlaintextBits,
/// PolynomialDegree,
/// RadixDecomposition,
/// RadixCount,
/// RadixLog,
/// },
/// rand::Stddev,
/// };
///
/// // Define the GLWE parameters
/// let params = GlweDef {
/// dim: GlweDimension {
/// size: GlweSize(1),
/// polynomial_degree: PolynomialDegree(8),
/// },
/// std: Stddev(0.0000000444778278004718),
/// };
/// let radix = RadixDecomposition {
/// count: RadixCount(3),
/// radix_log: RadixLog(4),
/// };
/// let plaintext_bits = PlaintextBits(4);
///
/// // Generate the GLWE secret key
/// let sk = keygen::generate_binary_glwe_sk(&params);
///
/// // Define and encrypt a message
/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
/// let ct = encryption::encrypt_glwe(&msg, &sk, &params, plaintext_bits);
///
/// // Generate a blind rotation amount
/// let mut blind_rotation_index = sunscreen_tfhe::entities::BlindRotationShiftFft::new(&params, &radix);
/// generate_blind_rotation_shift(&mut blind_rotation_index, 1, &sk, &params, &radix, plaintext_bits);
///
/// // Rotate the message polynomial by the blind rotation amount
/// let mut rotated_ct = GlweCiphertext::new(&params);
/// blind_rotation(&mut rotated_ct, &blind_rotation_index, &ct, &params, &radix);
///
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, &params, plaintext_bits);
///
/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15]));
/// ```
pub fn blind_rotation<S>(
output: &mut GlweCiphertextRef<S>,
blind_rotation_index: &BlindRotationShiftFftRef<Complex<f64>>,
ct: &GlweCiphertextRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
// Initialize with the unrotated message m
output.clone_from_ref(ct);
allocate_scratch_ref!(rotated_ct, GlweCiphertextRef<S>, (params.dim));
for (i, index_select) in blind_rotation_index.rows(params, radix).enumerate() {
let rotation = 1 << i;
rotate_glwe_negative_monomial_negacyclic(rotated_ct, output, rotation, params);
let tmp = output.to_owned();
cmux(output, &tmp, rotated_ct, index_select, params, radix);
}
}
/// Encrypt an amount to rotate the message polynomial by.
///
/// This function is mostly provided as a convenience. Bootstrapping will rotate
/// a message without encrypting using a cmux tree, so this function is not
/// strictly necessary.
pub fn generate_blind_rotation_shift<S>(
bootstrap_key: &mut BlindRotationShiftFftRef<Complex<f64>>,
rotation: usize,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
{
let degree = params.dim.polynomial_degree.0;
assert!(rotation < degree);
for (i, ggsw_fft) in bootstrap_key.rows_mut(params, radix).enumerate() {
let bit = ((rotation >> i) & 1) as u64;
let mut ct = GgswCiphertext::new(params, radix);
encrypt_ggsw_ciphertext_scalar(
&mut ct,
S::from_u64(bit),
sk,
params,
radix,
plaintext_bits,
);
ct.fft(ggsw_fft, params, radix);
}
}
#[cfg(test)]
mod tests {
use blind_rotation::generate_blind_rotation_shift;
use crate::{
entities::{
BlindRotationShiftFft, GgswCiphertext, GlweCiphertext, GlweSecretKey, Polynomial,
},
high_level::{TEST_GLWE_DEF_1, TEST_RADIX},
ops::{
bootstrapping::{blind_rotation, rotate_glwe_monomial_negacyclic},
encryption::decrypt_ggsw_ciphertext,
},
polynomial::polynomial_external_mad,
GlweDef, GlweDimension, GlweSize, PlaintextBits, PolynomialDegree, Torus,
};
#[test]
fn can_rotate() {
let params = GlweDef {
dim: GlweDimension {
polynomial_degree: PolynomialDegree(8),
size: GlweSize(2),
},
..TEST_GLWE_DEF_1
};
let plaintext_bits = PlaintextBits(4);
let modulus = 1 << plaintext_bits.0;
let degree = params.dim.polynomial_degree.0;
let sk = GlweSecretKey::<u64>::generate_binary(&params);
let msg_coeffs = (0..degree)
.map(|i| (i % modulus) as u64)
.collect::<Vec<_>>();
let msg = Polynomial::new(&msg_coeffs);
let ct = sk.encode_encrypt_glwe(&msg, &params, plaintext_bits);
for rotation in (-2i64 * (degree as i64))..=(2i64 * (degree as i64)) {
println!("Rotation: {}", rotation);
let mut rotation_polynomial = vec![Torus::from(0u64); degree];
let direction = if rotation < 0 { -1 } else { 1 };
let original_rotation = rotation;
let rotation = rotation.unsigned_abs() as usize;
#[allow(clippy::collapsible_else_if)]
if direction == 1 {
// Positive rotation
if rotation == 0 || rotation == 2 * degree {
rotation_polynomial[0] = Torus::from(1);
} else if rotation < degree {
rotation_polynomial[rotation] = Torus::from(1);
} else if rotation == degree {
rotation_polynomial[0] = -Torus::from(1);
} else {
rotation_polynomial[rotation % degree] = -Torus::from(1);
}
} else {
// Negative rotation
if rotation == 0 || rotation == 2 * degree {
rotation_polynomial[0] = Torus::from(1);
} else if rotation < degree {
rotation_polynomial[(degree - rotation) % degree] = -Torus::from(1);
} else if rotation == degree {
rotation_polynomial[0] = -Torus::from(1);
} else {
rotation_polynomial[(2 * degree - rotation) % degree] = Torus::from(1);
}
}
let rotation_polynomial = Polynomial::new(&rotation_polynomial);
let mut expected = Polynomial::<Torus<u64>>::zero(degree);
polynomial_external_mad(&mut expected, &rotation_polynomial, &msg);
// Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually.
let expected = expected.map(|x| x.inner() % (modulus as u64));
// Perform encrypted rotation
let mut output_ct = GlweCiphertext::new(&params);
rotate_glwe_monomial_negacyclic(
&mut output_ct,
&ct,
original_rotation as isize,
&params,
);
let output_msg = sk.decrypt_decode_glwe(&output_ct, &params, plaintext_bits);
assert_eq!(output_msg, expected);
}
}
#[test]
fn rotation_shift_encrypted_properly() {
let params = GlweDef {
dim: GlweDimension {
polynomial_degree: PolynomialDegree(8),
size: GlweSize(2),
},
..TEST_GLWE_DEF_1
};
let radix = TEST_RADIX;
let degree = params.dim.polynomial_degree.0;
let sk = GlweSecretKey::<u64>::generate_binary(&params);
for rotation in 0..(degree - 1) {
let mut ggsw_index = BlindRotationShiftFft::new(&params, &radix);
generate_blind_rotation_shift(
&mut ggsw_index,
rotation,
&sk,
&params,
&radix,
PlaintextBits(4),
);
let mut encrypted_rotation = 0u64;
for (i, bit_fft) in ggsw_index.rows(&params, &radix).enumerate() {
let mut bit = GgswCiphertext::<u64>::new(&params, &radix);
bit_fft.ifft(&mut bit, &params, &radix);
let mut pt = Polynomial::zero(degree);
decrypt_ggsw_ciphertext(&mut pt, &bit, &sk, &params, &radix);
encrypted_rotation |= pt.coeffs()[0].inner() << i;
}
assert_eq!(encrypted_rotation, rotation as u64);
}
}
#[test]
fn can_blind_rotate() {
let params = GlweDef {
dim: GlweDimension {
polynomial_degree: PolynomialDegree(8),
size: GlweSize(2),
},
..TEST_GLWE_DEF_1
};
let radix = TEST_RADIX;
let plaintext_bits = PlaintextBits(4);
let modulus = 1 << plaintext_bits.0;
let degree = params.dim.polynomial_degree.0;
let num_bits = (degree as u64).ilog2() as usize;
let sk = GlweSecretKey::<u64>::generate_binary(&params);
let msg_coeffs = (0..degree)
.map(|i| (i % modulus) as u64)
.collect::<Vec<_>>();
let msg = Polynomial::new(&msg_coeffs);
let ct = sk.encode_encrypt_glwe(&msg, &params, plaintext_bits);
#[allow(clippy::needless_range_loop)]
for rotation in 0..=(degree - 1) {
let mut expected = Polynomial::<Torus<u64>>::new(msg.map(|x| Torus::from(*x)).coeffs());
for i in 0..num_bits {
let bit = ((rotation >> i) & 1) as u64;
// We don't perform this rotation
if bit == 0 {
continue;
}
let mut rotation_polynomial = vec![Torus::from(0u64); degree];
rotation_polynomial[degree - (1 << i)] = -Torus::<u64>::from(1);
let rotation_polynomial = Polynomial::new(&rotation_polynomial);
let tmp = expected.map(|x| x.inner());
expected = Polynomial::<Torus<u64>>::zero(degree);
polynomial_external_mad(&mut expected, &rotation_polynomial, &tmp);
}
// Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually.
let expected = expected.map(|x| x.inner() % (modulus as u64));
// Perform encrypted rotation
let mut ggsw_index = BlindRotationShiftFft::new(&params, &radix);
generate_blind_rotation_shift(
&mut ggsw_index,
rotation,
&sk,
&params,
&radix,
plaintext_bits,
);
let mut output_ct = GlweCiphertext::new(&params);
blind_rotation(&mut output_ct, &ggsw_index, &ct, &params, &radix);
let output_msg = sk.decrypt_decode_glwe(&output_ct, &params, plaintext_bits);
// Make sure the zero point is rotated the correct amount.
assert_eq!(output_msg.coeffs()[(degree - rotation) % degree], 0);
// Make sure we have moved the element in the rotation position to the zero position.
assert_eq!(output_msg.coeffs()[0], msg_coeffs[rotation]);
assert_eq!(
&output_msg, &expected,
"CT encrypted message: {:?}, expected message: {:?}",
&output_msg, &expected
);
}
}
}

View File

@@ -0,0 +1,469 @@
use num::Complex;
use crate::{
dst::FromMutSlice,
entities::{
BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef,
LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
},
ops::{
bootstrapping::programmable_bootstrap, homomorphisms::rotate,
keyswitch::private_functional_keyswitch::private_functional_keyswitch,
},
scratch::allocate_scratch_ref,
GlweDef, LweDef, PlaintextBits, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus,
TorusOps,
};
/// Bootstraps a LWE ciphertext to a GGSW ciphertext.
#[allow(clippy::too_many_arguments)]
/// Transform [`LweCiphertextRef`] `input` encrypted under parameters `lwe_0` into
/// the [`GgswCiphertextRef`] `output` encrypted under parameters `glwe_1` with
/// radix decomposition `cbs_radix`. This resets the noise in `output` in the
/// process.
///
/// [`GgswCiphertext`](crate::entities::GgswCiphertext)s can be used as select
/// inputs for [`cmux`](crate::ops::fft_ops::cmux) operations.
///
/// # Remarks
/// The following diagram illustrates how circuit bootstrapping works
///
/// ![Circuit Bootstrapping](LINK TO GITHUB)
///
/// We perform `cbs_radix.count` programmable bootstrapping (PBS) operations to
/// decompose the original message m under radix `2^cbs_radix.radix_log`. These PBS
/// operations use a bootstrapping key encrypting the level 0 LWE secret key under
/// the level 2 GLWE secret key and internally perform their own radix decomposition
/// parameterized by `pbs_radix`. After performing bootstrapping, we now have
/// `cbs_radix.count` LWE ciphertexts encrypted under the level 2 GLWE secret key
/// (reinterpreted as an LWE key).
///
/// Next, we take each of these `cbs_radix.count` level 2 LWE ciphertexts and
/// perform `glwe_1.dim.size + 1` private functional keyswitching operations (`
/// (glwe_1.dim.size + 1) * cbs_radix.count` in total). For the first `glwe_1.dim.
/// size` rows of the [`GgswCiphertextRef`] output, this multiplies the radix
/// decomposed message by the negative corresponding secret key. For the last
/// row, we simply multiply our radix decomposed messages by 1.
///
/// Recall that [`private_functional_keyswitch`] (PFKS) transforms a list of LWE
/// ciphertexts into a [`GlweCiphertext`](crate::entities::GlweCiphertext). In
/// our case, this list contains a single
/// [`LweCiphertext`](crate::entities::LweCiphertext) for each PFKS operation.
/// Each row of the output [`GgswCiphertext`](crate::entities::GgswCiphertext)
/// corresponds to a different PFKS key, encapsulated in `cbsksk`.
///
/// These PFKS operations switch from a key under parameters `glwe_2` (interpreted
/// as LWE) to `glwe_1` with [`RadixDecomposition`] `pfks_radix`.
///
/// # Panics
/// * If `bsk` is not valid for bootrapping from parameters `lwe_0` to `glwe_2`
/// (reinterpreted as LWE) with radix decomposition `pbs_radix`.
/// * If `cbsksk` is not a valid keyswitch key set for switching from `glwe_2`
/// (reintrerpreted as LWE) to `glwe_1` with `glwe_1.dim.size` entries and radix
/// decomposition `pfks_radix`.
/// * If `output` is not the correct length for a GGSW ciphertext under `glwe_1`
/// parameters with `cbs_radix` decomposition.
/// * If `input` is not a valid LWE ciphertext under `lwe_0` parameters.
/// * If `lwe_0`, `glwe_1`, `glwe_2`, `cbs_radix`, `pfks_radix`, `pbs_radix` are
/// invalid.
///
/// # Example
/// ```
/// use sunscreen_tfhe::{
/// high_level,
/// high_level::{keygen, encryption, fft},
/// entities::GgswCiphertext,
/// ops::bootstrapping::circuit_bootstrap,
/// params::{
/// GLWE_5_256_80,
/// GLWE_1_1024_80,
/// LWE_512_80,
/// PlaintextBits,
/// RadixDecomposition,
/// RadixCount,
/// RadixLog
/// }
/// };
///
/// let pbs_radix = RadixDecomposition {
/// count: RadixCount(2),
/// radix_log: RadixLog(16),
/// };
/// let cbs_radix = RadixDecomposition {
/// count: RadixCount(2),
/// radix_log: RadixLog(5),
/// };
/// let pfks_radix = RadixDecomposition {
/// count: RadixCount(3),
/// radix_log: RadixLog(11),
/// };
///
/// let level_2_params = GLWE_5_256_80;
/// let level_1_params = GLWE_1_1024_80;
/// let level_0_params = LWE_512_80;
///
/// let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
/// let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
/// let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
///
/// let bsk = keygen::generate_bootstrapping_key(
/// &sk_0,
/// &sk_2,
/// &level_0_params,
/// &level_2_params,
/// &pbs_radix,
/// );
/// let bsk =
/// high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
///
/// let cbsksk = keygen::generate_cbs_ksk(
/// sk_2.to_lwe_secret_key(),
/// &sk_1,
/// &level_2_params.as_lwe_def(),
/// &level_1_params,
/// &pfks_radix,
/// );
///
/// let val = 1;
/// let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
///
/// let mut ggsw = GgswCiphertext::new(&level_1_params, &cbs_radix);
///
/// // ggsw will contain `val`
/// circuit_bootstrap(
/// &mut ggsw,
/// &ct,
/// &bsk,
/// &cbsksk,
/// &level_0_params,
/// &level_1_params,
/// &level_2_params,
/// &pbs_radix,
/// &cbs_radix,
/// &pfks_radix,
/// );
/// ```
pub fn circuit_bootstrap<S: TorusOps>(
output: &mut GgswCiphertextRef<S>,
input: &LweCiphertextRef<S>,
bsk: &BootstrapKeyFftRef<Complex<f64>>,
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<S>,
lwe_0: &LweDef,
glwe_1: &GlweDef,
glwe_2: &GlweDef,
pbs_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
pfks_radix: &RadixDecomposition,
) {
glwe_1.assert_valid();
glwe_2.assert_valid();
lwe_0.assert_valid();
pbs_radix.assert_valid::<S>();
cbs_radix.assert_valid::<S>();
pfks_radix.assert_valid::<S>();
cbsksk.assert_valid(&glwe_2.as_lwe_def(), glwe_1, pfks_radix);
bsk.assert_valid(lwe_0, glwe_2, pbs_radix);
output.assert_valid(glwe_1, cbs_radix);
input.assert_valid(lwe_0);
// Step 1, for each l in cbs_radix.count, use bootstrapping to base decompose the
// plaintext in input. We bootstrap from level 0 -> level 2.
allocate_scratch_ref!(
level_2_lwes,
LweCiphertextListRef<S>,
(glwe_2.as_lwe_def().dim, cbs_radix.count.0)
);
level_0_to_level_2(
level_2_lwes,
input,
bsk,
lwe_0,
glwe_2,
pbs_radix,
cbs_radix,
);
level_2_to_level1(
output,
level_2_lwes,
cbsksk,
glwe_2,
glwe_1,
pfks_radix,
cbs_radix,
);
}
#[allow(dead_code)]
#[inline(always)]
fn level_0_to_level_2<S: TorusOps>(
lwes_2: &mut LweCiphertextListRef<S>,
input: &LweCiphertextRef<S>,
bsk: &BootstrapKeyFftRef<Complex<f64>>,
lwe_0: &LweDef,
glwe_2: &GlweDef,
pbs_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
) {
allocate_scratch_ref!(lut, UnivariateLookupTableRef<S>, (glwe_2.dim));
allocate_scratch_ref!(lwe_rotated, LweCiphertextRef<S>, (lwe_0.dim));
allocate_scratch_ref!(
lwe_bootstrapped,
LweCiphertextRef<S>,
(glwe_2.as_lwe_def().dim)
);
// Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on
// -q/4.
rotate(
lwe_rotated,
input,
Torus::encode(S::one(), PlaintextBits(2)),
lwe_0,
);
for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() {
let cur_level = i + 1;
// Treat value as a T_{b^l+1} with one extra place for rounding as the last
// step.
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32);
// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
// value will map to 1 and a 0 will map to -1.
let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();
lut.fill_with_constant(minus_one, glwe_2, plaintext_bits);
programmable_bootstrap(
lwe_bootstrapped,
lwe_rotated,
lut,
bsk,
lwe_0,
glwe_2,
pbs_radix,
);
// Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits).
// This will overflow -1 to 0 and cause 1 to wrap to 2.
rotate(
lwe_2,
lwe_bootstrapped,
Torus::encode(S::one(), plaintext_bits),
&glwe_2.as_lwe_def(),
);
}
}
/// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext.
pub fn level_2_to_level1<S: TorusOps>(
result: &mut GgswCiphertextRef<S>,
lwes_2: &LweCiphertextListRef<S>,
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<S>,
glwe_2: &GlweDef,
glwe_1: &GlweDef,
pfks_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
) {
for (glev, pfksk) in result.rows_mut(glwe_1, cbs_radix).zip(cbsksk.keys(
&glwe_2.as_lwe_def(),
glwe_1,
pfks_radix,
)) {
for (decomp, glwe) in lwes_2
.ciphertexts(&glwe_2.as_lwe_def())
.zip(glev.glwe_ciphertexts_mut(glwe_1))
{
private_functional_keyswitch(
glwe,
&[decomp],
pfksk,
&glwe_2.as_lwe_def(),
glwe_1,
pfks_radix,
&PrivateFunctionalKeyswitchLweCount(1),
);
}
}
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{GgswCiphertext, LweCiphertextList},
high_level::{self, encryption, fft, keygen, TEST_LWE_DEF_1},
PlaintextBits, RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80,
LWE_512_80,
};
use super::{circuit_bootstrap, level_0_to_level_2};
#[test]
fn can_level_0_to_level_2() {
let pbs_radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(16),
};
let cbs_radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(5),
};
let glwe_params = GLWE_5_256_80;
let mut level_2 =
LweCiphertextList::<u64>::new(&glwe_params.as_lwe_def(), cbs_radix.count.0);
let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
let bsk = keygen::generate_bootstrapping_key(
&sk,
&glwe_sk,
&TEST_LWE_DEF_1,
&glwe_params,
&pbs_radix,
);
let bsk = fft::fft_bootstrap_key(&bsk, &TEST_LWE_DEF_1, &glwe_params, &pbs_radix);
let lwe = sk.encrypt(0, &TEST_LWE_DEF_1, PlaintextBits(1)).0;
level_0_to_level_2(
&mut level_2,
&lwe,
&bsk,
&TEST_LWE_DEF_1,
&glwe_params,
&pbs_radix,
&cbs_radix,
);
for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() {
let cur_level = i + 1;
let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32);
let actual =
glwe_sk
.to_lwe_secret_key()
.decrypt(lwe_2, &glwe_params.as_lwe_def(), bits);
assert_eq!(actual, 0);
}
let lwe = sk.encrypt(1, &TEST_LWE_DEF_1, PlaintextBits(1)).0;
level_0_to_level_2(
&mut level_2,
&lwe,
&bsk,
&TEST_LWE_DEF_1,
&glwe_params,
&pbs_radix,
&cbs_radix,
);
for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() {
let cur_level = i + 1;
let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32);
let actual =
glwe_sk
.to_lwe_secret_key()
.decrypt(lwe_2, &glwe_params.as_lwe_def(), bits);
assert_eq!(actual, 1);
}
}
#[test]
fn can_circuit_bootstrap() {
let pbs_radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(16),
};
let cbs_radix = RadixDecomposition {
count: RadixCount(2),
radix_log: RadixLog(5),
};
let pfks_radix = RadixDecomposition {
count: RadixCount(3),
radix_log: RadixLog(11),
};
let level_2_params = GLWE_5_256_80;
let level_1_params = GLWE_1_1024_80;
let level_0_params = LWE_512_80;
let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
let bsk = keygen::generate_bootstrapping_key(
&sk_0,
&sk_2,
&level_0_params,
&level_2_params,
&pbs_radix,
);
let bsk =
high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
let cbsksk = keygen::generate_cbs_ksk(
sk_2.to_lwe_secret_key(),
&sk_1,
&level_2_params.as_lwe_def(),
&level_1_params,
&pfks_radix,
);
for _ in 0..1 {
let val = thread_rng().next_u64() % 2;
let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix);
circuit_bootstrap(
&mut actual,
&ct,
&bsk,
&cbsksk,
&level_0_params,
&level_1_params,
&level_2_params,
&pbs_radix,
&cbs_radix,
&pfks_radix,
);
let expected =
encryption::encrypt_ggsw(val, &sk_1, &level_1_params, &cbs_radix, PlaintextBits(1));
for (a, e) in actual
.rows(&level_1_params, &cbs_radix)
.zip(expected.rows(&level_1_params, &cbs_radix))
{
for (i, (a, e)) in a
.glwe_ciphertexts(&level_1_params)
.zip(e.glwe_ciphertexts(&level_1_params))
.enumerate()
{
let plaintext_bits = (i + 1) * cbs_radix.radix_log.0;
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
let a = encryption::decrypt_glwe(a, &sk_1, &level_1_params, plaintext_bits);
let e = encryption::decrypt_glwe(e, &sk_1, &level_1_params, plaintext_bits);
assert_eq!(a, e);
}
}
}
}
}

View File

@@ -0,0 +1,8 @@
mod blind_rotation;
pub use blind_rotation::*;
mod circuit_bootstrapping;
pub use circuit_bootstrapping::*;
mod programmable_bootstrapping;
pub use programmable_bootstrapping::*;

View File

@@ -0,0 +1,824 @@
use num::Complex;
use crate::{
dst::FromMutSlice,
entities::{
BivariateLookupTableRef, BootstrapKeyFftRef, BootstrapKeyRef, GlweCiphertextRef,
GlweSecretKeyRef, LweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialRef,
UnivariateLookupTableRef,
},
ops::{
bootstrapping::rotate_glwe_positive_monomial_negacyclic,
ciphertext::{add_lwe_inplace, modulus_switch, sample_extract, scalar_mul_ciphertext_mad},
encryption::encrypt_ggsw_ciphertext_scalar,
fft_ops::cmux,
},
scratch::allocate_scratch_ref,
CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps,
};
use super::rotate_glwe_negative_monomial_negacyclic;
/// Generate a bootstrap key from a LWE secret key to a GLWE secret key.
///
/// Mathematically, this key is a list of GGSW ciphertexts, one for each bit of
/// the secret key being encrypted.
///
/// See
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap)
/// for an example of how to use this key.
pub fn generate_bootstrap_key<S>(
bootstrap_key: &mut BootstrapKeyRef<S>,
sk_to_encrypt: &LweSecretKeyRef<S>,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
sk.assert_valid(params);
radix.assert_valid::<S>();
for (s_i, ggsw) in sk_to_encrypt
.s()
.iter()
.zip(bootstrap_key.rows_mut(params, radix))
{
encrypt_ggsw_ciphertext_scalar(ggsw, *s_i, sk, params, radix, PlaintextBits(1));
}
}
/// Generate a negacyclic LUT for bootstrapping. Another name for this structure
/// is a test polynomial.
///
/// The map function passed in must have the following negacyclic property,
/// where N is the size of the polynomial:
///
/// ```text
/// map(N + i) = -map(i)
/// ```
#[allow(dead_code)]
fn generate_negacyclic_lut<S, F>(
output: &mut Polynomial<Torus<S>>,
map: F,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
F: Fn(u64) -> u64,
{
let p = (1 << plaintext_bits.0) as u64;
let n = params.dim.polynomial_degree.0 as u64;
let stride = 2 * n / p;
let delta = S::BITS - plaintext_bits.0;
let c = output.coeffs_mut();
// Written out this way because when we get to programmable boot strapping,
// this will involve replacing p_i with f(p_i)
for (j, p_i_unmapped) in (0..=p / 2).enumerate() {
let j = j as u64;
let p_i = map(p_i_unmapped);
assert!(p_i < p, "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i);
let p_i = p_i << delta;
if j == 0 {
for k in 0..(stride / 2) {
c[k as usize] = Torus::from(S::from_u64(p_i));
}
} else if j == p / 2 {
for k in (n - (stride / 2))..n {
c[k as usize] = Torus::from(S::from_u64(p_i));
}
} else {
for k in (stride / 2 + (j - 1) * stride)..(stride / 2 + j * stride) {
c[k as usize] = Torus::from(S::from_u64(p_i));
}
}
}
}
/// Generates a lookup table (LUT) to be used with bootstrapping. This LUT is
/// not negacyclic, and hence must be used with LWE inputs that have at least
/// one padding bit.
///
/// The input `map` is used for generating programmable bootstrapping LUTs. This
/// function takes an element in the plaintext space and must produce another
/// element in the plaintext space.
pub(crate) fn generate_lut<S, F>(
output: &mut PolynomialRef<Torus<S>>,
map: F,
params: &GlweDef,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
F: Fn(u64) -> u64,
{
let p = (1 << plaintext_bits.0) as usize;
let n = params.dim.polynomial_degree.0;
assert!(n >= p);
let stride = n / p;
let delta = S::BITS - plaintext_bits.0;
let c = output.coeffs_mut();
for (j, p_i_unmapped) in (0..=p - 1).enumerate() {
let p_i = map(p_i_unmapped as u64);
assert!(p_i < (p as u64), "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i);
let p_i = p_i << delta;
// Insert a stride amount into the LUT
c[j * stride..(j + 1) * stride].iter_mut().for_each(|c| {
*c = Torus::from(S::from_u64(p_i));
});
}
// Negate the first half of p_0 in the LUT in preparation for it to be
// rotated.
c[0..stride / 2].iter_mut().for_each(|c| {
*c = num::traits::WrappingNeg::wrapping_neg(c);
});
c.rotate_left(stride / 2);
}
/// Programmable bootstrapping with a univariate function.
///
/// The LUT this is a table that maps two inputs into a single output. For
/// example, say we want to encode the negation function `f(x) = (x + 1) % 2`
/// into a lookup table. We would create a
/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable) that
/// implements this function and then execute it on the input ciphertexts.
///
/// Important note: This function does not perform key switching. The output
/// ciphertext will be encrypted under the LWE key extracted from the GLWE
/// secret key used for the bootstrapping key. To perform a keyswitch, use
/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe)
/// after the bootstrapping operation.
///
/// # Example
///
/// ```
/// use sunscreen_tfhe::{
/// high_level::{keygen, encryption, fft},
/// entities::{UnivariateLookupTable, LweCiphertext},
/// ops::bootstrapping::programmable_bootstrap,
/// params::{
/// GLWE_1_1024_80,
/// LWE_512_80,
/// CarryBits,
/// PlaintextBits,
/// RadixDecomposition,
/// RadixCount,
/// RadixLog
/// },
/// };
///
/// // Parameters defining the scheme we are using
/// let lwe_params = LWE_512_80;
/// let glwe_params = GLWE_1_1024_80;
/// let radix = RadixDecomposition {
/// count: RadixCount(3),
/// radix_log: RadixLog(4),
/// };
///
/// // We will be showing a binary univariate function. Note that for
/// // programmable bootstrapping to work in general, you will need to include at
/// // least one padding bit to the input.
/// let plaintext_bits = PlaintextBits(1);
/// let carry_bits = CarryBits(1);
/// let plaintext_bits_carry = PlaintextBits(2);
///
/// // The univariate function we want to evaluate, encoded as a lookup table.
/// let negate = |x| (x + 1) % (1 << plaintext_bits.0);
/// let lut = UnivariateLookupTable::trivial_from_fn(
/// &negate,
/// &glwe_params,
/// plaintext_bits,
/// );
///
/// // Generate the secret keys and the bootstrapping key
/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
///
/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix);
/// let bsk =
/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix);
///
/// // Specify the inputs
/// let input_plain = 0;
///
/// // Encrypt the inputs. Note we are adding carry bits to the inputs.
/// let input = encryption::encrypt_lwe_secret(
/// input_plain,
/// &lwe_sk,
/// &lwe_params,
/// plaintext_bits_carry
/// );
///
/// // Perform the programmable bootstrapping
/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def());
/// programmable_bootstrap(
/// &mut result,
/// &input,
/// &lut,
/// &bsk,
/// &lwe_params,
/// &glwe_params,
/// &radix,
/// );
///
/// // Check the result matches our plaintext function.
/// let decrypted = encryption::decrypt_lwe(
/// &result,
/// &glwe_sk.to_lwe_secret_key(),
/// &glwe_params.as_lwe_def(),
/// plaintext_bits,
/// );
///
/// let expected = negate(input_plain);
/// assert_eq!(expected, decrypted);
/// ```
///
/// # See also
///
/// For the bivariate version of programmable bootstrapping, see
/// [`programmable_bootstrap_bivariate`](programmable_bootstrap_bivariate) and
/// its associated LUT
/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable).
pub fn programmable_bootstrap<S>(
output: &mut LweCiphertextRef<S>,
input: &LweCiphertextRef<S>,
lut: &UnivariateLookupTableRef<S>,
bootstrap_key: &BootstrapKeyFftRef<Complex<f64>>,
lwe_params: &LweDef,
glwe_params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
// Steps:
// 1. Modulus switch the ciphertext to 2N.
// 2. Use a cmux tree to blind rotate V using the elements of the bootstrap key (the input LWE secret key bits).
// 3. Sample extract.
// 4. (Optional, done outside of this method) Key switch to the output LWE
// secret key (should be the one extracted from the GLWE key).
let degree = glwe_params.dim.polynomial_degree.0;
let two_n = degree.ilog2() + 1;
// 1. Modulus switch the ciphertext to 2N.
let mut ct = input.to_owned();
modulus_switch(&mut ct, S::BITS, two_n, lwe_params);
let (ct_a, ct_b) = ct.a_b(lwe_params);
// 2. Use a cmux tree to blind rotate V using the elements of the bootstrap
// key (the input LWE secret key bits).
// Perform V_0 ^ X^{-b}
allocate_scratch_ref!(cmux_output, GlweCiphertextRef<S>, (glwe_params.dim));
cmux_output.clear();
rotate_glwe_negative_monomial_negacyclic(
cmux_output,
lut.glwe(),
ct_b.inner().to_u64() as usize,
glwe_params,
);
allocate_scratch_ref!(rotated_ct, GlweCiphertextRef<S>, (glwe_params.dim));
// Perform the cmux tree from the bootstrap key with the relation
// V_n = V_{n-1} ^ X^{a_{n-1} s_{n-1}}
for (a_i, index_select) in ct_a.iter().zip(bootstrap_key.rows(glwe_params, radix)) {
let tmp = cmux_output.to_owned();
// This operation performs a copy so the rotated_ct doesn't need to be
// cleared.
rotate_glwe_positive_monomial_negacyclic(
rotated_ct,
cmux_output,
a_i.inner().to_u64() as usize,
glwe_params,
);
cmux(
cmux_output,
&tmp,
rotated_ct,
index_select,
glwe_params,
radix,
);
}
// 3. Sample extract.
sample_extract(output, cmux_output, 0, glwe_params);
}
/// Evaluate a bivariate function on a packed input.
fn bivariate_function<F>(map: F, input: u64, plaintext_bits: PlaintextBits) -> u64
where
F: Fn(u64, u64) -> u64,
{
let modulus = 1 << plaintext_bits.0;
let lhs = (input / modulus) % modulus;
let rhs = input % modulus;
let result = map(lhs, rhs);
assert!(
result < modulus,
"The result of the bivariate function must be less than the plaintext modulus"
);
result
}
/// Generate a lookup table that takes two inputs and produces a single output.
pub(crate) fn generate_bivariate_lut<S, F>(
output: &mut PolynomialRef<Torus<S>>,
map: F,
params: &GlweDef,
plaintext_bits: PlaintextBits,
carry_bits: CarryBits,
) where
S: TorusOps,
F: Fn(u64, u64) -> u64,
{
assert!(
plaintext_bits.0 <= carry_bits.0,
"The number of plaintext bits must be less than or equal to the number of carry bits"
);
let wrapped_func = |input: u64| bivariate_function(&map, input, plaintext_bits);
generate_lut(
output,
wrapped_func,
params,
PlaintextBits(plaintext_bits.0 + carry_bits.0),
);
}
/// Programmable bootstrapping with a bivariate function.
///
/// The LUT this is a table that maps two inputs into a single output.
/// For example, say we want to encode the xor function `f(x, y) = (x + y) % 2`
/// into a lookup table. We would create a
/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable) that
/// implements this function and then execute it on the input ciphertexts.
///
/// Important note: This function does not perform key switching. The output
/// ciphertext will be encrypted under the LWE key extracted from the GLWE
/// secret key used for the bootstrapping key. To perform a keyswitch, use
/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe)
/// after the bootstrapping operation.
///
/// # Example
///
/// ```
/// use sunscreen_tfhe::{
/// high_level::{keygen, encryption, fft},
/// entities::{BivariateLookupTable, LweCiphertext},
/// ops::bootstrapping::programmable_bootstrap_bivariate,
/// params::{
/// GLWE_1_1024_80,
/// LWE_512_80,
/// CarryBits,
/// PlaintextBits,
/// RadixDecomposition,
/// RadixCount,
/// RadixLog
/// },
/// };
///
/// // Parameters defining the scheme we are using
/// let lwe_params = LWE_512_80;
/// let glwe_params = GLWE_1_1024_80;
/// let radix = RadixDecomposition {
/// count: RadixCount(3),
/// radix_log: RadixLog(4),
/// };
///
/// // We will be showing a binary bivariate function, but bivariate
/// // bootstrapping can be done on more plaintext bits. Note that the effective
/// // number of plaintext bits used is twice the number of plaintext bits
/// // specified because the inputs are packed into one ciphertext inside
/// // `programmable_bootstrap_bivariate`. The number of carry bits must always
/// // be greater than or equal to the number of plaintext bits.
/// let plaintext_bits = PlaintextBits(1);
/// let plaintext_bits_carry = PlaintextBits(2);
/// let carry_bits = CarryBits(1);
///
/// // The bivariate function we want to evaluate, encoded as a lookup table.
/// let xor = |x, y| (x + y) % (1 << plaintext_bits.0);
/// let lut = BivariateLookupTable::trivial_from_fn(
/// &xor,
/// &glwe_params,
/// plaintext_bits,
/// carry_bits
/// );
///
/// // Generate the secret keys and the bootstrapping key
/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
///
/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix);
/// let bsk =
/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix);
///
/// // Specify the inputs
/// let left_input_plain = 0;
/// let right_input_plain = 1;
///
/// // Encrypt the inputs. Note we are adding carry bits to the inputs.
/// let left_input = encryption::encrypt_lwe_secret(
/// left_input_plain,
/// &lwe_sk,
/// &lwe_params,
/// plaintext_bits_carry
/// );
/// let right_input = encryption::encrypt_lwe_secret(
/// right_input_plain,
/// &lwe_sk,
/// &lwe_params,
/// plaintext_bits_carry
/// );
///
/// // Perform the programmable bootstrapping
/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def());
/// programmable_bootstrap_bivariate(
/// &mut result,
/// &left_input,
/// &right_input,
/// &lut,
/// &bsk,
/// &lwe_params,
/// &glwe_params,
/// plaintext_bits,
/// &radix,
/// );
///
/// // Check the result matches our plaintext function.
/// let decrypted = encryption::decrypt_lwe_with_carry(
/// &result,
/// &glwe_sk.to_lwe_secret_key(),
/// &glwe_params.as_lwe_def(),
/// plaintext_bits,
/// carry_bits
/// );
///
/// let expected = xor(left_input_plain, right_input_plain);
/// assert_eq!(expected, decrypted);
/// ```
///
/// # See also
///
/// For the univariate version of programmable bootstrapping, see
/// [`programmable_bootstrap`](programmable_bootstrap) and its associated LUT
/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable).
#[allow(clippy::too_many_arguments)]
pub fn programmable_bootstrap_bivariate<S>(
output: &mut LweCiphertextRef<S>,
left_input: &LweCiphertextRef<S>,
right_input: &LweCiphertextRef<S>,
lut: &BivariateLookupTableRef<S>,
bootstrap_key: &BootstrapKeyFftRef<Complex<f64>>,
lwe_params: &LweDef,
glwe_params: &GlweDef,
plaintext_bits: PlaintextBits,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
// The general operation for a bivariate PBS is
//
// 1. Ensure that the number of carry bits is equal to the size of the
// message or greater.
// 2. Define a LUT where the function takes in one input and decomposes that
// input into the higher n plaintext and lower n plaintext bits. The
// higher n bits are the left input to the bivariate function, while the
// lower n bits are the right input to the bivariate function.
// 3. Encrypt the two input ciphertexts using the number of carry bits and
// the plaintext bits, with padding.
// 4. On the left encrypted input, shift it up by the number of plaintext
// bits by multiplying the ciphertext by the plaintext modulus.
// 5. Add the left and right encrypted inputs together.
// 6. Perform the programmable bootstrapping with this combined input.
let shift = (1 << plaintext_bits.0) as u64;
allocate_scratch_ref!(pbs_input, LweCiphertextRef<S>, (lwe_params.dim));
pbs_input.clear();
// (left * modulus) + right to pack the two inputs into a single LWE
scalar_mul_ciphertext_mad(pbs_input, &S::from_u64(shift), left_input, lwe_params);
add_lwe_inplace(pbs_input, right_input, lwe_params);
programmable_bootstrap(
output,
pbs_input,
lut.as_univariate(),
bootstrap_key,
lwe_params,
glwe_params,
radix,
)
}
#[cfg(test)]
mod tests {
use crate::{
entities::{
BivariateLookupTable, BootstrapKey, BootstrapKeyFft, LweCiphertext, LweKeyswitchKey,
UnivariateLookupTable,
},
high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
ops::{
encryption::{decrypt_ggsw_ciphertext, encrypt_lwe_ciphertext},
keyswitch::lwe_keyswitch_key::generate_keyswitch_key_lwe,
},
RoundedDiv, GLWE_1_1024_80,
};
use super::*;
fn generate_negacyclic_lut_from_formula<S>(
params: &GlweDef,
plaintext_bits: PlaintextBits,
) -> Polynomial<Torus<S>>
where
S: TorusOps,
{
let mut output = Polynomial::<Torus<S>>::zero(params.dim.polynomial_degree.0);
let p = (1 << plaintext_bits.0) as u64;
let n = params.dim.polynomial_degree.0 as u64;
let divisor = 2 * n;
for (j, c) in output.coeffs_mut().iter_mut().enumerate() {
let v_i = ((p * (j as u64)).div_rounded(divisor)) % p;
let v_i = v_i << (S::BITS - plaintext_bits.0);
*c = Torus::from(S::from_u64(v_i));
}
output
}
#[test]
fn can_generate_negacyclic_lut() {
let p = PlaintextBits(4);
let params = TEST_GLWE_DEF_1;
let mut poly = Polynomial::<Torus<u64>>::zero(params.dim.polynomial_degree.0);
generate_negacyclic_lut(&mut poly, |x| x, &params, p);
let expected = generate_negacyclic_lut_from_formula(&params, p);
assert_eq!(expected, poly);
}
#[test]
fn can_generate_bootstrap_key() {
let lwe_params = TEST_LWE_DEF_1;
let glwe_params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let sk = keygen::generate_binary_lwe_sk(&lwe_params);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
let mut bootstrap_key = BootstrapKey::new(&lwe_params, &glwe_params, &radix);
generate_bootstrap_key(&mut bootstrap_key, &sk, &glwe_sk, &glwe_params, &radix);
let mut count = 0;
for (s_i, ct) in sk.s().iter().zip(bootstrap_key.rows(&glwe_params, &radix)) {
let mut msg = Polynomial::<Torus<u64>>::zero(glwe_params.dim.polynomial_degree.0);
decrypt_ggsw_ciphertext(&mut msg, ct, &glwe_sk, &glwe_params, &radix);
assert_eq!(msg.coeffs()[0].inner(), *s_i);
count += 1
}
assert_eq!(count, sk.s().len());
}
fn bootstrap_helper(map: impl Fn(u64) -> u64) {
let bits = PlaintextBits(3);
let lwe = TEST_LWE_DEF_1;
let glwe = GLWE_1_1024_80;
let radix = TEST_RADIX;
let original_sk = keygen::generate_binary_lwe_sk(&lwe);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe);
// We want to switch from the sample extracted key to the new key.
let mut ksk = LweKeyswitchKey::<u64>::new(&glwe.as_lwe_def(), &lwe, &radix);
generate_keyswitch_key_lwe(
&mut ksk,
glwe_sk.to_lwe_secret_key(),
&original_sk,
&lwe,
&radix,
);
let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix);
generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix);
let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix);
bsk_nonfft.fft(&mut bsk, &glwe, &radix);
// Generate the LUT
let lut = UnivariateLookupTable::trivial_from_fn(&map, &glwe, bits);
let mut failed = Vec::new();
for msg in 0..(1 << bits.0) {
let mut original_ct = LweCiphertext::new(&lwe);
// Adding a padding bit
let encoded_msg = msg << (64 - bits.0 - 1);
encrypt_lwe_ciphertext(
&mut original_ct,
&original_sk,
Torus::from(encoded_msg),
&lwe,
);
let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def());
programmable_bootstrap(&mut new_ct, &original_ct, &lut, &bsk, &lwe, &glwe, &radix);
let decoded = glwe_sk
.to_lwe_secret_key()
.decrypt(&new_ct, &glwe.as_lwe_def(), bits);
let result = map(msg);
if result != decoded {
failed.push((result, decoded));
}
}
if !failed.is_empty() {
panic!(
"Failed to decrypt the following messages and decrypted values: {:?}",
failed
);
}
}
#[test]
fn can_bootstrap() {
bootstrap_helper(|x| x);
}
#[test]
fn can_bootstrap_with_map() {
bootstrap_helper(|x| (x + 3) % 8);
}
fn bivariate_bootstrap_helper(map: impl Fn(u64, u64) -> u64) {
let lwe = TEST_LWE_DEF_1;
let glwe = TEST_GLWE_DEF_1;
let _radix = TEST_RADIX;
let bits = PlaintextBits(1);
let carry_bits = CarryBits(1);
let radix = TEST_RADIX;
let original_sk = keygen::generate_binary_lwe_sk(&lwe);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe);
// We want to switch from the sample extracted key to the new key.
let mut ksk = LweKeyswitchKey::<u64>::new(&glwe.as_lwe_def(), &lwe, &radix);
generate_keyswitch_key_lwe(
&mut ksk,
glwe_sk.to_lwe_secret_key(),
&original_sk,
&lwe,
&radix,
);
let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix);
generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix);
let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix);
bsk_nonfft.fft(&mut bsk, &glwe, &radix);
// Generate the LUT
let lut = BivariateLookupTable::trivial_from_fn(&map, &glwe, bits, carry_bits);
let mut failed = Vec::new();
let mut succeeded = Vec::new();
for left_msg in 0..(1 << bits.0) {
for right_msg in 0..(1 << bits.0) {
let mut left_ct = LweCiphertext::new(&lwe);
let mut right_ct = LweCiphertext::new(&lwe);
// Adding a padding bit, hence the - 1
let encoded_left_msg = left_msg << (64 - bits.0 - carry_bits.0 - 1);
let encoded_right_msg = right_msg << (64 - bits.0 - carry_bits.0 - 1);
encrypt_lwe_ciphertext(
&mut left_ct,
&original_sk,
Torus::from(encoded_left_msg),
&lwe,
);
encrypt_lwe_ciphertext(
&mut right_ct,
&original_sk,
Torus::from(encoded_right_msg),
&lwe,
);
let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def());
programmable_bootstrap_bivariate(
&mut new_ct,
&left_ct,
&right_ct,
&lut,
&bsk,
&lwe,
&glwe,
bits,
&radix,
);
let decrypted = glwe_sk
.to_lwe_secret_key()
.decrypt_without_decode(&new_ct, &glwe.as_lwe_def());
// We manually decode here because the
let plain_bits = bits;
let round_bit = decrypted
.inner()
.wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1)
& 0x1;
let mask = (0x1 << plain_bits.0) - 1;
let decoded = (decrypted
.inner()
.wrapping_shr(64 - plain_bits.0 - carry_bits.0)
+ round_bit)
& mask;
let result = map(left_msg, right_msg);
if result != decoded {
failed.push(((left_msg, right_msg), result, decoded));
} else {
succeeded.push(((left_msg, right_msg), result, decoded));
}
}
}
if !failed.is_empty() {
panic!(
"Failed to decrypt the following messages and decrypted values (as ((left input, right_input), expected, decrypted)): {:?}. However, the following messages and decrypted values succeeded: {:?}",
failed, succeeded
);
}
}
fn bivariate_test_function(left: u64, right: u64) -> u64 {
(left + right) % 2
}
#[test]
fn can_bootstrap_with_bivariate_map() {
bivariate_bootstrap_helper(bivariate_test_function);
}
#[test]
fn can_decompose_bivariate_map() {
let plaintext_bits = PlaintextBits(2);
let modulus = 1 << plaintext_bits.0;
let map = &bivariate_test_function;
for left in 0u64..(plaintext_bits.0 as u64) {
for right in 0u64..(plaintext_bits.0 as u64) {
let left_shifted = left * modulus;
let input = left_shifted + right;
let result = bivariate_function(map, input, plaintext_bits);
assert_eq!(result, map(left, right));
}
}
}
}

View File

@@ -0,0 +1,58 @@
use crate::{
entities::{GlevCiphertextRef, GlweCiphertextRef, Polynomial},
radix::{PolynomialRadixIterator, ScalarRadixIterator},
GlweDef, TorusOps,
};
use super::{glwe_polynomial_mad, glwe_scalar_mad};
/// Compute `c += (G^-1 * a) \[*\] b`, where
/// * `G^-1 * a`` is the radix decomposition of `a`
/// * `b` is a GLEV ciphertext.
/// * `c` is a GLWE ciphertext.
/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials
///
/// # Remarks
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
pub fn decomposed_polynomial_glev_mad<S>(
c: &mut GlweCiphertextRef<S>,
mut a: PolynomialRadixIterator<S>,
b: &GlevCiphertextRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
// a = decomp(a_i)
// b = r
let b_glwe = b.glwe_ciphertexts(params);
let mut cur_radix: Polynomial<S> = Polynomial::zero(params.dim.polynomial_degree.0);
// The decomposition of
// <Decomp^{beta, l}(gamma), GLEV>
// can be performed using
// sum_{j = 1}^l gamma_j * C_j
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
// returns the decomposed values in the opposite order.
for b in b_glwe.rev() {
a.write_next(&mut cur_radix);
glwe_polynomial_mad(c, b, &cur_radix, params);
}
}
/// Compute `c += (G^-1 * a) \[*\] b`, where
/// * `G^-1 * a`` is the radix decomposition of `a`
/// * `b` is a GLEV ciphertext.
pub fn decomposed_scalar_glev_mad<S>(
c: &mut GlweCiphertextRef<S>,
a: ScalarRadixIterator<S>,
b: &GlevCiphertextRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
for (b, a) in b.glwe_ciphertexts(params).rev().zip(a) {
glwe_scalar_mad(c, b, a, params);
}
}

View File

@@ -0,0 +1,549 @@
use crate::{
dst::FromMutSlice,
entities::{
GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, LweCiphertextRef, PolynomialRef,
},
ops::ciphertext::decomposed_polynomial_glev_mad,
polynomial::{
polynomial_add, polynomial_external_mad, polynomial_negate, polynomial_scalar_mad,
polynomial_sub,
},
radix::PolynomialRadixIterator,
scratch::allocate_scratch_ref,
GlweDef, RadixDecomposition, TorusOps,
};
/**
* Extract a specific coefficient in a message M in a GLWE ciphertext as a LWE
* ciphertext under the LWE extracted secret key (extracted from the GLWE secret
* key).
*
* # Arguments
*
* * `output` - The output LWE ciphertext
* * `glwe` - The input GLWE ciphertext
* * `h` - The index of the coefficient to extract
*
* # Remarks
* For a GLWE ciphertext of size k and dimension N, the output LWE ciphertext
* passed in must have size k*N.
*/
pub fn sample_extract<S>(
output: &mut LweCiphertextRef<S>,
glwe: &GlweCiphertextRef<S>,
h: usize,
params: &GlweDef,
) where
S: TorusOps,
{
// We are copying parts of the GLWE ciphertext out according to the following rule:
// a_{N*i + j} = a_{i, h - j} for 0 <= i < k, 0 <= j <= h
// a_{N*i + j} = -a_{i, h - j + n} for 0 <= i < k, n + 1 <= j < N
// b = b_n
#[allow(non_snake_case)]
let N = params.dim.polynomial_degree.0;
let k = params.dim.size.0;
let lwe_size = k * N;
let (a_lwe, b_lwe) = output.a_b_mut(&params.as_lwe_def());
// Make sure that the correctly sized LWE was passed in.
assert_eq!(lwe_size, a_lwe.len());
let (a_glwe, b_glwe) = glwe.a_b(params);
for (i, a_gwe_i) in a_glwe.enumerate() {
#[allow(non_snake_case)]
let Ni = N * i;
let a_glwe_i_coeffs = a_gwe_i.coeffs();
for j in 0..=h {
a_lwe[Ni + j] = a_glwe_i_coeffs[h - j];
}
for j in (h + 1)..N {
// Note we add N to h first, otherwise h - j might underflow.
a_lwe[Ni + j] = num::traits::WrappingNeg::wrapping_neg(&a_glwe_i_coeffs[h + N - j]);
}
}
*b_lwe = b_glwe.coeffs()[h];
}
/// Add two GLWE ciphertexts together, storing the result in `c`.
pub fn add_glwe_ciphertexts<S>(
c: &mut GlweCiphertextRef<S>,
a: &GlweCiphertextRef<S>,
b: &GlweCiphertextRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
let (b_a, b_b) = b.a_b(params);
assert_eq!(c_a.len(), a_a.len());
assert_eq!(c_a.len(), b_a.len());
for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) {
polynomial_add(c, a, b);
}
polynomial_add(c_b, a_b, b_b);
}
/// Subtract two GLWE ciphertexts together, storing the result in `c`.
pub fn sub_glwe_ciphertexts<S>(
c: &mut GlweCiphertextRef<S>,
a: &GlweCiphertextRef<S>,
b: &GlweCiphertextRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
let (b_a, b_b) = b.a_b(params);
assert_eq!(c_a.len(), a_a.len());
assert_eq!(c_a.len(), b_a.len());
for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) {
polynomial_sub(c, a, b);
}
polynomial_sub(c_b, a_b, b_b);
}
/// Homomorphically compute -ct.
///
/// # Remarks
/// This operation is noiseless.
pub fn glwe_negate_inplace<S>(ct: &mut GlweCiphertextRef<S>, params: &GlweDef)
where
S: TorusOps,
{
for a in ct.a_mut(params) {
polynomial_negate(a);
}
polynomial_negate(ct.b_mut(params));
}
/// Compute c += a \[*\] b where \[*\] is the external product between a GLWE
/// ciphertext and an polynomial in Z\[X\]/(X^N + 1).
///
/// # Remarks
/// For this to produce the correct result, degree(b) must be "small" (ideally
/// 0) and the coefficient must be small (i.e. less than the message size).
pub fn glwe_polynomial_mad<S>(
c: &mut GlweCiphertextRef<S>,
a: &GlweCiphertextRef<S>,
b: &PolynomialRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
assert_eq!(c_a.len(), params.dim.size.0);
assert_eq!(a_a.len(), params.dim.size.0);
for (c, a) in c_a.zip(a_a) {
polynomial_external_mad(c, a, b);
}
polynomial_external_mad(c_b, a_b, b);
}
/// Compute `c += a \[*\] b`` where
/// * `a` is a GLWE ciphertext
/// * `b` is a scalar
/// * `\[*\]` is the external product operator GLWE \[*\] Z -> GLWE
pub fn glwe_scalar_mad<S>(
c: &mut GlweCiphertextRef<S>,
a: &GlweCiphertextRef<S>,
b: S,
params: &GlweDef,
) where
S: TorusOps,
{
for (c, a) in c.a_mut(params).zip(a.a(params)) {
polynomial_scalar_mad(c, a, b);
}
polynomial_scalar_mad(c.b_mut(params), a.b(params), b);
}
/// Compute `c += a \[*\] b`` where
/// * `a` is a GLWE ciphertext
/// * `b` is a GGSW cipheetext
/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE`
pub fn glwe_ggsw_mad<S>(
c: &mut GlweCiphertextRef<S>,
a: &GlweCiphertextRef<S>,
b: &GgswCiphertextRef<S>,
glwe_def: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
let (a_a, a_b) = a.a_b(glwe_def);
let rows = b.rows(glwe_def, radix);
// Generate an iterator that includes a_a and a_b
let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b));
allocate_scratch_ref!(scratch, PolynomialRef<S>, (glwe_def.dim.polynomial_degree));
// Performs the external operation
//
// GGSW ⊡ GLWE = sum_i=0^k <Decomp^{beta, l}(AB_i), C_i>
//
// where
// * `beta` is the decomposition base
// * `l` is the decomposition level
// * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B}
// * `C_i` is the i-th row of the GGSW ciphertext
for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) {
// For each a polynomial, compute the external product with the GLEV ciphertext
// at index i in the GGSW ciphertext.
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
decomposed_polynomial_glev_mad(c, decomp, r, glwe_def);
}
}
/// Compute external product of a GLWE ciphertext and a GGSW ciphertext.
/// GGSW ⊡ GLWE -> GLWE
pub fn external_product_ggsw_glwe<S>(
ggsw: &GgswCiphertextRef<S>,
glwe: &GlweCiphertextRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) -> GlweCiphertext<S>
where
S: TorusOps,
{
// Zero GLWE to sum over.
let mut result = GlweCiphertext::new(params);
glwe_ggsw_mad(&mut result, glwe, ggsw, params, radix);
result
}
#[cfg(test)]
mod tests {
use crate::{
entities::{GgswCiphertext, LweCiphertext, Polynomial},
high_level::*,
high_level::{keygen, TEST_GLWE_DEF_1},
ops::encryption::{
decrypt_ggsw_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret,
trivially_encrypt_glwe_ciphertext,
},
polynomial::polynomial_mad,
PlaintextBits, Torus,
};
use super::*;
use rand::{thread_rng, RngCore};
#[test]
fn polynomial_iteration_mut() {
let glwe = TEST_GLWE_DEF_1;
let mut sk = keygen::generate_binary_glwe_sk(&glwe);
assert_eq!(sk.s(&glwe).count(), glwe.dim.size.0);
for s_i in sk.s_mut(&glwe) {
assert_eq!(s_i.len(), glwe.dim.polynomial_degree.0);
for s in s_i.coeffs() {
assert!(*s == 0 || *s == 1);
}
}
}
#[test]
fn can_add_glwe_ciphertexts() {
let bits = PlaintextBits(4);
let glwe = TEST_GLWE_DEF_1;
let sk = keygen::generate_binary_glwe_sk(&glwe);
let plaintext = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 4)
.collect::<Vec<_>>(),
);
let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
let c = a + b;
let dec = sk.decrypt_decode_glwe(&c, &glwe, bits);
for (i, c) in dec.coeffs().iter().enumerate() {
assert_eq!(*c, 2 * (i as u64 % 4));
}
}
#[test]
fn can_sub_glwe_ciphertexts() {
let glwe = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_glwe_sk(&glwe);
let plaintext = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 4)
.collect::<Vec<_>>(),
);
let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
let c = a - b;
let dec = sk.decrypt_decode_glwe(&c, &glwe, bits);
for c in dec.coeffs() {
assert_eq!(*c, 0);
}
}
#[test]
fn can_internal_product_glwe_polynomial() {
let bits = PlaintextBits(4);
let glwe = TEST_GLWE_DEF_1;
let sk = keygen::generate_binary_glwe_sk(&glwe);
let large_poly = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 4)
.collect::<Vec<_>>(),
);
let small_poly = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| if x < 1 { 3 } else { 0 })
.collect::<Vec<_>>(),
);
// Do the external product with an encryption of the large polynomial.
let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits);
let mut c = GlweCiphertext::new(&glwe);
glwe_polynomial_mad(&mut c, &a, &small_poly, &glwe);
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
let mut expected = Polynomial::<u64>::zero(glwe.dim.polynomial_degree.0);
polynomial_mad(
expected.as_wrapping_mut(),
large_poly.as_wrapping(),
small_poly.as_wrapping(),
);
assert_eq!(expected, actual);
// Now do reverse the large and small polynomials' roles.
let a = sk.encode_encrypt_glwe(&small_poly, &glwe, bits);
let mut c = GlweCiphertext::new(&glwe);
glwe_polynomial_mad(&mut c, &a, &large_poly, &glwe);
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
assert_eq!(expected, actual);
}
#[test]
fn can_external_product_ggsw_glwe() {
let glwe = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let radix = TEST_RADIX;
let sk = keygen::generate_binary_glwe_sk(&glwe);
// Constant polynomial: [1, 0, 0, ...]
let ggsw_plaintext_polynomial = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| if x < 1 { 1 } else { 0 })
.collect::<Vec<_>>(),
);
let ggsw_plaintext_polynomial_torus = ggsw_plaintext_polynomial.map(|x| Torus::from(*x));
let mut ggsw_ct = GgswCiphertext::new(&glwe, &radix);
encrypt_ggsw_ciphertext(
&mut ggsw_ct,
&ggsw_plaintext_polynomial,
&sk,
&glwe,
&radix,
bits,
);
let mut ggsw_decrypt = Polynomial::zero(glwe.dim.polynomial_degree.0);
decrypt_ggsw_ciphertext(&mut ggsw_decrypt, &ggsw_ct, &sk, &glwe, &radix);
assert_eq!(
ggsw_decrypt, ggsw_plaintext_polynomial_torus,
"GGSW decrypt does not match plaintext"
);
// [1, 2, ...]
let glwe_plaintext_polynomial = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % (1 << bits.0))
.collect::<Vec<_>>(),
);
let glwe_plaintext_polynomial_encoded =
glwe_plaintext_polynomial.map(|x| Torus::encode(*x, bits));
// Do the external product with an encryption of the large polynomial.
let mut glwe_ct = GlweCiphertext::new(&glwe);
encrypt_glwe_ciphertext_secret(
&mut glwe_ct,
&glwe_plaintext_polynomial_encoded,
&sk,
&glwe,
);
let glwe_decrypt = sk.decrypt_decode_glwe(&glwe_ct, &glwe, bits);
assert_eq!(
glwe_decrypt, glwe_plaintext_polynomial,
"GLWE decrypt does not match plaintext"
);
let encrypted_result = external_product_ggsw_glwe(&ggsw_ct, &glwe_ct, &glwe, &radix);
let result = sk.decrypt_decode_glwe(&encrypted_result, &glwe, bits);
// expected is the polynomial multiplication of
// ggsw_plaintext_polynomial and glwe_plaintext_polynomial
let mut expected = Polynomial::<u64>::zero(glwe.dim.polynomial_degree.0);
polynomial_mad(
expected.as_wrapping_mut(),
ggsw_plaintext_polynomial.as_wrapping(),
glwe_plaintext_polynomial.as_wrapping(),
);
// Reduce modulo 2^BITS
let expected = expected.map(|x| x % (1 << bits.0));
// Verify that the ciphertext multiplication is correct
assert_eq!(
result, expected,
"External product does not match polynomial multiplication"
);
}
#[test]
fn can_add_glwe_with_trivial_glwe() {
let glwe = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_glwe_sk(&glwe);
let delta = 1u64 << (64 - bits.0);
let large_poly = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 4)
.collect::<Vec<_>>(),
);
let small_poly = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| if x < 1 { 3 } else { 0 })
.collect::<Vec<_>>(),
);
let small_poly_scaled = small_poly.map(|x| Torus::from(x * delta));
let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits);
let mut b = GlweCiphertext::new(&glwe);
trivially_encrypt_glwe_ciphertext(&mut b, &small_poly_scaled, &glwe);
let c = a.as_ref() + b.as_ref();
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
let expected = small_poly + large_poly;
assert_eq!(expected, actual);
let c2 = b.as_ref() + a.as_ref();
let actual = sk.decrypt_decode_glwe(&c2, &glwe, bits);
assert_eq!(expected, actual);
}
#[test]
fn test_sample_extract() {
let bits = PlaintextBits(2);
let glwe_params = TEST_GLWE_DEF_1;
let lwe_params = glwe_params.as_lwe_def();
let sk = keygen::generate_binary_glwe_sk(&glwe_params);
let extracted_lwe_sk = sk.to_lwe_secret_key();
let large_poly = Polynomial::new(
&(0..glwe_params.dim.polynomial_degree.0 as u64)
.map(|x| x % 4)
.collect::<Vec<_>>(),
);
let glwe = sk.encode_encrypt_glwe(&large_poly, &glwe_params, bits);
for h in 0..glwe_params.dim.polynomial_degree.0 {
let mut lwe = LweCiphertext::new(&lwe_params);
sample_extract(&mut lwe, &glwe, h, &glwe_params);
let lwe_msg = extracted_lwe_sk.decrypt(&lwe, &lwe_params, bits);
let expected = large_poly.coeffs()[h];
assert_eq!(expected, lwe_msg);
}
}
#[test]
fn can_glwe_scalar_mad() {
for _ in 0..20 {
let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
let plaintext_bits = (thread_rng().next_u64()) % 8 + 1;
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
let scalar = thread_rng().next_u64() % 64;
let pt = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0)
.map(|_| thread_rng().next_u64() % plaintext_bits.0 as u64)
.collect::<Vec<_>>();
let pt = Polynomial::new(&pt);
let ct = sk.encode_encrypt_glwe(&pt, &TEST_GLWE_DEF_1, plaintext_bits);
let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1);
glwe_scalar_mad(&mut result, &ct, scalar, &TEST_GLWE_DEF_1);
let actual = sk.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits);
for (pt, actual) in pt.coeffs().iter().zip(actual.coeffs()) {
let expected = pt.wrapping_mul(scalar) % (0x1u64 << plaintext_bits.0);
assert_eq!(expected, *actual);
}
}
}
}

View File

@@ -0,0 +1,42 @@
use crate::{
entities::{LevCiphertextRef, LweCiphertextRef, Polynomial},
radix::PolynomialRadixIterator,
LweDef, TorusOps,
};
use super::scalar_mul_ciphertext_mad;
/// Compute `c += (G^-1 * a) \[*\] b`, where
/// * `G^-1 * a`` is the radix decomposition of `a`
/// * `b` is a LEV ciphertext.
/// * `c` is a LWE ciphertext.
/// * \[*\] is the external product between a LEV ciphertext and the decomposed
/// LWE ciphertext.
///
/// # Remarks
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
pub fn decomposed_scalar_lev_mad<S>(
c: &mut LweCiphertextRef<S>,
mut a: PolynomialRadixIterator<S>,
b: &LevCiphertextRef<S>,
params: &LweDef,
) where
S: TorusOps,
{
let b_lwe = b.lwe_ciphertexts(params);
let mut cur_radix: Polynomial<S> = Polynomial::zero(1);
// The decomposition of
// <Decomp^{beta, l}(gamma), GLEV>
// can be performed using
// sum_{j = 1}^l gamma_j * C_j
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
// returns the decomposed values in the opposite order.
for b in b_lwe.rev() {
a.write_next(&mut cur_radix);
let radix = cur_radix.coeffs()[0];
scalar_mul_ciphertext_mad(c, &radix, b, params);
}
}

View File

@@ -0,0 +1,93 @@
use crate::{entities::LweCiphertextRef, LweDef, RoundedDiv, Torus, TorusOps};
/// Add the coefficients of a to the coefficients of c in place.
pub fn add_lwe_inplace<S>(c: &mut LweCiphertextRef<S>, a: &LweCiphertextRef<S>, params: &LweDef)
where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
assert_eq!(c_a.len(), a_a.len());
for (c, a) in c_a.iter_mut().zip(a_a.iter()) {
*c = num::traits::WrappingAdd::wrapping_add(c, a);
}
*c_b = num::traits::WrappingAdd::wrapping_add(c_b, a_b);
}
/// Subtract one LWE ciphertext from another, storing the result in the provided
/// output variable. Mostly meant to be used reduce the number of allocations
/// and with functions like [allocate_scratch_ref].
pub(crate) fn sub_lwe_ciphertexts<S>(
c: &mut LweCiphertextRef<S>,
a: &LweCiphertextRef<S>,
b: &LweCiphertextRef<S>,
params: &LweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
let (b_a, b_b) = b.a_b(params);
assert_eq!(c_a.len(), a_a.len());
assert_eq!(c_a.len(), b_a.len());
for (c, (a, b)) in c_a.iter_mut().zip(a_a.iter().zip(b_a.iter())) {
*c = num::traits::WrappingSub::wrapping_sub(a, b);
}
*c_b = num::traits::WrappingSub::wrapping_sub(a_b, b_b);
}
/// Multiplies an LWE ciphertext by a scalar, storing the result in the provided
/// output variable. Mostly meant to be used reduce the number of allocations
/// and with functions like [allocate_scratch_ref].
pub(crate) fn scalar_mul_ciphertext_mad<S>(
c: &mut LweCiphertextRef<S>,
scalar: &S,
a: &LweCiphertextRef<S>,
params: &LweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
assert_eq!(c_a.len(), a_a.len());
for (c, a) in c_a.iter_mut().zip(a_a.iter()) {
*c += a * scalar;
}
*c_b += a_b * scalar;
}
/// Perform modulus switching on a ciphertext. We are assuming that moduli are
/// both powers of two, and that the original number of bits is greater than the
/// new number of bits.
pub fn modulus_switch<S>(
ct: &mut LweCiphertextRef<S>,
original_bits: u32,
new_bits: u32,
params: &LweDef,
) where
S: TorusOps,
{
let (c_a, c_b) = ct.a_b_mut(params);
// We specifically want to zero out the MSBs instead of shifting them back
// around.
for a in c_a {
let c = a.inner().to_u64() as u128;
let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128);
*a = Torus::from(S::from_u64(res as u64));
}
let c = c_b.inner().to_u64() as u128;
let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128);
*c_b = Torus::from(S::from_u64(res as u64));
}

View File

@@ -0,0 +1,11 @@
mod lwe_ciphertext_ops;
pub use lwe_ciphertext_ops::*;
mod lev_ciphertext_ops;
pub use lev_ciphertext_ops::*;
mod glev_ciphertext_ops;
pub use glev_ciphertext_ops::*;
mod glwe_ciphertext_ops;
pub use glwe_ciphertext_ops::*;

View File

@@ -0,0 +1,403 @@
use num::Zero;
use crate::{
entities::{GgswCiphertextRef, GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef},
polynomial::{polynomial_external_mad, polynomial_scalar_mad, polynomial_scalar_mul},
GlweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps,
};
use super::{
decrypt_glwe_ciphertext, encrypt_glwe_ciphertext_secret,
trivially_encrypt_glwe_with_sk_argument,
};
/// Perform a ggsw encryption. This is generic in case a trivial GGSW encryption
/// is wanted (for example, for testing purposes).
pub(crate) fn encrypt_ggsw_ciphertext_generic<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
msg: &PolynomialRef<S>,
glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
encrypt: impl Fn(
&mut GlweCiphertextRef<S>,
&PolynomialRef<Torus<S>>,
&GlweSecretKeyRef<S>,
&GlweDef,
),
) where
S: TorusOps,
{
let max_val = S::from_u64(0x1 << plaintext_bits.0);
assert!(msg.coeffs().iter().all(|x| *x < max_val));
let decomposition_radix_log = radix.radix_log.0;
let polynomial_degree = params.dim.polynomial_degree.0;
let glwe_size = params.dim.size.0;
// k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption
// of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of
// M/B^{j+1}
for (i, row) in ggsw_ciphertext.rows_mut(params, radix).enumerate() {
let mut m_times_s = Polynomial::<Torus<S>>::zero(polynomial_degree);
let m_times_s = if i < glwe_size {
// The message is composed of the negated secret key and the message
// for all but the last row.
let s = glwe_secret_key.s(params).nth(i).unwrap();
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
// Negate the product.
for c in m_times_s.coeffs_mut().iter_mut() {
// Have to call the trait directly because deref is implemented on Torus
*c = num::traits::WrappingNeg::wrapping_neg(c);
}
&m_times_s
} else {
// Last row isn't multiplied by secret key.
msg.as_torus()
};
for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() {
let mut scaled_msg = Polynomial::zero(polynomial_degree);
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);
encrypt(col, &scaled_msg, glwe_secret_key, params);
}
}
}
/// Encrypt a GGSW ciphertext with a given message polynomial and secret key.
pub fn encrypt_ggsw_ciphertext<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
msg: &PolynomialRef<S>,
glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
{
encrypt_ggsw_ciphertext_generic(
ggsw_ciphertext,
msg,
glwe_secret_key,
params,
radix,
plaintext_bits,
encrypt_glwe_ciphertext_secret,
);
}
/// Encrypt a GGSW ciphertext with a given message polynomial and secret key.
/// This is a trivial encryption that doesn't use the secret key and is not
/// secure.
pub fn trivially_encrypt_ggsw_ciphertext<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
msg: &PolynomialRef<S>,
glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
{
encrypt_ggsw_ciphertext_generic(
ggsw_ciphertext,
msg,
glwe_secret_key,
params,
radix,
plaintext_bits,
trivially_encrypt_glwe_with_sk_argument,
);
}
/// Encrypt scalar (i.e. degree 0 polynomial) msg as a GGSW ciphertext.
pub fn encrypt_ggsw_ciphertext_scalar<S>(
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
msg: S,
glwe_secret_key: &GlweSecretKeyRef<S>,
glwe_def: &GlweDef,
radix: &RadixDecomposition,
plaintext_bits: PlaintextBits,
) where
S: TorusOps,
{
let max_val = S::from_u64(0x1 << plaintext_bits.0);
assert!(msg < max_val);
let decomposition_radix_log = radix.radix_log.0;
let polynomial_degree = glwe_def.dim.polynomial_degree.0;
let glwe_size = glwe_def.dim.size.0;
// k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption
// of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of
// M/B^{j+1}
for (i, row) in ggsw_ciphertext.rows_mut(glwe_def, radix).enumerate() {
let mut m_times_s = Polynomial::<Torus<S>>::zero(polynomial_degree);
let m_times_s = if i < glwe_size {
let s = glwe_secret_key.s(glwe_def).nth(i).unwrap();
polynomial_scalar_mad(&mut m_times_s, s.as_torus(), msg);
&m_times_s
} else {
// Last row isn't multiplied by secret key.
m_times_s.clear();
m_times_s.coeffs_mut()[0] = Torus::from(msg);
&m_times_s
};
for (j, col) in row.glwe_ciphertexts_mut(glwe_def).enumerate() {
let mut scaled_msg = Polynomial::zero(polynomial_degree);
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
if i < glwe_size {
let decomp_factor = decomp_factor.wrapping_neg();
polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);
} else {
scaled_msg.coeffs_mut()[0] = Torus::from(msg.wrapping_mul(&decomp_factor));
for c in scaled_msg.coeffs_mut().iter_mut().skip(1) {
*c = Torus::zero();
}
}
encrypt_glwe_ciphertext_secret(col, &scaled_msg, glwe_secret_key, glwe_def);
}
}
}
fn decrypt_glwe_in_ggsw<S>(
msg: &mut PolynomialRef<Torus<S>>,
ggsw_ciphertext: &GgswCiphertextRef<S>,
glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
row: usize,
column: usize,
) -> Option<()>
where
S: TorusOps,
{
let decomposition_radix_log = radix.radix_log.0;
// To decrypt a GGSW ciphertext, it suffices to decrypt the first GLWE ciphertext in
// the last row and divide by its decomposition factor.
let glev = ggsw_ciphertext.rows(params, radix).nth(row)?;
let glwe = glev.glwe_ciphertexts(params).nth(column)?;
// Decrypt that specific GLWE ciphertext, which should have a message of
// q / beta ^ {column + 1} * SM, where SM is the message times the secret
// every row but the last (-SM) and M for the last row.
decrypt_glwe_ciphertext(msg, glwe, glwe_secret_key, params);
let mask = (0x1 << decomposition_radix_log) - 1;
for c in msg.coeffs_mut() {
let val = c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1));
let r = (c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1) - 1))
& S::from_u64(0x1);
*c = Torus::from((val + r) & S::from_u64(mask));
}
Some(())
}
/// Decrypt a GGSW ciphertext with a given secret key.
pub fn decrypt_ggsw_ciphertext<S>(
msg: &mut PolynomialRef<Torus<S>>,
ggsw_ciphertext: &GgswCiphertextRef<S>,
glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
let row = params.dim.size.0;
decrypt_glwe_in_ggsw(msg, ggsw_ciphertext, glwe_secret_key, params, radix, row, 0).unwrap();
}
#[cfg(test)]
mod tests {
use crate::{entities::GgswCiphertext, high_level::TEST_GLWE_DEF_1, high_level::*};
use super::*;
#[test]
fn can_encrypt_decrypt_gsw_const_coeff() {
let params = TEST_GLWE_DEF_1;
let radix = &TEST_RADIX;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_glwe_sk(&params);
let msg = 1;
let ct = encryption::encrypt_ggsw(msg, &sk, &params, radix, bits);
let pt = encryption::decrypt_ggsw(&ct, &sk, &params, radix, bits);
assert_eq!(pt.coeffs()[0], msg);
for c in pt.coeffs().iter().skip(1) {
assert_eq!(*c, 0);
}
}
/// Test that each of the rows in the GGSW ciphertext is a GLWE ciphertext that encodes the
/// appropriate message (usually the decomposed message times the secret key)
#[test]
fn can_decrypt_all_elements_ggsw() {
let params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_glwe_sk(&params);
let coeffs = (0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>();
let msg = Polynomial::new(&coeffs);
let mut ct = GgswCiphertext::new(&params, &radix);
encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, &params, &radix, bits);
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, &params, &radix);
let pt = pt.map(|x| x.inner());
// Ensure that the basic decryption works.
assert_eq!(pt, msg);
let n_rows = ct.rows(&params, &radix).len();
let n_cols = ct
.rows(&params, &radix)
.next()
.unwrap()
.glwe_ciphertexts(&params)
.len();
// Beta
let decomposition_radix_log = radix.radix_log.0;
for i in 0..n_rows {
let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0);
let m_times_s = if i < params.dim.size.0 {
// The message is composed of the negated secret key and the message
// for all but the last row.
let s = sk.s(&params).nth(i).unwrap();
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
// Negate the product.
for c in m_times_s.coeffs_mut().iter_mut() {
// Have to call the trait directly because deref is implemented on Torus
*c = num::traits::WrappingNeg::wrapping_neg(c);
}
&m_times_s
} else {
// Last row isn't multiplied by secret key.
msg.as_torus()
};
for j in 0..n_cols {
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
let mut msg = m_times_s.to_owned();
let mask = (0x1 << decomposition_radix_log) - 1;
for c in msg.coeffs_mut() {
*c = Torus::from(c.inner() & mask);
}
decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, &params, &radix, i, j).unwrap();
assert_eq!(pt, msg);
}
}
}
#[test]
fn can_trivially_decrypy_ggsw() {
let params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_glwe_sk(&params);
let coeffs = (0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>();
let msg = Polynomial::new(&coeffs);
let mut ct = GgswCiphertext::new(&params, &radix);
trivially_encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, &params, &radix, bits);
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, &params, &radix);
let pt = pt.map(|x| x.inner());
// Ensure that the basic decryption works.
assert_eq!(pt, msg);
let n_rows = ct.rows(&params, &radix).len();
let n_cols = ct
.rows(&params, &radix)
.next()
.unwrap()
.glwe_ciphertexts(&params)
.len();
// Beta
let decomposition_radix_log = radix.radix_log.0;
for i in 0..n_rows {
let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0);
let m_times_s = if i < params.dim.size.0 {
// The message is composed of the negated secret key and the message
// for all but the last row.
let s = sk.s(&params).nth(i).unwrap();
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
// Negate the product.
for c in m_times_s.coeffs_mut().iter_mut() {
// Have to call the trait directly because deref is implemented on Torus
*c = num::traits::WrappingNeg::wrapping_neg(c);
}
&m_times_s
} else {
// Last row isn't multiplied by secret key.
msg.as_torus()
};
for j in 0..n_cols {
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
let mut msg = m_times_s.to_owned();
let mask = (0x1 << decomposition_radix_log) - 1;
for c in msg.coeffs_mut() {
*c = Torus::from(c.inner() & mask);
}
decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, &params, &radix, i, j).unwrap();
assert_eq!(pt, msg);
}
}
}
}

View File

@@ -0,0 +1,184 @@
use num::Zero;
use crate::{
entities::{GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef},
polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign},
rand::{normal_torus, uniform_torus},
GlweDef, Torus, TorusOps,
};
pub(crate) fn trivially_encrypt_glwe_with_sk_argument<S>(
glwe_ciphertext: &mut GlweCiphertextRef<S>,
msg: &PolynomialRef<Torus<S>>,
_glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
trivially_encrypt_glwe_ciphertext(glwe_ciphertext, msg, params);
}
/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.`
pub fn encrypt_glwe_ciphertext_secret_generic<S>(
c: &mut GlweCiphertextRef<S>,
msg: &PolynomialRef<Torus<S>>,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
let mut tmp = Polynomial::zero(params.dim.polynomial_degree.0);
let (a, b) = c.a_b_mut(params);
// tmp = A_i * S_i
for (a_i, s_i) in a.zip(sk.s(params)) {
// Fill a_i with uniform data
for c in a_i.coeffs_mut() {
*c = uniform_torus();
}
polynomial_external_mad(&mut tmp, a_i, s_i);
}
// b = A * S
polynomial_add_assign(b, &tmp);
// b = A * S + m
polynomial_add_assign(b, msg);
let e = Polynomial::new(
&(0..msg.len())
.map(|_| normal_torus::<S>(params.std))
.collect::<Vec<_>>(),
);
// b = A * S + m + e
polynomial_add_assign(b, &e);
}
/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.`
pub fn encrypt_glwe_ciphertext_secret<S>(
c: &mut GlweCiphertextRef<S>,
msg: &PolynomialRef<Torus<S>>,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params)
}
/// Generate a trivial GLWE encryption. Note that the caller will need to scale
/// the message appropriately; a factor like delta is not automatically applied.
pub fn trivially_encrypt_glwe_ciphertext<S>(
c: &mut GlweCiphertextRef<S>,
msg: &PolynomialRef<Torus<S>>,
params: &GlweDef,
) where
S: TorusOps,
{
let (a, b) = c.a_b_mut(params);
// tmp = A_i * S_i
for a_i in a {
// Fill a_i with zero data
for c in a_i.coeffs_mut() {
*c = Torus::zero();
}
}
// b = m
b.clone_from_ref(msg);
}
/// Decrypt GLWE ciphertext `ct` into `msg` using secret key `sk`.
pub fn decrypt_glwe_ciphertext<S>(
msg: &mut PolynomialRef<Torus<S>>,
ct: &GlweCiphertextRef<S>,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
let (a, b) = ct.a_b(params);
let mut tmp = Polynomial::zero(b.len());
// msg = b
msg.clone_from_ref(b);
// tmp = A_i * S_i
for (a_i, s_i) in a.zip(sk.s(params)) {
polynomial_external_mad(&mut tmp, a_i, s_i);
}
// msg = b - A * S = m + e
polynomial_sub_assign(msg, &tmp);
}
#[cfg(test)]
mod tests {
use crate::{high_level::*, PlaintextBits};
use super::*;
// Encryption
#[test]
fn can_encrypt_decrypt() {
let params = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_glwe_sk(&params);
let plaintext = Polynomial::new(
&(0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>(),
);
let ct = encryption::encrypt_glwe(&plaintext, &sk, &params, bits);
let dec = encryption::decrypt_glwe(&ct, &sk, &params, bits);
assert_eq!(dec, plaintext);
}
#[test]
fn can_encrypt_decrypt_uniform() {
let params = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_uniform_glwe_sk(&params);
let plaintext = Polynomial::new(
&(0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>(),
);
let ct = encryption::encrypt_glwe(&plaintext, &sk, &params, bits);
let dec = encryption::decrypt_glwe(&ct, &sk, &params, bits);
assert_eq!(dec, plaintext);
}
#[test]
fn trivial_glwe_decrypts() {
let params = TEST_GLWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_glwe_sk(&params);
let plaintext = Polynomial::new(
&(0..params.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>(),
);
let ct = encryption::trivial_glwe(&plaintext, &params, bits);
let dec = encryption::decrypt_glwe(&ct, &sk, &params, bits);
assert_eq!(dec, plaintext);
}
}

View File

@@ -0,0 +1,114 @@
use sunscreen_math::Zero;
use crate::{
entities::{LweCiphertextRef, LweSecretKeyRef},
math::{Torus, TorusOps},
rand::{normal_torus, uniform_torus},
LweDef, PlaintextBits,
};
/// Generate a trivial GLWE encryption. Note that the caller will need to scale
/// the message appropriately; a factor like delta is not automatically applied.
pub fn trivially_encrypt_lwe_ciphertext<S>(
c: &mut LweCiphertextRef<S>,
msg: &Torus<S>,
params: &LweDef,
) where
S: TorusOps,
{
let (a, b) = c.a_b_mut(params);
// tmp = A_i * S_i
for a_i in a {
*a_i = Torus::zero();
}
// b = m
*b = *msg;
}
/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the
/// randomness used to generate the ciphertext.
pub fn encrypt_lwe_ciphertext<S>(
ct: &mut LweCiphertextRef<S>,
sk: &LweSecretKeyRef<S>,
msg: Torus<S>,
params: &LweDef,
) -> Torus<S>
where
S: TorusOps,
{
let (a, b) = ct.a_b_mut(params);
for (a_i, d_i) in a.iter_mut().zip(sk.as_slice().iter()) {
*a_i = uniform_torus::<S>();
*b += *a_i * d_i;
}
let e = normal_torus(params.std);
*b += msg + e;
e
}
/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the
/// randomness used to generate the ciphertext.
pub fn encode_and_encrypt_lwe_ciphertext<S>(
ct: &mut LweCiphertextRef<S>,
sk: &LweSecretKeyRef<S>,
msg: S,
params: &LweDef,
plaintext_bits: PlaintextBits,
) -> Torus<S>
where
S: TorusOps,
{
let msg = Torus::<S>::encode(msg, plaintext_bits);
encrypt_lwe_ciphertext(ct, sk, msg, params)
}
#[cfg(test)]
mod tests {
use crate::{high_level::*, PlaintextBits};
#[test]
fn can_encrypt_decrypt() {
let params = TEST_LWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_lwe_sk(&params);
let ct = encryption::encrypt_lwe_secret(4, &sk, &params, bits);
let pt = encryption::decrypt_lwe(&ct, &sk, &params, bits);
assert_eq!(pt, 4);
}
#[test]
fn can_encrypt_decrypt_uniform() {
let params = TEST_LWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_uniform_lwe_sk(&params);
let ct = encryption::encrypt_lwe_secret(4, &sk, &params, bits);
let pt = encryption::decrypt_lwe(&ct, &sk, &params, bits);
assert_eq!(pt, 4);
}
#[test]
fn can_trivially_decrypt() {
let params = TEST_LWE_DEF_1;
let bits = PlaintextBits(4);
let sk = keygen::generate_binary_lwe_sk(&params);
let ct = encryption::trivial_lwe(4, &params, bits);
let pt = encryption::decrypt_lwe(&ct, &sk, &params, bits);
assert_eq!(pt, 4);
}
}

View File

@@ -0,0 +1,8 @@
mod lwe_encryption;
pub use lwe_encryption::*;
mod glwe_encryption;
pub use glwe_encryption::*;
mod ggsw_encryption;
pub use ggsw_encryption::*;

View File

@@ -0,0 +1,285 @@
use num::Complex;
use crate::{
dst::{FromMutSlice, OverlaySize},
entities::{
GgswCiphertextFftRef, GlevCiphertextFftRef, GlweCiphertextFftRef, GlweCiphertextRef,
PolynomialFftRef, PolynomialRef,
},
ops::ciphertext::{add_glwe_ciphertexts, sub_glwe_ciphertexts},
radix::PolynomialRadixIterator,
scratch::{allocate_scratch, allocate_scratch_ref},
GlweDef, RadixDecomposition, TorusOps,
};
/// Compute `c += a \[*] b`` where
/// * `a` is a GLWE ciphertext
/// * `b` is a GGSW cipheetext
/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE`
pub fn glwe_ggsw_mad<S>(
c_fft: &mut GlweCiphertextFftRef<Complex<f64>>,
a: &GlweCiphertextRef<S>,
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
let (a_a, a_b) = a.a_b(params);
let rows = b_fft.rows(params, radix);
// Generate an iterator that includes a_a and a_b
let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b));
allocate_scratch_ref!(scratch, PolynomialRef<S>, (params.dim.polynomial_degree));
// Performs the external operation
//
// GGSW ⊡ GLWE = sum_i=0^k <Decomp^{beta, l}(AB_i), C_i>
//
// where
// * `beta` is the decomposition base
// * `l` is the decomposition level
// * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B}
// * `C_i` is the i-th row of the GGSW ciphertext
for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) {
// For each a polynomial, compute the external product with the GLEV ciphertext
// at index i in the GGSW ciphertext.
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
decomposed_polynomial_glev_mad(c_fft, decomp, r, params);
}
}
/// Compute `c += (G^-1 * a) \[*\] b`, where
/// * `G^-1 * a`` is the radix decomposition of `a`
/// * `b` is a GLEV ciphertext.
/// * `c` is a GLWE ciphertext.
/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials
///
/// # Remarks
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
pub fn decomposed_polynomial_glev_mad<S>(
c: &mut GlweCiphertextFftRef<Complex<f64>>,
mut a: PolynomialRadixIterator<S>,
b: &GlevCiphertextFftRef<Complex<f64>>,
params: &GlweDef,
) where
S: TorusOps,
{
let b_glwe = b.glwe_ciphertexts(params);
let mut cur_radix = allocate_scratch::<S>(params.dim.polynomial_degree.0);
let cur_radix = PolynomialRef::from_mut_slice(cur_radix.as_mut_slice());
let mut decomp_fft = allocate_scratch(PolynomialFftRef::<Complex<f64>>::size(
params.dim.polynomial_degree,
));
let decomp_fft = PolynomialFftRef::from_mut_slice(decomp_fft.as_mut_slice());
// The decomposition of
// <Decomp^{beta, l}(gamma), GLEV>
// can be performed using
// sum_{j = 1}^l gamma_j * C_j
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
// returns the decomposed values in the opposite order.
for b in b_glwe.rev() {
a.write_next(cur_radix);
cur_radix.fft(decomp_fft);
glwe_polynomial_mad(c, b, decomp_fft, params);
}
}
/// Compute c += a \[*\] b where \[*\] is the external product
/// between a GLWE ciphertext and an polynomial in Z\[X\]/(X^N + 1).
///
/// # Remarks
/// For this to produce the correct result, degree(b) must be "small"
/// (ideally 0) and the coefficient must be small (i.e. less than the
/// message size).
pub fn glwe_polynomial_mad(
c: &mut GlweCiphertextFftRef<Complex<f64>>,
a: &GlweCiphertextFftRef<Complex<f64>>,
b: &PolynomialFftRef<Complex<f64>>,
params: &GlweDef,
) {
let (c_a, c_b) = c.a_b_mut(params);
let (a_a, a_b) = a.a_b(params);
assert_eq!(c_a.len(), params.dim.size.0);
assert_eq!(a_a.len(), params.dim.size.0);
for (c, a) in c_a.zip(a_a) {
c.multiply_add(a, b);
}
c_b.multiply_add(a_b, b);
}
/// Performs a CMUX operation, which enables one of two GLWE ciphertexts
/// to be selected from an encrypted boolean GGSW ciphertext. The result
/// is stored in `c`.
///
/// Conceptually, this can be seen as the following operation in Rust:
///
/// ```text
/// let c = if b_fft { d_1 } else { d_0 }
/// ```
///
/// where the output `c` is a different encryption than either of the initial
/// inputs. Note that this will result in higher noise than in the original
/// ciphertexts.
pub fn cmux<S>(
c: &mut GlweCiphertextRef<S>,
d_0: &GlweCiphertextRef<S>,
d_1: &GlweCiphertextRef<S>,
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
allocate_scratch_ref!(diff, GlweCiphertextRef<S>, (params.dim));
sub_glwe_ciphertexts(diff, d_1, d_0, params);
allocate_scratch_ref!(prod_fft, GlweCiphertextFftRef<Complex<f64>>, (params.dim));
prod_fft.clear();
glwe_ggsw_mad(prod_fft, diff, b_fft, params, radix);
allocate_scratch_ref!(prod, GlweCiphertextRef<S>, (params.dim));
prod_fft.ifft(prod, params);
add_glwe_ciphertexts(c, prod, d_0, params);
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{GgswCiphertextFft, GlweCiphertext, GlweCiphertextFft, Polynomial},
high_level::*,
PlaintextBits, Torus,
};
use super::*;
#[test]
fn can_fft_external_product_glwe_ggsw() {
let glwe_params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_glwe_sk(&glwe_params);
for _ in 0..100 {
let sel = thread_rng().next_u64() % 2;
let ggsw = encryption::encrypt_ggsw(sel, &sk, &glwe_params, &radix, bits);
let glwe_pt = (0..glwe_params.dim.polynomial_degree.0)
.map(|_| thread_rng().next_u64() % 2)
.collect::<Vec<_>>();
let glwe_pt = Polynomial::new(&glwe_pt);
let glwe = encryption::encrypt_glwe(&glwe_pt, &sk, &glwe_params, bits);
let mut ggsw_fft = GgswCiphertextFft::new(&glwe_params, &radix);
let mut res_fft = GlweCiphertextFft::new(&glwe_params);
let mut res = GlweCiphertext::<u64>::new(&glwe_params);
ggsw.fft(&mut ggsw_fft, &glwe_params, &radix);
glwe_ggsw_mad(&mut res_fft, &glwe, &ggsw_fft, &glwe_params, &radix);
res_fft.ifft(&mut res, &glwe_params);
let actual = encryption::decrypt_glwe(&res, &sk, &glwe_params, bits);
if sel == 1 {
assert_eq!(actual, glwe_pt);
} else {
assert_eq!(
actual,
Polynomial::zero(glwe_params.dim.polynomial_degree.0)
);
}
}
}
#[test]
fn can_cmux_fft() {
let glwe = TEST_GLWE_DEF_1;
let sk = keygen::generate_binary_glwe_sk(&glwe);
let radix = TEST_RADIX;
let bits = PlaintextBits(1);
for _ in 0..100 {
let sel = thread_rng().next_u64() % 2;
let sel_ct = encryption::encrypt_ggsw(sel, &sk, &glwe, &radix, bits);
let a = (0..glwe.dim.polynomial_degree.0)
.map(|_| thread_rng().next_u64() % 2)
.collect::<Vec<_>>();
let a = Polynomial::new(&a);
let a_ct = encryption::encrypt_glwe(&a, &sk, &glwe, bits);
let b = (0..glwe.dim.polynomial_degree.0)
.map(|_| thread_rng().next_u64() % 2)
.collect::<Vec<_>>();
let b = Polynomial::new(&b);
let b_ct = encryption::encrypt_glwe(&b, &sk, &glwe, bits);
let sel_fft = fft::fft_ggsw(&sel_ct, &glwe, &radix);
let mut res_ct = GlweCiphertext::new(&glwe);
cmux(&mut res_ct, &a_ct, &b_ct, &sel_fft, &glwe, &radix);
let actual = encryption::decrypt_glwe(&res_ct, &sk, &glwe, bits);
if sel == 1 {
assert_eq!(actual, b);
} else {
assert_eq!(actual, a);
}
}
}
#[test]
fn cmux_trivial_ciphertexts_yields_nontrivial() {
let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
let plaintext_bits = crate::PlaintextBits(1);
let a = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Polynomial<_>>();
let a = encryption::trivial_glwe(&a, &TEST_GLWE_DEF_1, plaintext_bits);
let b = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64)
.map(|x| (x + 1) % 2)
.collect::<Polynomial<_>>();
let b = encryption::trivial_glwe(&b, &TEST_GLWE_DEF_1, plaintext_bits);
let sel = encryption::encrypt_ggsw(1, &sk, &TEST_GLWE_DEF_1, &TEST_RADIX, plaintext_bits);
let sel = fft::fft_ggsw(&sel, &TEST_GLWE_DEF_1, &TEST_RADIX);
let res = evaluation::cmux(&sel, &a, &b, &TEST_GLWE_DEF_1, &TEST_RADIX);
for a in res.a(&TEST_GLWE_DEF_1) {
let zero = Polynomial::<Torus<u64>>::zero(TEST_GLWE_DEF_1.dim.polynomial_degree.0);
assert_ne!(a.to_owned(), zero);
}
}
}

View File

@@ -0,0 +1,67 @@
use crate::{entities::LweCiphertextRef, LweDef, Torus, TorusOps};
/// Add `amount` to each torus element (mod q) in the ciphertext.
/// This shifts where messages lie on the torus and adds no noise.
///
/// # Remark
/// Suppose we have plaintexts 0 and 1 that lie centered at 0 and q/2 respectively.
/// If we rotate by q/4, then the 0 lies centered at q/4 and 1 lies at 3q/4 == -q/4.
pub fn rotate<S: TorusOps>(
output: &mut LweCiphertextRef<S>,
input: &LweCiphertextRef<S>,
amount: Torus<S>,
lwe: &LweDef,
) {
output.assert_valid(lwe);
input.assert_valid(lwe);
output.a_mut(lwe).clone_from_slice(input.a(lwe));
*output.b_mut(lwe) = input.b(lwe) + amount;
}
#[cfg(test)]
mod tests {
use crate::{
entities::LweCiphertext,
high_level::{keygen, TEST_LWE_DEF_1},
PlaintextBits, Torus,
};
use super::rotate;
#[test]
fn can_rotate() {
let lwe_params = TEST_LWE_DEF_1;
for _ in 0..100 {
let sk = keygen::generate_binary_lwe_sk(&lwe_params);
let val = sk.encrypt(0, &lwe_params, PlaintextBits(1)).0;
let mut res = LweCiphertext::new(&lwe_params);
rotate(
&mut res,
&val,
Torus::encode(1, PlaintextBits(2)),
&lwe_params,
);
let t = sk.decrypt_without_decode(&res, &lwe_params);
assert!(t.inner() < 0x1u64 << 63);
let val = sk.encrypt(1, &lwe_params, PlaintextBits(1)).0;
rotate(
&mut res,
&val,
Torus::encode(1, PlaintextBits(2)),
&lwe_params,
);
let t = sk.decrypt_without_decode(&res, &lwe_params);
assert!(t.inner() > 0x1u64 << 63);
}
}
}

View File

@@ -0,0 +1,2 @@
mod lwe;
pub use lwe::*;

View File

@@ -0,0 +1,107 @@
use crate::{
dst::FromMutSlice,
entities::{GlweCiphertext, GlweCiphertextRef, GlweKeyswitchKeyRef, PolynomialRef},
ops::{
ciphertext::{decomposed_polynomial_glev_mad, sub_glwe_ciphertexts},
encryption::trivially_encrypt_glwe_ciphertext,
},
radix::PolynomialRadixIterator,
scratch::allocate_scratch_ref,
GlweDef, RadixDecomposition, TorusOps,
};
/// Switches a ciphertext under the original key to a ciphertext under the new
/// key using a keyswitch key.
///
/// # Remark
///
/// This performs the following operation:
///
/// ```text
/// switched_ciphertext = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
/// ```
///
/// where `trivial_encrypt` is the encryption of the body of the original
/// ciphertext.
pub fn keyswitch_glwe_to_glwe<S>(
output: &mut GlweCiphertextRef<S>,
ciphertext_under_original_key: &GlweCiphertextRef<S>,
keyswitch_key: &GlweKeyswitchKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(params);
let keyswitch_glevs = keyswitch_key.rows(params, radix);
let mut a_i_decomp_sum = GlweCiphertext::new(params);
allocate_scratch_ref!(scratch, PolynomialRef<S>, (params.dim.polynomial_degree));
// sum_i(<decomp(ciphertext_a_i), glev_i>)
for (a_i, glev_i) in ciphertext_a.zip(keyswitch_glevs) {
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
decomposed_polynomial_glev_mad(&mut a_i_decomp_sum, decomp, glev_i, params);
}
// trivial_encrypt(ciphertext_b)
let mut trivial_b = GlweCiphertext::new(params);
trivially_encrypt_glwe_ciphertext(&mut trivial_b, ciphertext_b, params);
// output = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
sub_glwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, params);
}
#[cfg(test)]
mod tests {
use crate::{
entities::{GlweCiphertext, GlweKeyswitchKey, Polynomial},
high_level::*,
ops::keyswitch::{
glwe_keyswitch::keyswitch_glwe_to_glwe, glwe_keyswitch_key::generate_keyswitch_key_glwe,
},
PlaintextBits,
};
#[test]
fn keyswitch_glwe() {
let glwe = TEST_GLWE_DEF_1;
let bits = PlaintextBits(1);
let original_sk = keygen::generate_binary_glwe_sk(&glwe);
let new_sk = keygen::generate_binary_glwe_sk(&glwe);
let mut ksk = GlweKeyswitchKey::<u64>::new(&TEST_GLWE_DEF_1, &TEST_RADIX);
generate_keyswitch_key_glwe(
&mut ksk,
&original_sk,
&new_sk,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
);
let msg = Polynomial::new(
&(0..glwe.dim.polynomial_degree.0 as u64)
.map(|x| x % 2)
.collect::<Vec<_>>(),
);
let original_ct = original_sk.encode_encrypt_glwe(&msg, &glwe, bits);
let mut new_ct = GlweCiphertext::new(&glwe);
keyswitch_glwe_to_glwe(
&mut new_ct,
&original_ct,
&ksk,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
);
let new_decrypted = new_sk.decrypt_decode_glwe(&new_ct, &glwe, bits);
assert_eq!(new_decrypted.coeffs(), msg.coeffs());
}
}

View File

@@ -0,0 +1,150 @@
use crate::{
entities::{
GlweCiphertextRef, GlweKeyswitchKeyRef, GlweSecretKeyRef, Polynomial, PolynomialRef,
},
ops::encryption::encrypt_glwe_ciphertext_secret_generic,
polynomial::polynomial_scalar_mul,
GlweDef, RadixDecomposition, Torus, TorusOps,
};
fn encrypt_glwe_ciphertext_secret_with_keyswitch_noise<S>(
c: &mut GlweCiphertextRef<S>,
msg: &PolynomialRef<Torus<S>>,
sk: &GlweSecretKeyRef<S>,
params: &GlweDef,
) where
S: TorusOps,
{
encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params);
}
/**
* Generates a keyswitch key from the original key to the new key. The resulting
* keyswitch key is encrypted under the new key. This function is generic over
* the encrypt function should there be a need.
*
* The specific operation on each GLev ciphertext row inside a keyswitch key is
*
* ```text
* KSK_i = (GLWE_{s', })
* ```
*/
fn encrypt_keyswitch_key_generic<S>(
keyswitch_key: &mut GlweKeyswitchKeyRef<S>,
original_glwe_secret_key: &GlweSecretKeyRef<S>,
new_glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
encrypt: impl Fn(
&mut GlweCiphertextRef<S>,
&PolynomialRef<Torus<S>>,
&GlweSecretKeyRef<S>,
&GlweDef,
),
) where
S: TorusOps,
{
let decomposition_radix_log = radix.radix_log.0;
let polynomial_degree = params.dim.polynomial_degree.0;
for (i, row) in keyswitch_key.rows_mut(params, radix).enumerate() {
let s = original_glwe_secret_key
.s(params)
.nth(i)
.unwrap()
.map(|x| Torus::from(*x));
for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() {
let mut scaled_original_key = Polynomial::zero(polynomial_degree);
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
polynomial_scalar_mul(&mut scaled_original_key, &s, decomp_factor);
encrypt(col, &scaled_original_key, new_glwe_secret_key, params);
}
}
}
/// Generate a keyswitch key from the original key to the new key.
/// For use with
/// [`keyswitch_glwe_to_glwe`](crate::ops::keyswitch::glwe_keyswitch::keyswitch_glwe_to_glwe).
pub fn generate_keyswitch_key_glwe<S>(
keyswitch_key: &mut GlweKeyswitchKeyRef<S>,
original_glwe_secret_key: &GlweSecretKeyRef<S>,
new_glwe_secret_key: &GlweSecretKeyRef<S>,
params: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
encrypt_keyswitch_key_generic(
keyswitch_key,
original_glwe_secret_key,
new_glwe_secret_key,
params,
radix,
encrypt_glwe_ciphertext_secret_with_keyswitch_noise,
)
}
#[cfg(test)]
mod tests {
use crate::{
dst::FromSlice,
entities::{GlweKeyswitchKey, GlweKeyswitchKeyRef},
high_level::{TEST_GLWE_DEF_1, TEST_RADIX},
Torus,
};
#[test]
fn test_generate_keyswitch_key_glwe() {
// Size of the arrary should be (k + 1) * l * k * poly_degree as we are
// GLev encrypting all S_i with a given radix count and base.
let ksk = GlweKeyswitchKey::<u64>::new(&TEST_GLWE_DEF_1, &TEST_RADIX);
let k = TEST_GLWE_DEF_1.dim.size.0;
let l = TEST_RADIX.count.0;
let poly_degree = TEST_GLWE_DEF_1.dim.polynomial_degree.0;
assert_eq!(ksk.as_slice().len(), (k + 1) * l * k * poly_degree);
// Generate fake data to iterate through.
let ksk_data = (0..ksk.as_slice().len())
.map(|x| Torus::from(x as u64))
.collect::<Vec<Torus<u64>>>();
let ksk = GlweKeyswitchKeyRef::<u64>::from_slice(&ksk_data);
// Check that the data is correct.
let glwe_size = TEST_GLWE_DEF_1.dim.polynomial_degree.0 * (TEST_GLWE_DEF_1.dim.size.0 + 1);
let mut count = 0;
for row in ksk.rows(&TEST_GLWE_DEF_1, &TEST_RADIX) {
for glwe in row.glwe_ciphertexts(&TEST_GLWE_DEF_1) {
let (a, b) = glwe.a_b(&TEST_GLWE_DEF_1);
let expected_vector = (count..(glwe_size + count))
.map(|x| Torus::from(x as u64))
.collect::<Vec<_>>();
let (a_expected, b_expected) = expected_vector
.split_at(TEST_GLWE_DEF_1.dim.size.0 * TEST_GLWE_DEF_1.dim.polynomial_degree.0);
let a_i_expected = a_expected
.chunks(TEST_GLWE_DEF_1.dim.polynomial_degree.0)
.map(|x| x.to_vec())
.collect::<Vec<Vec<Torus<u64>>>>();
for (a_j, a_j_expected) in a.zip(a_i_expected) {
assert_eq!(a_j.coeffs(), a_j_expected);
}
assert_eq!(b.coeffs(), b_expected);
count += glwe_size;
}
}
}
}

View File

@@ -0,0 +1,91 @@
use crate::{
dst::{FromMutSlice, FromSlice},
entities::{LweCiphertext, LweCiphertextRef, LweKeyswitchKeyRef, PolynomialRef},
ops::{
ciphertext::{decomposed_scalar_lev_mad, sub_lwe_ciphertexts},
encryption::trivially_encrypt_lwe_ciphertext,
},
radix::PolynomialRadixIterator,
scratch::allocate_scratch_ref,
LweDef, PolynomialDegree, RadixDecomposition, TorusOps,
};
/// Switches a ciphertext under the original key to a ciphertext under the new
/// key using a keyswitch key.
///
/// Arguments:
///
/// * output: the output ciphertext
/// * ciphertext_under_original_key: the input ciphertext
/// * keyswitch_key: the keyswitch key
/// * old_params: the parameters of the original ciphertext
/// * new_params: the parameters of the output ciphertext
pub fn keyswitch_lwe_to_lwe<S>(
output: &mut LweCiphertextRef<S>,
ciphertext_under_original_key: &LweCiphertextRef<S>,
keyswitch_key: &LweKeyswitchKeyRef<S>,
old_params: &LweDef,
new_params: &LweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
keyswitch_key.assert_valid(old_params, new_params, radix);
let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(old_params);
let keyswitch_levs = keyswitch_key.rows(new_params, radix);
let mut a_i_decomp_sum = LweCiphertext::new(new_params);
allocate_scratch_ref!(scratch, PolynomialRef<S>, (PolynomialDegree(1)));
// sum_i(<decomp(ciphertext_a_i), lev_i>)
for (a_i, lev_i) in ciphertext_a.iter().zip(keyswitch_levs) {
let decomp =
PolynomialRadixIterator::new(PolynomialRef::from_slice(&[*a_i]), scratch, radix);
decomposed_scalar_lev_mad(&mut a_i_decomp_sum, decomp, lev_i, new_params);
}
// trivial_encrypt(ciphertext_b)
let mut trivial_b = LweCiphertext::new(new_params);
trivially_encrypt_lwe_ciphertext(&mut trivial_b, ciphertext_b, new_params);
// output = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
sub_lwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, new_params);
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{high_level::*, PlaintextBits};
#[test]
fn keyswitch_lwe() {
let bits = PlaintextBits(4);
let from_lwe = TEST_LWE_DEF_1;
let to_lwe = TEST_LWE_DEF_2;
let radix = TEST_RADIX;
for _ in 0..50 {
let original_sk = keygen::generate_binary_lwe_sk(&from_lwe);
let new_sk = keygen::generate_binary_lwe_sk(&to_lwe);
let ksk = keygen::generate_ksk(&original_sk, &new_sk, &from_lwe, &to_lwe, &radix);
let msg = thread_rng().next_u64() % (1 << bits.0);
let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0;
let new_ct =
evaluation::keyswitch_lwe_to_lwe(&original_ct, &ksk, &from_lwe, &to_lwe, &radix);
let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits);
assert_eq!(new_decrypted, msg);
}
}
}

View File

@@ -0,0 +1,41 @@
use crate::{
entities::{LweKeyswitchKeyRef, LweSecretKeyRef},
ops::encryption::encrypt_lwe_ciphertext,
LweDef, RadixDecomposition, Torus, TorusOps,
};
/// Generates a keyswitch key from an original LWE key to a new LWE key. The
/// resulting keyswitch key is encrypted under the new key.
///
/// Arguments:
///
/// * keyswitch_key: the resulting keyswitch key
/// * original_lwe_secret_key: the original LWE secret key
/// * new_lwe_secret_key: the new LWE secret key
/// * new_params: the parameters of the new LWE secret key
pub fn generate_keyswitch_key_lwe<S>(
keyswitch_key: &mut LweKeyswitchKeyRef<S>,
original_lwe_secret_key: &LweSecretKeyRef<S>,
new_lwe_secret_key: &LweSecretKeyRef<S>,
new_params: &LweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
{
let decomposition_radix_log = radix.radix_log.0;
for (i, row) in keyswitch_key.rows_mut(new_params, radix).enumerate() {
let s_i = original_lwe_secret_key.s()[i];
for (j, col) in row.lwe_ciphertexts_mut(new_params).enumerate() {
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
let decomp_factor =
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
let msg = decomp_factor * s_i;
encrypt_lwe_ciphertext(col, new_lwe_secret_key, Torus::from(msg), new_params);
}
}
}

View File

@@ -0,0 +1,17 @@
/// Methods for performing a private functional keyswitch (PFKS)
pub mod private_functional_keyswitch;
/// Methods for performing a public functional keyswitch (PuFKS)
pub mod public_functional_keyswitch;
/// Generate LWE keyswitch keys.
pub mod lwe_keyswitch_key;
/// Generate GLWE keyswitch keys.
pub mod glwe_keyswitch_key;
/// Methods for performing a LWE keyswitch.
pub mod lwe_keyswitch;
/// Methods for performing a GLWE keyswitch.
pub mod glwe_keyswitch;

View File

@@ -0,0 +1,350 @@
use sunscreen_math::Zero;
use crate::{
dst::FromMutSlice,
entities::{
CircuitBootstrappingKeyswitchKeysRef, GlweCiphertextRef, GlweSecretKeyRef,
LweCiphertextRef, LweSecretKeyRef, PolynomialRef, PrivateFunctionalKeyswitchKeyRef,
},
ops::{
ciphertext::{decomposed_scalar_glev_mad, glwe_negate_inplace},
encryption::encrypt_glwe_ciphertext_secret,
},
radix::{scale_by_decomposition_factor, ScalarRadixIterator},
scratch::allocate_scratch_ref,
GlweDef, LweDef, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus, TorusOps,
};
/// Initialize `output`, a
/// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey),
/// under the given scheme parameters for the given secret mapping `map`.
/// Conceptually, this map transforms a list of torus plaintexts into a
/// polynomial plaintext.
///
/// # Remarks
/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`.
///
/// The first parameter in `map` is the output
/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter
/// is initialized to 0.
///
/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`.
///
/// # Security
/// To prevent side channels, `map` must run in constant time.
///
/// # Panics
/// * If `output` is not valid for the given `from_lwe`, `to_glwe`, `radix`, `lwe_count`.
/// * If any of `from_lwe`, `to_glwe`, `radix`, `lwe_count` are invalid.
#[allow(clippy::too_many_arguments)]
pub fn generate_private_functional_keyswitch_key<S, F>(
output: &mut PrivateFunctionalKeyswitchKeyRef<S>,
from_key: &LweSecretKeyRef<S>,
to_key: &GlweSecretKeyRef<S>,
map: F,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
lwe_count: &PrivateFunctionalKeyswitchLweCount,
) where
S: TorusOps,
F: Fn(&mut PolynomialRef<Torus<S>>, &[Torus<S>]),
{
output.assert_valid(from_lwe, to_glwe, radix, lwe_count);
radix.assert_valid::<S>();
from_key.assert_valid(from_lwe);
to_key.assert_valid(to_glwe);
to_glwe.assert_valid();
from_lwe.assert_valid();
lwe_count.assert_valid();
allocate_scratch_ref!(
pt_poly,
PolynomialRef<Torus<S>>,
(to_glwe.dim.polynomial_degree)
);
allocate_scratch_ref!(pt_touri, [Torus<S>], lwe_count.0);
let mut glevs = output.glevs_mut(to_glwe, radix);
let minus_one = <S as Zero>::zero().wrapping_sub(&<S as num::One>::one());
for z in 0..lwe_count.0 {
for s_i in from_key.s().iter().chain([minus_one].iter()) {
let glev = glevs.next().unwrap();
for (j, glwe) in glev.glwe_ciphertexts_mut(to_glwe).enumerate() {
let scaled_s_i = scale_by_decomposition_factor(*s_i, j, radix);
pt_poly.clear();
pt_touri.iter_mut().for_each(|x| *x = Torus::zero());
pt_touri[z] = Torus::from(scaled_s_i);
map(pt_poly, pt_touri);
encrypt_glwe_ciphertext_secret(glwe, pt_poly, to_key, to_glwe);
}
}
}
}
/// Perform a private functional keyswitch. See
/// [`module`](crate::ops::keyswitch::private_functional_keyswitch) documentation for more
/// details.
pub fn private_functional_keyswitch<S: TorusOps>(
output: &mut GlweCiphertextRef<S>,
inputs: &[&LweCiphertextRef<S>],
pfksk: &PrivateFunctionalKeyswitchKeyRef<S>,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
lwe_count: &PrivateFunctionalKeyswitchLweCount,
) {
output.assert_valid(to_glwe);
pfksk.assert_valid(from_lwe, to_glwe, radix, lwe_count);
from_lwe.assert_valid();
to_glwe.assert_valid();
radix.assert_valid::<S>();
lwe_count.assert_valid();
assert_eq!(lwe_count.0, inputs.len());
let mut ksk_glevs = pfksk.glevs(to_glwe, radix);
for input in inputs.iter() {
for i in 0..from_lwe.dim.0 + 1 {
// Treating the z'th ciphertext as slice of length n + 1 allows us to iterate
// over a || b.
let ab = input.as_slice();
let glev = ksk_glevs.next().unwrap();
let decomp = ScalarRadixIterator::new(ab[i], radix);
decomposed_scalar_glev_mad(output, decomp, glev, to_glwe);
}
}
// Return minus output.
glwe_negate_inplace(output, to_glwe);
}
/// Generate the keys for a private functional keyswitch.
pub fn generate_circuit_bootstrapping_pfks_keys<S: TorusOps>(
output: &mut CircuitBootstrappingKeyswitchKeysRef<S>,
from_key: &LweSecretKeyRef<S>,
to_key: &GlweSecretKeyRef<S>,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) {
output.assert_valid(from_lwe, to_glwe, radix);
from_key.assert_valid(from_lwe);
to_glwe.assert_valid();
to_key.assert_valid(to_glwe);
radix.assert_valid::<S>();
from_lwe.assert_valid();
// Fill in k pfks keys that multiply each of the "a" GLEVs by the corresponding
// polynomial in the GLWE secret key.
for (pfksk, s) in output
.keys_mut(from_lwe, to_glwe, radix)
.zip(to_key.s(to_glwe))
.take(to_glwe.dim.size.0)
{
let map = |poly: &mut PolynomialRef<Torus<S>>, x: &[Torus<S>]| {
for (c, a) in poly.coeffs_mut().iter_mut().zip(s.coeffs().iter()) {
*c = -x[0] * a;
}
};
generate_private_functional_keyswitch_key(
pfksk,
from_key,
to_key,
map,
from_lwe,
to_glwe,
radix,
&PrivateFunctionalKeyswitchLweCount(1),
);
}
// Now fill in the "b" GLEV.
// TODO: We could compute this row with public key switching. Is it worth it?
let b = output
.keys_mut(from_lwe, to_glwe, radix)
.nth(to_glwe.dim.size.0)
.unwrap();
let map = |poly: &mut PolynomialRef<Torus<S>>, x: &[Torus<S>]| {
poly.clear();
poly.coeffs_mut()[0] = x[0];
};
generate_private_functional_keyswitch_key(
b,
from_key,
to_key,
map,
from_lwe,
to_glwe,
radix,
&PrivateFunctionalKeyswitchLweCount(1),
)
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{GlweCiphertext, PrivateFunctionalKeyswitchKey},
high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
PlaintextBits, PrivateFunctionalKeyswitchLweCount,
};
use super::*;
#[test]
fn can_create_private_functional_keyswitch_key() {
for _ in 0..5 {
let lwe_count =
PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1);
let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
let mut pfks_key = PrivateFunctionalKeyswitchKey::<u64>::new(
&TEST_LWE_DEF_1,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
&lwe_count,
);
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, inputs: &[Torus<S>]) {
for (i, input) in inputs.iter().enumerate() {
poly.coeffs_mut()[i] = *input;
}
}
generate_private_functional_keyswitch_key(
&mut pfks_key,
&lwe_key,
&glwe_key,
map,
&TEST_LWE_DEF_1,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
&lwe_count,
);
let mut glevs = pfks_key.glevs(&TEST_GLWE_DEF_1, &TEST_RADIX);
let minus_one = u64::MAX;
for z in 0..lwe_count.0 {
for s_i in lwe_key.s().iter().chain([minus_one].iter()) {
let glev = glevs.next().unwrap();
for (j, glwe) in glev.glwe_ciphertexts(&TEST_GLWE_DEF_1).enumerate() {
let plaintext_bits = (j + 1) * TEST_RADIX.radix_log.0;
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
let pt =
glwe_key.decrypt_decode_glwe(glwe, &TEST_GLWE_DEF_1, plaintext_bits);
for i in 0..pt.coeffs().len() {
if *s_i == minus_one && i == z {
let expected = (0x1u64 << ((j + 1) * TEST_RADIX.radix_log.0)) - 1;
let actual = pt.coeffs()[i];
assert_eq!(actual, expected);
} else if i == z {
assert_eq!(pt.coeffs()[i], *s_i);
} else {
assert_eq!(pt.coeffs()[i], 0);
}
}
}
}
}
}
}
#[test]
fn can_private_functional_keyswitch() {
for _ in 0..5 {
let lwe_count =
PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1);
let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
let mut pfks_key = PrivateFunctionalKeyswitchKey::<u64>::new(
&TEST_LWE_DEF_1,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
&lwe_count,
);
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, inputs: &[Torus<S>]) {
for (i, input) in inputs.iter().enumerate() {
poly.coeffs_mut()[i] = *input;
}
}
generate_private_functional_keyswitch_key(
&mut pfks_key,
&lwe_key,
&glwe_key,
map,
&TEST_LWE_DEF_1,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
&lwe_count,
);
let plaintext_bits = PlaintextBits(4);
let pts = (0..lwe_count.0)
.map(|_x| thread_rng().next_u64() % (0x1u64 << plaintext_bits.0))
.collect::<Vec<u64>>();
let lwe_cts = pts
.iter()
.map(|x| lwe_key.encrypt(*x, &TEST_LWE_DEF_1, plaintext_bits))
.collect::<Vec<_>>();
let mut lwe_ct_refs: Vec<&LweCiphertextRef<_>> = vec![];
for ct in lwe_cts.iter() {
lwe_ct_refs.push(&ct.0);
}
let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1);
private_functional_keyswitch(
&mut result,
&lwe_ct_refs,
&pfks_key,
&TEST_LWE_DEF_1,
&TEST_GLWE_DEF_1,
&TEST_RADIX,
&lwe_count,
);
let actual = glwe_key.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits);
for (i, (c, pt)) in actual
.coeffs()
.iter()
.zip(pts.iter().cycle().take(actual.coeffs().len()))
.enumerate()
{
if i < lwe_count.0 {
assert_eq!(*c, *pt);
} else {
assert_eq!(*c, 0);
}
}
}
}
}

View File

@@ -0,0 +1,264 @@
use num::Complex;
use crate::dst::FromMutSlice;
use crate::entities::{
GlevCiphertextFftRef, GlweCiphertextRef, LweCiphertextRef, LweSecretKeyRef, PolynomialRef,
};
use crate::ops::ciphertext::glwe_negate_inplace;
use crate::ops::encryption::encrypt_glwe_ciphertext_secret;
use crate::ops::fft_ops::decomposed_polynomial_glev_mad;
use crate::polynomial::polynomial_add_assign;
use crate::radix::PolynomialRadixIterator;
use crate::scratch::allocate_scratch;
use crate::Torus;
use crate::{
entities::{GlweCiphertextFftRef, GlweSecretKeyRef, PublicFunctionalKeyswitchKeyRef},
radix::scale_by_decomposition_factor,
scratch::allocate_scratch_ref,
GlweDef, LweDef, RadixDecomposition, TorusOps,
};
/// Generate a public functional keyswitch key, which is used to transform a
/// list of LWE ciphertexts into a GLWE ciphertext while applying a provided
/// function that converts the scalars in the LWE ciphertexts to the polynomial
/// message space in the GLWE ciphertext.
///
/// See
/// [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch)
/// for more details.
pub fn generate_public_functional_keyswitch_key<S: TorusOps>(
output: &mut PublicFunctionalKeyswitchKeyRef<S>,
from_sk: &LweSecretKeyRef<S>,
to_sk: &GlweSecretKeyRef<S>,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) {
from_sk.assert_valid(from_lwe);
to_sk.assert_valid(to_glwe);
output.assert_valid(from_lwe, to_glwe, radix);
allocate_scratch_ref!(pt, PolynomialRef<Torus<S>>, (to_glwe.dim.polynomial_degree));
pt.clear();
for (s_i, glev) in from_sk.s().iter().zip(output.glevs_mut(to_glwe, radix)) {
for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts_mut(to_glwe)) {
let x = scale_by_decomposition_factor(*s_i, j, radix);
pt.coeffs_mut()[0] = Torus::from(x);
encrypt_glwe_ciphertext_secret(glwe_ct, pt, to_sk, to_glwe);
}
}
}
/// Perform a public functional keyswitch, where a list of LWE ciphertexts are
/// transformed into a GLWE ciphertext while applying a provided function.
/// Conceptually, this map transforms a list of torus plaintexts into a
/// polynomial plaintext.
///
/// This operation is called "public" because the function F is public; a
/// variant of this operation called
/// [`private_functional_keyswitch`](crate::ops::keyswitch::private_functional_keyswitch)
/// is also available, where the function F is secret encoded in the keyswitch
/// key.
///
/// # Remarks
/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`.
///
/// The first parameter in `map` is the output
/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter
/// is initialized to 0.
///
/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`.
pub fn public_functional_keyswitch<S, F>(
output: &mut GlweCiphertextRef<S>,
inputs: &[&LweCiphertextRef<S>],
pufksk: &PublicFunctionalKeyswitchKeyRef<S>,
f: F,
from_lwe: &LweDef,
to_glwe: &GlweDef,
radix: &RadixDecomposition,
) where
S: TorusOps,
F: Fn(&mut PolynomialRef<Torus<S>>, &[Torus<S>]),
{
pufksk.assert_valid(from_lwe, to_glwe, radix);
output.assert_valid(to_glwe);
for i in inputs {
i.assert_valid(from_lwe);
}
assert!(inputs.len() <= to_glwe.dim.polynomial_degree.0);
allocate_scratch_ref!(
poly,
PolynomialRef<Torus<S>>,
(to_glwe.dim.polynomial_degree)
);
output.clear();
allocate_scratch_ref!(
decomp_scratch,
PolynomialRef<S>,
(to_glwe.dim.polynomial_degree)
);
let mut a_buf = allocate_scratch::<Torus<S>>(inputs.len());
let lwe_vals = a_buf.as_mut_slice();
allocate_scratch_ref!(
glev_fft,
GlevCiphertextFftRef<Complex<f64>>,
(to_glwe.dim, radix.count)
);
allocate_scratch_ref!(
output_fft,
GlweCiphertextFftRef<Complex<f64>>,
(to_glwe.dim)
);
output_fft.clear();
// Compute all the a terms
for (i, row) in pufksk.glevs(to_glwe, radix).enumerate() {
for (j, a_i) in inputs.iter().map(|x| x.a(from_lwe)[i]).enumerate() {
lwe_vals[j] = a_i;
}
row.fft(glev_fft, to_glwe);
f(poly, lwe_vals);
let decomp = PolynomialRadixIterator::new(poly, decomp_scratch, radix);
decomposed_polynomial_glev_mad(output_fft, decomp, glev_fft, to_glwe);
}
// Compute the b term
for (j, b) in inputs.iter().map(|x| x.b(from_lwe)).enumerate() {
lwe_vals[j] = *b;
}
f(poly, lwe_vals);
output_fft.ifft(output, to_glwe);
// Compute (0, b) - output
glwe_negate_inplace(output, to_glwe);
polynomial_add_assign(output.b_mut(to_glwe), poly);
}
#[cfg(test)]
mod tests {
use rand::{thread_rng, RngCore};
use crate::{
entities::{GlweCiphertext, PublicFunctionalKeyswitchKey},
high_level::{encryption, keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
PlaintextBits,
};
use super::*;
#[test]
fn can_generate_public_functional_keyswitch_key() {
let lwe_params = TEST_LWE_DEF_1;
let glwe_params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX);
let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
generate_public_functional_keyswitch_key(
&mut pufksk,
&lwe_sk,
&glwe_sk,
&lwe_params,
&glwe_params,
&radix,
);
for (s_i, glev) in lwe_sk.s().iter().zip(pufksk.glevs(&glwe_params, &radix)) {
for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts(&glwe_params)) {
let plaintext_bits = ((j + 1) * radix.radix_log.0) as u32;
let plaintext_bits = PlaintextBits(plaintext_bits);
let pt = encryption::decrypt_glwe(glwe_ct, &glwe_sk, &glwe_params, plaintext_bits);
assert_eq!(pt.coeffs()[0], *s_i);
for x in pt.coeffs().iter().skip(1) {
assert_eq!(*x, 0);
}
}
}
}
#[test]
fn can_public_functional_keyswitch() {
let lwe_params = TEST_LWE_DEF_1;
let glwe_params = TEST_GLWE_DEF_1;
let radix = TEST_RADIX;
let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX);
let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
let plaintext_bits = PlaintextBits(4);
generate_public_functional_keyswitch_key(
&mut pufksk,
&lwe_sk,
&glwe_sk,
&lwe_params,
&glwe_params,
&radix,
);
for _ in 0..10 {
let lwe_count = thread_rng().next_u64() as usize % glwe_params.dim.polynomial_degree.0;
let pts = (0..lwe_count)
.map(|_| thread_rng().next_u64() % (0x1 << plaintext_bits.0))
.collect::<Vec<_>>();
let lwes = pts
.iter()
.map(|x| encryption::encrypt_lwe_secret(*x, &lwe_sk, &lwe_params, plaintext_bits))
.collect::<Vec<_>>();
let mut lwe_refs: Vec<&LweCiphertextRef<u64>> = vec![];
for x in lwes.iter() {
lwe_refs.push(x);
}
let mut output = GlweCiphertext::new(&glwe_params);
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, tori: &[Torus<S>]) {
for (c, t) in poly.coeffs_mut().iter_mut().zip(tori.iter()) {
*c = *t;
}
}
public_functional_keyswitch(
&mut output,
&lwe_refs,
&pufksk,
map,
&lwe_params,
&glwe_params,
&radix,
);
let actual = encryption::decrypt_glwe(&output, &glwe_sk, &glwe_params, plaintext_bits);
for (a, e) in actual.coeffs().iter().zip(pts.iter()) {
assert_eq!(a, e);
}
}
}
}

View File

@@ -0,0 +1,19 @@
/// Ciphertext operations where one of the operands is in FFT form.
pub mod fft_ops;
/// Methods for key switching a ciphertext from one key to another, potentially
/// switching the parameters at the same time.
pub mod keyswitch;
/// Methods for bootstrapping an LWE ciphertext from one key to another, while
/// refreshing the noise in the ciphertext.
pub mod bootstrapping;
/// Methods for homomorphic operations on ciphertexts.
pub mod homomorphisms;
/// Methods for operating on different ciphertext types.
pub mod ciphertext;
/// Methods for encrypting and decrypting to various ciphertext types.
pub mod encryption;

View File

@@ -0,0 +1,310 @@
use serde::{Deserialize, Serialize};
use crate::rand::Stddev;
use crate::TorusOps;
use sunscreen_math::security::{lwe_std_to_security_level, SecurityLevelResults};
trait SecurityLevel {
fn security_level(&self) -> f64;
fn assert_security_level(&self, specified_security_level: usize) {
// Note that the underlying approximation is accurate up to 4 bits, so
// that is what we use here.
let tolerance = 4.0;
let security_level = self.security_level();
let security_difference = (security_level - specified_security_level as f64).abs();
assert!(
security_difference <= tolerance,
"Security level mismatch: expected {}, got {}",
specified_security_level,
security_level
)
}
}
fn generic_security_level(dimension: usize, std: f64) -> f64 {
match lwe_std_to_security_level(dimension, std) {
SecurityLevelResults::Level(level) => level,
SecurityLevelResults::BelowStandardDeviationBound => {
panic!("Standard deviation is too small for the given dimension")
}
SecurityLevelResults::AboveStandardDeviationBound => {
panic!("Standard deviation is too large for the given dimension")
}
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of torus elements in the LWE lattice.
pub struct LweDimension(pub usize);
impl LweDimension {
/// Asserts this LWE problem is well-formed.
pub fn assert_valid(&self) {
assert_ne!(self.0, 0);
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The degree of the modulus polynomial `(x^N+1)` in a GLWE instance.
///
/// # Remarks
/// GLWE encryption uses polynomials in `Z_q\[X\]/(X^N + 1)` where N is a power
/// of two. I.e. negacyclic polynomials modulo (X^N + 1) where the coefficients
/// are integers mod `q`.
pub struct PolynomialDegree(pub usize);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of polynomials in a GLWE instance.
pub struct GlweSize(pub usize);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of plaintext bits to encode into a message.
///
/// # Remarks
/// Packing too much data into messages can result in incorrect decryptions due
/// to noise.
///
/// For binary, set this to one.
pub struct PlaintextBits(pub u32);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of padding bits to include in an LWE ciphertext.
pub struct CarryBits(pub u32);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of digits to decompose a value into.
pub struct RadixCount(pub usize);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of bits in a digit output during base decomposition.
pub struct RadixLog(pub usize);
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The number of [`LweCiphertext`](crate::entities::LweCiphertext)s that get
/// mapped into a [`GlweCiphertext`](crate::entities::GlweCiphertext) during
/// private functional keyswitching.
pub struct PrivateFunctionalKeyswitchLweCount(pub usize);
impl PrivateFunctionalKeyswitchLweCount {
#[inline(always)]
/// Assert this [`PrivateFunctionalKeyswitchLweCount`] is valid.
pub fn assert_valid(&self) {
assert_ne!(self.0, 0);
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
/// The parameters defining how to do approximately perform base decomposition. I.e.
/// decompose values into digits.
///
/// # Validity
/// For [`RadixDecomposition`] parameters to be valid:
/// * `count` must be greater than zero.
/// * `radix_log` must be greater than zero.
/// * `count * radix_log` must be less than equal to the number of bits in the
/// [`Torus`](crate::Torus) element used in the operation(s) performing radix
/// decomposition.
///
/// Calling [`assert_valid`](Self::assert_valid) will panic if the parameters are invalid.
pub struct RadixDecomposition {
/// The number of digits to decompose a value into.
pub count: RadixCount,
/// The number of bits in a digit output during base decomposition.
pub radix_log: RadixLog,
}
impl RadixDecomposition {
#[inline(always)]
/// Panics if these [`RadixDecomposition`] parameters are invalid.
pub fn assert_valid<S: TorusOps>(&self) {
assert!(self.count.0 > 0);
assert!(self.radix_log.0 > 0);
assert!(self.count.0 * self.radix_log.0 <= S::BITS as usize);
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
/// A [`PolynomialDegree`] and [`GlweSize`] in a GLWE instance.
pub struct GlweDimension {
/// The degree of the polynomial in a GLWE instance.
pub polynomial_degree: PolynomialDegree,
/// The number of polynomials in a GLWE instance.
pub size: GlweSize,
}
impl GlweDimension {
/// Reinterpret this GLWE problem instance as an LWE problem instance, returning
/// the [`LweDimension`].
pub fn as_lwe_dimension(&self) -> LweDimension {
LweDimension(self.polynomial_degree.0 * self.size.0)
}
#[inline(always)]
/// Assert these GLWE parameters are valid.
pub fn assert_valid(&self) {
assert!(self.polynomial_degree.0.is_power_of_two());
assert!(self.polynomial_degree.0 > 0);
assert!(self.size.0 > 0);
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
/// Parameters that define an LWE problem instance.
///
/// # Security
/// `dim` and `std` must be properly set to attain the desired security level. Improper
/// values will result in an insecure scheme.
pub struct LweDef {
/// The dimension of the LWE lattice.
pub dim: LweDimension,
/// The standard deviation of the noise in the LWE lattice.
pub std: Stddev,
}
impl LweDef {
#[inline(always)]
/// Asserts this LWE problem is well-formed.
pub fn assert_valid(&self) {
self.dim.assert_valid();
}
}
impl SecurityLevel for LweDef {
fn security_level(&self) -> f64 {
generic_security_level(self.dim.0, self.std.0)
}
}
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
/// Parameters that define a GLWE problem instance.
///
/// # Security
/// `dim` and `std` must be properly set to attain the desired security level. Improper
/// values will result in an insecure scheme.
pub struct GlweDef {
/// The dimension of the GLWE instance.
pub dim: GlweDimension,
/// The standard deviation of the noise in the GLWE instance.
pub std: Stddev,
}
impl GlweDef {
/// Reinterpret this GLWE instance as an LWE instance of the same lattice dimension.
pub fn as_lwe_def(&self) -> LweDef {
LweDef {
dim: self.dim.as_lwe_dimension(),
std: self.std,
}
}
#[inline(always)]
/// Assert the GLWE instance is valid.
pub fn assert_valid(&self) {
self.dim.assert_valid();
}
}
impl SecurityLevel for GlweDef {
fn security_level(&self) -> f64 {
self.as_lwe_def().security_level()
}
}
/// 128-bit secure parameters for an LWE instance with a dimension of 512.
pub const LWE_512_128: LweDef = LweDef {
dim: LweDimension(512),
std: Stddev(0.0004899836456140595),
};
/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024.
pub const GLWE_1_1024_128: GlweDef = GlweDef {
dim: GlweDimension {
size: GlweSize(1),
polynomial_degree: PolynomialDegree(1024),
},
std: Stddev(0.0000000444778278004718),
};
/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 2048.
pub const GLWE_1_2048_128: GlweDef = GlweDef {
dim: GlweDimension {
size: GlweSize(1),
polynomial_degree: PolynomialDegree(2048),
},
std: Stddev(0.00000000000000034667670193445625),
};
/// 80-bit secure parameters for an LWE instance with a dimension of 512.
pub const LWE_512_80: LweDef = LweDef {
dim: LweDimension(512),
std: Stddev(0.000001842343446823844),
};
/// 80-bit secure parameters for a GLWE instance with 5 polynomials of degree 256.
pub const GLWE_5_256_80: GlweDef = GlweDef {
dim: GlweDimension {
size: GlweSize(5),
polynomial_degree: PolynomialDegree(256),
},
std: Stddev(0.000000000000002106764669572764),
};
/// 80-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024.
pub const GLWE_1_1024_80: GlweDef = GlweDef {
dim: GlweDimension {
size: GlweSize(1),
polynomial_degree: PolynomialDegree(1024),
},
std: Stddev(0.0000000000010900242107812643),
};
#[cfg(test)]
mod tests {
use sunscreen_math::security::lwe_security_level_to_std;
use super::*;
#[test]
fn check_security_levels() {
let actual_lwe_std = lwe_security_level_to_std(512, 128.0);
println!("LWE 512 128: {}", actual_lwe_std);
LWE_512_128.assert_security_level(128);
let actual_glwe_std = lwe_security_level_to_std(1024, 128.0);
println!("GLWE 1 1024 128: {}", actual_glwe_std);
GLWE_1_1024_128.assert_security_level(128);
let actual_glwe_std = lwe_security_level_to_std(2048, 128.0);
println!("GLWE 1 2048 128: {}", actual_glwe_std);
GLWE_1_2048_128.assert_security_level(128);
let actual_lwe_std = lwe_security_level_to_std(512, 80.0);
println!("LWE 512 80: {}", actual_lwe_std);
LWE_512_80.assert_security_level(80);
let actual_glwe_std = lwe_security_level_to_std(256 * 5, 80.0);
println!("GLWE 5 256 80: {}", actual_glwe_std);
GLWE_5_256_80.assert_security_level(80);
let actual_glwe_std = lwe_security_level_to_std(1024, 80.0);
println!("GLWE 1 1024 80: {}", actual_glwe_std);
GLWE_1_1024_80.assert_security_level(80);
}
}

View File

@@ -0,0 +1,97 @@
use std::fmt::Debug;
use rand::{thread_rng, Rng, RngCore};
use rand_distr::Normal;
use serde::{Deserialize, Serialize};
use crate::math::{Torus, TorusOps};
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
#[serde(transparent)]
/// The standard deviation of a Gaussian distribution normalized over the torus
/// `T_q`.
pub struct Stddev(pub f64);
/// Sample a random torus element from the a normal distribution
/// with a mean of 0 and the given stddev
pub fn normal_torus<S: TorusOps>(std: Stddev) -> Torus<S> {
let dist = Normal::new(0., std.0).unwrap();
let e_0 = thread_rng().sample(dist);
let q = (S::BITS as f64).exp2();
let e = f64::round(e_0 * q) as i64;
let e: u64 = unsafe { std::mem::transmute(e) };
Torus::from(S::from_u64(e))
}
/// Generate a random torus element uniformly
pub fn uniform_torus<S: TorusOps>() -> Torus<S> {
Torus::from(S::from_u64(thread_rng().next_u64()))
}
/// Generate a random binary torus element
pub fn binary<S: TorusOps>() -> S {
S::from_u64(thread_rng().next_u64() % 2)
}
#[cfg(test)]
mod tests {
use std::mem::transmute_copy;
use crate::math::ToF64;
use super::*;
#[test]
fn can_produce_random_torus() {
pub fn case<S, I>()
where
S: TorusOps,
I: ToF64 + Copy + Debug,
<S as TryFrom<u64>>::Error: Debug,
{
let q = (S::BITS as f64).exp2();
let n: i32 = 100_000;
let dev = Stddev(0.000_448_516_698_238_696_5);
let data = (0..n)
.map(|_| {
let t = normal_torus::<S>(dev).inner();
unsafe { transmute_copy::<S, I>(&t) }
})
.collect::<Vec<_>>();
// Reinterpreting the torus points as i64 values should give a mean of approximately zero.
let mean = data
.iter()
.copied()
.map(|x| x.to_f64())
.fold(0., |s, x| s + x)
/ (q * n as f64);
assert!(mean < 1e-5);
// Scale the integer values back to the range [0, 1) and compute the stddev
let measured_std = data
.iter()
.copied()
.map(|x| {
let val = (x.to_f64() / q) - mean;
val * val
})
.fold(0f64, |s, x| s + x)
/ n as f64;
let measured_std = measured_std.sqrt();
assert!((measured_std - dev.0).abs() < 0.00001);
}
case::<u32, i32>();
case::<u64, i64>();
}
}

View File

@@ -0,0 +1,427 @@
use linked_list::{Cursor, LinkedList};
use num::{Complex, Float};
use rustfft::FftNum;
use std::{
alloc::Layout,
cell::RefCell,
marker::PhantomData,
mem::{size_of, transmute},
};
use crate::{Torus, TorusOps};
thread_local! {
static SCRATCH: RefCell<Option<Scratch>> = RefCell::new(None);
}
macro_rules! allocate_scratch_ref {
($out_ident:ident,$ref_t:ident<$t_arg:ty>, ($($args:expr),*)) => {
let mut tmp = crate::scratch::allocate_scratch(<$ref_t<$t_arg> as crate::dst::OverlaySize>::size(($($args),*)));
let $out_ident = <$ref_t<$t_arg>>::from_mut_slice(tmp.as_mut_slice());
};
($out_ident:ident,[$ref_t:ident<$t_arg:ty>], $len:expr) => {
let mut tmp = crate::scratch::allocate_scratch::<$ref_t<$t_arg>>(<[$ref_t<$t_arg>] as crate::dst::OverlaySize>::size($len));
let $out_ident = tmp.as_mut_slice();
}
}
pub(crate) use allocate_scratch_ref;
/// Indicates this is a "Plain Old Data" type. For `T` qualify as such,
/// all bit patterns must be considered a properly initialized instance of
/// `T`.
///
/// # Safety
/// Implementing this trait on types that don't meet the above requirements
/// may result in undefined behavior.
pub unsafe trait Pod {}
unsafe impl Pod for u8 {}
unsafe impl Pod for u16 {}
unsafe impl Pod for u32 {}
unsafe impl Pod for u64 {}
unsafe impl Pod for u128 {}
unsafe impl Pod for i8 {}
unsafe impl Pod for i16 {}
unsafe impl Pod for i32 {}
unsafe impl Pod for i64 {}
unsafe impl Pod for i128 {}
unsafe impl Pod for f32 {}
unsafe impl Pod for f64 {}
unsafe impl<T> Pod for Complex<T> where T: Float + FftNum {}
unsafe impl<S> Pod for Torus<S> where S: TorusOps {}
/// Allocate a scratch buffer in a cache efficient manner. Freed scratch
/// buffers are reused in subsequent allocations.
///
/// # Remarks
/// The returned [`ScratchBuffer`] will be aligned to `align_of::<T>()` and
/// have a length equal to count.
///
/// # Panics
/// If `T` is a zero-sized type (e.g. `()`).
pub fn allocate_scratch<T>(count: usize) -> ScratchBuffer<'static, T>
where
T: Pod,
{
SCRATCH.with(|s| {
if s.borrow().is_none() {
let new_scratch = Scratch::new();
*s.borrow_mut() = Some(new_scratch);
}
(*s.borrow_mut()).as_mut().unwrap().allocate::<T>(count)
})
}
/// An "allocator" designed for allocating scratch memory.
///
/// # Remarks
/// Internally, this data structure is a [`LinkedList`] of
/// [`Vec`]s treated as a stack.
///
/// [`Scratch`] is designed to provide a cache locality for
/// temporary buffers by reusing allocations.
///
/// Upon allocation, we update the "top" of the stack to be the
/// furthest consecutive free buffer from the actual top.
///
/// The only way to use this type is through the `thread_local`
/// singleton, as this is the only way to guarantee soundness.
/// The references doled out by allocate have `'static``
/// lifetimes. This is needed so you can have mutable references
/// to different allocations at the same time.
struct Scratch {
// Only accessed through the cursor, so compiler thinks it's unused.
#[allow(unused)]
stack: Box<LinkedList<Allocation>>,
top: *mut Cursor<'static, Allocation>,
}
impl Drop for Scratch {
fn drop(&mut self) {
let top = unsafe { Box::from_raw(self.top) };
std::mem::drop(top);
}
}
impl Scratch {
/// We require this to be private. [`allocate_scratch`] should be
/// the only way to use scratch memory, which will allocate memory
/// using a thread_local allocator.
fn new() -> Self {
let mut list = Box::new(LinkedList::new());
let cursor = Box::new(list.cursor());
let top = unsafe { transmute(Box::into_raw(cursor)) };
Self { stack: list, top }
}
/// Allocate a buffer matching the given specification.
fn allocate<T>(&mut self, count: usize) -> ScratchBuffer<'static, T>
where
T: Pod,
{
assert_ne!(size_of::<T>(), 0);
let top = unsafe { &mut *self.top };
// Push the top as far down until we hit the bottom or an allocation
// currently in use.
loop {
let prev = top.peek_prev();
if let Some(x) = prev {
if x.is_free {
top.prev().unwrap();
continue;
}
}
break;
}
let layout = Layout::array::<T>(count).unwrap();
let req_len = layout.size() + layout.align();
let allocation = match top.peek_next() {
Some(d) => {
assert!(d.is_free);
// Resize the allocation if needed.
if d.data.len() < req_len {
d.data.resize(req_len, 0u8);
}
d.requested_len = count;
top.next().unwrap()
}
None => {
let data = vec![0u8; req_len];
let allocation = Allocation {
requested_len: count,
is_free: false,
data,
};
top.insert(allocation);
top.next().unwrap()
}
};
allocation.is_free = false;
ScratchBuffer {
allocation: allocation as *mut Allocation,
_phantom: PhantomData,
}
}
}
struct Allocation {
requested_len: usize,
data: Vec<u8>,
is_free: bool,
}
pub struct ScratchBuffer<'a, T> {
allocation: *mut Allocation,
_phantom: PhantomData<&'a T>,
}
impl<'a, T> ScratchBuffer<'a, T> {
#[allow(unused)]
/// Get a slice to the underlying data.
///
/// # Remarks
/// While not extremely expensive, this operation does require capturing
/// an aligned slice of data in an underlying allocation. As such,
/// you should avoid repeated calls.
pub fn as_slice(&self) -> &[T] {
let count = unsafe { (*self.allocation).requested_len };
let (_, slice, _) = unsafe { (*self.allocation).data.align_to::<T>() };
unsafe { transmute(&slice[0..count]) }
}
/// Get a mutable slice to the underlying data.
///
/// # Remarks
/// While not extremely expensive, this operation does require capturing
/// an aligned slice of data in an underlying allocation. As such,
/// you should avoid repeated calls.
pub fn as_mut_slice(&mut self) -> &mut [T] {
let count = unsafe { (*self.allocation).requested_len };
let (_pre, slice, _post) = unsafe { (*self.allocation).data.align_to_mut::<T>() };
unsafe { transmute(&mut slice[0..count]) }
}
}
impl<'a, T> Drop for ScratchBuffer<'a, T> {
fn drop(&mut self) {
unsafe { (*self.allocation).is_free = true };
}
}
#[cfg(test)]
mod tests {
use std::mem::align_of;
use super::*;
#[test]
fn can_allocate() {
let mut scratch = Scratch::new();
let mut d = scratch.allocate::<u64>(64);
let d = d.as_mut_slice();
assert_eq!(d.len(), 64);
for (i, d_i) in d.iter_mut().enumerate() {
*d_i = i as u64;
}
assert_eq!(scratch.stack.len(), 1);
}
#[test]
fn buffers_get_reused() {
let mut scratch = Scratch::new();
let b = scratch.allocate::<u64>(64);
assert_eq!(scratch.stack.len(), 1);
let b_slice = b.as_slice();
assert_eq!(b_slice.len(), 64);
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
let first_ptr = b_slice.as_ptr();
std::mem::drop(b);
let mut b = scratch.allocate::<u64>(64);
let b_slice = b.as_mut_slice();
assert_eq!(first_ptr, b_slice.as_ptr());
assert_eq!(scratch.stack.len(), 1);
assert_eq!(b_slice.len(), 64);
for (i, b_i) in b_slice.iter_mut().enumerate() {
*b_i = i as u64;
}
}
#[test]
#[ignore]
fn reallocate_on_bigger_request() {
let mut scratch = Scratch::new();
let mut b = scratch.allocate::<u64>(64);
assert_eq!(scratch.stack.len(), 1);
let b_slice = b.as_mut_slice();
assert_eq!(b_slice.len(), 64);
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
let first_ptr = b_slice.as_ptr();
for (i, b_i) in b_slice.iter_mut().enumerate() {
*b_i = i as u64;
}
std::mem::drop(b);
let mut b = scratch.allocate::<u64>(16384);
let b = b.as_mut_slice();
assert_ne!(first_ptr, b.as_ptr());
assert_eq!(scratch.stack.len(), 1);
assert_eq!(b.len(), 16384);
for (i, b_i) in b.iter_mut().enumerate() {
*b_i = i as u64;
}
}
#[test]
fn allocate_two_buffers() {
let mut scratch = Scratch::new();
let mut a = scratch.allocate::<u64>(12);
let mut b = scratch.allocate::<u64>(12);
let a = a.as_mut_slice();
let b = b.as_mut_slice();
assert_eq!(a.len(), 12);
assert_eq!(b.len(), 12);
assert_eq!(scratch.stack.len(), 2);
assert_ne!(a.as_mut_ptr(), b.as_mut_ptr());
for i in 0..a.len() {
a[i] = i as u64;
b[i] = i as u64;
}
}
#[test]
fn align_16() {
let mut scratch = Scratch::new();
let mut buffers = (0..10)
.map(|_| scratch.allocate::<u128>(10))
.collect::<Vec<_>>();
assert_eq!(scratch.stack.len(), 10);
for b in buffers.iter_mut() {
let b = b.as_mut_slice();
assert_eq!(b.len(), 10);
for (i, b_i) in b.iter_mut().enumerate() {
*b_i = i as u128;
}
}
}
#[test]
fn align_65536() {
// Chose an alignment larger than any reasonable OS's page size
// to try to force the alignment algorithm into play.
#[repr(C, align(65536))]
#[derive(Copy, Clone)]
struct Foo {
x: u32,
}
unsafe impl Pod for Foo {}
let mut scratch = Scratch::new();
let mut b = scratch.allocate::<Foo>(15);
let b_slice = b.as_mut_slice();
assert_eq!(b_slice.len(), 15);
for b in b_slice {
b.x = 22;
let ptr = b as *mut Foo;
// Check that each item is memory aligned to the proper
// location.
assert_eq!(ptr.align_offset(align_of::<Foo>()), 0);
}
}
#[test]
fn stack_coalesces_correctly() {
let mut scratch = Scratch::new();
let a = scratch.allocate::<u64>(16);
let mut b: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
let b_ptr = b.as_mut_slice().as_mut_ptr();
let c: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
let d: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
std::mem::drop(b);
assert_eq!(scratch.stack.len(), 4);
// We can't reuse b's buffer until c, d, e get dropped.
let mut e: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
assert_ne!(b_ptr, e.as_mut_slice().as_mut_ptr());
assert_eq!(scratch.stack.len(), 5);
std::mem::drop(c);
std::mem::drop(d);
std::mem::drop(e);
// Now we can reuse b's buffer.
let mut f = scratch.allocate::<u64>(16);
assert_eq!(f.as_mut_slice().as_mut_ptr(), b_ptr);
assert_eq!(scratch.stack.len(), 5);
std::mem::drop(a);
}
#[test]
fn zero_size_allocations() {
let mut scratch = Scratch::new();
let a = scratch.allocate::<u64>(2);
let b = scratch.allocate::<u64>(0);
let a_slice = a.as_slice();
let b_slice = b.as_slice();
assert_eq!(scratch.stack.len(), 2);
assert_eq!(a_slice.len(), 2);
assert_eq!(b_slice.len(), 0);
}
#[test]
#[should_panic]
fn zst_allocations_should_panic() {
struct Foo {}
unsafe impl Pod for Foo {}
let mut scratch = Scratch::new();
let _ = scratch.allocate::<Foo>(0x1 << 48);
}
}

986
sunscreen_tfhe/src/zkp.rs Normal file
View File

@@ -0,0 +1,986 @@
use logproof::{
crypto::CryptoHash,
linear_algebra::{Matrix, PolynomialMatrix},
math::ModSwitch,
rings::ZqRistretto,
Bounds, LogProofProverKnowledge, LogProofVerifierKnowledge as VerifierKnowledge,
};
use sunscreen_math::{
poly::Polynomial,
ring::{BarrettBackend, Ring, RingModulus, Zq},
BarrettConfig, One, Zero,
};
use crate::{
entities::{LweCiphertext, LwePublicKey, LweSecretKey, TlwePublicEncRandomness},
math::{Torus, TorusOps},
LweDef, PlaintextBits,
};
/// Proof statements for the SDLP proof system when applied to TFHE.
#[derive(Debug)]
pub enum ProofStatement<'a, 'b, S: TorusOps + TorusZq> {
/// A private key encryption statement.
PrivateKeyEncryption {
/// The message ID being encrypted.
message_id: usize,
/// The ciphertext being encrypted.
ciphertext: &'a LweCiphertext<S>,
},
/// A public key encryption statement.
PublicKeyEncryption {
/// The message ID being encrypted.
message_id: usize,
/// The encrypted data under the provided public key.
ciphertext: &'a LweCiphertext<S>,
/// The public key being used to encrypt the message.
public_key: &'b LwePublicKey<S>,
},
}
/// Witness information for the SDLP proof system when applied to TFHE.
/// This is the private information used when generating a proof.
#[derive(Debug)]
pub enum Witness<'a, 'b, S: TorusOps + TorusZq> {
/// A private key encryption witness.
PrivateKeyEncryption {
/// The randomness used in the encryption.
randomness: Torus<S>,
/// The private key used in the encryption.
private_key: &'a LweSecretKey<S>,
},
/// A public key encryption witness.
PublicKeyEncryption {
/// The randomness used in the encryption.
randomness: &'b TlwePublicEncRandomness<S>,
},
}
/// Generate LogProofProverKnowledge for the SDLP proof system.
pub fn generate_tfhe_sdlp_prover_knowledge<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
messages: &[Torus<S>],
witness: &[Witness<S>],
lwe: &LweDef,
plaintext_bits: PlaintextBits,
) -> LogProofProverKnowledge<S::Zq> {
let vk = generate_tfhe_sdlp_verifier_knowledge(statements, lwe, plaintext_bits);
let s = compute_s(statements, witness, messages, lwe);
LogProofProverKnowledge { vk, s }
}
/// Modulus for u32.
#[derive(BarrettConfig)]
#[barrett_config(modulus = "4294967296", num_limbs = 1)]
pub struct U32Config;
/// Field for u32.
pub type Zq32 = Zq<1, BarrettBackend<1, U32Config>>;
/// Modulus for u64.
#[derive(BarrettConfig)]
#[barrett_config(modulus = "18446744073709551616", num_limbs = 2)]
pub struct U64Config;
/// Field for u64.
pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>;
/// Torus properties on a discrete ring `Z_q`.
pub trait TorusZq
where
Self: Sized,
{
/// The discrete ring `Z_q`.
type Zq: Ring
+ From<Self>
+ CryptoHash
+ ModSwitch<ZqRistretto>
+ RingModulus<4>
+ Ord
+ From<u32>;
}
impl TorusZq for u32 {
type Zq = Zq32;
}
impl TorusZq for u64 {
type Zq = Zq64;
}
/// Computes the public information needed to prove and verify public and private key
/// TLWE encryptions.
///
/// # Remarks
/// Using only private key encryption results in significantly faster runtime since we
/// can characterize `Z_q[X]/f` with `f = X + 1`.
///
/// # Details
/// Let `N` be the TLWE lattice dimension.
/// let `M` be the number of messages encrypted in in the ciphertexts
/// Note that multiple ciphertexts can encrypt the same message and SDLP can prove they contain
/// the same message.
///
/// ## A matrix's structure
/// SDLP proves a linear relation `AS=T` where
/// * `A in Z_q[X]^{m x k}/f`
/// * `S in Z_q[X]^{k x n}/f`
/// * `T in Z_q[X]^{m x n}/f`
/// * `f = X^D + 1`
///
/// For proving encryptions of TLWE ciphertexts, we have:
/// * `m` is the number of proof statements.
/// * `n = 1`.
/// * `k = num_messages + num_public_keys * (N + 1) + num_private_keys * N + num_private_encs + num_public_encs * (N + 1)`.
/// * See [quotient ring modulus `f`](#quotient-ring-modulus-f).
///
/// The matrix `A` is arranged into rows, one per proof statement whose column structure depends
/// on whether the statement is for public or private key encryption.
///
/// ### Private key statements
/// Each private key encryption statement row has the following column arrangement:
/// ```ignore
/// 1 at msg_idx pub_key (N/A) a at pvt_key_id * N 1 at pvt_stmt_idx e_idx (N/A)
/// V V V V V
/// [ 0 .. 1 .. 0 0 .. 0 0 .. a_0, a_1, .. a_N .. 0 0 .. 1 .. 0 0 .. 0]
/// ```
/// ### Public key statements
/// Recall that a TLWE ciphertext consists of (a_1, ... a_N, b), where `a_i, b` in the discrete
/// torus `T`.
///
/// Furthermore, the public key `P = (p_1, ... p_N)` where `p_i` is a secret key encryption of
/// zero.
///
/// Each public key encryption statement row has the following column arrangement:
/// ```ignore
/// 1 at msg_idx p at pub_key_id * N a (N/A) b (N/A) e_idx
/// V V V V V
/// [ 0 .. 1 .. 0 0 .. p_0, p_1, .. p_N .. 0 0 .. 0 0 .. 0 0 .. 1 .. 0 ]
/// ```
/// For these entries, we reinterpret each `p_i` as a polynomial `mod (X^N + 1)`.
///
/// ## T's structure
/// While SDLP support T as a matrix, we only need a vector of `m` rows.
///
/// The structure of each row in T depends on whether said the corresponding statement describes
/// a public or private key encryption.
///
/// ### Private key statements
/// The row's polynomial is simply `b * X^0`.
///
/// ### Public key statements
/// The rows's polynomial is `a_0 * X^0, a_1 * X^1, ... a_N * X^{N - 1} b * X^N`
///
/// ## Quotient ring modulus `f`
/// When `statements` contains only secret key encryption statements, `f = X + 1`. Otherwise,
/// `f = X^{N + 1} + 1`.
pub fn generate_tfhe_sdlp_verifier_knowledge<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
lwe: &LweDef,
plaintext_bits: PlaintextBits,
) -> VerifierKnowledge<S::Zq> {
// If we need to prove any public encryption statements, then `f = X^{lwe_dimension + 1} + 1`.
// If not, we can use the more efficient X + 1.
let mut f_coeffs = vec![S::Zq::from(<S as Zero>::zero()); f_degree(statements, lwe) + 1];
f_coeffs[0] = S::Zq::from(S::one());
let last_coeff = f_coeffs.len() - 1;
f_coeffs[last_coeff] = S::Zq::from(S::one());
let f = Polynomial::new(&f_coeffs);
let a = compute_a(statements, lwe, plaintext_bits);
let t = compute_t(statements, lwe);
let bounds = compute_bounds(statements, lwe, plaintext_bits);
VerifierKnowledge::new(a, t, f, bounds)
}
fn compute_bounds<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
lwe: &LweDef,
plaintext_bits: PlaintextBits,
) -> Matrix<Bounds> {
let (_, cols) = proof_matrix_dim(statements, lwe.dim.0);
let offsets = compute_a_column_offsets(statements, lwe);
let mut bounds = Matrix::<Bounds>::new(cols, 1);
let num_messages = num_messages(statements);
let num_coeffs = f_degree(statements, lwe);
let lwe_dimension = lwe.dim.0;
// Bounds for messages
for i in 0..num_messages {
let mut b = Bounds(vec![0; num_coeffs]);
b.0[0] = 0x1u64 << plaintext_bits.0;
debug_assert_eq!(bounds[(i, 0)].0, &[]);
bounds[(i, 0)] = b;
}
// Public r and e
for i in 0..num_public(statements) {
for j in 0..lwe_dimension {
let mut b = Bounds(vec![0; num_coeffs]);
// Values of r are binary
b.0[0] = 0x1u64 << plaintext_bits.0;
debug_assert_eq!(
bounds[(offsets.public_keys + i * lwe_dimension + j, 0)].0,
&[]
);
bounds[(offsets.public_keys + i * lwe_dimension + j, 0)] = b;
}
// e is normal distributed over the torus.
// TODO: This bound is too high. Get a tighter bound.
let b = Bounds(vec![0x1u64 << (60 - plaintext_bits.0); num_coeffs]);
debug_assert_eq!(bounds[(offsets.public_e + i, 0)].0, &[]);
bounds[(offsets.public_e + i, 0)] = b;
}
// Private s and e
for i in 0..num_private(statements) {
for j in 0..lwe_dimension {
let mut b = Bounds(vec![0; num_coeffs]);
// Values of s are binary
b.0[0] = 0x1u64 << plaintext_bits.0;
debug_assert_eq!(
bounds[(offsets.private_a + j + i * lwe_dimension, 0)].0,
&[]
);
bounds[(offsets.private_a + j + i * lwe_dimension, 0)] = b;
}
let mut b = Bounds(vec![0; num_coeffs]);
// e is normal distributed over the torus.
// TODO: This bound is too high. Get a tighter bound.
b.0[0] = 0x1u64 << (62 - plaintext_bits.0);
debug_assert_eq!(bounds[(offsets.private_e + i, 0)].0, &[]);
bounds[(offsets.private_e + i, 0)] = b;
}
bounds
}
fn f_degree<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>], lwe: &LweDef) -> usize {
let lwe_dimension = lwe.dim.0;
if num_public(statements) > 0 {
lwe_dimension + 1
} else {
1
}
}
fn encoding_factor<S: TorusZq + TorusOps>(plaintext_bits: PlaintextBits) -> S::Zq {
let x = S::from_u64(0x1u64 << (S::BITS - plaintext_bits.0));
S::Zq::from(x)
}
struct IdxOffsets {
public_keys: usize,
public_e: usize,
private_a: usize,
private_e: usize,
}
#[inline]
fn compute_a_column_offsets<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
lwe: &LweDef,
) -> IdxOffsets {
let lwe_dimension = lwe.dim.0;
let public_keys = num_messages(statements);
let public_e = public_keys + num_public(statements) * lwe_dimension;
let private_a = public_e + num_public(statements);
let private_e = private_a + num_private(statements) * lwe_dimension;
IdxOffsets {
public_keys,
public_e,
private_a,
private_e,
}
}
fn compute_a<S: TorusZq + TorusOps>(
statements: &[ProofStatement<S>],
lwe: &LweDef,
plaintext_bits: PlaintextBits,
) -> Matrix<Polynomial<S::Zq>> {
let lwe_dimension = lwe.dim.0;
let offsets = compute_a_column_offsets(statements, lwe);
let (rows, cols) = proof_matrix_dim(statements, lwe_dimension);
let mut a = Matrix::<Polynomial<S::Zq>>::new(rows, cols);
assert!(plaintext_bits.0 > 0);
let msg_encode = encoding_factor::<S>(plaintext_bits);
let mut cur_public = 0;
let mut cur_private = 0;
for (i, s) in statements.iter().enumerate() {
let (is_public, public_key, message_id, ciphertext) = match s {
ProofStatement::PrivateKeyEncryption {
ciphertext,
message_id,
} => (false, None, *message_id, *ciphertext),
ProofStatement::PublicKeyEncryption {
public_key,
ciphertext,
message_id,
} => (true, Some(public_key), *message_id, *ciphertext),
};
// Insert the message coefficient. For private encryptions, this goes in
// the constant coefficient. For public, it goes in the d-1 coefficient.
let coeffs = if !is_public {
vec![msg_encode.clone()]
} else {
let mut coeffs = vec![S::Zq::from(<S as Zero>::zero()); lwe_dimension + 1];
coeffs[lwe_dimension] = msg_encode.clone();
coeffs
};
debug_assert_eq!(a[(i, message_id)], Polynomial::zero());
a[(i, message_id)] = Polynomial::new(&coeffs);
if is_public {
// Push the public key
let pk = public_key.unwrap();
for (j, z) in pk.enc_zeros(lwe).enumerate() {
let (z_a, z_b) = z.a_b(lwe);
let mut coeffs = z_a
.iter()
.map(|x| S::Zq::from(x.inner()))
.collect::<Vec<_>>();
coeffs.push(S::Zq::from(z_b.inner()));
let public_key_idx = offsets.public_keys + cur_public * lwe_dimension + j;
debug_assert_eq!(a[(i, public_key_idx)], Polynomial::zero());
a[(i, public_key_idx)] = Polynomial::new(&coeffs);
}
// Push the randomness
debug_assert_eq!(a[(i, offsets.public_e + cur_public)], Polynomial::zero());
a[(i, offsets.public_e + cur_public)] = Polynomial::one();
cur_public += 1;
} else {
let (c_a, _c_b) = ciphertext.a_b(lwe);
// Place the a values of the cipertext in the matrix.
for (j, a_j) in c_a.iter().enumerate() {
let private_key_idx = offsets.private_a + j + cur_private * lwe_dimension;
debug_assert_eq!(a[(i, private_key_idx)], Polynomial::zero());
a[(i, private_key_idx)] = Polynomial::new(&[S::Zq::from(a_j.inner())]);
}
debug_assert_eq!(a[(i, offsets.private_e)], Polynomial::zero());
a[(i, offsets.private_e + i)] = Polynomial::one();
cur_private += 1;
}
}
a
}
fn compute_t<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
lwe: &LweDef,
) -> PolynomialMatrix<S::Zq> {
let mut t = PolynomialMatrix::new(statements.len(), 1);
for (i, s) in statements.iter().enumerate() {
match s {
ProofStatement::PrivateKeyEncryption {
ciphertext: c,
message_id: _,
} => {
let (_, c_b) = c.a_b(lwe);
debug_assert_eq!(t[(i, 0)], Polynomial::zero());
t[(i, 0)] = Polynomial::new(&[S::Zq::from(c_b.inner())]);
}
ProofStatement::PublicKeyEncryption {
public_key: _,
message_id: _,
ciphertext: c,
} => {
let (c_a, c_b) = c.a_b(lwe);
let mut coeffs = c_a
.iter()
.map(|x| S::Zq::from(x.inner()))
.collect::<Vec<_>>();
coeffs.push(S::Zq::from(c_b.inner()));
debug_assert_eq!(t[(i, 0)], Polynomial::zero());
t[(i, 0)] = Polynomial::new(&coeffs);
}
}
}
t
}
#[inline]
fn compute_s<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
witness: &[Witness<S>],
messages: &[Torus<S>],
lwe: &LweDef,
) -> PolynomialMatrix<S::Zq> {
assert_eq!(statements.len(), witness.len());
let lwe_dimension = lwe.dim.0;
// If the a matrix is rows x cols, the S witness must have 'cols' rows.
let (_, cols) = proof_matrix_dim(statements, lwe_dimension);
// Column offsets in the A matrix become row offsets in the S witness.
let offsets = compute_a_column_offsets(statements, lwe);
let mut s = PolynomialMatrix::new(cols, 1);
// Put the messages into the witness
for (i, m) in messages.iter().take(num_messages(statements)).enumerate() {
debug_assert_eq!(s[(i, 0)], Polynomial::zero());
s[(i, 0)] = Polynomial::new(&[S::Zq::from(m.inner())]);
}
let public_entries = witness.iter().filter_map(|x| match x {
Witness::PublicKeyEncryption { randomness } => Some(randomness),
_ => None,
});
// Put public 'r' and 'e' randomness into the witness
for (i, rnd) in public_entries.enumerate() {
for (j, r) in rnd.r.iter().enumerate() {
let public_key_index = offsets.public_keys + j + i * lwe_dimension;
debug_assert_eq!(s[(public_key_index, 0)], Polynomial::zero());
s[(public_key_index, 0)] = Polynomial::new(&[S::Zq::from(*r)]);
}
let mut coeffs = rnd
.e
.a_b(lwe)
.0
.iter()
.map(|x| S::Zq::from(x.inner()))
.collect::<Vec<_>>();
coeffs.push(S::Zq::from(rnd.e.a_b(lwe).1.inner()));
let public_randomness_index = offsets.public_e + i;
debug_assert_eq!(s[(public_randomness_index, 0)], Polynomial::zero());
s[(public_randomness_index, 0)] = Polynomial::new(&coeffs);
}
let private_keys = witness.iter().filter_map(|x| match x {
Witness::PrivateKeyEncryption {
randomness,
private_key,
} => Some((private_key, randomness)),
_ => None,
});
// Put the secret keys into the witness.
for (i, (sk, e)) in private_keys.enumerate() {
for (j, sk) in sk.s().iter().enumerate() {
let private_key_index = offsets.private_a + j + i * lwe_dimension;
debug_assert_eq!(s[(private_key_index, 0)], Polynomial::zero());
s[(private_key_index, 0)] = Polynomial::new(&[S::Zq::from(*sk)]);
}
let private_randomness_index = offsets.private_e + i;
debug_assert_eq!(s[(private_randomness_index, 0)], Polynomial::zero());
s[(private_randomness_index, 0)] = Polynomial::new(&[S::Zq::from(e.inner())]);
}
s
}
#[inline(always)]
fn num_private<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
// Public key statements require 2 rows while private key statements require only one.
statements
.iter()
.filter(|x| {
matches!(
x,
ProofStatement::PrivateKeyEncryption {
ciphertext: _,
message_id: _
}
)
})
.count()
}
#[inline(always)]
fn num_public<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
statements.len() - num_private(statements)
}
/// Returns a tuple of the `(rows, cols)` in SDLP's `A` matrix.
#[inline]
fn proof_matrix_dim<S: TorusOps + TorusZq>(
statements: &[ProofStatement<S>],
lwe_dimension: usize,
) -> (usize, usize) {
let num_rows = statements.len();
let num_cols = num_messages(statements) // message terms
+ num_public(statements) * lwe_dimension // public key terms
+ num_private(statements) * lwe_dimension // a terms
+ num_rows; // Public + private e term '1' coeffs
(num_rows, num_cols)
}
#[inline]
fn num_messages<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
statements.iter().fold(0usize, |max, x| {
let message_id = match x {
ProofStatement::PublicKeyEncryption {
public_key: _,
ciphertext: _,
message_id,
} => message_id,
ProofStatement::PrivateKeyEncryption {
ciphertext: _,
message_id,
} => message_id,
};
usize::max(max, *message_id)
}) + 1
}
#[cfg(test)]
mod tests {
use logproof::{InnerProductVerifierKnowledge, LogProof, LogProofGenerators};
use merlin::Transcript;
use rand::{thread_rng, RngCore};
use crate::{
high_level::*,
zkp::{num_private, num_public},
LweDef, LweDimension, LWE_512_80,
};
use super::*;
#[test]
fn can_compute_a_column_offsets() {
let lwe = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let lwe_dimension = lwe.dim.0;
let ct = LweCiphertext::<u64>::zero(&lwe);
let sk = keygen::generate_binary_lwe_sk(&lwe);
let pk = keygen::generate_lwe_pk(&sk, &lwe);
let statements = [
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 1,
},
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 0,
},
ProofStatement::PublicKeyEncryption {
ciphertext: &ct,
public_key: &pk,
message_id: 0,
},
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 1,
},
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 0,
},
ProofStatement::PublicKeyEncryption {
ciphertext: &ct,
public_key: &pk,
message_id: 1,
},
ProofStatement::PublicKeyEncryption {
ciphertext: &ct,
public_key: &pk,
message_id: 1,
},
];
let idx = compute_a_column_offsets(&statements, &lwe);
let num_messages = num_messages(&statements);
let num_private = num_private(&statements);
let num_public = num_public(&statements);
assert_eq!(num_messages, 2);
assert_eq!(num_private, 4);
assert_eq!(num_public, 3);
// We have 2 messages that precede this.
assert_eq!(idx.public_keys, num_messages);
// We have 3 different public encryptions, so there are 3 public key
// column entries. That 2 use the same key is immaterial, as they
// must match different randomness in the witness.
assert_eq!(idx.public_e, idx.public_keys + num_public * lwe_dimension);
// We have 3 public key encryptions that contribute to this offset.
// Each 'e' is a polynomial of degree `lwe_dimension`.
assert_eq!(idx.private_a, idx.public_e + num_public);
// Finally, each secret_a entry takes `lwe_dimension` columns.
assert_eq!(idx.private_e, idx.private_a + num_private * lwe_dimension);
}
#[test]
fn num_messages_works() {
let lwe = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let sk = keygen::generate_binary_lwe_sk(&lwe);
let pk = keygen::generate_lwe_pk(&sk, &lwe);
let ct = encryption::trivial_lwe(0, &lwe, PlaintextBits(1));
let num_messages = num_messages(&[
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 0,
},
ProofStatement::PrivateKeyEncryption {
ciphertext: &ct,
message_id: 1,
},
ProofStatement::PublicKeyEncryption {
public_key: &pk,
ciphertext: &ct,
message_id: 0,
},
]);
assert_eq!(num_messages, 2);
}
fn prove_and_verify<S: TorusOps + TorusZq>(pk: &LogProofProverKnowledge<S::Zq>) {
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
let u = InnerProductVerifierKnowledge::get_u();
let mut p_t = Transcript::new(b"test");
let proof = LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u);
let mut v_t = Transcript::new(b"test");
proof.verify(&mut v_t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
}
#[test]
fn one_secret_key() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let sk = keygen::generate_binary_lwe_sk(&params);
let (ct, rng) =
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, &params, PlaintextBits(1));
let pk = generate_tfhe_sdlp_prover_knowledge(
&[ProofStatement::PrivateKeyEncryption {
message_id: 0,
ciphertext: &ct,
}],
&[Torus::from(1)],
&[Witness::PrivateKeyEncryption {
randomness: rng,
private_key: &sk,
}],
&params,
PlaintextBits(1),
);
prove_and_verify::<u64>(&pk);
}
#[test]
fn two_secret_key() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let (ct0, rng0) =
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, &params, bits);
let (ct1, rng1) =
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, &params, bits);
let pk = generate_tfhe_sdlp_prover_knowledge(
&[
ProofStatement::PrivateKeyEncryption {
message_id: 0,
ciphertext: &ct0,
},
ProofStatement::PrivateKeyEncryption {
message_id: 1,
ciphertext: &ct1,
},
],
&[Torus::from(1), Torus::from(1)],
&[
Witness::PrivateKeyEncryption {
randomness: rng0,
private_key: &sk,
},
Witness::PrivateKeyEncryption {
randomness: rng1,
private_key: &sk,
},
],
&params,
bits,
);
prove_and_verify::<u64>(&pk);
}
#[test]
fn one_public_key() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
let (ct, rng) = encryption::encrypt_lwe_and_return_randomness(1, &pk, &params, bits);
let pk = generate_tfhe_sdlp_prover_knowledge(
&[ProofStatement::PublicKeyEncryption {
message_id: 0,
public_key: &pk,
ciphertext: &ct,
}],
&[Torus::from(1)],
&[Witness::PublicKeyEncryption { randomness: &rng }],
&params,
bits,
);
prove_and_verify::<u64>(&pk);
}
#[test]
fn two_public_key() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
let (ct0, rng0) = encryption::encrypt_lwe_and_return_randomness(1, &pk, &params, bits);
let (ct1, rng1) = encryption::encrypt_lwe_and_return_randomness(1, &pk, &params, bits);
let pk = generate_tfhe_sdlp_prover_knowledge(
&[
ProofStatement::PublicKeyEncryption {
message_id: 0,
public_key: &pk,
ciphertext: &ct0,
},
ProofStatement::PublicKeyEncryption {
message_id: 1,
public_key: &pk,
ciphertext: &ct1,
},
],
&[Torus::from(1), Torus::from(1)],
&[
Witness::PublicKeyEncryption { randomness: &rng0 },
Witness::PublicKeyEncryption { randomness: &rng1 },
],
&params,
bits,
);
prove_and_verify::<u64>(&pk);
}
#[test]
fn one_public_one_private() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let bits = PlaintextBits(1);
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
let (ct_priv, rng_priv) =
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, &params, bits);
let (ct_pub, rng_pub) =
encryption::encrypt_lwe_and_return_randomness(1, &pk, &params, bits);
let pk = generate_tfhe_sdlp_prover_knowledge(
&[
ProofStatement::PrivateKeyEncryption {
message_id: 0,
ciphertext: &ct_priv,
},
ProofStatement::PublicKeyEncryption {
message_id: 0,
public_key: &pk,
ciphertext: &ct_pub,
},
],
&[Torus::from(1)],
&[
Witness::PrivateKeyEncryption {
randomness: rng_priv,
private_key: &sk,
},
Witness::PublicKeyEncryption {
randomness: &rng_pub,
},
],
&params,
bits,
);
prove_and_verify::<u64>(&pk);
}
#[ignore]
#[test]
fn complex_examples() {
let params = LweDef {
dim: LweDimension(4),
..LWE_512_80
};
let bits = PlaintextBits(1);
let case = || {
let sk = keygen::generate_binary_lwe_sk(&params);
let pk = keygen::generate_lwe_pk(&sk, &params);
let num_messages = thread_rng().next_u64() as usize % 7 + 1;
let num_secret_encryptions = thread_rng().next_u64() as usize % 8;
let num_public_encryptions = thread_rng().next_u64() as usize % 8;
let messages = (0..num_messages)
.map(|_| thread_rng().next_u64() % 2)
.collect::<Vec<_>>();
// Skip trivial cases. I don't think SDLP allows it and it's boring..
if num_public_encryptions == 0 && num_secret_encryptions == 0 {
return;
}
let mut statements = vec![];
let mut witnesses = vec![];
let mut private_info = vec![];
let mut public_info = vec![];
for _ in 0..num_secret_encryptions {
let msg_id = thread_rng().next_u64() as usize % num_messages;
let (ct, noise) = encryption::encrypt_lwe_secret_and_return_randomness(
messages[msg_id],
&sk,
&params,
bits,
);
private_info.push((ct, noise, msg_id));
}
for (ct, noise, msg_id) in private_info.iter() {
statements.push(ProofStatement::PrivateKeyEncryption {
message_id: *msg_id,
ciphertext: ct,
});
witnesses.push(Witness::PrivateKeyEncryption {
randomness: *noise,
private_key: &sk,
});
}
for _ in 0..num_public_encryptions {
let msg_id = thread_rng().next_u64() as usize % num_messages;
let (ct, noise) = encryption::encrypt_lwe_and_return_randomness(
messages[msg_id],
&pk,
&params,
bits,
);
public_info.push((ct, noise, msg_id));
}
for (ct, noise, msg_id) in public_info.iter() {
statements.push(ProofStatement::PublicKeyEncryption {
message_id: *msg_id,
ciphertext: ct,
public_key: &pk,
});
witnesses.push(Witness::PublicKeyEncryption { randomness: noise });
}
let messages = messages.iter().map(|x| Torus::from(*x)).collect::<Vec<_>>();
let pk = generate_tfhe_sdlp_prover_knowledge(
&statements,
&messages,
&witnesses,
&params,
bits,
);
prove_and_verify::<u64>(&pk);
};
for _ in 0..5 {
case();
}
}
}