diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 8c61effc6..0f97a2511 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -60,14 +60,14 @@ jobs: # Build package in prep for user docs - name: Build sunscreen and bincode - run: cargo build --release --features bulletproofs,linkedproofs --package sunscreen --package bincode + run: cargo build --profile mdbook --features bulletproofs,linkedproofs --package sunscreen --package bincode # Build mdbook - name: Build mdBook run: cargo build --release working-directory: ./mdBook # Build user documentation - name: Test docs - run: ../mdBook/target/release/mdbook test -L dependency=../target/release/deps --extern sunscreen=../target/release/libsunscreen.rlib --extern bincode=../target/release/libbincode.rlib + run: ../mdBook/target/release/mdbook test -L dependency=../target/mdbook/deps --extern sunscreen=../target/mdbook/libsunscreen.rlib --extern bincode=../target/mdbook/libbincode.rlib working-directory: ./sunscreen_docs lint: diff --git a/.gitignore b/.gitignore index 03ce603a0..9c748ccc9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,189 @@ sunscreen_docs/book -/target -.DS_Store __bindgen.ii __bindgen.cpp + +# Ignore output from the proptest crate +proptest-regressions/ + +# Created by https://www.toptal.com/developers/gitignore/api/emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust +# Edit at https://www.toptal.com/developers/gitignore?templates=emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust + +### direnv ### +.direnv +.envrc + +### Emacs ### +# -*- mode: gitignore; -*- +*~ +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc +auto-save-list +tramp +.\#* + +# Org-mode +.org-id-locations +*_archive + +# flymake-mode +*_flymake.* + +# eshell files +/eshell/history +/eshell/lastdir + +# elpa packages +/elpa/ + +# reftex files +*.rel + +# AUCTeX auto folder +/auto/ + +# cask packages +.cask/ +dist/ + +# Flycheck +flycheck_*.el + +# server auth directory +/server/ + +# projectiles files +.projectile + +# directory configuration +.dir-locals.el + +# network security +/network-security.data + + +### Linux ### + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### Rust ### +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +# End of https://www.toptal.com/developers/gitignore/api/emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust diff --git a/Cargo.lock b/Cargo.lock index 98586d802..b2098c3d9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -643,6 +643,8 @@ dependencies = [ "num-traits 0.2.17", "once_cell", "oorandom", + "plotters", + "rayon", "regex", "serde", "serde_derive", @@ -1593,6 +1595,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-list" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4dacf969043dc69f1f731b5042eb05e030d264bcf34f2242889fcbdc7a65f06" + [[package]] name = "linux-raw-sys" version = "0.4.10" @@ -2196,6 +2204,15 @@ dependencies = [ "syn 2.0.38", ] +[[package]] +name = "primal-check" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0" +dependencies = [ + "num-integer", +] + [[package]] name = "private_tx_linkedproof" version = "0.1.0" @@ -2220,9 +2237,9 @@ checksum = "f89dff0959d98c9758c88826cc002e2c3d0b9dfac4139711d1f30de442f1139b" [[package]] name = "proptest" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e" +checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf" dependencies = [ "bit-set", "bit-vec", @@ -2232,7 +2249,7 @@ dependencies = [ "rand", "rand_chacha", "rand_xorshift", - "regex-syntax 0.7.5", + "regex-syntax 0.8.2", "rusty-fork", "tempfile", "unarray", @@ -2356,6 +2373,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "realfft" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953d9f7e5cdd80963547b456251296efc2626ed4e3cbf36c869d9564e0220571" +dependencies = [ + "rustfft", +] + [[package]] name = "redox_syscall" version = "0.2.11" @@ -2393,9 +2419,9 @@ checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" [[package]] name = "regex-syntax" -version = "0.7.5" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "remove_dir_all" @@ -2485,6 +2511,21 @@ dependencies = [ "semver", ] +[[package]] +name = "rustfft" +version = "6.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43806561bc506d0c5d160643ad742e3161049ac01027b5e6d7524091fd401d86" +dependencies = [ + "num-complex", + "num-integer", + "num-traits 0.2.17", + "primal-check", + "strength_reduce", + "transpose", + "version_check", +] + [[package]] name = "rustix" version = "0.38.20" @@ -2773,6 +2814,12 @@ dependencies = [ "rand", ] +[[package]] +name = "strength_reduce" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" + [[package]] name = "strsim" version = "0.10.0" @@ -2994,6 +3041,27 @@ dependencies = [ "thiserror", ] +[[package]] +name = "sunscreen_tfhe" +version = "0.1.0" +dependencies = [ + "bytemuck", + "criterion 0.5.1", + "linked-list", + "logproof", + "merlin", + "num", + "paste", + "proptest", + "rand", + "rand_distr", + "realfft", + "rustfft", + "serde", + "sunscreen_math", + "thiserror", +] + [[package]] name = "sunscreen_zkp_backend" version = "0.8.1" @@ -3195,6 +3263,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "transpose" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6522d49d03727ffb138ae4cbc1283d3774f0d10aa7f9bf52e6784c45daf9b23" +dependencies = [ + "num-integer", + "strength_reduce", +] + [[package]] name = "try-lock" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index bea76af39..5fa7a8d45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ members = [ "sunscreen_math", "sunscreen_math_macros", "sunscreen_runtime", - "sunscreen_compiler_common", + "sunscreen_tfhe", "sunscreen_zkp_backend", ] exclude = ["mdBook", "rust-playground"] @@ -25,6 +25,17 @@ exclude = ["mdBook", "rust-playground"] [profile.release] split-debuginfo = "packed" debug = true +lto = "fat" +codegen-units = 1 + +[profile.bench] +lto = "fat" +codegen-units = 1 + +[profile.mdbook] +inherits = "release" +lto = false +codegen-units = 16 [workspace.dependencies] bytemuck = "1.13.0" diff --git a/sunscreen_tfhe/.vscode/launch.json b/sunscreen_tfhe/.vscode/launch.json new file mode 100644 index 000000000..da4143492 --- /dev/null +++ b/sunscreen_tfhe/.vscode/launch.json @@ -0,0 +1,81 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'sunscreen_tfhe'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=sunscreen_tfhe" + ], + "filter": { + "name": "sunscreen_tfhe", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'tfhe_proof'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=tfhe_proof", + "--package=sunscreen_tfhe" + ], + "filter": { + "name": "tfhe_proof", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'fft'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=fft", + "--package=sunscreen_tfhe" + ], + "filter": { + "name": "fft", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'ops'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=ops", + "--package=sunscreen_tfhe" + ], + "filter": { + "name": "ops", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/sunscreen_tfhe/.vscode/settings.json b/sunscreen_tfhe/.vscode/settings.json new file mode 100644 index 000000000..4d9636b55 --- /dev/null +++ b/sunscreen_tfhe/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.showUnlinkedFileNotification": false +} \ No newline at end of file diff --git a/sunscreen_tfhe/Cargo.toml b/sunscreen_tfhe/Cargo.toml new file mode 100644 index 000000000..2ec640c8a --- /dev/null +++ b/sunscreen_tfhe/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "sunscreen_tfhe" +version = "0.1.0" +edition = "2021" + +authors = ["Sunscreen"] +rust-version = "1.56.0" +license = "AGPL-3.0-only" +description = "This crate contains the Sunscreen Torus FHE (TFHE) implementation" +homepage = "https://sunscreen.tech" +repository = "https://github.com/Sunscreen-tech/Sunscreen" +documentation = "https://docs.sunscreen.tech" +keywords = ["FHE", "TFHE", "lattice", "cryptography"] +categories = ["cryptography"] +readme = "crates-io.md" + + +[dependencies] +bytemuck = { workspace = true } +# TODO: Remove when Rust stabilizes Cursor API +linked-list = "0.0.3" +logproof = { workspace = true, optional = true } +num = { workspace = true } +paste = { workspace = true } +rand = { workspace = true } +rand_distr = { workspace = true } +realfft = "3.3.0" +rustfft = "6.1.0" +serde = { workspace = true } +sunscreen_math = { workspace = true } +thiserror = { workspace = true } + +[dev-dependencies] +criterion = "0.5.1" +merlin = "3.0.0" +proptest = "1.4.0" + +[features] +logproof = ["dep:logproof"] + +[[bench]] +name = "tfhe_proof" +harness = false +required-features= ["logproof"] + +[[bench]] +name = "fft" +harness = false + +[[bench]] +name = "ops" +harness = false diff --git a/sunscreen_tfhe/barrett.py b/sunscreen_tfhe/barrett.py new file mode 100644 index 000000000..168c6c3ea --- /dev/null +++ b/sunscreen_tfhe/barrett.py @@ -0,0 +1,35 @@ +import math +import sys + +radix = 10 + +if len(sys.argv) < 3: + print("Usage: barrett []") + +n = int(sys.argv[1]) + +if len(sys.argv) == 4: + radix = int(sys.argv[3]) + +p = int(sys.argv[2], radix) + +def compute_vals(n, p): + r = math.floor(2**(64 * n) // p) + s = math.floor(2**(128 * n) // p) - 2**(64 * n) * r + t = 2**(64*n) - r * p + + return (n, r, s, t) + +(_, r, s, t) = compute_vals(n, p) + +def print_value(n, name, x): + print(name + " = [") + + for i in range(n): + print(" " + str((x >> (64 * i)) & 0xFFFFFFFFFFFFFFFF) + ",") + + print("]") + +print_value(n, "r", r) +print_value(n, "s", s) +print_value(n, "t", t) diff --git a/sunscreen_tfhe/benches/fft.rs b/sunscreen_tfhe/benches/fft.rs new file mode 100644 index 000000000..96b7e57b5 --- /dev/null +++ b/sunscreen_tfhe/benches/fft.rs @@ -0,0 +1,47 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use num::Complex; +use sunscreen_tfhe::{math::fft::negacyclic::TwistedFft, FrequencyTransform}; + +fn negacyclic_fft(c: &mut Criterion) { + let n = 2048; + + let plan = TwistedFft::::new(n); + + let x = (0..n).map(|x| x as f64).collect::>(); + let mut y = vec![Complex::from(0.0); x.len() / 2]; + + c.bench_function("FFT 2048", |s| { + s.iter(|| { + plan.forward(&x, &mut y); + }); + }); + + let n = 1024; + + let plan = TwistedFft::::new(n); + + let x = (0..n).map(|x| x as f64).collect::>(); + let mut y = vec![Complex::from(0.0); x.len() / 2]; + + c.bench_function("FFT 1024", |s| { + s.iter(|| { + plan.forward(&x, &mut y); + }); + }); + + let n: usize = 256; + + let plan = TwistedFft::::new(n); + + let x = (0..n).map(|x| x as f64).collect::>(); + let mut y = vec![Complex::from(0.0); x.len() / 2]; + + c.bench_function("FFT 256", |s| { + s.iter(|| { + plan.forward(&x, &mut y); + }); + }); +} + +criterion_group!(benches, negacyclic_fft); +criterion_main!(benches); diff --git a/sunscreen_tfhe/benches/ops.rs b/sunscreen_tfhe/benches/ops.rs new file mode 100644 index 000000000..350e54be4 --- /dev/null +++ b/sunscreen_tfhe/benches/ops.rs @@ -0,0 +1,235 @@ +use criterion::{ + criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion, +}; + +use sunscreen_tfhe::{ + entities::{ + GgswCiphertext, GgswCiphertextFft, GlweCiphertext, Polynomial, UnivariateLookupTable, + }, + high_level::*, + ops::bootstrapping::circuit_bootstrap, + rand::Stddev, + GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PlaintextBits, PolynomialDegree, + RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80, LWE_512_80, +}; + +fn cmux(c: &mut Criterion) { + struct CmuxParams { + gsw_radix: RadixDecomposition, + glwe: GlweDef, + } + + fn cmux_params(params: &CmuxParams, c: &mut Criterion) { + let sk = keygen::generate_binary_glwe_sk(¶ms.glwe); + let bits = PlaintextBits(1); + + let msg = (0..params.glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(); + let msg = Polynomial::new(&msg); + + let a = encryption::encrypt_glwe(&msg, &sk, ¶ms.glwe, bits); + let b = a.clone(); + let sel = encryption::encrypt_ggsw(1, &sk, ¶ms.glwe, ¶ms.gsw_radix, bits); + let mut sel_fft = GgswCiphertextFft::new(¶ms.glwe, ¶ms.gsw_radix); + + sel.fft(&mut sel_fft, ¶ms.glwe, ¶ms.gsw_radix); + + let name = format!( + "cmux N={} k={} l={}", + params.glwe.dim.polynomial_degree.0, params.glwe.dim.size.0, params.gsw_radix.count.0 + ); + + let mut result = GlweCiphertext::new(¶ms.glwe); + + c.bench_function(&name, |bench| { + bench.iter(|| { + sunscreen_tfhe::ops::fft_ops::cmux( + &mut result, + &a, + &b, + &sel_fft, + ¶ms.glwe, + ¶ms.gsw_radix, + ); + }); + }); + } + + let params = CmuxParams { + gsw_radix: RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(10), + }, + glwe: GLWE_5_256_80, + }; + + cmux_params(¶ms, c); + + let params = CmuxParams { + gsw_radix: RadixDecomposition { + count: RadixCount(1), + radix_log: RadixLog(11), + }, + glwe: GLWE_1_1024_80, + }; + + cmux_params(¶ms, c); +} + +fn programmable_bootstrapping(c: &mut Criterion) { + fn run_bench( + name: &str, + g: &mut BenchmarkGroup, + lwe: &LweDef, + glwe: &GlweDef, + bs_radix: &RadixDecomposition, + ) { + let lwe_sk = keygen::generate_binary_lwe_sk(lwe); + let glwe_sk = keygen::generate_binary_glwe_sk(glwe); + let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, lwe, glwe, bs_radix); + let bsk = fft::fft_bootstrap_key(&bsk, lwe, glwe, bs_radix); + + let ct = lwe_sk.encrypt(1, &glwe.as_lwe_def(), PlaintextBits(1)).0; + let lut = UnivariateLookupTable::trivial_from_fn(|x| x, glwe, PlaintextBits(1)); + + g.bench_function(name, |b| { + b.iter(|| { + evaluation::univariate_programmable_bootstrap(&ct, &lut, &bsk, lwe, glwe, bs_radix); + }); + }); + } + + let mut g = c.benchmark_group("Bootstrapping"); + + // CBS parameters + let radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(16), + }; + + run_bench( + "CBS parameters", + &mut g, + &LWE_512_80, + &GLWE_5_256_80, + &radix, + ); + + // Binary PBS parameters + let bs_radix = RadixDecomposition { + count: RadixCount(3), + radix_log: RadixLog(6), + }; + + run_bench( + "boolean PBS parameters", + &mut g, + &LweDef { + dim: LweDimension(722), + std: Stddev(0.000013071021089943935), + }, + &GlweDef { + dim: GlweDimension { + size: GlweSize(2), + polynomial_degree: PolynomialDegree(512), + }, + std: Stddev(0.00000004990272175010415), + }, + &bs_radix, + ); + + // 3-bit message 1-bit carry PBS parameters + let bs_radix = RadixDecomposition { + count: RadixCount(1), + radix_log: RadixLog(23), + }; + + run_bench( + "3+1 message PBS parameters", + &mut g, + &LweDef { + dim: LweDimension(742), + std: Stddev(0.000007069849454709433), + }, + &GlweDef { + dim: GlweDimension { + size: GlweSize(1), + polynomial_degree: PolynomialDegree(2048), + }, + std: Stddev(0.00000000000000029403601535432533), + }, + &bs_radix, + ); +} + +fn circuit_bootstrapping(c: &mut Criterion) { + let pbs_radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(16), + }; + let cbs_radix = RadixDecomposition { + count: RadixCount(1), + radix_log: RadixLog(11), + }; + let pfks_radix = RadixDecomposition { + count: RadixCount(3), + radix_log: RadixLog(11), + }; + + let level_2_params = GLWE_5_256_80; + let level_1_params = GLWE_1_1024_80; + let level_0_params = LWE_512_80; + + let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params); + let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params); + let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params); + + let bsk = keygen::generate_bootstrapping_key( + &sk_0, + &sk_2, + &level_0_params, + &level_2_params, + &pbs_radix, + ); + let bsk = fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix); + + let cbsksk = keygen::generate_cbs_ksk( + sk_2.to_lwe_secret_key(), + &sk_1, + &level_2_params.as_lwe_def(), + &level_1_params, + &pfks_radix, + ); + + let val = 0; + + let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1)); + + let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix); + + c.bench_function("Circuit bootstrap", |b| { + b.iter(|| { + circuit_bootstrap( + &mut actual, + &ct, + &bsk, + &cbsksk, + &level_0_params, + &level_1_params, + &level_2_params, + &pbs_radix, + &cbs_radix, + &pfks_radix, + ); + }); + }); +} + +criterion_group!( + benches, + cmux, + programmable_bootstrapping, + circuit_bootstrapping +); +criterion_main!(benches); diff --git a/sunscreen_tfhe/benches/tfhe_proof.rs b/sunscreen_tfhe/benches/tfhe_proof.rs new file mode 100644 index 000000000..b5f58d6e8 --- /dev/null +++ b/sunscreen_tfhe/benches/tfhe_proof.rs @@ -0,0 +1,128 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use logproof::{ + InnerProductVerifierKnowledge, LogProof, LogProofGenerators, LogProofProverKnowledge, +}; +use merlin::Transcript; +use sunscreen_tfhe::{ + high_level::*, + zkp::{generate_tfhe_sdlp_prover_knowledge, ProofStatement, TorusZq, Witness}, + PlaintextBits, Torus, TorusOps, LWE_512_80, +}; + +fn make_proof(pk: &LogProofProverKnowledge) -> LogProof { + let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize); + let u = InnerProductVerifierKnowledge::get_u(); + let mut p_t = Transcript::new(b"test"); + + LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u) +} + +fn tfhe_secret_proof(c: &mut Criterion) { + let params = LWE_512_80; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let enc_data = (0..32) + .map(|_| encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits)) + .collect::>(); + let msg = vec![Torus::from(1u64); 32]; + + let statement = enc_data + .iter() + .enumerate() + .map(|(i, d)| ProofStatement::PrivateKeyEncryption { + message_id: i, + ciphertext: &d.0, + }) + .collect::>(); + + let witness = enc_data + .iter() + .enumerate() + .map(|(_i, d)| Witness::PrivateKeyEncryption { + randomness: d.1, + private_key: &sk, + }) + .collect::>(); + + let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, ¶ms, bits); + + let p = make_proof::(&pk); + + let mut g = c.benchmark_group("Secret key encryption"); + g.sample_size(10); + + g.bench_function("Prove 32-bit secret encryption", |b| { + b.iter(|| { + let _ = make_proof::(&pk); + }); + }); + + g.bench_function("Verify 32-bit secret encryption", |b| { + let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize); + let u = InnerProductVerifierKnowledge::get_u(); + + b.iter(|| { + let mut t = Transcript::new(b"test"); + + p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap(); + }); + }); +} + +fn tfhe_public_proof(c: &mut Criterion) { + let params = LWE_512_80; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let public = keygen::generate_lwe_pk(&sk, ¶ms); + + let enc_data = (0..32) + .map(|_| encryption::encrypt_lwe_and_return_randomness(1, &public, ¶ms, bits)) + .collect::>(); + let msg = vec![Torus::from(1u64); 32]; + + let statement = enc_data + .iter() + .enumerate() + .map(|(i, d)| ProofStatement::PublicKeyEncryption { + public_key: &public, + message_id: i, + ciphertext: &d.0, + }) + .collect::>(); + + let witness = enc_data + .iter() + .enumerate() + .map(|(_i, d)| Witness::PublicKeyEncryption { randomness: &d.1 }) + .collect::>(); + + let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, ¶ms, bits); + + let p = make_proof::(&pk); + + let mut g = c.benchmark_group("Public key encryption"); + g.sample_size(10); + + g.bench_function("Prove 32-bit public encryption", |b| { + b.iter(|| { + let _ = make_proof::(&pk); + }); + }); + + g.bench_function("Verify 32-bit public encryption", |b| { + let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize); + let u = InnerProductVerifierKnowledge::get_u(); + + b.iter(|| { + let mut t = Transcript::new(b"test"); + + p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap(); + }); + }); +} + +criterion_group!(benches, tfhe_secret_proof, tfhe_public_proof,); +criterion_main!(benches); diff --git a/sunscreen_tfhe/images/circuit_bootstrapping.graffle b/sunscreen_tfhe/images/circuit_bootstrapping.graffle new file mode 100644 index 000000000..d229b3339 Binary files /dev/null and b/sunscreen_tfhe/images/circuit_bootstrapping.graffle differ diff --git a/sunscreen_tfhe/images/circuit_bootstrapping.png b/sunscreen_tfhe/images/circuit_bootstrapping.png new file mode 100644 index 000000000..4b3d4ba68 Binary files /dev/null and b/sunscreen_tfhe/images/circuit_bootstrapping.png differ diff --git a/sunscreen_tfhe/mont.py b/sunscreen_tfhe/mont.py new file mode 100644 index 000000000..f5d6b5e7c --- /dev/null +++ b/sunscreen_tfhe/mont.py @@ -0,0 +1,34 @@ +p = 13 +r = 16 +rinv = (r**(p-2)) % p + +p_inv = (p**(p-2)) % r +p_prime = (p_inv * (r - 1)) % r +print((p_inv * p) % r) + +print((r * rinv) % p) + +def to_mont(x): + return (r * x) % p + +def from_mont(x): + return (x * rinv) % p + +def mont_add(x, y): + return (x + y) % p + +def mont_mul(x, y): + return (x * y * rinv) % p + +def redc(): + pass + +a = to_mont(5) +b = to_mont(6) + + +c = mont_mul(a, b) +d = mont_add(a, b) + +print(from_mont(c), from_mont(d)) + diff --git a/sunscreen_tfhe/src/dst.rs b/sunscreen_tfhe/src/dst.rs new file mode 100644 index 000000000..72228d5ee --- /dev/null +++ b/sunscreen_tfhe/src/dst.rs @@ -0,0 +1,293 @@ +use crate::scratch::Pod; + +macro_rules! dst { + ($(#[$meta:meta])* $t:ty, $ref_t:ty, $wrapper:ty, ($($derive:ident),* $(,)? ), ($($t_bounds:ty),* $(,)? )) => { + paste::paste! { + + $(#[$meta])* + #[derive($($derive,)*)] + pub struct $t where T: Clone $(+ $t_bounds)* { + data: Vec<$wrapper> + } + + /// A reference to the data structure. + #[repr(transparent)] + pub struct $ref_t where T: Clone $(+ $t_bounds)* { + data: [$wrapper], + } + + impl $ref_t where T: Clone $(+ $t_bounds)* { + /// Clones the contents of rhs into self + pub fn clone_from_ref(&mut self, rhs: &$ref_t) { + for (l, r) in self.data.iter_mut().zip(rhs.data.iter()) { + *l = r.clone(); + } + } + + /// Returns a slice view of the data representing a $t. + pub fn as_slice(&self) -> &[$wrapper] { + &self.data + } + + /// Returns a mutable slice view of the data representing a $t. + pub fn as_mut_slice(&mut self) -> &mut [$wrapper] { + &mut self.data + } + + /// Move the contents of rhs into self. + pub fn move_from(&mut self, rhs: $t) { + for (l, r) in self.data.iter_mut().zip(rhs.data.into_iter()) { + *l = r; + } + } + } + + impl crate::dst::FromSlice<$wrapper> for $ref_t where T: Clone $(+ $t_bounds)* { + fn from_slice(s: &[$wrapper]) -> &$ref_t { + unsafe { &*(s as *const [$wrapper] as *const $ref_t) } + } + } + + impl crate::dst::FromMutSlice<$wrapper> for $ref_t where T: Clone $(+ $t_bounds)* { + fn from_mut_slice(s: &mut [$wrapper]) -> &mut $ref_t { + unsafe { &mut *(s as *mut [$wrapper] as *mut $ref_t) } + } + } + + impl $ref_t where T: Clone $(+ $t_bounds)*, $wrapper: num::Zero { + /// Clears the contents of self to contain zero + pub fn clear(&mut self) { + + for x in self.as_mut_slice() { + *x = <$wrapper as num::Zero>::zero(); + } + } + } + + impl std::borrow::Borrow< $ref_t > for $t where T: Clone $(+ $t_bounds)* { + fn borrow(&self) -> &$ref_t { + let ptr = self.data.as_slice() as *const [$wrapper] as *const $ref_t; + + unsafe { &*ptr } + + } + } + + impl std::convert::AsRef< $ref_t > for $t where T: Clone $(+ $t_bounds)* + { + fn as_ref(&self) -> &$ref_t { + >>::borrow(self) + } + } + + impl std::borrow::BorrowMut< $ref_t> for $t where T: Clone $(+ $t_bounds)* { + fn borrow_mut(&mut self) -> &mut $ref_t { + let ptr = self.data.as_mut_slice() as *mut [$wrapper] as *mut $ref_t; + + unsafe { &mut *ptr } + + } + } + + impl std::borrow::ToOwned for $ref_t where T: Clone $(+ $t_bounds)* { + type Owned = $t; + + fn to_owned(&self) -> Self::Owned { + $t { data: self.data.to_owned() } + } + } + + impl std::ops::Deref for $t where T: Clone $(+ $t_bounds)* { + type Target = $ref_t; + + fn deref(&self) -> &Self::Target { + >>::borrow(&self) + } + } + + impl std::ops::DerefMut for $t where T: Clone $(+ $t_bounds)* { + fn deref_mut(&mut self) -> &mut Self::Target { + >>::borrow_mut(self) + } + } + } + }; +} + +macro_rules! dst_iter { + ($t:ty, $t_mut:ty, $wrapper_type: ty, $item_ref:ty, ($($t_bounds:ty,)*)) => { + paste::paste!{ + /// An iterator to access references to an underlying type. + pub struct $t<'a, T> where T: Clone $(+ $t_bounds)* { + data: &'a [$wrapper_type], + stride: usize, + front_idx: usize, + back_idx: i64 + } + + impl<'a, T> $t<'a, T> where T: Clone $(+ $t_bounds)* { + /// Create a new iterator that emits references to the contained type + /// by striding over the underlying data. + pub fn new(data: &'a [$wrapper_type], stride: usize) -> Self { + assert_eq!(data.len() % stride, 0); + + Self { + data, + stride, + front_idx: 0, + back_idx: (data.len() as i64) - (stride as i64) + } + } + + #[inline] + /// The total number of items this iterator will emit and has emitted. + /// + /// # Remarks + /// This method returns the same value regardless of how many times + /// `next` has been called. + /// + /// This operation does not consume the iterator. + pub fn len(&self) -> usize { + self.data.len() / self.stride + } + + /// Returns true if the iterator is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + } + + impl<'a, T> std::iter::Iterator for $t<'a, T> where T: Clone $(+ $t_bounds)* { + type Item = &'a $item_ref; + + fn next(&mut self) -> Option { + if self.front_idx >= self.data.len() { + return None; + } + + if (self.front_idx as i64) == self.back_idx + self.stride as i64 { + return None; + } + + let data = <$item_ref as crate::dst::FromSlice<$wrapper_type>>::from_slice( + &self.data[self.front_idx..self.front_idx + self.stride] + ); + + self.front_idx += self.stride; + + Some(data) + } + } + + impl<'a, T> std::iter::DoubleEndedIterator for $t<'a, T> where T: Clone $(+ $t_bounds)* { + fn next_back(&mut self) -> Option { + if self.back_idx < 0 { + return None; + } + + if (self.front_idx as i64) == self.back_idx + self.stride as i64 { + return None; + } + + let start = self.back_idx as usize; + + let data = <$item_ref as crate::dst::FromSlice<$wrapper_type>>::from_slice( + &self.data[start..start + self.stride] + ); + + self.back_idx -= (self.stride as i64); + + Some(data) + } + } + + /// A mutable iterator to access references to an underlying type. + pub struct $t_mut<'a, T> where T: Clone $(+ $t_bounds)* { + data: *mut $wrapper_type, + len: usize, + stride: usize, + idx: usize, + _phantom: std::marker::PhantomData<&'a T>, + } + + impl<'a, T> $t_mut<'a, T> where T: Clone $(+ $t_bounds)* { + /// Create a new iterator that emits references to the contained type + /// by striding over the underlying data, mutably. + pub fn new(data: &'a mut [$wrapper_type], stride: usize) -> Self { + assert_eq!(data.len() % stride, 0); + + Self { + idx: 0, + stride, + data: data.as_mut_ptr(), + len: data.len(), + _phantom: std::marker::PhantomData + } + } + + #[inline] + /// The total number of items this iterator will emit and has emitted. + /// + /// # Remarks + /// This method returns the same value regardless of how many times + /// `next` has been called. + /// + /// This operation does not consume the iterator. + pub fn len(&self) -> usize { + self.len / self.stride + } + + /// Returns true if the iterator is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + } + + impl<'a, T> std::iter::Iterator for $t_mut<'a, T> where T: Clone $(+ $t_bounds)* { + type Item = &'a mut $item_ref; + + fn next(&mut self) -> Option { + if self.idx == self.len { + return None; + } + + // Since the slices emitted by this iterator don't overlap, this is sound. + let data = unsafe { + let slice = self.data.add(self.idx); + std::slice::from_raw_parts_mut(slice, self.stride) + }; + + self.idx += self.stride; + + Some(<$item_ref as crate::dst::FromMutSlice<$wrapper_type>>::from_mut_slice(data)) + } + } + } + }; +} + +pub type NoWrapper = T; + +pub trait OverlaySize { + type Inputs: Copy + Clone; + + fn size(t: Self::Inputs) -> usize; +} + +impl OverlaySize for [S] { + type Inputs = usize; + + fn size(t: Self::Inputs) -> usize { + t + } +} + +pub trait FromSlice { + fn from_slice(data: &[T]) -> &Self; +} + +pub trait FromMutSlice { + fn from_mut_slice(data: &mut [T]) -> &mut Self; +} diff --git a/sunscreen_tfhe/src/entities/bivariate_lookup_table.rs b/sunscreen_tfhe/src/entities/bivariate_lookup_table.rs new file mode 100644 index 000000000..93ccefb79 --- /dev/null +++ b/sunscreen_tfhe/src/entities/bivariate_lookup_table.rs @@ -0,0 +1,98 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{ + dst::{FromMutSlice, FromSlice, OverlaySize}, + entities::PolynomialRef, + ops::{bootstrapping::generate_bivariate_lut, encryption::trivially_encrypt_glwe_ciphertext}, + scratch::allocate_scratch_ref, + CarryBits, GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps, +}; + +use super::{GlweCiphertextRef, UnivariateLookupTableRef}; + +dst! { + /// Lookup table for a bivariate function used during + /// [`programmable_bootstrap_bivariate`](crate::ops::bootstrapping::programmable_bootstrap_bivariate) + /// See that function for more details. + BivariateLookupTable, + BivariateLookupTableRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for BivariateLookupTableRef { + type Inputs = GlweDimension; + + fn size(t: Self::Inputs) -> usize { + GlweCiphertextRef::::size(t) + } +} + +impl BivariateLookupTable { + /// Creates a [BivariateLookupTable] filled with the result of + /// a function applied to every possible pair of plaintext inputs. + pub fn trivial_from_fn( + map: F, + glwe: &GlweDef, + plaintext_bits: PlaintextBits, + carry_bits: CarryBits, + ) -> Self + where + F: Fn(u64, u64) -> u64, + { + let mut lut = BivariateLookupTable { + data: vec![Torus::zero(); BivariateLookupTableRef::::size(glwe.dim)], + }; + + lut.fill_trivial_from_fn(map, glwe, plaintext_bits, carry_bits); + + lut + } +} + +impl BivariateLookupTableRef { + /// Convert a [BivariateLookupTableRef] to a [UnivariateLookupTableRef]. + pub fn as_univariate(&self) -> &UnivariateLookupTableRef { + // This works because a bivariate lookup table is just a univariate + // lookup table. + UnivariateLookupTableRef::from_slice(&self.data) + } + + /// Gets a copy of the underlying [GlweCiphertextRef] from the + /// [BivariateLookupTableRef]. + pub fn glwe(&self) -> &GlweCiphertextRef { + GlweCiphertextRef::from_slice(&self.data) + } + + /// Gets a mutable copy of the underlying [GlweCiphertextRef] from the + /// [BivariateLookupTableRef]. + pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef { + GlweCiphertextRef::from_mut_slice(&mut self.data) + } + + /// Fills the [BivariateLookupTableRef] with the result of a bivariate + /// function. + pub fn fill_trivial_from_fn u64>( + &mut self, + map: F, + glwe: &GlweDef, + plaintext_bits: PlaintextBits, + carry_bits: CarryBits, + ) { + allocate_scratch_ref!(poly, PolynomialRef>, (glwe.dim.polynomial_degree)); + + generate_bivariate_lut(poly, map, glwe, plaintext_bits, carry_bits); + + trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe); + } + + /// Creates a lookup table filled with the same value at every entry. + pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) { + self.clear(); + for o in self.glwe_mut().b_mut(glwe).coeffs_mut() { + *o = Torus::encode(val, plaintext_bits); + } + } +} diff --git a/sunscreen_tfhe/src/entities/blind_rotation_shift.rs b/sunscreen_tfhe/src/entities/blind_rotation_shift.rs new file mode 100644 index 000000000..0e605227f --- /dev/null +++ b/sunscreen_tfhe/src/entities/blind_rotation_shift.rs @@ -0,0 +1,137 @@ +use num::{Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + entities::{ + GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef, + GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef, + }, + GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +dst! { + /// An encrypted amount to rotate the polynomials in a GLWE ciphertext by. + /// The [BlindRotationShiftFft] variant of this type is used by the + /// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function. + BlindRotationShift, + BlindRotationShiftRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for BlindRotationShiftRef { + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize; + + GgswCiphertextRef::::size(t) * n_bits + } +} + +impl BlindRotationShift { + /// Create a new zero [BlindRotationShift] with the given parameters. + /// + /// A blind rotation shift is a collection of GGSW ciphertexts, each of + /// which encrypts a single bit in position `i` representing how much to + /// shift the input by as `2^i`. + pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = BlindRotationShiftRef::::size((params.dim, radix.count)); + + Self { + data: vec![Torus::zero(); len], + } + } +} + +impl BlindRotationShiftRef { + /// Iterate over the rows of the [BlindRotationShift]. + pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator { + let stride = GgswCiphertextRef::::size((params.dim, radix.count)); + + GgswCiphertextIterator::new(self.as_slice(), stride) + } + + /// Iterate over the rows of the [BlindRotationShift] mutably. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextIteratorMut { + let stride = GgswCiphertextRef::::size((params.dim, radix.count)); + + GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride) + } +} + +dst! { + /// An encrypted amount to rotate the polynomials in a GLWE ciphertext by. + /// Used by the + /// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function. + /// The non-FFT version of this type is [BlindRotationShift]. + BlindRotationShiftFft, + BlindRotationShiftFftRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} + +impl OverlaySize for BlindRotationShiftFftRef> { + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize; + + GgswCiphertextFftRef::>::size(t) * n_bits + } +} + +impl BlindRotationShiftFft> { + /// Create a new zero [BlindRotationShiftFft] with the given parameters. + pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = BlindRotationShiftFftRef::size((params.dim, radix.count)); + + Self { + data: vec![Complex::zero(); len], + } + } +} + +impl BlindRotationShiftFftRef> { + /// Iterate over the rows of the [BlindRotationShiftFft]. + pub fn rows( + &self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextFftIterator> { + let stride = GgswCiphertextFftRef::>::size((params.dim, radix.count)); + + GgswCiphertextFftIterator::new(self.as_slice(), stride) + } + + /// Iterate over the rows of the [BlindRotationShiftFft] mutably. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextFftIteratorMut> { + let stride = GgswCiphertextFftRef::>::size((params.dim, radix.count)); + + GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride) + } + + /// Compute the inverse FFT of the [BlindRotationShiftFft] and store the + /// result in the result [BlindRotationShift]. + pub fn ifft( + &self, + result: &mut BlindRotationShiftRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) { + for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) { + s.ifft(r, params, radix); + } + } +} diff --git a/sunscreen_tfhe/src/entities/bootstrap_key.rs b/sunscreen_tfhe/src/entities/bootstrap_key.rs new file mode 100644 index 000000000..a147a2e3a --- /dev/null +++ b/sunscreen_tfhe/src/entities/bootstrap_key.rs @@ -0,0 +1,160 @@ +use num::{Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + entities::{ + GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef, + GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef, + }, + GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +dst! { + /// Keys used for bootstrapping. The [BootstrapKeyFft] variant of this type + /// is used by the bootstrapping functions such as + /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + BootstrapKey, + BootstrapKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for BootstrapKeyRef { + type Inputs = (LweDimension, GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GgswCiphertextRef::::size((t.1, t.2)) * t.0 .0 + } +} + +impl BootstrapKey { + /// Create a new zero [BootstrapKey] with the given parameters. + /// + /// A bootstrapping key is a collection of GGSW ciphertexts, each of which + /// encrypts a single bit of an LWE secret key. This representation cannot + /// be directly with the bootstrapping functions, but the FFT version of the + /// bootstrapping key that can be used with the bootstrapping functions can + /// be created by calling the [BootstrapKeyRef::fft] method + pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = BootstrapKeyRef::::size((lwe_params.dim, glwe_params.dim, radix.count)); + + Self { + data: vec![Torus::zero(); len], + } + } +} + +impl BootstrapKeyRef { + /// Iterate over the rows of the [BootstrapKey]. + pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator { + let stride = GgswCiphertextRef::::size((params.dim, radix.count)); + + GgswCiphertextIterator::new(self.as_slice(), stride) + } + + /// Iterate over the rows of the [BootstrapKey] mutably. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextIteratorMut { + let stride = GgswCiphertextRef::::size((params.dim, radix.count)); + + GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride) + } + + /// Perform an FFT on the [BootstrapKey] to obtain a [BootstrapKeyFft]. + pub fn fft( + &self, + result: &mut BootstrapKeyFftRef>, + params: &GlweDef, + radix: &RadixDecomposition, + ) { + for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) { + s.fft(r, params, radix); + } + } +} + +dst! { + /// Keys used for bootstrapping. Used by the bootstrapping functions such as + /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + /// The non-FFT variant of this type is [BootstrapKey]. + BootstrapKeyFft, + BootstrapKeyFftRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} + +impl OverlaySize for BootstrapKeyFftRef> { + type Inputs = (LweDimension, GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GgswCiphertextFftRef::>::size((t.1, t.2)) * t.0 .0 + } +} + +impl BootstrapKeyFft> { + /// Create a new zero [BootstrapKeyFft] with the given parameters. + /// + /// A bootstrapping key is a collection of GGSW ciphertexts, each of which + /// encrypts a single bit of an LWE secret key. In this representation, the + /// GGSW ciphertexts are in the frequency domain and can be used directly by + /// the bootstrapping functions such as + /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap). + pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count)); + + Self { + data: vec![Complex::zero(); len], + } + } +} + +impl BootstrapKeyFftRef> { + /// Iterate over the rows of the [BootstrapKeyFft]. + pub fn rows( + &self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextFftIterator> { + let stride = GgswCiphertextFftRef::>::size((params.dim, radix.count)); + + GgswCiphertextFftIterator::new(self.as_slice(), stride) + } + + /// Iterate over the rows of the [BootstrapKeyFft] mutably. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextFftIteratorMut> { + let stride = GgswCiphertextFftRef::>::size((params.dim, radix.count)); + + GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride) + } + + /// Perform an IFFT on the [BootstrapKeyFft] to obtain a [BootstrapKey]. + pub fn ifft( + &self, + result: &mut BootstrapKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) { + for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) { + s.ifft(r, params, radix); + } + } + + /// Asserts that the [BootstrapKeyFft] is valid for the given parameters. + #[inline(always)] + pub(crate) fn assert_valid(&self, lwe: &LweDef, glwe: &GlweDef, radix: &RadixDecomposition) { + assert_eq!( + self.as_slice().len(), + BootstrapKeyFftRef::size((lwe.dim, glwe.dim, radix.count)) + ); + } +} diff --git a/sunscreen_tfhe/src/entities/circuit_bootstrapping_private_keyswitch_keys.rs b/sunscreen_tfhe/src/entities/circuit_bootstrapping_private_keyswitch_keys.rs new file mode 100644 index 000000000..00145991b --- /dev/null +++ b/sunscreen_tfhe/src/entities/circuit_bootstrapping_private_keyswitch_keys.rs @@ -0,0 +1,98 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{ + dst::OverlaySize, GlweDef, GlweDimension, LweDef, LweDimension, + PrivateFunctionalKeyswitchLweCount, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +use super::{ + PrivateFunctionalKeyswitchKeyIter, PrivateFunctionalKeyswitchKeyIterMut, + PrivateFunctionalKeyswitchKeyRef, +}; + +dst! { + /// Key for Circuit Bootstrapping Key Switching. + CircuitBootstrappingKeyswitchKeys, + CircuitBootstrappingKeyswitchKeysRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for CircuitBootstrappingKeyswitchKeysRef { + type Inputs = (LweDimension, GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + PrivateFunctionalKeyswitchKeyRef::::size(( + t.0, + t.1, + t.2, + PrivateFunctionalKeyswitchLweCount(1), + )) * (t.1.size.0 + 1) + } +} + +impl CircuitBootstrappingKeyswitchKeys { + /// Allocate a new [`CircuitBootstrappingKeyswitchKeys`] for the given parameters. + pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = CircuitBootstrappingKeyswitchKeysRef::::size(( + from_lwe.dim, + to_glwe.dim, + radix.count, + )); + + Self { + data: vec![Torus::zero(); len], + } + } +} + +impl CircuitBootstrappingKeyswitchKeysRef { + /// Get an iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s. + pub fn keys( + &self, + lwe: &LweDef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> PrivateFunctionalKeyswitchKeyIter { + let stride = PrivateFunctionalKeyswitchKeyRef::::size(( + lwe.dim, + glwe.dim, + radix.count, + PrivateFunctionalKeyswitchLweCount(1), + )); + + PrivateFunctionalKeyswitchKeyIter::new(self.as_slice(), stride) + } + + /// Get a mutable iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s. + pub fn keys_mut( + &mut self, + lwe: &LweDef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> PrivateFunctionalKeyswitchKeyIterMut { + let stride = PrivateFunctionalKeyswitchKeyRef::::size(( + lwe.dim, + glwe.dim, + radix.count, + PrivateFunctionalKeyswitchLweCount(1), + )); + + PrivateFunctionalKeyswitchKeyIterMut::new(self.as_mut_slice(), stride) + } + + #[inline(always)] + /// Assert these keys are valid under the given parameters. + pub fn assert_valid(&self, from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) { + assert_eq!( + self.as_slice().len(), + CircuitBootstrappingKeyswitchKeysRef::::size(( + from_lwe.dim, + to_glwe.dim, + radix.count, + )) + ); + } +} diff --git a/sunscreen_tfhe/src/entities/ggsw_ciphertext.rs b/sunscreen_tfhe/src/entities/ggsw_ciphertext.rs new file mode 100644 index 000000000..3bec24874 --- /dev/null +++ b/sunscreen_tfhe/src/entities/ggsw_ciphertext.rs @@ -0,0 +1,114 @@ +use num::{Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::OverlaySize, ops::ciphertext::external_product_ggsw_glwe, GlweDef, GlweDimension, + RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +use super::{ + GgswCiphertextFftRef, GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef, + GlweCiphertext, GlweCiphertextRef, +}; + +dst! { + /// A GGSW ciphertext. For the FFT variant, see + /// [`GgswCiphertextFft`](crate::entities::GgswCiphertextFft). + GgswCiphertext, + GgswCiphertextRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} +dst_iter! { GgswCiphertextIterator, GgswCiphertextIteratorMut, Torus, GgswCiphertextRef, (TorusOps,)} + +impl OverlaySize for GgswCiphertextRef +where + S: TorusOps, +{ + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlevCiphertextRef::::size(t) * (t.0.size.0 + 1) + } +} + +impl GgswCiphertext +where + S: TorusOps, +{ + /// Create a new zero GGSW ciphertext with the given parameters. + pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self { + let elems = GgswCiphertextRef::::size((params.dim, radix.count)); + + Self { + data: vec![Torus::zero(); elems], + } + } + + /// Create a new GGSW ciphertext from a slice of Torus elements. + pub fn from_slice(data: &[Torus], params: &GlweDef, radix: &RadixDecomposition) -> Self { + let elems = GgswCiphertextRef::::size((params.dim, radix.count)); + + assert_eq!(data.len(), elems); + + Self { + data: data.to_vec(), + } + } + + /// Computes the external product between a GGSW ciphertext and a GLWE ciphertext. + /// GGSW ⊡ GLWE -> GLWE + pub fn external_product( + &self, + glwe: &GlweCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlweCiphertext { + external_product_ggsw_glwe(self, glwe, params, radix) + } +} + +impl GgswCiphertextRef +where + S: TorusOps, +{ + /// Returns an iterator over the rows of the GGSW ciphertext, which are + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. + pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator { + let stride = GlevCiphertextRef::::size((params.dim, radix.count)); + + GlevCiphertextIterator::new(&self.data, stride) + } + + /// Returns a mutable iterator over the rows of the GGSW ciphertext, which are + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIteratorMut { + let stride = GlevCiphertextRef::::size((params.dim, radix.count)); + + GlevCiphertextIteratorMut::new(&mut self.data, stride) + } + + /// Compute the FFT of each of the GLWE ciphertexts in the GGSW ciphertext. + /// The result is stored in `result`. + pub fn fft( + &self, + result: &mut GgswCiphertextFftRef>, + params: &GlweDef, + radix: &RadixDecomposition, + ) { + for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) { + s.fft(r, params); + } + } + + /// Assert that the GGSW ciphertext is valid for the given parameters. + #[inline(always)] + pub(crate) fn assert_valid(&self, glwe: &GlweDef, radix: &RadixDecomposition) { + assert_eq!(self.as_slice().len(), Self::size((glwe.dim, radix.count))); + } +} diff --git a/sunscreen_tfhe/src/entities/ggsw_ciphertext_fft.rs b/sunscreen_tfhe/src/entities/ggsw_ciphertext_fft.rs new file mode 100644 index 000000000..59b7fe526 --- /dev/null +++ b/sunscreen_tfhe/src/entities/ggsw_ciphertext_fft.rs @@ -0,0 +1,79 @@ +use num::{Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + entities::GgswCiphertextRef, + GlweDef, GlweDimension, RadixCount, RadixDecomposition, TorusOps, +}; + +use super::{GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, GlevCiphertextFftRef}; + +dst! { + /// The FFT variant of a GGSW ciphertext. See + /// [`GgswCiphertext`](crate::entities::GgswCiphertext) for more details. + GgswCiphertextFft, + GgswCiphertextFftRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} +dst_iter! { GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, NoWrapper, GgswCiphertextFftRef, ()} + +impl OverlaySize for GgswCiphertextFftRef> { + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlevCiphertextFftRef::>::size(t) * (t.0.size.0 + 1) + } +} + +impl GgswCiphertextFft> { + /// Creates a new GGSW ciphertext with FFT representation. + pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextFft> { + let len = GgswCiphertextFftRef::size((params.dim, radix.count)); + + GgswCiphertextFft { + data: vec![Complex::zero(); len], + } + } +} + +impl GgswCiphertextFftRef> { + /// Returns an iterator over the rows of the GGSW ciphertext, which are + /// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s. + pub fn rows( + &self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextFftIterator> { + let stride = GlevCiphertextFftRef::>::size((params.dim, radix.count)); + + GlevCiphertextFftIterator::new(self.as_slice(), stride) + } + + /// Returns a mutable iterator over the rows of the GGSW ciphertext, which are + /// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextFftIteratorMut> { + let stride = GlevCiphertextFftRef::>::size((params.dim, radix.count)); + + GlevCiphertextFftIteratorMut::new(self.as_mut_slice(), stride) + } + + /// Computes the inverse FFT of the GGSW ciphertexts and stores computation + /// in `result`. + pub fn ifft( + &self, + result: &mut GgswCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) { + for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) { + s.ifft(r, params); + } + } +} diff --git a/sunscreen_tfhe/src/entities/glev_ciphertext.rs b/sunscreen_tfhe/src/entities/glev_ciphertext.rs new file mode 100644 index 000000000..bfd7a9c42 --- /dev/null +++ b/sunscreen_tfhe/src/entities/glev_ciphertext.rs @@ -0,0 +1,58 @@ +use num::Complex; +use serde::{Deserialize, Serialize}; + +use crate::{dst::OverlaySize, GlweDef, GlweDimension, RadixCount, Torus, TorusOps}; + +use super::{ + GlevCiphertextFftRef, GlweCiphertextIterator, GlweCiphertextIteratorMut, GlweCiphertextRef, +}; + +dst! { + /// A GLEV ciphertext. For the FFT variant, see + /// [`GlevCiphertextFft`](crate::entities::GlevCiphertextFft). + GlevCiphertext, + GlevCiphertextRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps,) +} +dst_iter! { GlevCiphertextIterator, GlevCiphertextIteratorMut, Torus, GlevCiphertextRef, (TorusOps,)} + +impl OverlaySize for GlevCiphertextRef +where + S: TorusOps, +{ + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlweCiphertextRef::::size(t.0) * t.1 .0 + } +} + +impl GlevCiphertextRef +where + S: TorusOps, +{ + /// Returns an iterator over the rows of the GLEV ciphertext, which are + /// [`GlweCiphertext`](crate::entities::GlweCiphertext)s. + pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextIterator { + GlweCiphertextIterator::new(&self.data, GlweCiphertextRef::::size(params.dim)) + } + + /// Returns a mutable iterator over the rows of the GLEV ciphertext, which are + /// [`GlweCiphertext](crate::entities::GlweCiphertext)s. + pub fn glwe_ciphertexts_mut(&mut self, params: &GlweDef) -> GlweCiphertextIteratorMut { + GlweCiphertextIteratorMut::new(&mut self.data, GlweCiphertextRef::::size(params.dim)) + } + + /// Compute the FFT of each of the GLWE ciphertexts in the GLEV ciphertext. + /// The result is stored in `result`. + pub fn fft(&self, result: &mut GlevCiphertextFftRef>, params: &GlweDef) { + for (i, fft) in self + .glwe_ciphertexts(params) + .zip(result.glwe_ciphertexts_mut(params)) + { + i.fft(fft, params); + } + } +} diff --git a/sunscreen_tfhe/src/entities/glev_ciphertext_fft.rs b/sunscreen_tfhe/src/entities/glev_ciphertext_fft.rs new file mode 100644 index 000000000..6a5d4aae0 --- /dev/null +++ b/sunscreen_tfhe/src/entities/glev_ciphertext_fft.rs @@ -0,0 +1,65 @@ +use num::Complex; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + GlweDef, GlweDimension, RadixCount, TorusOps, +}; + +use super::{ + GlevCiphertextRef, GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut, + GlweCiphertextFftRef, +}; + +dst! { + /// The FFT variant of a GLEV ciphertext. See + /// [GlevCiphertext](crate::entities::GlevCiphertext) for more details. + GlevCiphertextFft, + GlevCiphertextFftRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} +dst_iter! { GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, NoWrapper, GlevCiphertextFftRef, ()} + +impl OverlaySize for GlevCiphertextFftRef> { + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlweCiphertextFftRef::>::size(t.0) * t.1 .0 + } +} + +impl GlevCiphertextFftRef> { + /// Returns an iterator over the rows of the GLEV ciphertext, which are + /// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s. + pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextFftIterator> { + GlweCiphertextFftIterator::new( + &self.data, + GlweCiphertextFftRef::>::size(params.dim), + ) + } + + /// Returns a mutable iterator over the rows of the GLEV ciphertext, which are + /// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s. + pub fn glwe_ciphertexts_mut( + &mut self, + params: &GlweDef, + ) -> GlweCiphertextFftIteratorMut> { + GlweCiphertextFftIteratorMut::new( + &mut self.data, + GlweCiphertextFftRef::>::size(params.dim), + ) + } + + /// Computes the inverse FFT of the GLEV ciphertexts and stores computation + /// in `result`. + pub fn ifft(&self, result: &mut GlevCiphertextRef, params: &GlweDef) { + for (i, ifft) in self + .glwe_ciphertexts(params) + .zip(result.glwe_ciphertexts_mut(params)) + { + i.ifft(ifft, params); + } + } +} diff --git a/sunscreen_tfhe/src/entities/glwe_ciphertext.rs b/sunscreen_tfhe/src/entities/glwe_ciphertext.rs new file mode 100644 index 000000000..5c87486b0 --- /dev/null +++ b/sunscreen_tfhe/src/entities/glwe_ciphertext.rs @@ -0,0 +1,159 @@ +use num::{Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{FromMutSlice, FromSlice, OverlaySize}, + entities::GgswCiphertextRef, + macros::{impl_binary_op, impl_unary_op}, + ops::ciphertext::external_product_ggsw_glwe, + GlweDef, GlweDimension, RadixDecomposition, Torus, TorusOps, +}; + +use super::{ + GlweCiphertextFftRef, GlweSecretKeyRef, PolynomialIterator, PolynomialIteratorMut, + PolynomialRef, +}; + +dst! { + /// A GLWE ciphertext. + GlweCiphertext, + GlweCiphertextRef, + Torus, + (Debug, Clone, Serialize, Deserialize), + (TorusOps) +} +dst_iter! { GlweCiphertextIterator, GlweCiphertextIteratorMut, Torus, GlweCiphertextRef, (TorusOps,) } + +// Also implements the assign operators. +impl_binary_op!(Add, GlweCiphertext, (TorusOps,)); +impl_binary_op!(Sub, GlweCiphertext, (TorusOps,)); +impl_unary_op!(Neg, GlweCiphertext); + +impl OverlaySize for GlweCiphertextRef +where + S: TorusOps, +{ + type Inputs = GlweDimension; + + fn size(t: Self::Inputs) -> usize { + // We have `n` a polynomials plus 1 b polynomial each of degree d. + GlweSecretKeyRef::::size(t) + t.polynomial_degree.0 + } +} + +impl GlweCiphertext +where + S: TorusOps, +{ + /// Initialize an empty (zero) GLWE ciphertext + pub fn new(params: &GlweDef) -> GlweCiphertext { + params.dim.assert_valid(); + + let len = GlweCiphertextRef::::size(params.dim); + + let data = (0..len).map(|_| Torus::::zero()).collect::>(); + + GlweCiphertext { data } + } + + /// Computes the external product of a GLWE ciphertext and a GGSW ciphertext. + /// GGSW ⊡ GLWE -> GLWE + pub fn external_product( + &self, + ggsw: &GgswCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlweCiphertext { + external_product_ggsw_glwe(ggsw, self, params, radix) + } + + /// Generate a GLWE ciphertext from a slice of [crate::TorusOps] elements. + pub fn from_slice(data: &[S], params: &GlweDef) -> GlweCiphertext { + assert_eq!(data.len(), GlweCiphertextRef::::size(params.dim)); + + GlweCiphertext { + data: data + .iter() + .map(|x| Torus::from(*x)) + .collect::>>(), + } + } +} + +impl GlweCiphertextRef +where + S: TorusOps, +{ + /// Returns an iterator over the `a` polynomials and the `b` polynomial. + pub fn a_b( + &self, + params: &GlweDef, + ) -> (PolynomialIterator>, &PolynomialRef>) { + let (a, b) = self.data.as_ref().split_at(self.split_idx(params)); + + ( + PolynomialIterator::new(a, params.dim.polynomial_degree.0), + PolynomialRef::from_slice(b), + ) + } + + /// Returns an interator over the a polynomials in a GLWE ciphertext. + pub fn a(&self, params: &GlweDef) -> PolynomialIterator> { + self.a_b(params).0 + } + + /// Returns a reference to the b polynomial in a GLWE ciphertext. + pub fn b(&self, params: &GlweDef) -> &PolynomialRef> { + self.a_b(params).1 + } + + /// Returns an iterator over the `a` polynomials and the `b` polynomial. + pub fn a_b_mut( + &mut self, + params: &GlweDef, + ) -> ( + PolynomialIteratorMut>, + &mut PolynomialRef>, + ) { + let polynomial_degree = params.dim.polynomial_degree; + let split_idx = self.split_idx(params); + + let (a, b) = self.data.as_mut().split_at_mut(split_idx); + + ( + PolynomialIteratorMut::new(a, polynomial_degree.0), + PolynomialRef::from_mut_slice(b), + ) + } + + /// Returns a mutable iterator over the a polynomials in a GLWE ciphertext. + pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut> { + self.a_b_mut(params).0 + } + + /// Returns a mutable reference to the b polynomial in a GLWE ciphertext. + pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialRef> { + self.a_b_mut(params).1 + } + + fn split_idx(&self, params: &GlweDef) -> usize { + params.dim.size.0 * params.dim.polynomial_degree.0 + } + + /// Create an FFT transformed version of `self` stored to result. + pub fn fft(&self, result: &mut GlweCiphertextFftRef>, params: &GlweDef) { + for (a, fft) in self.a(params).zip(result.a_mut(params)) { + a.fft(fft); + } + + self.b(params).fft(result.b_mut(params)); + } + + #[inline(always)] + pub(crate) fn assert_valid(&self, params: &GlweDef) { + assert_eq!( + self.as_slice().len(), + GlweCiphertextRef::::size(params.dim) + ) + } +} diff --git a/sunscreen_tfhe/src/entities/glwe_ciphertext_fft.rs b/sunscreen_tfhe/src/entities/glwe_ciphertext_fft.rs new file mode 100644 index 000000000..2eb226081 --- /dev/null +++ b/sunscreen_tfhe/src/entities/glwe_ciphertext_fft.rs @@ -0,0 +1,139 @@ +use num::{complex::Complex64, Complex, Zero}; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize}, + GlweDef, GlweDimension, TorusOps, +}; + +use super::{GlweCiphertextRef, PolynomialFftIterator, PolynomialFftIteratorMut, PolynomialFftRef}; + +dst! { + /// The FFT variant of a GLWE ciphertext. See + /// [`GlweCiphertext`](crate::entities::GlweCiphertext) for more details. + GlweCiphertextFft, + GlweCiphertextFftRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} +dst_iter! { GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut, NoWrapper, GlweCiphertextFftRef, ()} + +impl OverlaySize for GlweCiphertextFftRef> { + type Inputs = GlweDimension; + + fn size(t: Self::Inputs) -> usize { + // FFT polynomials are half the length of their standard counterparts. + PolynomialFftRef::>::size(t.polynomial_degree) * (t.size.0 + 1) + } +} + +impl GlweCiphertextFft> { + /// Creates a new zero GLWE ciphertext in the frequency domain. + pub fn new(params: &GlweDef) -> Self { + let len = GlweCiphertextFftRef::size(params.dim); + + Self { + data: vec![Complex::zero(); len], + } + } +} + +impl GlweCiphertextFftRef> { + /// Returns an iterator over the `a` polynomials and the `b` polynomial. + pub fn a_b( + &self, + params: &GlweDef, + ) -> ( + PolynomialFftIterator>, + &PolynomialFftRef>, + ) { + let (a, b) = self.as_slice().split_at(self.split_idx(params)); + + ( + PolynomialFftIterator::new(a, params.dim.polynomial_degree.0 / 2), + PolynomialFftRef::from_slice(b), + ) + } + + /// Returns an interator over the a polynomials in a GLWE ciphertext. + pub fn a(&self, params: &GlweDef) -> PolynomialFftIterator> { + self.a_b(params).0 + } + + /// Returns a reference to the b polynomial in a GLWE ciphertext. + pub fn b(&self, params: &GlweDef) -> &PolynomialFftRef { + self.a_b(params).1 + } + + /// Returns an iterator over the `a` polynomials and the `b` polynomial. + pub fn a_b_mut( + &mut self, + params: &GlweDef, + ) -> ( + PolynomialFftIteratorMut>, + &mut PolynomialFftRef>, + ) { + let polynomial_degree = params.dim.polynomial_degree; + let split_idx = self.split_idx(params); + + let (a, b) = self.as_mut_slice().split_at_mut(split_idx); + + ( + PolynomialFftIteratorMut::new(a, polynomial_degree.0 / 2), + PolynomialFftRef::from_mut_slice(b), + ) + } + + /// Returns a mutable iterator over the a polynomials in a GLWE ciphertext. + pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialFftIteratorMut> { + self.a_b_mut(params).0 + } + + /// Returns a mutable reference to the b polynomial in a GLWE ciphertext. + pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialFftRef> { + self.a_b_mut(params).1 + } + + #[inline(always)] + fn split_idx(&self, params: &GlweDef) -> usize { + params.dim.size.0 * params.dim.polynomial_degree.0 / 2 + } + + /// Computes the inverse FFT of the GLWE ciphertext and stores the + /// computation in `result`. + pub fn ifft(&self, result: &mut GlweCiphertextRef, params: &GlweDef) { + for (a, fft) in self.a(params).zip(result.a_mut(params)) { + a.ifft(fft); + } + + self.b(params).ifft(result.b_mut(params)); + } +} + +#[cfg(test)] +mod tests { + use crate::{entities::Polynomial, high_level::*, PlaintextBits, GLWE_1_1024_80}; + + #[test] + fn can_decrypt_glwe_after_fft_roundtrip() { + let params = GLWE_1_1024_80; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let pt = (0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(); + let pt = Polynomial::new(&pt); + + let mut ct = encryption::encrypt_glwe(&pt, &sk, ¶ms, bits); + let fft = fft::fft_glwe(&ct, ¶ms); + + fft.ifft(&mut ct, ¶ms); + + let actual = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits); + + assert_eq!(actual, pt); + } +} diff --git a/sunscreen_tfhe/src/entities/glwe_keyswitch_key.rs b/sunscreen_tfhe/src/entities/glwe_keyswitch_key.rs new file mode 100644 index 000000000..a596c185b --- /dev/null +++ b/sunscreen_tfhe/src/entities/glwe_keyswitch_key.rs @@ -0,0 +1,75 @@ +use num::Zero; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::OverlaySize, GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +use super::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef}; + +// TODO: This GLWE keyswitch only works for switching to a new key with the same +// parameter. Copy what is above but changed for polynomials to enable +// converting to a different key parameter set. +dst! { + /// A GLWE keyswitch key used to switch a ciphertext from one key to another. + /// See [`module`](crate::ops::keyswitch) documentation for more details. + GlweKeyswitchKey, + GlweKeyswitchKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps,) +} + +impl OverlaySize for GlweKeyswitchKeyRef +where + S: TorusOps, +{ + type Inputs = (GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlevCiphertextRef::::size(t) * (t.0.size.0) + } +} + +impl GlweKeyswitchKey +where + S: TorusOps, +{ + /// Creates a new GLWE keyswitch key. This enables switching to a new key as + /// well as switching from the `original_params` that define the first key + /// to the `new_params` that define the second key. + pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self { + // TODO: Shouldn't this function take 2 GlweDefs? + // Ryan: to whoever wrote this, yes, see the above todo next to the dst. + let elems = GlweKeyswitchKeyRef::::size((params.dim, radix.count)); + + Self { + data: vec![Torus::zero(); elems], + } + } +} + +impl GlweKeyswitchKeyRef +where + S: TorusOps, +{ + /// Returns an iterator over the rows of the GLWE keyswitch key, which are + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. + pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator { + let stride = GlevCiphertextRef::::size((params.dim, radix.count)); + + GlevCiphertextIterator::new(&self.data, stride) + } + + /// Returns a mutable iterator over the rows of the GLWE keyswitch key, which are + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. + pub fn rows_mut( + &mut self, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIteratorMut { + let stride = GlevCiphertextRef::::size((params.dim, radix.count)); + + GlevCiphertextIteratorMut::new(&mut self.data, stride) + } +} diff --git a/sunscreen_tfhe/src/entities/glwe_secret_key.rs b/sunscreen_tfhe/src/entities/glwe_secret_key.rs new file mode 100644 index 000000000..9fc48eb98 --- /dev/null +++ b/sunscreen_tfhe/src/entities/glwe_secret_key.rs @@ -0,0 +1,385 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{FromSlice, NoWrapper, OverlaySize}, + entities::GgswCiphertext, + macros::{impl_binary_op, impl_unary_op}, + ops::encryption::{ + decrypt_glwe_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret, + }, + rand::{binary, uniform_torus}, + GlweDef, GlweDimension, PlaintextBits, RadixDecomposition, Torus, TorusOps, +}; + +use super::{ + GlweCiphertext, GlweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialIterator, + PolynomialIteratorMut, PolynomialRef, +}; + +dst! { + /// A GLWE secret key. This is a list of `s` polynomials of degree `n` where + /// `n` is the polynomial degree. The length of the list is `k` where `k` is + /// the size of the GLWE secret key. + GlweSecretKey, + GlweSecretKeyRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + (TorusOps,) +} + +// We can use these macros (which does operations on the underlying data vector) +// because addition, subtraction, and negation of polynomials is just element +// wise addition, subtraction, and negation of the underlying data vector. +impl_binary_op!(Add, GlweSecretKey, (TorusOps,)); +impl_binary_op!(Sub, GlweSecretKey, (TorusOps,)); +impl_unary_op!(Neg, GlweSecretKey); + +impl OverlaySize for GlweSecretKeyRef +where + S: TorusOps, +{ + type Inputs = GlweDimension; + + fn size(t: Self::Inputs) -> usize { + PolynomialRef::::size(t.polynomial_degree) * t.size.0 + } +} + +impl GlweSecretKey +where + S: TorusOps, +{ + fn generate(params: &GlweDef, torus_element_generator: impl Fn() -> S) -> GlweSecretKey { + params.dim.assert_valid(); + + let len = GlweSecretKeyRef::::size(params.dim); + + GlweSecretKey { + data: (0..len) + .map(|_| torus_element_generator()) + .collect::>(), + } + } + + /// Generate a random binary GLWE secret key. + pub fn generate_binary(params: &GlweDef) -> GlweSecretKey { + Self::generate(params, binary) + } + + /// Generate a secret key with uniformly random coefficients. This can be + /// used when performing threshold decryption, which needs random secret + /// keys that are uniform over the entire ciphertext modulus. Uniform + /// secret keys are also valid keys for encryption/decryption but are not + /// widely used. + pub fn generate_uniform(params: &GlweDef) -> GlweSecretKey { + Self::generate(params, || uniform_torus::().inner()) + } +} + +impl GlweSecretKeyRef +where + S: TorusOps, +{ + /// Returns an iterator over the `s` polynomials in a GLWE secret key. + pub fn s(&self, params: &GlweDef) -> PolynomialIterator { + PolynomialIterator::new(&self.data, params.dim.polynomial_degree.0) + } + + /// Decrypts and decodes a GLWE ciphertext into a polynomial. + pub fn decrypt_decode_glwe( + &self, + ct: &GlweCiphertextRef, + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> Polynomial + where + S: TorusOps, + { + let mut result = Polynomial::zero(ct.a_b(params).1.len()); + + decrypt_glwe_ciphertext(&mut result, ct, self, params); + + result.map(|x| x.decode(plaintext_bits)) + } + + /// Encodes and encrypts a message as a GLWE ciphertext using a secret key. + pub fn encode_encrypt_glwe( + &self, + plaintext: &PolynomialRef, + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> GlweCiphertext + where + S: TorusOps, + { + let plaintext = plaintext.map(|x| Torus::encode(*x, plaintext_bits)); + + let mut ct = GlweCiphertext::new(params); + + encrypt_glwe_ciphertext_secret(&mut ct, &plaintext, self, params); + + ct + } + + /// Encodes and encrypts a message as a GGSW ciphertext using a secret key. + pub fn encode_encrypt_ggsw( + &self, + msg: &PolynomialRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, + ) -> GgswCiphertext + where + S: TorusOps, + { + let mut ggsw = GgswCiphertext::new(params, radix); + + encrypt_ggsw_ciphertext(&mut ggsw, msg, self, params, radix, plaintext_bits); + + ggsw + } + + /// Returns a representation of a GLWE secret key as an LWE secret key. + /// This is a LWE secret key with dimension `N * k` where `N` is the + /// polynomial degree and `k` is the size of the GLWE secret key. + pub fn to_lwe_secret_key(&self) -> &LweSecretKeyRef { + LweSecretKeyRef::from_slice(&self.data) + } + + #[inline(always)] + pub(crate) fn assert_valid(&self, params: &GlweDef) { + assert_eq!( + self.as_slice().len(), + GlweSecretKeyRef::::size(params.dim) + ); + } +} + +impl GlweSecretKeyRef +where + S: TorusOps, +{ + /// Returns an mutable iterator over the `s` polynomials in a GLWE secret + /// key. + pub fn s_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut { + PolynomialIteratorMut::new(&mut self.data, params.dim.polynomial_degree.0) + } +} + +#[cfg(test)] +mod tests { + use crate::{high_level::*, GLWE_1_1024_80}; + + use num::traits::{WrappingAdd, WrappingNeg, WrappingSub}; + + #[test] + fn secret_key_dimensions() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + assert_eq!(sk.s(¶ms).count(), params.dim.size.0); + + for s_i in sk.s(¶ms) { + assert_eq!(s_i.len(), params.dim.polynomial_degree.0); + + for s in s_i.coeffs() { + assert!(*s == 0 || *s == 1); + } + } + } + + // Addition + + #[test] + fn add_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk + sk2; + + assert_eq!(sk3_expected, sk3.data) + } + + #[test] + fn add_assign_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk2_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + sk2 += sk; + + assert_eq!(sk2_expected, sk2.data) + } + + #[test] + fn add_secret_key_refs() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk.as_ref() + sk2.as_ref(); + + assert_eq!(sk3_expected, sk3.data) + } + + #[test] + fn wrapping_add_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk.wrapping_add(&sk2); + + assert_eq!(sk3_expected, sk3.data) + } + + // Subtraction + + #[test] + fn sub_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk - sk2; + + assert_eq!(sk3_expected, sk3.data) + } + + #[test] + fn sub_assign_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk2_expected = sk2 + .data + .iter() + .zip(sk.data.iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + sk2 -= sk; + + assert_eq!(sk2_expected, sk2.data) + } + + #[test] + fn sub_secret_key_refs() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk.as_ref() - sk2.as_ref(); + + assert_eq!(sk3_expected, sk3.data) + } + + #[test] + fn wrapping_sub_secret_keys() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + let sk2 = keygen::generate_uniform_glwe_sk(¶ms); + + let sk3_expected = sk + .data + .iter() + .zip(sk2.data.iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk.wrapping_sub(&sk2); + + assert_eq!(sk3_expected, sk3.data) + } + + // Negation + + #[test] + fn neg_secret_key() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + + let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = -sk; + + assert_eq!(sk2_expected, sk2.data) + } + + #[test] + fn neg_secret_key_ref() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + + let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = -sk.as_ref(); + + assert_eq!(sk2_expected, sk2.data) + } + + #[test] + fn wrapping_neg_secret_key() { + let params = GLWE_1_1024_80; + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = sk.wrapping_neg(); + + assert_eq!(sk2_expected, sk2.data) + } +} diff --git a/sunscreen_tfhe/src/entities/lev_ciphertext.rs b/sunscreen_tfhe/src/entities/lev_ciphertext.rs new file mode 100644 index 000000000..b1286b276 --- /dev/null +++ b/sunscreen_tfhe/src/entities/lev_ciphertext.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +use crate::{dst::OverlaySize, LweDef, LweDimension, RadixCount, Torus, TorusOps}; + +use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef}; + +// Iteration over LWE ciphertexts +dst! { + /// A Lev Ciphertext is a collection of LWE ciphertexts. + LevCiphertext, + LevCiphertextRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps,) +} +dst_iter! { LevCiphertextIterator, LevCiphertextIteratorMut, Torus, LevCiphertextRef, (TorusOps,)} + +impl OverlaySize for LevCiphertextRef +where + S: TorusOps, +{ + type Inputs = (LweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + LweCiphertextRef::::size(t.0) * t.1 .0 + } +} + +impl LevCiphertextRef +where + S: TorusOps, +{ + /// Returns an iterator over the rows of the Lev ciphertext, which are + /// [`LweCiphertext`](crate::entities::LweCiphertext)s. + pub fn lwe_ciphertexts(&self, params: &LweDef) -> LweCiphertextIterator { + LweCiphertextIterator::new(&self.data, LweCiphertextRef::::size(params.dim)) + } + + /// Returns a mutable iterator over the rows of the Lev ciphertext, which are + /// [`LweCiphertext`](crate::entities::LweCiphertext)s. + pub fn lwe_ciphertexts_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut { + LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::::size(params.dim)) + } +} diff --git a/sunscreen_tfhe/src/entities/lwe_ciphertext.rs b/sunscreen_tfhe/src/entities/lwe_ciphertext.rs new file mode 100644 index 000000000..bdfaf5543 --- /dev/null +++ b/sunscreen_tfhe/src/entities/lwe_ciphertext.rs @@ -0,0 +1,192 @@ +use num::Zero; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::OverlaySize, + macros::{impl_binary_op, impl_unary_op}, + LweDef, LweDimension, Torus, TorusOps, +}; + +dst! { + /// An LWE ciphertext. + LweCiphertext, + LweCiphertextRef, + Torus, + (Clone, Debug,Serialize, Deserialize), + (TorusOps) +} +dst_iter! { LweCiphertextIterator, LweCiphertextIteratorMut, Torus, LweCiphertextRef, (TorusOps,) } + +impl_binary_op!(Add, LweCiphertext, (TorusOps,)); +impl_binary_op!(Sub, LweCiphertext, (TorusOps,)); +impl_unary_op!(Neg, LweCiphertext); + +impl OverlaySize for LweCiphertextRef +where + S: TorusOps, +{ + type Inputs = LweDimension; + + fn size(t: Self::Inputs) -> usize { + t.0 + 1 + } +} + +impl LweCiphertext { + /// Create a new LWE ciphertext with all coefficients set to zero. + pub fn new(params: &LweDef) -> Self { + Self::zero(params) + } + + /// Create a new LWE ciphertext with all coefficients set to zero. + pub fn zero(params: &LweDef) -> Self { + let data = vec![Torus::zero(); LweCiphertextRef::::size(params.dim)]; + + Self { data } + } +} + +impl LweCiphertextRef { + fn split_at_idx(&self, params: &LweDef) -> usize { + params.dim.0 + } + + /// Returns a reference to the mask A and body B in an LWE ciphertext. + pub fn a_b(&self, params: &LweDef) -> (&[Torus], &Torus) { + let (a, b) = self.data.split_at(self.split_at_idx(params)); + + (a, &b[0]) + } + + /// Returns a reference to the mask A in an LWE ciphertext. + pub fn a(&self, params: &LweDef) -> &[Torus] { + let (a, _) = self.a_b(params); + + a + } + + /// Returns a reference to the body B in an LWE ciphertext. + pub fn b(&self, params: &LweDef) -> &Torus { + let (_, b) = self.a_b(params); + + b + } + + /// Returns a mutable reference to the mask A and body B in an LWE ciphertext. + pub fn a_b_mut(&mut self, params: &LweDef) -> (&mut [Torus], &mut Torus) { + let (a, b) = self.data.split_at_mut(self.split_at_idx(params)); + + (a, &mut b[0]) + } + + /// Returns a mutable reference to the mask A in an LWE ciphertext. + pub fn a_mut(&mut self, params: &LweDef) -> &mut [Torus] { + let (a, _) = self.a_b_mut(params); + + a + } + + /// Returns a mutable reference to the body B in an LWE ciphertext. + pub fn b_mut(&mut self, params: &LweDef) -> &mut Torus { + let (_, b) = self.a_b_mut(params); + + b + } + + /// Asserts that the LWE ciphertext is valid for a given LWE dimension. + #[inline(always)] + pub(crate) fn assert_valid(&self, params: &LweDef) { + assert_eq!( + self.as_slice().len(), + LweCiphertextRef::::size(params.dim) + ); + } +} + +#[cfg(test)] +mod tests { + use crate::{high_level::*, PlaintextBits, LWE_512_80}; + use proptest::prelude::*; + + // Test that the negation of a ciphertext is the same as the negation of the + // plaintext. + proptest! { + #[test] + fn negation_homomorphism(a in any::()) { + let bits = PlaintextBits(4); + + let params = LWE_512_80; + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits); + + let a_enc_neg = -a_enc; + + prop_assert_eq!(encryption::decrypt_lwe(&a_enc_neg, &sk, ¶ms, bits), a.wrapping_neg() % (0x1 << bits.0 as u64)); + } + } + + // Test that the addition of ciphertexts is the same as the addition of the + // plaintexts. + proptest! { + #[test] + fn additive_homomorphism(a in any::(), b in any::()) { + let params = LWE_512_80; + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let bits = PlaintextBits(4); + + let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits); + let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits); + + let c_enc = a_enc + b_enc; + + prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, ¶ms, bits), a.wrapping_add(b) % (0x1 << bits.0 as u64)); + } + } + + // Test that the subtraction of ciphertexts is the same as the subtraction + // of the plaintexts. + proptest! { + #[test] + fn subtraction_homomorphism(a in any::(), b in any::()) { + let params = LWE_512_80; + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let bits = PlaintextBits(4); + + let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits); + let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits); + + let c_enc = a_enc - b_enc; + + prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, ¶ms, bits), a.wrapping_sub(b) % (0x1 << bits.0 as u64)); + } + } + + // Testing that the addition of a ciphertext and a negated ciphertext is the + // same as the subtraction of the ciphertexts. + proptest! { + #[test] + fn add_negative_is_subtraction(a in any::(), b in any::()) { + let params = LWE_512_80; + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let bits = PlaintextBits(4); + + let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits); + let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits); + + let c_enc_by_add_neg = a_enc.as_ref() + (-(b_enc.as_ref())).as_ref(); + let c_enc_by_sub = a_enc.as_ref() - b_enc.as_ref(); + + // Test that the a values are the same + for (a_enc_by_add_neg_i, a_enc_by_sub_i) in c_enc_by_add_neg.a(¶ms).iter().zip(c_enc_by_sub.a(¶ms).iter()) { + assert_eq!(a_enc_by_add_neg_i, a_enc_by_sub_i); + } + + // Test that the b values are the same + assert_eq!(c_enc_by_add_neg.b(¶ms), c_enc_by_sub.b(¶ms)); + } + } +} diff --git a/sunscreen_tfhe/src/entities/lwe_ciphertext_list.rs b/sunscreen_tfhe/src/entities/lwe_ciphertext_list.rs new file mode 100644 index 000000000..b7e71aa1b --- /dev/null +++ b/sunscreen_tfhe/src/entities/lwe_ciphertext_list.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{dst::OverlaySize, LweDef, LweDimension, Torus, TorusOps}; + +use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef}; + +dst! { + /// A list of LWE ciphertexts. Used during + /// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap). + LweCiphertextList, + LweCiphertextListRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for LweCiphertextListRef { + type Inputs = (LweDimension, usize); + + #[inline(always)] + fn size(t: Self::Inputs) -> usize { + LweCiphertextRef::::size(t.0) * t.1 + } +} + +impl LweCiphertextList { + /// Create a new zero [LweCiphertextList] with the given parameters. + /// + /// This data structure represents is a list of LWE ciphertexts, used for + /// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap). + pub fn new(lwe: &LweDef, count: usize) -> Self { + Self { + data: vec![Torus::zero(); LweCiphertextListRef::::size((lwe.dim, count))], + } + } +} + +impl LweCiphertextListRef { + /// Iterate over the LWE ciphertexts in the list. + pub fn ciphertexts(&self, lwe: &LweDef) -> LweCiphertextIterator { + LweCiphertextIterator::new(self.as_slice(), LweCiphertextRef::::size(lwe.dim)) + } + + /// Iterate over the LWE ciphertexts in the list mutably. + pub fn ciphertexts_mut(&mut self, lwe: &LweDef) -> LweCiphertextIteratorMut { + LweCiphertextIteratorMut::new(self.as_mut_slice(), LweCiphertextRef::::size(lwe.dim)) + } +} diff --git a/sunscreen_tfhe/src/entities/lwe_keyswitch_key.rs b/sunscreen_tfhe/src/entities/lwe_keyswitch_key.rs new file mode 100644 index 000000000..df22ac3eb --- /dev/null +++ b/sunscreen_tfhe/src/entities/lwe_keyswitch_key.rs @@ -0,0 +1,170 @@ +use num::Zero; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::OverlaySize, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +use super::{LevCiphertextIterator, LevCiphertextIteratorMut, LevCiphertextRef}; + +dst! { + /// A LWE keyswitch key used to switch a ciphertext from one key to another. + /// See [`module`](crate::ops::keyswitch) documentation for more details. + LweKeyswitchKey, + LweKeyswitchKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps,) +} + +impl OverlaySize for LweKeyswitchKeyRef +where + S: TorusOps, +{ + // Old LWE dimension, new LWE dimension, radix count + type Inputs = (LweDimension, LweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + // Number of rows should be equal to the number of elements in the original key + let num_rows = t.0 .0; + + // Each row is made up of encryptions under the new key + let len_row = LevCiphertextRef::::size((t.1, t.2)); + + // Encrypt the secret key s_i in each row + len_row * (num_rows) + } +} + +impl LweKeyswitchKey +where + S: TorusOps, +{ + /// Creates a new LWE keyswitch key. This enables switching to a new key as + /// well as switching from the `original_params` that define the first key + /// to the `new_params` that define the second key. + pub fn new(original_params: &LweDef, new_params: &LweDef, radix: &RadixDecomposition) -> Self { + let elems = + LweKeyswitchKeyRef::::size((original_params.dim, new_params.dim, radix.count)); + + Self { + data: vec![Torus::zero(); elems], + } + } +} + +impl LweKeyswitchKeyRef +where + S: TorusOps, +{ + /// Returns an iterator over the rows of the LWE keyswitch key, which are + /// [`LevCiphertext`](crate::entities::LevCiphertext)s. + pub fn rows( + &self, + new_params: &LweDef, + radix: &RadixDecomposition, + ) -> LevCiphertextIterator { + let stride = LevCiphertextRef::::size((new_params.dim, radix.count)); + + LevCiphertextIterator::new(&self.data, stride) + } + + /// Returns a mutable iterator over the rows of the LWE keyswitch key, which are + /// [`LevCiphertext`](crate::entities::LevCiphertext)s. + pub fn rows_mut( + &mut self, + new_params: &LweDef, + radix: &RadixDecomposition, + ) -> LevCiphertextIteratorMut { + let stride = LevCiphertextRef::::size((new_params.dim, radix.count)); + + LevCiphertextIteratorMut::new(&mut self.data, stride) + } + + /// Asserts that the keyswitch key is valid for the given parameters. + #[inline(always)] + pub(crate) fn assert_valid( + &self, + original_params: &LweDef, + new_params: &LweDef, + radix: &RadixDecomposition, + ) { + assert_eq!( + self.as_slice().len(), + LweKeyswitchKeyRef::::size((original_params.dim, new_params.dim, radix.count)) + ); + } +} + +#[cfg(test)] +mod tests { + + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{LweCiphertext, LweKeyswitchKey}, + high_level::*, + high_level::{TEST_LWE_DEF_1, TEST_LWE_DEF_2, TEST_RADIX}, + ops::keyswitch::{ + lwe_keyswitch::keyswitch_lwe_to_lwe, lwe_keyswitch_key::generate_keyswitch_key_lwe, + }, + PlaintextBits, + }; + + #[test] + fn keyswitch_lwe() { + let bits = PlaintextBits(4); + let from_lwe = TEST_LWE_DEF_1; + let to_lwe = TEST_LWE_DEF_2; + + for _ in 0..50 { + let original_sk = keygen::generate_binary_lwe_sk(&from_lwe); + let new_sk = keygen::generate_binary_lwe_sk(&to_lwe); + + let mut ksk = LweKeyswitchKey::::new(&from_lwe, &to_lwe, &TEST_RADIX); + generate_keyswitch_key_lwe(&mut ksk, &original_sk, &new_sk, &to_lwe, &TEST_RADIX); + + let msg = thread_rng().next_u64() % (1 << bits.0); + + let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0; + + let mut new_ct = LweCiphertext::new(&to_lwe); + keyswitch_lwe_to_lwe( + &mut new_ct, + &original_ct, + &ksk, + &from_lwe, + &to_lwe, + &TEST_RADIX, + ); + + let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits); + + assert_eq!(new_decrypted, msg); + } + } + + #[test] + fn lwe_keyswitch_keygen() { + let from_lwe = TEST_LWE_DEF_1; + let to_lwe = TEST_LWE_DEF_2; + + for _ in 0..10 { + let sk_1 = keygen::generate_binary_lwe_sk(&from_lwe); + let sk_2 = keygen::generate_binary_lwe_sk(&to_lwe); + + let mut ksk = LweKeyswitchKey::::new(&from_lwe, &to_lwe, &TEST_RADIX); + generate_keyswitch_key_lwe(&mut ksk, &sk_1, &sk_2, &to_lwe, &TEST_RADIX); + + for (i, r) in ksk.rows(&to_lwe, &TEST_RADIX).enumerate() { + for (j, l) in r.lwe_ciphertexts(&to_lwe).enumerate() { + let decomp = (j + 1) * TEST_RADIX.radix_log.0; + + let res = sk_2.decrypt(l, &to_lwe, PlaintextBits(decomp as u32)); + + assert_eq!(res, sk_1.s()[i]); + } + } + } + } +} diff --git a/sunscreen_tfhe/src/entities/lwe_public_key.rs b/sunscreen_tfhe/src/entities/lwe_public_key.rs new file mode 100644 index 000000000..c53856b6e --- /dev/null +++ b/sunscreen_tfhe/src/entities/lwe_public_key.rs @@ -0,0 +1,158 @@ +use num::Zero; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::OverlaySize, + ops::encryption::encode_and_encrypt_lwe_ciphertext, + rand::{binary, normal_torus}, + LweDef, LweDimension, PlaintextBits, Torus, TorusOps, +}; + +use super::{ + LweCiphertext, LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef, + LweSecretKeyRef, +}; + +/// Randomness used to encrypt a message with a public key. +#[derive(Debug)] +pub struct TlwePublicEncRandomness { + /// The binary selectors of the encryptions of zero in the public key. + pub r: Vec, + + /// The gaussian noise added to make the LWE problem. + pub e: LweCiphertext, +} + +dst! { + /// An LWE public key. + LwePublicKey, + LwePublicKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for LwePublicKeyRef +where + S: TorusOps, +{ + type Inputs = LweDimension; + + fn size(t: Self::Inputs) -> usize { + LweCiphertextRef::::size(t) * t.0 + } +} + +impl LwePublicKey +where + S: TorusOps, +{ + /// Generate an LWE public key from a given secret key. This is done by + /// encrypting the LWE dimension number of zeros under the secret key, and + /// then using the resulting ciphertext as the public key. + pub fn generate(sk: &LweSecretKeyRef, params: &LweDef) -> Self { + let mut pk = LwePublicKey { + data: vec![Torus::zero(); LwePublicKeyRef::::size(params.dim)], + }; + let enc_zeros = pk.enc_zeros_mut(params); + + for z in enc_zeros { + encode_and_encrypt_lwe_ciphertext(z, sk, ::zero(), params, PlaintextBits(1)); + } + + pk + } +} + +impl LwePublicKeyRef +where + S: TorusOps, +{ + /// Get the public key data as an iterator. + pub fn enc_zeros(&self, params: &LweDef) -> LweCiphertextIterator { + LweCiphertextIterator::new(&self.data, LweCiphertextRef::::size(params.dim)) + } + + /// Get the public key data as a mutable iterator. + pub fn enc_zeros_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut { + LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::::size(params.dim)) + } + + /// Encrypt a message as an LWE ciphertext using a public key, returning the + /// encrypted message and the randomness used. + pub fn encrypt( + &self, + msg: S, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> (LweCiphertext, TlwePublicEncRandomness) { + let msg = Torus::::encode(msg, plaintext_bits); + let lwe_dimension = params.dim.0; + + let mut acc = LweCiphertext::zero(params); + let (acc_a, acc_b) = acc.a_b_mut(params); + + let mut r_noise = vec![]; + let mut e = LweCiphertext::zero(params); + let (e_a, e_b) = e.a_b_mut(params); + + for z in self.enc_zeros(params) { + let (a, b) = z.a_b(params); + let r = binary::(); + r_noise.push(r); + + for i in 0..lwe_dimension { + acc_a[i] += a[i] * r; + } + + *acc_b += *b * r; + } + + for i in 0..lwe_dimension { + let a_noise = normal_torus(params.std); + e_a[i] = a_noise; + acc_a[i] += a_noise; + } + + *acc_b += msg; + *e_b = normal_torus(params.std); + *acc_b += *e_b; + + let noise = TlwePublicEncRandomness { r: r_noise, e }; + + (acc, noise) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + high_level::{encryption, keygen, TEST_LWE_DEF_1}, + PlaintextBits, + }; + + #[test] + fn public_key_is_zeros() { + let params = TEST_LWE_DEF_1; + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + for ct in pk.enc_zeros(¶ms) { + let pt = encryption::decrypt_lwe(ct, &sk, ¶ms, PlaintextBits(1)); + assert_eq!(pt, 0); + } + } + + #[test] + fn can_public_key_encrypt() { + let params = TEST_LWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + let ct = encryption::encrypt_lwe(5, &pk, ¶ms, bits); + assert_eq!(encryption::decrypt_lwe(&ct, &sk, ¶ms, bits), 5); + } +} diff --git a/sunscreen_tfhe/src/entities/lwe_secret_key.rs b/sunscreen_tfhe/src/entities/lwe_secret_key.rs new file mode 100644 index 000000000..4a2b782f0 --- /dev/null +++ b/sunscreen_tfhe/src/entities/lwe_secret_key.rs @@ -0,0 +1,337 @@ +use num::Zero; +use serde::{Deserialize, Serialize}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + macros::{impl_binary_op, impl_unary_op}, + ops::encryption::encode_and_encrypt_lwe_ciphertext, + rand::{binary, uniform_torus}, + LweDef, LweDimension, PlaintextBits, Torus, TorusOps, +}; + +use super::{LweCiphertext, LweCiphertextRef}; + +dst! { + /// An LWE secret key. + LweSecretKey, + LweSecretKeyRef, + NoWrapper, + (Clone, Debug, Serialize, Deserialize), + () +} + +impl_binary_op!(Add, LweSecretKey, (TorusOps,)); +impl_binary_op!(Sub, LweSecretKey, (TorusOps,)); +impl_unary_op!(Neg, LweSecretKey); + +impl OverlaySize for LweSecretKeyRef +where + S: TorusOps, +{ + type Inputs = LweDimension; + + fn size(t: Self::Inputs) -> usize { + t.0 + } +} + +impl LweSecretKey +where + S: TorusOps, +{ + fn generate(params: &LweDef, torus_element_generator: fn() -> S) -> Self { + let len = LweSecretKeyRef::::size(params.dim); + + LweSecretKey { + data: (0..len) + .map(|_| torus_element_generator()) + .collect::>(), + } + } + + /// Generate a random binary LWE secret key + pub fn generate_binary(params: &LweDef) -> Self { + Self::generate(params, binary) + } + + /// Generate a secret key with uniformly random coefficients. This can be + /// used when performing threshold decryption, which needs random secret + /// keys that are uniform over the entire ciphertext modulus. Uniform + /// secret keys are also valid keys for encryption/decryption but are not + /// widely used. + pub fn generate_uniform(params: &LweDef) -> Self { + Self::generate(params, || uniform_torus::().inner()) + } +} + +impl LweSecretKeyRef +where + S: TorusOps, +{ + /// Create an LWE ciphertext from a given message with a private key. The + /// message should be in the plaintext space, and will be encoded onto the + /// Torus automatically. + pub fn encrypt( + &self, + msg: S, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> (LweCiphertext, Torus) { + let mut ct = LweCiphertext::::zero(params); + + let e = encode_and_encrypt_lwe_ciphertext(&mut ct, self, msg, params, plaintext_bits); + + (ct, e) + } + + /// Decrypts the given ciphertext, returning the message. The message will + /// not be decoded into the plaintext space; the caller is responsible for + /// performing operations like shifting by delta and rounding. See + /// [Self::decrypt] for a function that performs the decoding automatically. + pub fn decrypt_without_decode(&self, ct: &LweCiphertextRef, params: &LweDef) -> Torus { + ct.assert_valid(params); + + let (a, b) = ct.a_b(params); + + let mut dot = Torus::::zero(); + + for (a_i, d_i) in a.iter().zip(self.data.iter()) { + dot += a_i * d_i + } + + b - dot + } + + /// Decrypts and decodes a ciphertext, returning the message. The message + /// will be decoded into the plaintext space. See + /// [Self::decrypt_without_decode] for a function that does not perform the + /// decoding. + pub fn decrypt( + &self, + ct: &LweCiphertextRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> S { + let msg = self.decrypt_without_decode(ct, params); + + msg.decode(plaintext_bits) + } + + /// Asserts that a given secret key is valid for a given LWE dimension. + pub fn assert_valid(&self, params: &LweDef) { + assert_eq!( + self.as_slice().len(), + LweSecretKeyRef::::size(params.dim) + ); + } +} + +impl LweSecretKeyRef +where + S: TorusOps, +{ + /// Returns the secret key data as a slice. + pub fn s(&self) -> &[S] { + &self.data + } +} + +#[cfg(test)] +mod tests { + use crate::high_level::*; + use num::traits::{WrappingAdd, WrappingNeg, WrappingSub}; + + // Addition + + #[test] + fn add_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk + sk2; + + assert_eq!(sk3_expected, sk3.s()) + } + + #[test] + fn add_assign_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let mut sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk2_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + sk2 += sk; + + assert_eq!(sk2_expected, sk2.s()) + } + + #[test] + fn add_secret_key_refs() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk.as_ref() + sk2.as_ref(); + + assert_eq!(sk3_expected, sk3.s()) + } + + #[test] + fn wrapping_add_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_add(b)) + .collect::>(); + + let sk3 = sk.wrapping_add(&sk2); + + assert_eq!(sk3_expected, sk3.s()) + } + + // Subtraction + + #[test] + fn sub_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk - sk2; + + assert_eq!(sk3_expected, sk3.s()) + } + + #[test] + fn sub_assign_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let mut sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk2_expected = sk2 + .s() + .iter() + .zip(sk.s().iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + sk2 -= sk; + + assert_eq!(sk2_expected, sk2.s()) + } + + #[test] + fn sub_secret_key_refs() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk.as_ref() - sk2.as_ref(); + + assert_eq!(sk3_expected, sk3.s()) + } + + #[test] + fn wrapping_sub_secret_keys() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + let sk2 = keygen::generate_uniform_lwe_sk(params); + + let sk3_expected = sk + .s() + .iter() + .zip(sk2.s().iter()) + .map(|(a, b)| a.wrapping_sub(b)) + .collect::>(); + + let sk3 = sk.wrapping_sub(&sk2); + + assert_eq!(sk3_expected, sk3.s()) + } + + // Negation + + #[test] + fn neg_secret_key() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + + let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = -sk; + + assert_eq!(sk2_expected, sk2.s()) + } + + #[test] + fn neg_secret_key_ref() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + + let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = -sk.as_ref(); + + assert_eq!(sk2_expected, sk2.s()) + } + + #[test] + fn wrapping_neg_secret_key() { + let params = &TEST_LWE_DEF_1; + + let sk = keygen::generate_uniform_lwe_sk(params); + + let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::>(); + let sk2 = sk.wrapping_neg(); + + assert_eq!(sk2_expected, sk2.s()) + } +} diff --git a/sunscreen_tfhe/src/entities/mod.rs b/sunscreen_tfhe/src/entities/mod.rs new file mode 100644 index 000000000..7e89c7e83 --- /dev/null +++ b/sunscreen_tfhe/src/entities/mod.rs @@ -0,0 +1,71 @@ +mod public_functional_keyswitch_key; +pub use public_functional_keyswitch_key::*; + +mod blind_rotation_shift; +pub use blind_rotation_shift::*; + +mod lwe_ciphertext_list; +pub use lwe_ciphertext_list::*; + +mod private_functional_keyswitch_key; +pub use private_functional_keyswitch_key::*; + +mod circuit_bootstrapping_private_keyswitch_keys; +pub use circuit_bootstrapping_private_keyswitch_keys::*; + +mod bootstrap_key; +pub use bootstrap_key::*; + +mod univariate_lookup_table; +pub use univariate_lookup_table::*; + +mod bivariate_lookup_table; +pub use bivariate_lookup_table::*; + +mod glwe_secret_key; +pub use glwe_secret_key::*; + +mod glwe_ciphertext; +pub use glwe_ciphertext::*; + +mod glwe_ciphertext_fft; +pub use glwe_ciphertext_fft::*; + +mod lwe_ciphertext; +pub use lwe_ciphertext::*; + +mod lwe_secret_key; +pub use lwe_secret_key::*; + +mod lwe_public_key; +pub use lwe_public_key::*; + +mod glev_ciphertext; +pub use glev_ciphertext::*; + +mod glev_ciphertext_fft; +pub use glev_ciphertext_fft::*; + +mod ggsw_ciphertext; +pub use ggsw_ciphertext::*; + +mod ggsw_ciphertext_fft; +pub use ggsw_ciphertext_fft::*; + +mod lev_ciphertext; +pub use lev_ciphertext::*; + +mod lwe_keyswitch_key; +pub use lwe_keyswitch_key::*; + +mod glwe_keyswitch_key; +pub use glwe_keyswitch_key::*; + +mod polynomial; +pub use polynomial::*; + +mod polynomial_fft; +pub use polynomial_fft::*; + +mod polynomial_list; +pub use polynomial_list::*; diff --git a/sunscreen_tfhe/src/entities/polynomial.rs b/sunscreen_tfhe/src/entities/polynomial.rs new file mode 100644 index 000000000..67e378eba --- /dev/null +++ b/sunscreen_tfhe/src/entities/polynomial.rs @@ -0,0 +1,588 @@ +use std::{ + num::Wrapping, + ops::{Add, AddAssign, Mul, Sub, SubAssign}, +}; + +use num::{Complex, Zero}; + +use crate::{ + dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize}, + fft::negacyclic::get_fft, + polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign}, + scratch::allocate_scratch, + FrequencyTransform, PolynomialDegree, ReinterpretAsSigned, ToF64, Torus, TorusOps, +}; + +use super::PolynomialFftRef; + +dst! { + /// A type representing a polynomial. + Polynomial, + PolynomialRef, + NoWrapper, + (Debug, Clone, PartialEq, Eq,), + () +} +dst_iter! { PolynomialIterator, PolynomialIteratorMut, NoWrapper, PolynomialRef, () } + +impl OverlaySize for PolynomialRef +where + T: Clone, +{ + type Inputs = PolynomialDegree; + + fn size(t: Self::Inputs) -> usize { + t.0 + } +} + +impl Polynomial +where + T: Clone, +{ + /// Create a new polynomial from a slice of coefficients. + pub fn new(data: &[T]) -> Polynomial { + Polynomial { + data: data.to_owned(), + } + } + + /// Create a new polynomial filled with zeros of a specified length. + pub fn zero(len: usize) -> Polynomial + where + T: Zero, + { + Polynomial { + data: vec![T::zero(); len], + } + } +} + +impl FromIterator for Polynomial +where + T: Clone, +{ + fn from_iter>(iter: I) -> Self { + Self { + data: iter.into_iter().collect::>(), + } + } +} + +impl PolynomialRef +where + T: Clone, +{ + /// Returns the coefficients of the polynomial in ascending order. + pub fn coeffs(&self) -> &[T] { + &self.data + } + + /// Returns the mutable coefficients of the polynomial in ascending order. + pub fn coeffs_mut(&mut self) -> &mut [T] { + &mut self.data + } + + /// Apply a function to each coefficient of the polynomial and return a new + /// polynomial. + pub fn map(&self, f: F) -> Polynomial + where + F: Fn(&T) -> U, + U: Clone, + { + Polynomial { + data: self.data.iter().map(f).collect::>(), + } + } + + /// Maps this polynomial using f into the dst [`PolynomialRef`]. + /// + /// # Panics + /// If `dst.len() != self.len()` + pub fn map_into(&self, dst: &mut PolynomialRef, f: F) + where + F: Fn(&T) -> U, + U: Clone, + { + assert_eq!(dst.len(), self.len()); + + dst.coeffs_mut() + .iter_mut() + .zip(self.coeffs().iter()) + .for_each(|(d, s)| *d = f(s)); + } + + /// Returns the number of coefficients in the polynomial. + #[inline] + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns true if the polynomial has no coefficients. + #[inline] + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +impl PolynomialRef +where + T: TorusOps, +{ + /// Reinterpret the this polynomial as a polynomial of torus elements. + pub fn as_torus(&self) -> &PolynomialRef> { + let as_torus = bytemuck::cast_slice(&self.data); + + PolynomialRef::from_slice(as_torus) + } + + /// Reinterpret this polynomial of integers as having wrapping semantics. + pub fn as_wrapping(&self) -> &PolynomialRef> { + let as_wrapping = bytemuck::cast_slice(&self.data); + + PolynomialRef::from_slice(as_wrapping) + } + + /// Reinterpret this polynomial of integers as having wrapping semantics. + pub fn as_wrapping_mut(&mut self) -> &mut PolynomialRef> { + let as_wrapping = bytemuck::cast_slice_mut(&mut self.data); + + PolynomialRef::from_mut_slice(as_wrapping) + } +} + +impl PolynomialRef> +where + T: TorusOps, +{ + /** + * Multiply by a monomial X^-degree, returning a new Polynomial. + */ + pub fn mul_by_negative_monomial_negacyclic(&mut self, degree: usize) { + let len = self.len(); + + // The behavior of the rotation is the same for every degree + k*2N for + // k >= 0. + let degree = degree % (2 * len); + + // If the degree is 0 (or a multiple of 2N), the polynomial is unchanged. + if degree == 0 { + return; + } + + // If the degree is N (or of the form N + k*2N), the polynomial is negated. + if degree == len { + self.data + .iter_mut() + .for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x)); + return; + } + + let shift = degree % len; + self.data.rotate_left(shift); + + let negate_segment = if degree < len { + (len - shift)..len + } else { + 0..(len - shift) + }; + + for i in negate_segment { + self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]); + } + } + + /** + * Multiply by a monomial X^degree, returning a new Polynomial. + */ + pub fn mul_by_positive_monomial_negacyclic(&mut self, degree: usize) { + let len = self.len(); + + // The behavior of the rotation is the same for every degree + k*2N for + // k >= 0. + let degree = degree % (2 * len); + + // If the degree is 0 (or a multiple of 2N), the polynomial is unchanged. + if degree == 0 { + return; + } + + // If the degree is N (or of the form N + k*2N), the polynomial is negated. + if degree == len { + self.data + .iter_mut() + .for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x)); + return; + } + + let shift = degree % len; + self.data.rotate_right(shift); + + let negate_segment = if degree < len { 0..degree } else { shift..len }; + + for i in negate_segment { + self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]); + } + } + + /** + * Multiply by a monomial X^degree, returning a new Polynomial. The degree + * can be either positive or negative. + */ + pub fn mul_by_monomial_negacyclic(&mut self, degree: isize) { + if degree < 0 { + self.mul_by_negative_monomial_negacyclic(degree.unsigned_abs()); + } else { + self.mul_by_positive_monomial_negacyclic(degree as usize); + } + } +} + +impl PolynomialRef +where + U: ToF64, + T: Clone + Copy + ReinterpretAsSigned, +{ + /// Compute the FFT of the polynomial. + pub fn fft(&self, out: &mut PolynomialFftRef>) { + assert!(self.len().is_power_of_two()); + assert_eq!(self.len(), out.len() * 2); + + let mut self_f64 = allocate_scratch::(self.len()); + let self_f64 = self_f64.as_mut_slice(); + + for (o, i) in self_f64.iter_mut().zip(self.coeffs().iter()) { + // Reinterperet [0, 1) to [-q/2, q/2) to slightly increase + // precision. + *o = (*i).reinterpret_as_signed().to_f64(); + } + + let log_n = self.len().ilog2() as usize; + + let fft = get_fft(log_n); + fft.forward(self_f64, out.as_mut_slice()); + } +} + +impl Add> for Polynomial +where + S: Add + Copy, +{ + type Output = Polynomial; + + fn add(self, rhs: Polynomial) -> Self::Output { + self.as_ref().add(rhs.as_ref()) + } +} + +impl Add<&PolynomialRef> for &PolynomialRef +where + S: Add + Copy, +{ + type Output = Polynomial; + + fn add(self, rhs: &PolynomialRef) -> Self::Output { + assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len()); + + let coeffs = self + .coeffs() + .as_ref() + .iter() + .zip(rhs.coeffs().as_ref().iter()) + .map(|(a, b)| *a + *b) + .collect::>(); + + Polynomial { data: coeffs } + } +} + +impl AddAssign<&PolynomialRef> for PolynomialRef +where + S: AddAssign + Copy, +{ + fn add_assign(&mut self, rhs: &PolynomialRef) { + polynomial_add_assign(self, rhs) + } +} + +impl Sub> for Polynomial +where + S: Sub + Copy, +{ + type Output = Polynomial; + + fn sub(self, rhs: Polynomial) -> Self::Output { + self.as_ref().sub(rhs.as_ref()) + } +} + +impl Sub<&PolynomialRef> for &PolynomialRef +where + S: Sub + Copy, +{ + type Output = Polynomial; + + fn sub(self, rhs: &PolynomialRef) -> Self::Output { + assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len()); + + let coeffs = self + .coeffs() + .as_ref() + .iter() + .zip(rhs.coeffs().as_ref().iter()) + .map(|(a, b)| *a - *b) + .collect::>(); + + Polynomial { data: coeffs } + } +} + +impl SubAssign<&PolynomialRef> for PolynomialRef +where + S: SubAssign + Copy, +{ + fn sub_assign(&mut self, rhs: &PolynomialRef) { + polynomial_sub_assign(self, rhs) + } +} + +impl Mul<&PolynomialRef> for &PolynomialRef> +where + S: TorusOps, +{ + type Output = Polynomial>; + + /// External product of T\[X\]/f * Z\[X\]/f + /// TODO: use NTT to do in nlog(n) time. + fn mul(self, rhs: &PolynomialRef) -> Self::Output { + assert_eq!(rhs.len(), self.len()); + + let mut c = Polynomial { + data: vec![Torus::zero(); rhs.len()], + }; + + polynomial_external_mad(&mut c, self, rhs); + + c + } +} + +#[cfg(test)] +mod tests { + use std::ops::Deref; + + use crate::{entities::Polynomial, Torus}; + + #[test] + fn can_add_polynomials() { + let a = Polynomial::new(&[1, 2, 3]); + + let b = Polynomial::new(&[4, 5, 6]); + + let expected = Polynomial::new(&[5, 7, 9]); + + let c = a.deref() + b.deref(); + assert_eq!(c, expected); + } + + // A golden test but easier than recoding the logic again. + #[test] + fn can_multiply_by_positive_monomial_negacyclic() { + let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::::from)); + + let mut shift_0 = original.clone(); + shift_0.mul_by_positive_monomial_negacyclic(0); + let expected_0 = original.clone(); + + assert_eq!(shift_0, expected_0); + + let mut shift_1 = original.clone(); + shift_1.mul_by_positive_monomial_negacyclic(1); + let expected_1 = Polynomial::new(&[ + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + Torus::from(2), + Torus::from(3), + ]); + + assert_eq!(shift_1, expected_1); + + let mut shift_2 = original.clone(); + shift_2.mul_by_positive_monomial_negacyclic(2); + let expected_2 = Polynomial::new(&[ + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + Torus::from(2), + ]); + + assert_eq!(shift_2, expected_2); + + let mut shift_3 = original.clone(); + shift_3.mul_by_positive_monomial_negacyclic(3); + let expected_3 = Polynomial::new(&[ + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + ]); + + assert_eq!(shift_3, expected_3); + + let mut shift_4 = original.clone(); + shift_4.mul_by_positive_monomial_negacyclic(4); + let expected_4 = Polynomial::new(&[ + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + ]); + + assert_eq!(shift_4, expected_4); + + let mut shift_5 = original.clone(); + shift_5.mul_by_positive_monomial_negacyclic(5); + let expected_5 = Polynomial::new(&[ + Torus::from(4u64), + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + ]); + + assert_eq!(shift_5, expected_5); + + let mut shift_6 = original.clone(); + shift_6.mul_by_positive_monomial_negacyclic(6); + let expected_6 = Polynomial::new(&[ + Torus::from(3u64), + Torus::from(4u64), + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + ]); + + assert_eq!(shift_6, expected_6); + + let mut shift_7 = original.clone(); + shift_7.mul_by_positive_monomial_negacyclic(7); + let expected_7 = Polynomial::new(&[ + Torus::from(2u64), + Torus::from(3u64), + Torus::from(4u64), + Torus::from(1u64.wrapping_neg()), + ]); + + assert_eq!(shift_7, expected_7); + + let mut shift_8 = original.clone(); + shift_8.mul_by_positive_monomial_negacyclic(8); + let expected_8 = original.clone(); + + assert_eq!(shift_8, expected_8); + + let mut shift_9 = original.clone(); + shift_9.mul_by_positive_monomial_negacyclic(9); + let expected_9 = shift_1.clone(); + + assert_eq!(shift_9, expected_9); + } + + #[test] + fn can_multiply_by_negative_monomial_negacyclic() { + let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::::from)); + + let mut shift_0 = original.clone(); + shift_0.mul_by_negative_monomial_negacyclic(0); + let expected_0 = original.clone(); + + assert_eq!(shift_0, expected_0); + + let mut shift_1 = original.clone(); + shift_1.mul_by_negative_monomial_negacyclic(1); + let expected_1 = Polynomial::new(&[ + Torus::from(2), + Torus::from(3), + Torus::from(4), + Torus::from(1u64.wrapping_neg()), + ]); + + assert_eq!(shift_1, expected_1); + + let mut shift_2 = original.clone(); + shift_2.mul_by_negative_monomial_negacyclic(2); + let expected_2 = Polynomial::new(&[ + Torus::from(3), + Torus::from(4), + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + ]); + + assert_eq!(shift_2, expected_2); + + let mut shift_3 = original.clone(); + shift_3.mul_by_negative_monomial_negacyclic(3); + let expected_3 = Polynomial::new(&[ + Torus::from(4), + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + ]); + + assert_eq!(shift_3, expected_3); + + let mut shift_4 = original.clone(); + shift_4.mul_by_negative_monomial_negacyclic(4); + let expected_4 = Polynomial::new(&[ + Torus::from(1u64.wrapping_neg()), + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + ]); + + assert_eq!(shift_4, expected_4); + + let mut shift_5 = original.clone(); + shift_5.mul_by_negative_monomial_negacyclic(5); + let expected_5 = Polynomial::new(&[ + Torus::from(2u64.wrapping_neg()), + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + ]); + + assert_eq!(shift_5, expected_5); + + let mut shift_6 = original.clone(); + shift_6.mul_by_negative_monomial_negacyclic(6); + let expected_6 = Polynomial::new(&[ + Torus::from(3u64.wrapping_neg()), + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + Torus::from(2), + ]); + + assert_eq!(shift_6, expected_6); + + let mut shift_7 = original.clone(); + shift_7.mul_by_negative_monomial_negacyclic(7); + let expected_7 = Polynomial::new(&[ + Torus::from(4u64.wrapping_neg()), + Torus::from(1), + Torus::from(2), + Torus::from(3), + ]); + + assert_eq!(shift_7, expected_7); + + let mut shift_8 = original.clone(); + shift_8.mul_by_negative_monomial_negacyclic(8); + let expected_8 = original.clone(); + + assert_eq!(shift_8, expected_8); + + let mut shift_9 = original.clone(); + shift_9.mul_by_negative_monomial_negacyclic(9); + let expected_9 = shift_1.clone(); + + assert_eq!(shift_9, expected_9); + } +} diff --git a/sunscreen_tfhe/src/entities/polynomial_fft.rs b/sunscreen_tfhe/src/entities/polynomial_fft.rs new file mode 100644 index 000000000..5208d7b78 --- /dev/null +++ b/sunscreen_tfhe/src/entities/polynomial_fft.rs @@ -0,0 +1,158 @@ +use num::{traits::MulAdd, Complex}; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + fft::negacyclic::get_fft, + scratch::allocate_scratch, + FrequencyTransform, FromF64, NumBits, PolynomialDegree, +}; + +use super::PolynomialRef; + +dst! { + /// The FFT of a polynomial. See [`Polynomial`](crate::entities::Polynomial) + /// for the non-FFT variant. + PolynomialFft, + PolynomialFftRef, + NoWrapper, + (Debug, Clone, PartialEq, Eq,), + () +} +dst_iter!( + PolynomialFftIterator, + PolynomialFftIteratorMut, + NoWrapper, + PolynomialFftRef, + () +); + +impl OverlaySize for PolynomialFftRef +where + T: Clone, +{ + type Inputs = PolynomialDegree; + + fn size(t: Self::Inputs) -> usize { + t.0 / 2 + } +} + +impl PolynomialFft +where + T: Clone, +{ + /// Create a new polynomial with the given length in the fourier domain. + pub fn new(data: &[T]) -> Self { + Self { + data: data.to_owned(), + } + } +} + +impl PolynomialFftRef +where + T: Clone, +{ + /// Returns the coefficients of the polynomial in the fourier domain. + pub fn coeffs(&self) -> &[T] { + &self.data + } + + /// Returns the mutable coefficients of the polynomial in the fourier domain. + pub fn coeffs_mut(&mut self) -> &mut [T] { + &mut self.data + } + + /// Returns the number of coefficients in the polynomial. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns true if the polynomial has no coefficients. + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +impl PolynomialFftRef> { + /// Compute the inverse FFT of the polynomial. + pub fn ifft(&self, poly: &mut PolynomialRef) + where + T: Clone + FromF64 + NumBits, + { + assert!(self.len().is_power_of_two()); + assert_eq!(self.len() * 2, poly.len()); + + let log_n = poly.len().ilog2() as usize; + + let fft = get_fft(log_n); + + let mut ifft = allocate_scratch::(poly.len()); + let ifft = ifft.as_mut_slice(); + + fft.reverse(&self.data, ifft); + + // When the exponent != 0 && exponent != 1024, + // IEEE-754 doubles are represented as -1**s * 1.m * 2**(e - 1023). + // + // m is 52 bits, e is 11 bits, and s is 1 bit. + // + // Thus, to compute 2**x, we set e = 1023 + x, m=0, and s = 0. So, we just + // need to fill in EXP and shift it up 52 places. + // + // We first reduce modulo q + let exp: u64 = 1023 + T::BITS as u64; + let q: f64 = f64::from_bits(exp << 52); + + let exp_div_2 = exp - 1; + let q_div_2 = f64::from_bits(exp_div_2 << 52); + + // Exploit the fact that q is a power of 2 when performing the modulo + // reduction. Could possibly be even faster by masking and shifting + // the mantissa and tweaking the exponent. However, profiling on ARM + // indicates this is no longer a bottleneck with the code below. + // + // See https://stackoverflow.com/questions/49139283/are-there-any-numbers-that-enable-fast-modulo-calculation-on-floats + // + // Don't know why Rust decides not to inline this. Inlining allows + // the below loop to get unrolled, vectorized, and division gets + // replaced with multiplication since q is a known constant. + #[inline(always)] + fn mod_q(val: f64, q: f64) -> f64 { + f64::mul_add(-(val / q).trunc(), q, val) + } + + for (o, ifft) in poly.coeffs_mut().iter_mut().zip(ifft.iter()) { + let mut ifft = mod_q(*ifft, q); + + // Next, we need to adjust x outside [-q/2, q/2) to wrap to the correct torus + // point. + if ifft >= q_div_2 { + ifft -= q; + } else if ifft <= -q_div_2 { + ifft += q; + } + + *o = T::from_f64(ifft); + } + } + + /// Computes the multiplication of two polynomials as `c += a * b`. This is + /// more efficient than the naive method, and has a runtime of O(N). Note + /// that performing the FFT and IFFT to get in and out of the fourier domain + /// costs O(N log N). + pub fn multiply_add( + &mut self, + a: &PolynomialFftRef>, + b: &PolynomialFftRef>, + ) { + for ((c, a), b) in self + .coeffs_mut() + .iter_mut() + .zip(a.coeffs().iter()) + .zip(b.coeffs().iter()) + { + *c = a.mul_add(b, c); + } + } +} diff --git a/sunscreen_tfhe/src/entities/polynomial_list.rs b/sunscreen_tfhe/src/entities/polynomial_list.rs new file mode 100644 index 000000000..67b07c593 --- /dev/null +++ b/sunscreen_tfhe/src/entities/polynomial_list.rs @@ -0,0 +1,52 @@ +use num::Zero; + +use crate::{ + dst::{NoWrapper, OverlaySize}, + PolynomialDegree, +}; + +use super::{PolynomialIterator, PolynomialIteratorMut, PolynomialRef}; + +dst! { + /// A list of polynomials. + PolynomialList, + PolynomialListRef, + NoWrapper, + (Debug, Clone, PartialEq, Eq,), + () +} + +impl OverlaySize for PolynomialListRef { + type Inputs = (PolynomialDegree, usize); + + fn size(t: Self::Inputs) -> usize { + PolynomialRef::::size(t.0) * t.1 + } +} + +impl PolynomialList +where + S: Clone + Zero, +{ + /// Create a new polynomial list, where each polynomial has the same degree. + pub fn new(degree: PolynomialDegree, count: usize) -> Self { + Self { + data: vec![S::zero(); degree.0 * count], + } + } +} + +impl PolynomialListRef +where + S: Clone + Zero, +{ + /// Iterate over the polynomials in the list. + pub fn iter(&self, degree: PolynomialDegree) -> PolynomialIterator { + PolynomialIterator::new(&self.data, PolynomialRef::::size(degree)) + } + + /// Iterate over the polynomials in the list mutably. + pub fn iter_mut(&mut self, degree: PolynomialDegree) -> PolynomialIteratorMut { + PolynomialIteratorMut::new(&mut self.data, PolynomialRef::::size(degree)) + } +} diff --git a/sunscreen_tfhe/src/entities/private_functional_keyswitch_key.rs b/sunscreen_tfhe/src/entities/private_functional_keyswitch_key.rs new file mode 100644 index 000000000..4ecd5df76 --- /dev/null +++ b/sunscreen_tfhe/src/entities/private_functional_keyswitch_key.rs @@ -0,0 +1,134 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{ + dst::OverlaySize, + entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef}, + GlweDef, GlweDimension, LweDef, LweDimension, PrivateFunctionalKeyswitchLweCount, RadixCount, + RadixDecomposition, Torus, TorusOps, +}; + +use super::LweSecretKeyRef; + +dst! { + /// Key for Private Functional Key Switching. See + /// [`module`](crate::ops::keyswitch::private_functional_keyswitch) + /// documentation for more details. + PrivateFunctionalKeyswitchKey, + PrivateFunctionalKeyswitchKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} +dst_iter!( + PrivateFunctionalKeyswitchKeyIter, + PrivateFunctionalKeyswitchKeyIterMut, + Torus, + PrivateFunctionalKeyswitchKeyRef, + (TorusOps,) +); + +impl OverlaySize for PrivateFunctionalKeyswitchKeyRef { + type Inputs = ( + LweDimension, + GlweDimension, + RadixCount, + PrivateFunctionalKeyswitchLweCount, + ); + + fn size(t: Self::Inputs) -> usize { + GlevCiphertextRef::::size((t.1, t.2)) * (LweSecretKeyRef::::size(t.0) + 1) * t.3 .0 + } +} + +impl PrivateFunctionalKeyswitchKey { + /// Construct a new uninitialized [`PrivateFunctionalKeyswitchKey`]. This key is used + /// compute a secret function mapping `lwe_count` + /// [`LweCiphertext`](crate::entities::LweCiphertext)s to a + /// [`GlweCiphertext`](crate::entities::GlweCiphertext). + /// + /// # Remarks + /// The key is composed of a `(from_lwe.dim + 1) * lwe_count.0` matrix of + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. The leading + /// dimension contains radix-scaled encryptions of the bits in a + /// [`LweSecretKey`](crate::entities::LweSecretKey) followed by a scaled + /// encryption of -1. + /// + /// Plaintexts in the GLEV are scaled by the usual `q/beta^(j+1)`. + /// + /// The trailing dimension iterates over `0..lwe_count.0`. + pub fn new( + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + lwe_count: &PrivateFunctionalKeyswitchLweCount, + ) -> Self { + Self { + data: vec![ + Torus::zero(); + PrivateFunctionalKeyswitchKeyRef::::size(( + from_lwe.dim, + to_glwe.dim, + radix.count, + *lwe_count + )) + ], + } + } +} + +impl PrivateFunctionalKeyswitchKeyRef { + /// Returns an iterator over the + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that + /// compose this key. + /// + /// # See also + /// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks). + pub fn glevs( + &self, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIterator { + GlevCiphertextIterator::new( + self.as_slice(), + GlevCiphertextRef::::size((to_glwe.dim, radix.count)), + ) + } + + /// Returns a muitable iterator over the + /// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that compose this + /// key. + /// + /// # See also + /// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks). + pub fn glevs_mut( + &mut self, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIteratorMut { + GlevCiphertextIteratorMut::new( + self.as_mut_slice(), + GlevCiphertextRef::::size((to_glwe.dim, radix.count)), + ) + } + + #[inline(always)] + /// Assert this value is correct for the given parameters. + pub fn assert_valid( + &self, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + lwe_count: &PrivateFunctionalKeyswitchLweCount, + ) { + assert_eq!( + self.as_slice().len(), + PrivateFunctionalKeyswitchKeyRef::::size(( + from_lwe.dim, + to_glwe.dim, + radix.count, + *lwe_count + )) + ) + } +} diff --git a/sunscreen_tfhe/src/entities/public_functional_keyswitch_key.rs b/sunscreen_tfhe/src/entities/public_functional_keyswitch_key.rs new file mode 100644 index 000000000..3204de486 --- /dev/null +++ b/sunscreen_tfhe/src/entities/public_functional_keyswitch_key.rs @@ -0,0 +1,79 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{ + dst::OverlaySize, + entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef}, + GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps, +}; + +use super::GlweCiphertextRef; + +dst! { + /// Public Functional Key Switching Key. See + /// [`module`](crate::ops::keyswitch::public_functional_keyswitch) + /// documentation for more details. + PublicFunctionalKeyswitchKey, + PublicFunctionalKeyswitchKeyRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for PublicFunctionalKeyswitchKeyRef { + type Inputs = (LweDimension, GlweDimension, RadixCount); + + fn size(t: Self::Inputs) -> usize { + GlweCiphertextRef::::size(t.1) * t.0 .0 * t.2 .0 + } +} + +impl PublicFunctionalKeyswitchKey { + /// Construct a new uninitialized [`PublicFunctionalKeyswitchKey`]. This key is used + /// when performing a [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch). + pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self { + let len = + PublicFunctionalKeyswitchKeyRef::::size((from_lwe.dim, to_glwe.dim, radix.count)); + + Self { + data: vec![Torus::zero(); len], + } + } +} + +impl PublicFunctionalKeyswitchKeyRef { + /// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`]. + pub fn glevs( + &self, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIterator { + let stride = GlevCiphertextRef::::size((to_glwe.dim, radix.count)); + + GlevCiphertextIterator::new(self.as_slice(), stride) + } + + /// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`] mutably. + pub fn glevs_mut( + &mut self, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> GlevCiphertextIteratorMut { + let stride = GlevCiphertextRef::::size((to_glwe.dim, radix.count)); + + GlevCiphertextIteratorMut::new(self.as_mut_slice(), stride) + } + + /// Asserts that the key is valid for the given parameters. + pub(crate) fn assert_valid( + &self, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) { + assert_eq!( + self.as_slice().len(), + PublicFunctionalKeyswitchKeyRef::::size((from_lwe.dim, to_glwe.dim, radix.count)) + ); + } +} diff --git a/sunscreen_tfhe/src/entities/univariate_lookup_table.rs b/sunscreen_tfhe/src/entities/univariate_lookup_table.rs new file mode 100644 index 000000000..643ca30d4 --- /dev/null +++ b/sunscreen_tfhe/src/entities/univariate_lookup_table.rs @@ -0,0 +1,83 @@ +use serde::{Deserialize, Serialize}; +use sunscreen_math::Zero; + +use crate::{ + dst::{FromMutSlice, FromSlice, OverlaySize}, + entities::PolynomialRef, + ops::{bootstrapping::generate_lut, encryption::trivially_encrypt_glwe_ciphertext}, + scratch::allocate_scratch_ref, + GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps, +}; + +use super::GlweCiphertextRef; + +dst! { + /// Lookup table for a univariate function used during + /// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap) + /// and [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap). + UnivariateLookupTable, + UnivariateLookupTableRef, + Torus, + (Clone, Debug, Serialize, Deserialize), + (TorusOps) +} + +impl OverlaySize for UnivariateLookupTableRef { + type Inputs = GlweDimension; + + fn size(t: Self::Inputs) -> usize { + GlweCiphertextRef::::size(t) + } +} + +impl UnivariateLookupTable { + /// Creates a lookup table that is trivially encrypted. + pub fn trivial_from_fn(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self + where + F: Fn(u64) -> u64, + { + let mut lut = UnivariateLookupTable { + data: vec![Torus::zero(); UnivariateLookupTableRef::::size(glwe.dim)], + }; + + lut.fill_trivial_from_fn(map, glwe, plaintext_bits); + + lut + } +} + +impl UnivariateLookupTableRef { + /// Return the underlying GLWE representation of a lookup table. + pub fn glwe(&self) -> &GlweCiphertextRef { + GlweCiphertextRef::from_slice(&self.data) + } + + /// Return a mutable representation of the underlying GLWE representation of + /// a lookup table. + pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef { + GlweCiphertextRef::from_mut_slice(&mut self.data) + } + + /// Generates a look up table filled with the values from the provided map, + /// and trivially encrypts the lookup table. + pub fn fill_trivial_from_fn u64>( + &mut self, + map: F, + glwe: &GlweDef, + plaintext_bits: PlaintextBits, + ) { + allocate_scratch_ref!(poly, PolynomialRef>, (glwe.dim.polynomial_degree)); + + generate_lut(poly, map, glwe, plaintext_bits); + + trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe); + } + + /// Creates a lookup table filled with the same value at every entry. + pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) { + self.clear(); + for o in self.glwe_mut().b_mut(glwe).coeffs_mut() { + *o = Torus::encode(val, plaintext_bits); + } + } +} diff --git a/sunscreen_tfhe/src/error.rs b/sunscreen_tfhe/src/error.rs new file mode 100644 index 000000000..bc89924cb --- /dev/null +++ b/sunscreen_tfhe/src/error.rs @@ -0,0 +1,4 @@ +#[derive(thiserror::Error)] +pub enum Error { + OutOfRange +} \ No newline at end of file diff --git a/sunscreen_tfhe/src/high_level.rs b/sunscreen_tfhe/src/high_level.rs new file mode 100644 index 000000000..1d0abfc3e --- /dev/null +++ b/sunscreen_tfhe/src/high_level.rs @@ -0,0 +1,939 @@ +#![allow(dead_code)] + +use crate::{ + rand::Stddev, GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PolynomialDegree, + RadixCount, RadixDecomposition, RadixLog, +}; + +#[doc(hidden)] +pub const TEST_RADIX: RadixDecomposition = RadixDecomposition { + count: RadixCount(3), + radix_log: RadixLog(4), +}; + +#[doc(hidden)] +pub const TEST_GLWE_DEF_1: GlweDef = GlweDef { + dim: GlweDimension { + polynomial_degree: PolynomialDegree(128), + size: GlweSize(2), + }, + std: Stddev(1e-16), +}; + +#[doc(hidden)] +pub const TEST_GLWE_DEF_2: GlweDef = GlweDef { + dim: GlweDimension { + polynomial_degree: PolynomialDegree(256), + size: GlweSize(3), + }, + std: Stddev(1e-16), +}; + +#[doc(hidden)] +pub const TEST_LWE_DEF_1: LweDef = LweDef { + dim: LweDimension(128), + std: Stddev(1e-16), +}; + +#[doc(hidden)] +pub const TEST_LWE_DEF_2: LweDef = LweDef { + dim: LweDimension(256), + std: Stddev(1e-16), +}; + +#[doc(hidden)] +pub const TEST_LWE_DEF_3: LweDef = LweDef { + dim: LweDimension(128), + std: Stddev(0.0), +}; + +/// TFHE functionality related to key generation. +pub mod keygen { + use crate::{ + entities::{ + BootstrapKey, CircuitBootstrappingKeyswitchKeys, GlweSecretKey, GlweSecretKeyRef, + LweKeyswitchKey, LwePublicKey, LweSecretKey, LweSecretKeyRef, + }, + ops::{ + bootstrapping::generate_bootstrap_key, + keyswitch::{ + lwe_keyswitch_key::generate_keyswitch_key_lwe, + private_functional_keyswitch::generate_circuit_bootstrapping_pfks_keys, + }, + }, + GlweDef, LweDef, RadixDecomposition, + }; + + /// Generate a new binary [`LweSecretKey`] under the given LWE parameters. + /// + /// # Remarks + /// Any functions that use this key will need the same [`LweDef`]. + /// + /// These keys may be used to create bootstrapping keys. + /// + /// # Panics + /// If [`LweDef`] is invalid. + /// + /// # Security + /// These keys are *not* secure under some threshold cryptography settings. + /// Under those settings, you should use [`generate_uniform_lwe_sk`]. + /// + /// This key is secret and care should be taken as to which parties + /// possess it. Anyone who possesses the returned [`LweSecretKey`] + /// can decrypt any messages encrypted under it. + pub fn generate_binary_lwe_sk(params: &LweDef) -> LweSecretKey { + LweSecretKey::generate_binary(params) + } + + /// Generate a new binary [`LweSecretKey`] under the given LWE parameters. + /// + /// # Remarks + /// Any functions that use this key will need the same [`LweDef`]. + /// + /// These keys may *not* directly be used to create bootstrapping keys. + /// However, in threshold schemes that use them, usually you derive + /// a binary key from uniform key shares. + /// + /// # Panics + /// If [`LweDef`] is invalid. + /// + /// # Security + /// This key is secret and care should be taken as to which parties + /// possess it. Anyone who possesses the returned [`LweSecretKey`] + /// can decrypt any messages encrypted under it. + pub fn generate_uniform_lwe_sk(params: &LweDef) -> LweSecretKey { + LweSecretKey::generate_uniform(params) + } + + /// Generate a new [`LwePublicKey`] under the given parameters. This + /// public key is paired with `sk` - that is messages encrypted under + /// this public key can be decrypted with `sk`. + /// + /// # Remarks + /// Any functions that use this key will need to use the same `params`. + /// + /// # Panics + /// If `params` is invalid + /// If `params` doesn't correspond with `sk`. + /// + /// # Security + /// This key is public and sharing it does not compromise semantic + /// security. + pub fn generate_lwe_pk(sk: &LweSecretKeyRef, params: &LweDef) -> LwePublicKey { + LwePublicKey::generate(sk, params) + } + + /// Generate a new GLWE secret key under the given GLWE parameters. + /// The key will consist of {0,1}-valued coefficients. + /// + /// # Remarks + /// Any functions that use this key will need the same [`GlweDef`]. + /// + /// # Panics + /// If [`GlweDef`] is invalid. + /// + /// # Security + /// Binary [GlweSecretKey]s are insecure in some threshold cryptography + /// settings. Under those settings, you should use + /// [`generate_uniform_glwe_sk`]. + /// + /// This key is secret and care should be taken as to which parties + /// possess it. Anyone who possesses the returned [`GlweSecretKey`] + /// can decrypt any messages encrypted under it. + pub fn generate_binary_glwe_sk(params: &GlweDef) -> GlweSecretKey { + GlweSecretKey::generate_binary(params) + } + + /// Generate a new GLWE secret key under the given GLWE parameters. + /// The key will consist of uniform coefficients over the Torus's + /// isomorphic ring. + /// + /// # Remarks + /// Any functions that use this key will need the same [`GlweDef`]. + /// + /// This is used over binary in threshold cryptography settings. + /// + /// # Panics + /// If [`GlweDef`] is invalid. + /// + /// # Security + /// This key is secret and care should be taken as to which parties + /// possess it. Anyone who possesses the returned [`GlweSecretKey`] + /// can decrypt any messages encrypted under it. + pub fn generate_uniform_glwe_sk(params: &GlweDef) -> GlweSecretKey { + GlweSecretKey::generate_uniform(params) + } + + /// Generate a new bootstrapping key, which is used in bootstrapping operations. + /// + /// See also + /// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap) + /// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap). + /// + /// # Remarks + /// A bootstrapping key is an encryption of an LWE secret key under a different + /// GLWE secret key reinterpreted as an LWE key. + /// + /// `lwe` and `glwe` are the LWE and GLWE parameters used when you generated the + /// [`LweSecretKey`] and [`GlweSecretKey`], respectively. + /// + /// `radix` specifies the decomposition to use during bootstrapping. + /// + /// You should use the same `lwe`, `glwe`, `radix` values here as when you + /// call + /// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap). + /// + /// The returned bootstrapping key is not immediately useful outside of serialization. + /// You need to FFT transform is first (see [fft_bootstrap_key](super::fft::fft_bootstrap_key)). + /// + /// ## Circuit bootstrapping + /// When using [`circuit_bootstrap`](super::evaluation::circuit_bootstrap), `pbs_radix` + /// must match this `radix`, `lwe_0` must match `lwe`, and `glwe_2` must match `glwe`. + /// + /// # Panics + /// If `lwe`, `glwe`, or `radix` are invalid. + /// If `glwe_key` isn't valid under `glwe`. + /// If `sk` isn't valid under `lwe`. + /// + /// # Security + /// The returned key is public and does not compromise semantic security. + /// However, anyone who possesses `glwe_key` can easily use the returned + /// [`BootstrapKey`] to recover `sk`. + pub fn generate_bootstrapping_key( + sk: &LweSecretKey, + glwe_key: &GlweSecretKey, + lwe: &LweDef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> BootstrapKey { + let mut bsk = BootstrapKey::new(lwe, glwe, radix); + + generate_bootstrap_key(&mut bsk, sk, glwe_key, glwe, radix); + + bsk + } + + /// Generate an LWE keyswitch key. LWE keyswitching allows you take an encryption of `m` + /// under [LWESecretKey](crate::entities::LweSecretKey) `from_sk` and turn it into an + /// encryption of `m` under `to_sk`. + /// + /// # Remarks + /// `from_lwe` and `to_lwe` are the parameters under which you generated `from_sk` and + /// `to_sk`, respectively. + /// + /// `radix` specifies the decomposition to use during LWE keyswitching. + /// + /// TODO: mention keyswitch operation. + /// + /// # Panics + /// If the `from_lwe` parameters aren't valid for `from_sk`. + /// If the `to_lwe` parameters aren't valid for `to_sk`. + /// If `from_lwe`, `to_lwe`, or `radix` parameters are invalid. + /// + /// # Security + /// The returned key is public and sharing it does not compromise semantic + /// security. However, anyone who possesses `to_lwe` will effectively + /// be able to decrypt any message encrypted under `from_lwe` with the + /// returned [`LweKeyswitchKey`]. + pub fn generate_ksk( + from_sk: &LweSecretKeyRef, + to_sk: &LweSecretKeyRef, + from_lwe: &LweDef, + to_lwe: &LweDef, + radix: &RadixDecomposition, + ) -> LweKeyswitchKey { + let mut ksk = LweKeyswitchKey::new(from_lwe, to_lwe, radix); + + generate_keyswitch_key_lwe(&mut ksk, from_sk, to_sk, to_lwe, radix); + + ksk + } + + /// Generate a set of [`CircuitBootstrappingKeyswitchKeys`] to use during + /// [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations. + /// + /// # Remarks + /// Internally, [`CircuitBootstrappingKeyswitchKeys`] is a list of + /// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s. + /// + /// During [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations, + /// these keys are used to convert [`LweCiphertext`](crate::entities::LweCiphertext)s + /// encrypted under `from_sk` into [`GlweCiphertext`](crate::entities::GlweCiphertext)s + /// encrypted under `to_sk`. These [`GlweCiphertext`](crate::entities::GlweCiphertext)s + /// together form a [`GgswCiphertext`](crate::entities::GgswCiphertext). + /// + /// The `from_lwe` and `to_glwe` parameters correspond to those used when you generated + /// `from_sk` and `to_sk`, respectively. + /// + /// The `radix` parameter describes the [`RadixDecomposition`] to use during the private + /// functional keyswitch operation. When performing a + /// [circuit_bootstrap](super::evaluation::circuit_bootstrap), these same parameters should + /// be passed as `pfks_radix`. + /// + /// # Panics + /// If `from_lwe`, `to_glwe`, or `radix` are invalid. + /// If `from_lwe` or `to_glwe` parameters don't correspond with `to_sk` or `to_glwe`, respectively. + /// + /// # Security + /// The returned [`CircuitBootstrappingKeyswitchKeys`] are public and + /// don't in of themselves compromise semantic security. However, + /// anyone who possesses `to_sk` can easily recover `from_sk` using this + /// information. + pub fn generate_cbs_ksk( + from_sk: &LweSecretKeyRef, + to_sk: &GlweSecretKeyRef, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> CircuitBootstrappingKeyswitchKeys { + let mut cbs_ksk = CircuitBootstrappingKeyswitchKeys::new(from_lwe, to_glwe, radix); + + generate_circuit_bootstrapping_pfks_keys( + &mut cbs_ksk, + from_sk, + to_sk, + from_lwe, + to_glwe, + radix, + ); + + cbs_ksk + } +} + +/// TFHE functionality related to encryption. +pub mod encryption { + use crate::{ + entities::{ + GgswCiphertext, GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, GlweSecretKeyRef, + LweCiphertext, LweCiphertextRef, LwePublicKeyRef, LweSecretKeyRef, Polynomial, + PolynomialRef, TlwePublicEncRandomness, + }, + ops::encryption::{encrypt_ggsw_ciphertext_scalar, trivially_encrypt_lwe_ciphertext}, + CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus, + }; + + /// Create an [`LweCiphertext`] encryption of `val` under + /// [LweSecretKey](crate::entities::LweSecretKey) `sk`. + /// + /// # Remarks + /// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk) to + /// generate an [`LweSecretKey`](crate::entities::LweSecretKey). + /// + /// `params` should be the same parameters used when generating the secret key. + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` should not exceed `2^plaintext_bits.0`. + /// + /// # Panics + /// If `params` is invalid. + /// If `params` don't correspond with `sk`. + pub fn encrypt_lwe_secret( + val: u64, + sk: &LweSecretKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> LweCiphertext { + sk.encrypt(val, params, plaintext_bits).0 + } + + /// Create a tuple containing an [`LweCiphertext`] encryption of `val` + /// under [LweSecretKey](crate::entities::LweSecretKey) `sk` and the + /// randomness used to generate it. + /// + /// This randomness can be used to produce zero-knowledge proofs. + /// + /// # Remarks + /// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk) + /// to generate an [`LweSecretKey`](crate::entities::LweSecretKey). + /// + /// `params` should be the same parameters used when generating the secret key. + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` should not exceed `2^plaintext_bits.0`. + /// + /// # Panics + /// If `params` is invalid. + /// If `params` don't correspond with `sk`. + /// + /// # Security + /// Revealing the returned randomness compromises the confidentiality + /// of the returned [`LweCiphertext`]. + pub fn encrypt_lwe_secret_and_return_randomness( + val: u64, + sk: &LweSecretKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> (LweCiphertext, Torus) { + sk.encrypt(val, params, plaintext_bits) + } + + /// Create an [LweCiphertext] encryption of `val` under the secret + /// key that pairs with `pk`. + /// + /// This function uses the [`LwePublicKey`](crate::entities::LwePublicKey), + /// which may be freely distributed without compromising security. + /// + /// # Remarks + /// Use [`generate_lwe_pk`](super::keygen::generate_lwe_pk) to generate + /// a [`LwePublicKey`](crate::entities::LwePublicKey). + /// + /// # Panics + /// If `params` is invalid. + /// If `params` doesn't correspond with `pk`. + pub fn encrypt_lwe( + val: u64, + pk: &LwePublicKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> LweCiphertext { + pk.encrypt(val, params, plaintext_bits).0 + } + + /// Create a tuple containing an [`LweCiphertext`] encryption of `val` + /// under the [LweSecretKey](crate::entities::LweSecretKey) paired with + /// `pk` and the randomness used to generate it. + /// + /// This randomness can be used to produce zero-knowledge proofs. + /// + /// # Remarks + /// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk) + /// to generate an [`LweSecretKey`](crate::entities::LweSecretKey). + /// + /// `params` should be the same parameters used when generating the secret key. + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` should not exceed `2^plaintext_bits.0`. + /// + /// # Panics + /// If `params` is invalid. + /// If `params` don't correspond with `pk`. + /// + /// # Security + /// Revealing the returned randomness compromises the confidentiality + /// of the returned [`LweCiphertext`]. + pub fn encrypt_lwe_and_return_randomness( + val: u64, + pk: &LwePublicKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> (LweCiphertext, TlwePublicEncRandomness) { + pk.encrypt(val, params, plaintext_bits) + } + + /// Create a [`GlweCiphertext`] encryption of `pt` under `sk`. + /// + /// # Remarks + /// `params` should be a same as those used when creating `sk`. + /// `pt.len()` should equal `sk.dim.polynomial_degree.0`. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// polynomial coefficients will contain the message. No coefficient in + /// `pt` should exceed `2^plaintext_bits.0`. + /// + /// # Panics + /// If `params` is invalid. + /// If `params` doesn't correspond with `sk` + /// If `pt` doesn't have the same number of coefficients as + /// `params.dim.polynomial_degree.0`. + pub fn encrypt_glwe( + pt: &PolynomialRef, + sk: &GlweSecretKeyRef, + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> GlweCiphertext { + sk.encode_encrypt_glwe(pt, params, plaintext_bits) + } + + /// Create a trivial LWE encryption. Trivial encryptions have no noise and are thus + /// insecure. However, they are useful for creating public constants in TFHE computations. + /// + /// Trivial encryptions are valid encryptions of `val` under *every* LWE secret key + /// using the same `params`. + /// + /// # Remarks + /// `params` are the parameters under which to create this encryption. Note that homomorphic + /// operations require every operand use the same [`LweDef`] parameters. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` should not exceed `2^plaintext_bits.0`. Truncation + /// of `val`'s most-significant bits will occur otherwise. + /// + /// # Panics + /// If `plaintext_bits > 63`. + /// If `params` are invalid. + /// + /// # Security + /// Trivial encryptions provide no cryptographic security. + pub fn trivial_lwe( + val: u64, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> LweCiphertext { + let mut ct = LweCiphertext::new(params); + + trivially_encrypt_lwe_ciphertext(&mut ct, &Torus::encode(val, plaintext_bits), params); + + ct + } + + /// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey) + /// `sk`. Decode this decrypted value and return it. + /// + /// # Remarks + /// `params` must correspond with `ct` and `sk`. + /// + /// If a different `sk` is used than the one that produced `ct`, the result will be + /// garbage. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally, + /// you should use the same `plaintext_bits` that were used during encryption. + /// + /// # Panics + /// If `params` doesn't correspond with either `ct` or `sk`. + /// If `params` is invalid. + pub fn decrypt_lwe( + ct: &LweCiphertextRef, + sk: &LweSecretKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + ) -> u64 { + sk.decrypt(ct, params, plaintext_bits) + } + + /// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey) + /// `sk` and decode it, handling carry bits. + /// + /// # Remarks + /// `params` must correspond with `ct` and `sk`. + /// + /// If a different `sk` is used than the one that produced `ct`, the result will be + /// garbage. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally, + /// you should use the same `plaintext_bits` that were used during encryption. + /// + /// `carry_bits` describes how many of the most-significant bits of the + /// [`Torus`] will contain the carry. + /// + /// # Panics + /// If `params` doesn't correspond with either `ct` or `sk`. + /// If `params` is invalid. + pub fn decrypt_lwe_with_carry( + ct: &LweCiphertextRef, + sk: &LweSecretKeyRef, + params: &LweDef, + plaintext_bits: PlaintextBits, + carry_bits: CarryBits, + ) -> u64 { + let decrypted = sk.decrypt_without_decode(ct, params); + + // We manually decode here because the padding bit. + let plain_bits = plaintext_bits; + + let round_bit = decrypted + .inner() + .wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1) + & 0x1; + let mask = (0x1 << plain_bits.0) - 1; + + (decrypted + .inner() + .wrapping_shr(64 - plain_bits.0 - carry_bits.0) + + round_bit) + & mask + } + + /// Decrypt [GlweCiphertext] `ct` encrypted under [GlweSecretKey](crate::entities::GlweSecretKey) + /// `sk`. Decode this decrypted value and return it. + /// + /// # Remarks + /// `params` must correspond with `ct` and `sk`. + /// + /// If a different `sk` is used than the one that produced `ct`, the result will be + /// garbage. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] polynomial's + /// coefficients contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally, + /// you should use the same `plaintext_bits` that were used during encryption. + /// + /// # Panics + /// If `params` doesn't correspond with either `ct` or `sk`. + /// If `params` is invalid. + pub fn decrypt_glwe( + ct: &GlweCiphertextRef, + sk: &GlweSecretKeyRef, + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> Polynomial { + sk.decrypt_decode_glwe(ct, params, plaintext_bits) + } + + /// Create a trivial encryption of `pt` as a [GlweCiphertext]. Trivial encryptions contain + /// no noise and are thus insecure. However, they are useful as public constants in + /// a TFHE computation. + /// + /// # Remarks + /// Trivial encryption are valid under *every* [GlweSecretKey](crate::entities::GlweSecretKey) + /// using `params`. Note that homomorphic operations require every operand to use the same + /// `params` and secret key. + /// + /// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] + /// polynomial coefficients will contain the message. `val` should not exceed + /// `2^plaintext_bits.0`. Truncation of `val`'s most-significant bits will occur otherwise. + /// + /// # Panics + /// If `params` are invalid. + /// If `plaintext_bits >= 64`. + /// + /// # Security + /// Trivial encryptions provide no cryptographic security. + pub fn trivial_glwe( + pt: &PolynomialRef, + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> GlweCiphertext { + let mut result = GlweCiphertext::new(params); + + for (b_out, b_in) in result + .b_mut(params) + .coeffs_mut() + .iter_mut() + .zip(pt.coeffs().iter()) + { + *b_out = Torus::encode(*b_in, plaintext_bits); + } + + result + } + + /// Create a [`GgswCiphertext`] encrypting `msg` under + /// [GlweSecretKey](crate::entities::GlweSecretKey) `sk`. + /// + /// # Remarks + /// This encrypts `msg` as a constant coefficient polynomial. While [`GgswCiphertext`]s + /// theoretically support encrypting arbitrary polynomial messages, such ciphertexts have + /// no known uses in TFHE. + /// + /// Typically, you'll want to set `plaintext_bits` to 1 and encrypt a binary + /// `msg`. This allows you to use the [`GgswCiphertext`] as the select input to + /// a [`cmux`](super::evaluation::cmux) operation. + /// + /// `params` should match those under which you generated `sk`. + /// + /// `radix` defines how many decompositions to include in the result. Subsequent + /// [`cmux`](super::evaluation::cmux) operations using this result must use the + /// same `radix`. + /// + /// [`GgswCiphertext`]s are not immediate useful outside of serialization. You must + /// first take its Fourier transform using [`fft_ggsw`](super::fft::fft_ggsw) before + /// using it in a [`cmux`](super::evaluation::cmux) operation. + /// + /// # Panics + /// If `params` or `radix` are invalid. + /// If `params` doesn't correspond to `sk`. + /// If `plaintext_bits >= 64`. + pub fn encrypt_ggsw( + msg: u64, + sk: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, + ) -> GgswCiphertext { + let mut result = GgswCiphertext::new(params, radix); + + encrypt_ggsw_ciphertext_scalar(&mut result, msg, sk, params, radix, plaintext_bits); + + result + } + + /// Decrypt a [`GgswCiphertext`] encrypted under `sk`. + /// Since GGSW ciphertexts generally contain binary, you should + /// usually set `plaintext_bits` to 1. + /// + /// # Remarks + /// `params` should correspond with `ct` and `sk`. + /// `radix` should correspond with `ct`. + /// + /// # Panics + /// If `params` or `radix` are invalid. + /// If `params` or `radix` don't correspond to `ct` + /// If `params` don't correspond to `sk`. + pub fn decrypt_ggsw( + ct: &GgswCiphertextRef, + sk: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + _plaintext_bits: PlaintextBits, + ) -> Polynomial { + let mut msg = Polynomial::zero(params.dim.polynomial_degree.0); + + crate::ops::encryption::decrypt_ggsw_ciphertext(&mut msg, ct, sk, params, radix); + + msg.map(|x| x.inner()) + } +} + +/// Operations for producing Fourier-transformed versions of entities. +pub mod fft { + use num::Complex; + + use crate::{ + entities::{ + BootstrapKeyFft, BootstrapKeyRef, GgswCiphertextFft, GgswCiphertextRef, + GlweCiphertextFft, GlweCiphertextRef, + }, + GlweDef, LweDef, RadixDecomposition, + }; + + /// Take the fourier transform of a [`GlweCiphertext`](crate::entities::GlweCiphertext). + /// + /// # Remarks + /// `params` must be the same parameters that produce `ct`. + /// + /// # Panics + /// `params` is invalid. + /// `params` doesn't correspond with `ct`. + pub fn fft_glwe( + ct: &GlweCiphertextRef, + params: &GlweDef, + ) -> GlweCiphertextFft> { + let mut fft = GlweCiphertextFft::new(params); + + ct.fft(&mut fft, params); + + fft + } + + /// Take the fourier transform of a [`GgswCiphertext`](crate::entities::GgswCiphertext). + /// + /// # Remarks + /// `glwe` and `radix` must be the same parameters that produced `ggsw`. + /// + /// For [`GgswCiphertext`](crate::entities::GgswCiphertext)s that result from a + /// [`circuit_bootstrap`](super::evaluation::circuit_bootstrap) operation, these + /// must match `glwe_1` and `cbs_radix` respectively. + /// + /// # Panics + /// If `glwe` and `radix` don't correspond with `ggsw`. + /// If `glwe` or `radix` are invalid. + pub fn fft_ggsw( + ggsw: &GgswCiphertextRef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> GgswCiphertextFft> { + let mut fft = GgswCiphertextFft::new(glwe, radix); + + ggsw.fft(&mut fft, glwe, radix); + + fft + } + + /// Take the fourier transform of a [BootstrapKey](crate::entities::BootstrapKey). + /// The resulting [`BootstrapKeyFft`] may be used in + /// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap) + /// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap) + /// operations. + /// + /// # Remarks + /// `glwe` and `radix` must be the same parameters that produced `bsk`. + /// + /// # Panics + /// If `glwe` and `radix` don't correspond with `bsk`. + /// If `glwe` or `radix` are invalid. + pub fn fft_bootstrap_key( + bsk: &BootstrapKeyRef, + lwe: &LweDef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> BootstrapKeyFft> { + let mut bsk_fft = BootstrapKeyFft::new(lwe, glwe, radix); + + bsk.fft(&mut bsk_fft, glwe, radix); + + bsk_fft + } +} + +/// TFHE operations for performing computation. +pub mod evaluation { + use num::Complex; + + use crate::{ + entities::{ + BootstrapKeyFft, BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, + GgswCiphertext, GgswCiphertextFftRef, GlweCiphertext, GlweCiphertextRef, LweCiphertext, + LweCiphertextRef, LweKeyswitchKeyRef, UnivariateLookupTableRef, + }, + GlweDef, LweDef, RadixDecomposition, + }; + + /// Perform a multiplexing operation. When `b_fft` encrypts a zero polynomial, + /// the resulting [`GlweCiphertext`] will the same message as `d_0`. When `b_fft` + /// encrypts the 1 polynomial, the result will contain the same message as `d_1`. + /// + /// # Remarks + /// `b_fft`, `d_0`, and `d_1` must all be encrypted under the same + /// [`GlweSecretKey`](crate::entities::GlweSecretKey). This implies `params` must + /// correspond with all three values. + /// + /// Additionally, `radix` must correspond to `b_fft`. + /// + /// For + /// [`GgswCiphertext`] resulting from [`circuit_bootstrap`] operations, + /// `radix` must be the same as `cbs_radix` and `params` must be the same as + /// `glwe_1`. + /// + /// # Panics + /// If `params` doesn't correspond with `b_fft`, `d_0`, `d_1`. + /// If `radix` doesn't correspond with `b_fft`. + /// If `radix` or `params` are invalid. + pub fn cmux( + b_fft: &GgswCiphertextFftRef>, + d_0: &GlweCiphertextRef, + d_1: &GlweCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, + ) -> GlweCiphertext { + let mut result = GlweCiphertext::new(params); + + crate::ops::fft_ops::cmux(&mut result, d_0, d_1, b_fft, params, radix); + + result + } + + #[allow(clippy::too_many_arguments)] + /// Perform a programmable bootstrapping operation. Bootstrapping takes + /// `input` and produces a new ciphertext with a fixed noise level, applying + /// a univariate function defined by `lut` in the process. + /// + /// This new ciphertext is encrypted under a (usually) different + /// [`LweSecretKey`](crate::entities::LweSecretKey) defined by `glwe` interpreted + /// as an [`LweDef`]. + /// + /// See also [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable). + /// + /// To switch the message back to the original key, you need to perform an LWE + /// keyswitch operation. TODO: hotlink + /// + /// # Remarks + /// `input` must be valid under the `lwe` parameters. + /// `lwe`, `glwe`, and `radix` parameters must be the same as those used when + /// first creating the `bsk`. + /// `lut` must be valid under `glwe` parameters. + /// + /// # Panics + /// If `lwe`, `glwe`, or `radix` parameters are invalid. + /// If `input` doesn't correspond to `lwe` parameters. + /// If `bsk` doesn't correspond to `lwe`, `glwe`, `radix` parameters. + /// If `lut` doesn't correspond to `glwe` parameters. + pub fn univariate_programmable_bootstrap( + input: &LweCiphertextRef, + lut: &UnivariateLookupTableRef, + bsk: &BootstrapKeyFft>, + lwe: &LweDef, + glwe: &GlweDef, + radix: &RadixDecomposition, + ) -> LweCiphertext { + let mut out = LweCiphertext::new(&glwe.as_lwe_def()); + + crate::ops::bootstrapping::programmable_bootstrap( + &mut out, input, lut, bsk, lwe, glwe, radix, + ); + + out + } + + #[allow(clippy::too_many_arguments)] + /// Perform a circuit bootstrapping operation. Circuit bootstrapping takes + /// `input` [LweCiphertext] encrypted under a [LweSecretKey](crate::entities::LweSecretKey) + /// constructed with `lwe_0` parameters and produces a [GgswCiphertext] encrypted + /// under a [GlweSecretKey](crate::entities::GlweSecretKey) constructed with `glwe_1` + /// parameters. + /// + /// See also [generate_bootstrapping_key](super::keygen::generate_bootstrapping_key) and + /// [generate_cbs_pfks](super::keygen::generate_cbs_ksk) for how to generate the required + /// keys. + /// + /// # Remarks + /// Internally, circuit bootstrapping occurs in 2 steps. First we perform programmable + /// bootstraps (PBS) from `lwe_0` parameters to `glwe_2` reinterpreted as LWE parameters. We + /// do this `cbs_radix.count` times using univariate functions that map the message in + /// input to its corresponding radix decomposition. + /// + /// For step 2, we use private functional keyswitching (PFKS) to transform the + /// `cbs_radix.count` [LweCiphertext]s encrypted under `glwe_2` into + /// `cbs_radix.count * glwe_2.size + 1` [GlweCiphertext]s. The PFKS operations multiply each + /// [GlevCiphertext](crate::entities::GlevCiphertext) by the corresponding polynomial in + /// the `glwe_1` [GlweSecretKey](crate::entities::GlweSecretKey) to create a valid + /// [GgswCiphertext]. + /// + /// To summarize, we use PBS to turn an `lwe_0` LWE ciphertext into `glwe_2` LWE ciphertexts. + /// We then use PFKS to turn the `glwe_2` LWE ciphertexts into `glwe_1` GLWE ciphertexts + /// arranged as a valid [GgswCiphertext] encrypting the same value as `input` as a constant + /// coefficient polynomial. + /// + /// See [crate::ops::bootstrapping::circuit_bootstrap] for more details. + /// + /// `pbs_radix` parameterizes the bootstrapping operation (step 1). + /// + /// `pfks_radix` parameterizes the PFKS operation (step 2). These should + /// + /// `cbs_radix` parameterizes the final decomposition of the resulting [`GgswCiphertext`]. This + /// should match the radix used when creating the + /// [CircuitBootstrappingKeyswitchKeys](crate::entities::CircuitBootstrappingKeyswitchKeys) + /// + /// + /// # Panics + /// If `pbs_radix`, `cbs_radix`, `pfksk_radix`, `lwe_0`, `glwe_2`, or `glwe_1` are invalid. + /// If `pbs_radix`, `lwe_0`, `glwe_2` don't match the `radix`, `lwe`, `glwe` (respectively) used to generate `bsk`. + /// If `pfks_radix`, `glwe_2.as_lwe_def()`, `glwe_1` parameters don't match the `radix`, `from_lwe`, `to_lwe` + /// (respectively) used to generate `cbsksk`. + pub fn circuit_bootstrap( + input: &LweCiphertextRef, + bsk: &BootstrapKeyFftRef>, + cbsksk: &CircuitBootstrappingKeyswitchKeysRef, + lwe_0: &LweDef, + glwe_1: &GlweDef, + glwe_2: &GlweDef, + pbs_radix: &RadixDecomposition, + cbs_radix: &RadixDecomposition, + pfks_radix: &RadixDecomposition, + ) -> GgswCiphertext { + let mut out = GgswCiphertext::new(glwe_1, cbs_radix); + + crate::ops::bootstrapping::circuit_bootstrap( + &mut out, input, bsk, cbsksk, lwe_0, glwe_1, glwe_2, pbs_radix, cbs_radix, pfks_radix, + ); + + out + } + + /// Perform LWE keyswitching to produce a new [`LweCiphertext`] encrypted + /// under a different [`LweSecretKey`](crate::entities::LweSecretKey). + /// + /// # Remarks + /// When creating `ksk` with [`generate_ksk`](super::keygen::generate_ksk), + /// you pass 2 different [`LweSecretKey`](crate::entities::LweSecretKey) + /// values: a `from_sk` and `to_sk`. `ct` should be encrypted under the + /// same `from_sk` and ciphertext this function returns will be encrypted + /// under `to_sk`. + pub fn keyswitch_lwe_to_lwe( + ct: &LweCiphertextRef, + ksk: &LweKeyswitchKeyRef, + from_lwe: &LweDef, + to_lwe: &LweDef, + radix: &RadixDecomposition, + ) -> LweCiphertext { + let mut new_ct = LweCiphertext::new(to_lwe); + crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe( + &mut new_ct, + ct, + ksk, + from_lwe, + to_lwe, + radix, + ); + + new_ct + } +} diff --git a/sunscreen_tfhe/src/lib.rs b/sunscreen_tfhe/src/lib.rs new file mode 100644 index 000000000..39951b62b --- /dev/null +++ b/sunscreen_tfhe/src/lib.rs @@ -0,0 +1,35 @@ +//! TFHE low-level library + +#![deny(missing_docs)] +#![deny(rustdoc::broken_intra_doc_links)] + +#[macro_use] +mod dst; + +/// The entities module contains the main data structures used in the library. +pub mod entities; + +/// Higher level operations in TFHE such as programmable bootstrapping, keyswitching, etc. +pub mod ops; + +/// Parameters that define a TFHE scheme. +pub mod params; +pub use params::*; + +/// Math operations on various math primitives such as polynomials. +pub mod math; +pub use math::*; + +mod macros; + +/// Random number generation. +pub mod rand; +mod scratch; + +/// A high-level API for interfacing with TFHE. Allocates, computes with and returns +/// objects as you would expect from a Rust API. +pub mod high_level; + +/// Zero Knowledge proofs for TFHE. +#[cfg(feature = "logproof")] +pub mod zkp; diff --git a/sunscreen_tfhe/src/macros.rs b/sunscreen_tfhe/src/macros.rs new file mode 100644 index 000000000..842d48e6b --- /dev/null +++ b/sunscreen_tfhe/src/macros.rs @@ -0,0 +1,108 @@ +macro_rules! impl_binary_op { + ($op:ident, $type:ty, ($($t_bounds:ty),* $(,)? )) => { + paste::paste! { + + // Ex: AddAssign for LweSecretKey + impl std::ops::[<$op Assign>] for $type + where + S: $($t_bounds)* + { + fn [<$op:lower _assign>](&mut self, rhs: Self) { + self.data.iter_mut().zip(rhs.data.iter()).for_each(|(a, b)| { + *a = num::traits::[]::[](a, b); + }); + } + } + + // Ex: Add for LweSecretKey + // Calls Add for &LweSecretKeyRef + impl std::ops::$op for $type + where + S: TorusOps, + { + type Output = Self; + + fn [<$op:lower >](self, rhs: Self) -> Self::Output { + std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref()) + } + } + + // Ex: WrappingAdd for LweSecretKey + // Calls Add for &LweSecretKeyRef + impl num::traits::[] for $type + where + S: TorusOps, + { + fn [](&self, rhs: &Self) -> Self { + std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref()) + } + } + + // Ex: Add for &LweSecretKeyRef + // Calls AddAssign for LweSecretKey + impl std::ops::$op for &[<$type Ref>] + where + S: TorusOps, + { + type Output = $type; + + fn [< $op:lower >](self, rhs: Self) -> Self::Output { + let mut a = self.to_owned(); + std::ops::[< $op Assign >]::[< $op:lower _assign>](&mut a, rhs.to_owned()); + a + } + } + } + }; +} + +macro_rules! impl_unary_op { + ($op:ident, $type:ty) => { + paste::paste! { + + // Ex: Neg for LweSecretKey + // Calls Neg for &LweSecretKeyRef + impl std::ops::$op for $type + where + S: TorusOps, + { + type Output = Self; + + fn [<$op:lower>](self) -> Self::Output { + std::ops::$op::[<$op:lower>](self.as_ref()) + } + } + + // Ex: Neg for &LweSecretKeyRef + impl std::ops::$op for &[<$type Ref>] + where + S: TorusOps, + { + type Output = $type; + + fn [<$op:lower>](self) -> Self::Output { + // We call the wrapping trait instead of using the dot + // syntax because the dot syntax can dereference the value + // and can cause problems with Deref. + let data = self.data.iter().map(|a| num::traits::[]::[](a)).collect(); + + $type { data } + } + } + + // Ex: WrappingNeg for LweSecretKey + // Calls Neg for &LweSecretKeyRef + impl num::traits::[] for $type + where + S: TorusOps, + { + fn [](&self) -> Self { + std::ops::$op::[<$op:lower>](self.as_ref()) + } + } + } + }; +} + +pub(crate) use impl_binary_op; +pub(crate) use impl_unary_op; diff --git a/sunscreen_tfhe/src/math/basic.rs b/sunscreen_tfhe/src/math/basic.rs new file mode 100644 index 000000000..0f53ac39c --- /dev/null +++ b/sunscreen_tfhe/src/math/basic.rs @@ -0,0 +1,64 @@ +/// An integer type that supports rounding division. +pub trait RoundedDiv { + /// Divides two numbers and rounds the result to the nearest integer. + fn div_rounded(&self, divisor: Self) -> Self; +} + +macro_rules! div_rounded { + ($t:ty) => { + impl RoundedDiv for $t { + #[inline(always)] + fn div_rounded(&self, divisor: $t) -> $t { + // There are a few ways to do this, but we chose the following + // because it allows the entire range of a type to be used. The + // other common method is to add half the divisor to the + // numerator, but that effectively cuts the size of the possible + // inputs in half before overflow. + let q = self / divisor; + let r = self % divisor; + if r >= divisor / 2 { + q + 1 + } else { + q + } + } + } + }; +} + +div_rounded!(u8); +div_rounded!(u16); +div_rounded!(u32); +div_rounded!(u64); +div_rounded!(u128); +div_rounded!(usize); +div_rounded!(i8); +div_rounded!(i16); +div_rounded!(i32); +div_rounded!(i64); +div_rounded!(i128); +div_rounded!(isize); + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use super::*; + + #[test] + fn test_div_rounded() { + for _ in 0..1_000 { + let a = thread_rng().next_u64(); + let mut b = thread_rng().next_u64(); + + if b == 0 { + b = 1; + } + + let expected = ((a as f64) / (b as f64)).round() as u64; + let actual = a.div_rounded(b); + + assert_eq!(expected, actual); + } + } +} diff --git a/sunscreen_tfhe/src/math/fft/cyclic/mod.rs b/sunscreen_tfhe/src/math/fft/cyclic/mod.rs new file mode 100644 index 000000000..39a1382d4 --- /dev/null +++ b/sunscreen_tfhe/src/math/fft/cyclic/mod.rs @@ -0,0 +1,184 @@ +use std::ops::{Add, Mul, Neg}; +use std::sync::Arc; + +use num::{complex::Complex, Float}; +use realfft::{ComplexToReal, FftNum, RealFftPlanner, RealToComplex}; +use sunscreen_math::{One, Zero as SunscreenZero}; + +use crate::{FrequencyTransform, Inverse, Pow, RootOfUnity}; + +/// A struct that can perform a real FFT. +pub struct RealFft +where + T: FftNum, +{ + pub(crate) fplan: Arc>, + pub(crate) rplan: Arc>, + pub(crate) scale: T, +} + +impl RealFft +where + T: FftNum + Float, +{ + /// Create a new [RealFft] with the given size. + pub fn new(n: usize) -> Self { + let mut plan = RealFftPlanner::::new(); + let fplan = plan.plan_fft_forward(n); + let rplan = plan.plan_fft_inverse(n); + + let scale = T::from(1.0).unwrap() / T::from(n).unwrap(); + + Self { + fplan, + rplan, + scale, + } + } +} + +impl FrequencyTransform for RealFft +where + T: FftNum + Float, +{ + type BaseRepr = T; + type FrequencyRepr = Complex; + + fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) { + assert_eq!(data.len() / 2 + 1, output.len()); + + self.fplan.process(&mut data.to_owned(), output).unwrap(); + } + + fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) { + self.rplan.process(&mut data.to_owned(), output).unwrap(); + + output.iter_mut().for_each(|x| { + *x = *x * self.scale; + }); + } +} + +/// A struct that can perform a NTT using a naive algorithm. +#[allow(unused)] +pub struct NaiveNtt { + twiddle: Vec, + inv_twiddle: Vec, + n_inv: T, +} + +impl NaiveNtt +where + T: RootOfUnity + + Copy + + SunscreenZero + + One + + Mul + + Add + + Neg + + Pow + + From + + Inverse, +{ + /// Create a new [NaiveNtt] with the given size. + #[allow(unused)] + pub fn new(n: usize) -> Self { + let mut twiddle = vec![]; + let mut inv_twiddle = vec![]; + let root = T::nth_root_of_unity(n as u64); + let inv_root = root.inverse(); + + for i in 0..n as u64 { + twiddle.push(root.pow(i)); + inv_twiddle.push(inv_root.pow(i)); + } + + Self { + twiddle, + inv_twiddle, + n_inv: T::from(n as u64).inverse(), + } + } +} + +#[cfg(test)] +mod tests { + use num::complex::ComplexFloat; + + use crate::FrequencyTransform; + + use super::*; + + #[test] + fn can_roundtrip_real_fft() { + let n = 256; + + let fft = RealFft::::new(n); + let input = (0..n).map(|x| x as f64).collect::>(); + let mut points = vec![Complex::from(0.0); input.len() / 2 + 1]; + let mut result = vec![0.0; input.len()]; + + fft.forward(&input, &mut points); + fft.reverse(&points, &mut result); + + for (l, r) in input.iter().zip(result.iter()) { + assert!((l - r).abs() < 1e-12); + } + } + + #[test] + fn negacyclic_gives_odd_harmonics() { + let n = 512; + let nega_len = n / 2; + + let fft = RealFft::::new(n); + let input = (0..n) + .enumerate() + .map(|(i, x)| { + let val = x % nega_len; + + if i < nega_len { + val as f64 + } else { + -(val as f64) + } + }) + .collect::>(); + + let mut points = vec![Complex::from(0.0); input.len() / 2 + 1]; + + fft.forward(&input, &mut points); + + for (i, x) in points.iter().enumerate() { + if i % 2 == 0 { + assert!(x.re().abs() < 1e-12); + assert!(x.im().abs() < 1e-12); + } else { + assert!(x.re().abs() > 1.0); + assert!(x.im().abs() > 1.0); + } + } + } + + #[test] + fn can_cyclic_convolution() { + let n = 4; + + let fft = RealFft::::new(n); + let x = (0..n).map(|x| x as f64).collect::>(); + let mut actual = vec![0.0; n]; + let mut y = vec![Complex::from(0.0); n / 2 + 1]; + + fft.forward(&x, &mut y); + + let z = y.iter().map(|y| y * y).collect::>(); + + fft.reverse(&z, &mut actual); + + let expected = [10.0, 12.0, 10.0, 4.0]; + + for (e, a) in expected.iter().zip(actual.iter()) { + assert!((e - a).abs() < 1e-12); + } + } +} diff --git a/sunscreen_tfhe/src/math/fft/mod.rs b/sunscreen_tfhe/src/math/fft/mod.rs new file mode 100644 index 000000000..04a5ecf97 --- /dev/null +++ b/sunscreen_tfhe/src/math/fft/mod.rs @@ -0,0 +1,5 @@ +/// FFT based operations over real numbers. +pub mod cyclic; + +/// FFT based operations over twisted cyclotomics. +pub mod negacyclic; diff --git a/sunscreen_tfhe/src/math/fft/negacyclic/mod.rs b/sunscreen_tfhe/src/math/fft/negacyclic/mod.rs new file mode 100644 index 000000000..540c79a1f --- /dev/null +++ b/sunscreen_tfhe/src/math/fft/negacyclic/mod.rs @@ -0,0 +1,176 @@ +use std::{ + f64::consts::PI, + sync::{Arc, Once}, +}; + +use num::{Complex, Float, One}; +use realfft::FftNum; +use rustfft::{Fft, FftPlanner}; + +use crate::{scratch::allocate_scratch, FrequencyTransform}; + +static FFT_CACHE_INIT: Once = Once::new(); +static mut FFT_CACHE: Vec> = vec![]; + +/// Get a [TwistedFft] for a given log N. +pub fn get_fft(log_n: usize) -> &'static TwistedFft { + // Can FFT powers of 2 from N=1 up to 4096. + assert!(log_n < 13); + + FFT_CACHE_INIT.call_once(|| { + for i in 0..13 { + unsafe { FFT_CACHE.push(TwistedFft::new(0x1 << i)) }; + } + }); + + unsafe { &FFT_CACHE[log_n] } +} + +/// Perform FFT with a twist so points can be used for +/// negacyclic convolution. +/// +/// # Remarks +/// See `` for algorithm. +pub struct TwistedFft +where + T: FftNum, +{ + fwd: Arc>, + rev: Arc>, + + twist: Vec>, + twist_inv: Vec>, +} + +impl TwistedFft +where + T: FftNum + Float, +{ + /// Create a new [TwistedFft] with the given size. + pub fn new(n: usize) -> Self { + assert!(n.is_power_of_two()); + + let n_2 = T::from(n * 2).unwrap(); + let k = n / 2; + + // The true length of the negacyclic sequence is 2N + let mut plan = FftPlanner::new(); + let fwd = plan.plan_fft_forward(k); + let rev = plan.plan_fft_inverse(k); + + let two_pi = T::from(PI).unwrap() * T::from(2.0).unwrap(); + + let w_2n = (Complex::from(two_pi / n_2) * Complex::i()).exp(); + + let twist = (0..k) + .map(|x| w_2n.powf(T::from(x).unwrap())) + .collect::>(); + + let twist_inv = twist + .iter() + .copied() + .map(|t| t.powf(-T::one())) + .collect::>(); + + debug_assert!(twist.iter().zip(twist_inv.iter()).all(|(a, b)| { + let a_a_inv = a * b; + let err = a_a_inv - Complex::one(); + + err.re.abs() < T::from(1e-12).unwrap() && err.im.abs() < T::from(1e-12).unwrap() + })); + + Self { + fwd, + rev, + twist, + twist_inv, + } + } +} + +impl FrequencyTransform for TwistedFft +where + T: FftNum + Float, +{ + type BaseRepr = T; + type FrequencyRepr = Complex; + + fn forward(&self, x: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) { + assert_eq!(x.len(), self.fwd.len() * 2); + + let n_div_2 = x.len() / 2; + + for i in 0..n_div_2 { + output[i] = Complex::new(x[i], x[i + n_div_2]) * self.twist[i]; + } + + let mut scratch = allocate_scratch(self.fwd.get_inplace_scratch_len()); + let scratch_slice = scratch.as_mut_slice(); + + self.fwd.process_with_scratch(output, scratch_slice); + } + + fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) { + assert_eq!(data.len(), self.rev.len()); + + let mut ifft = allocate_scratch(data.len()); + let ifft_slice = ifft.as_mut_slice(); + ifft_slice.copy_from_slice(data); + + let mut scratch = allocate_scratch(self.rev.get_inplace_scratch_len()); + let scratch_slice = scratch.as_mut_slice(); + + self.rev.process_with_scratch(ifft_slice, scratch_slice); + + let n_inv = T::one() / T::from(data.len()).unwrap(); + + for (i, x) in ifft_slice.iter().enumerate() { + let tmp = *x * n_inv * self.twist_inv[i]; + + output[i] = tmp.re.round(); + output[i + data.len()] = tmp.im.round(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn can_roundtrip_negacyclic_fft() { + let n = 8; + + let plan = TwistedFft::::new(n); + + let x = (0..n).map(|x| x as f64).collect::>(); + let mut y = vec![Complex::from(0.0); x.len() / 2]; + let mut actual = vec![0.0; x.len()]; + + plan.forward(&x, &mut y); + plan.reverse(&y, &mut actual); + + for (l, r) in actual.iter().zip(x.iter()) { + assert!((l - r).abs() < 1e-12); + } + } + + #[test] + fn can_negacyclic_conv() { + let n = 4; + + let plan = TwistedFft::::new(n); + + let x = (0..n).map(|x| x as f64).collect::>(); + let mut y = vec![Complex::from(0.0); x.len() / 2]; + let mut actual = vec![0.0; x.len()]; + + plan.forward(&x, &mut y); + + let z = y.iter().map(|y| y * y).collect::>(); + + plan.reverse(&z, &mut actual); + + assert_eq!(actual, vec![-10.0, -12.0, -8.0, 4.0]); + } +} diff --git a/sunscreen_tfhe/src/math/goldilocks_field.rs b/sunscreen_tfhe/src/math/goldilocks_field.rs new file mode 100644 index 000000000..666d375a4 --- /dev/null +++ b/sunscreen_tfhe/src/math/goldilocks_field.rs @@ -0,0 +1,359 @@ +use std::ops::{Add, Mul, Neg, Sub}; + +use num::traits::{WrappingAdd, WrappingMul, WrappingNeg, WrappingSub}; +use sunscreen_math::{refify_binary_op, One, Zero}; + +use crate::{Inverse, Pow, RootOfUnity}; + +/// 2^64 - 2^32 + 1 +pub const GOLDILOCKS_PRIME: u64 = 0xFFFFFFFF00000001; + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +#[repr(transparent)] +/// A value in the Goldilocks field (F_p where p = 2^64 - 2^32 + 1). +/// See +/// https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/ +/// for why this field is so magical. +pub struct Fg(u64); + +impl RootOfUnity for Fg { + fn nth_root_of_unity(n: u64) -> Self { + assert!( + n.is_power_of_two(), + "Goldilocks prime requires power of 2 < 2^32 for n when computing nth root of unity." + ); + + // See https://nufhe.readthedocs.io/en/latest/implementation_details.html for where this constant comes + // from. + const C: Fg = Fg(12037493425763644479); + + let exp = 0x1_0000_0000u64 / n; + + C.pow(exp) + } +} + +impl Fg { + /// Returns `x % GOLDILOCKS_PRIME` + pub fn new(x: u64) -> Self { + if x > GOLDILOCKS_PRIME { + Self(x - GOLDILOCKS_PRIME) + } else { + Self(x) + } + } + + #[inline] + pub fn unreduced_add(self, rhs: Self) -> Fg96 { + let (c, carry) = self.0.overflowing_add(rhs.0); + + Fg96 { + lo: c, + hi: carry.into(), + } + } + + #[inline] + pub fn unreduced_sub(self, rhs: Self) -> Fg96 { + self.unreduced_add(Fg(GOLDILOCKS_PRIME - rhs.0)) + } + + #[inline] + pub fn unreduced_mul(self, rhs: Self) -> Fg159 { + let res = self.0 as u128 * rhs.0 as u128; + + res.into() + } + + #[inline] + /// Compute `self * b + c` and don't reduce the result. + pub fn unreduced_mad(self, b: Self, c: Self) -> Fg159 { + let res = self.0 as u128 * b.0 as u128 + c.0 as u128; + + res.into() + } + + #[inline] + /// Compute `self * b + c`. + pub fn mad(self, b: Self, c: Self) -> Self { + self.unreduced_mad(b, c).reduce() + } +} + +impl From for Fg { + fn from(value: u64) -> Self { + Self::new(value) + } +} + +#[refify_binary_op] +impl Add<&Fg> for &Fg { + type Output = Fg; + + #[inline] + fn add(self, rhs: &Fg) -> Self::Output { + let (res, c) = self.0.overflowing_add(rhs.0); + + // If overflow occurred, then we have 1 carry bit. This gives us + // s = 1 * 2^64 + res = (2^32 - 1) + res. + // + // If overflow didn't occur, but res >= g, then we need to compute + // s - g = s + (2^32 + 1). Note that -g (mod 2^64) = 2^32 - 1 (mod 2^64). + // Thus, we can add (2^32 + 1) in both cases! + if c || res >= GOLDILOCKS_PRIME { + Fg(res.wrapping_add(0xFFFFFFFFu64)) + } else { + Fg(res) + } + } +} + +#[refify_binary_op] +impl Mul<&Fg> for &Fg { + type Output = Fg; + + #[inline] + fn mul(self, rhs: &Fg) -> Self::Output { + self.unreduced_mul(*rhs).reduce() + } +} + +#[refify_binary_op] +impl Sub<&Fg> for &Fg { + type Output = Fg; + + #[inline] + fn sub(self, rhs: &Fg) -> Self::Output { + let (res, c) = self.0.overflowing_sub(rhs.0); + + let offset = 0u32.wrapping_sub(u32::from(c)) as u64; + + // If we underflow, then we need to add g. + // Note that g (mod 2^64) = -(2^32 + 1) (mod 32) + // Thus, this is equivalent to subtracting -(2^32 + 1). + Fg(res.wrapping_sub(offset)) + } +} + +impl Neg for Fg { + type Output = Fg; + + fn neg(self) -> Self::Output { + Self(GOLDILOCKS_PRIME - self.0) + } +} + +impl WrappingAdd for Fg { + #[inline] + fn wrapping_add(&self, v: &Self) -> Self { + self + v + } +} + +impl WrappingMul for Fg { + #[inline] + fn wrapping_mul(&self, v: &Self) -> Self { + self * v + } +} + +impl WrappingSub for Fg { + #[inline] + fn wrapping_sub(&self, v: &Self) -> Self { + self - v + } +} + +impl WrappingNeg for Fg { + #[inline] + fn wrapping_neg(&self) -> Self { + -*self + } +} + +impl Zero for Fg { + #[inline] + fn vartime_is_zero(&self) -> bool { + self.0 == 0 + } + + #[inline] + fn zero() -> Self { + Fg(0) + } +} + +impl One for Fg { + #[inline] + fn one() -> Self { + Fg(1) + } +} + +impl Inverse for Fg { + fn inverse(&self) -> Self { + >::pow(self, GOLDILOCKS_PRIME - 2) + } +} + +#[derive(Copy, Clone, Debug)] +/// A number of the form hi << 64 + C, where B is 32-bit and C = 64-bit. +pub struct Fg96 { + lo: u64, + hi: u64, +} + +impl Fg96 { + #[inline] + /// Reduce the number mod [`GOLDILOCKS_PRIME`]. + /// + /// /// # Remarks + /// Note that 2^64 = 2^32 - 1 (mod g) + /// + /// Thus, we can rewrite as (2^32 - 1) * B + C. + pub fn reduce(self) -> Fg { + let prod = (self.hi << 32) - self.hi; + let mut res = prod.wrapping_add(self.lo); + + if res < prod || res >= GOLDILOCKS_PRIME { + res = res.wrapping_sub(GOLDILOCKS_PRIME); + } + + Fg(res) + } +} + +/// An unreduced value of 159 or fewer bits, where lo is 64-bit, mid is 32-bit and hi is 63-bit. +/// We can represent this value as `lo + mid << 64 + hi << 96`. +pub struct Fg159 { + lo: u64, + mid: u64, + hi: u64, +} + +impl Fg159 { + #[inline] + /// Reduce the number mod [`GOLDILOCKS_PRIME`]. + /// + /// /// # Remarks + /// Note that + /// * 2^64 = 2^32 - 1 (mod g) + /// * 2^96 = -1 + /// + /// Thus, we can rewrite as (2^32 - 1) * B + (C - A). + pub fn reduce(self) -> Fg { + let mut lo2 = self.lo.wrapping_sub(self.hi); + if self.hi > self.lo { + lo2 = lo2.wrapping_add(GOLDILOCKS_PRIME); + } + + let prod = (self.mid << 32) - self.mid; + let mut res = lo2.wrapping_add(prod); + + if res < prod || res >= GOLDILOCKS_PRIME { + res = res.wrapping_sub(GOLDILOCKS_PRIME); + } + + Fg(res) + } +} + +impl From for Fg159 { + #[inline(always)] + fn from(res: u128) -> Self { + Fg159 { + lo: res as u64, + mid: (res >> 64) as u32 as u64, + hi: (res >> 96) as u64, + } + } +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use super::*; + + #[test] + fn can_add_fg() { + for _ in 0..1000 { + let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + + let c = a + b; + let expected = ((a.0 as u128 + b.0 as u128) % GOLDILOCKS_PRIME as u128) as u64; + + assert_eq!(c, Fg(expected)); + } + } + + #[test] + fn can_sub_fg() { + for _ in 0..1000 { + let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + + let c = a - b; + let expected = ((a.0 as u128 + (GOLDILOCKS_PRIME - b.0) as u128) + % GOLDILOCKS_PRIME as u128) as u64; + + assert_eq!(c, Fg(expected)); + } + } + + #[test] + fn can_neg_fg() { + for _ in 0..1000 { + let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + + let c = -a; + let expected = Fg(0) - a; + + assert_eq!(c, expected); + } + } + + #[test] + fn can_mul_fg() { + fn test_case(a: u64, b: u64) { + let c = Fg(a) * Fg(b); + let expected = ((a as u128 * b as u128) % GOLDILOCKS_PRIME as u128) as u64; + + assert_eq!(c, Fg(expected)); + } + + test_case(GOLDILOCKS_PRIME - 1, GOLDILOCKS_PRIME - 1); + + for _ in 0..1000 { + test_case( + thread_rng().next_u64() % GOLDILOCKS_PRIME, + thread_rng().next_u64() % GOLDILOCKS_PRIME, + ); + } + } + + #[test] + fn can_mad_fg() { + for _ in 0..1000 { + let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + let c = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME); + + let acutal = a.mad(b, c); + let expected = + ((a.0 as u128 * b.0 as u128 + c.0 as u128) % GOLDILOCKS_PRIME as u128) as u64; + + assert_eq!(acutal, Fg(expected)); + } + } + + #[test] + fn nth_root_of_unity() { + for i in 1..16u64 { + let root = Fg::nth_root_of_unity(0x1 << i); + + assert_eq!(root.pow(0x1u64 << i), Fg::one()); + } + } +} diff --git a/sunscreen_tfhe/src/math/mod.rs b/sunscreen_tfhe/src/math/mod.rs new file mode 100644 index 000000000..43c255e7a --- /dev/null +++ b/sunscreen_tfhe/src/math/mod.rs @@ -0,0 +1,165 @@ +use std::ops::{Add, BitAnd, Mul, Shr}; + +pub use sunscreen_math::{One, Zero}; + +/// FFT based operations. +pub mod fft; + +mod goldilocks_field; + +/// Math operations on polynomials. +pub mod polynomial; + +/// Operations for performing radix decompositions. +pub mod radix; +mod torus; +pub use torus::*; + +mod basic; +pub use basic::*; + +/// Types where the roots of unity in the given field can be found. +pub trait RootOfUnity +where + Self: Sized, +{ + /// Find the n-th root of unity for the given field. + /// + /// # Remarks + /// `w` is an n-th root of unity if `w^n = 1`. + /// + /// # Panics + /// Implementers may choose to panic if no n-th root of unity + /// exists. You should take care in choosing your field modulus + /// to ensure the root does exist. + fn nth_root_of_unity(n: u64) -> Self; +} + +/// Numbers that have an inverse element. +pub trait Inverse +where + Self: Sized, +{ + /// Find the inverse of the given number. + fn inverse(&self) -> Self; +} + +/// Numbers that can be raised to a power. +pub trait Pow +where + Self: Sized, + T: Shr + BitAnd + One + Eq, +{ + /// Raise the number to the given power. + fn pow(&self, exp: T) -> Self; +} + +impl Pow for U +where + T: Shr + BitAnd + One + Eq + Copy, + U: sunscreen_math::One + Copy + Add + Mul, +{ + fn pow(&self, exp: T) -> U { + let mut result = U::one(); + let mut power = *self; + + for i in 0..64 { + if (exp >> i) & T::one() == T::one() { + result = result * power + } + + power = power * power; + } + + result + } +} + +/// Reinterpret the bits of an unsigned value as signed. +/// +/// # Remarks +/// This is different than a cast. For example, UINT_MAX becomes -1 +/// when interpreted as a 2's complement signed value. +pub trait ReinterpretAsSigned { + /// The output type of the reinterpretation. + type Output: ToF64; + + /// Reinterpret the bits of an unsigned value as signed. + fn reinterpret_as_signed(self) -> Self::Output; +} + +macro_rules! impl_reinterpret_signed { + ($ut:ty, $st:ty) => { + impl ReinterpretAsSigned for $ut { + type Output = $st; + + #[inline(always)] + fn reinterpret_as_signed(self) -> Self::Output { + unsafe { std::mem::transmute(self) } + } + } + }; +} + +impl_reinterpret_signed!(u8, i8); +impl_reinterpret_signed!(u16, i16); +impl_reinterpret_signed!(u32, i32); +impl_reinterpret_signed!(u64, i64); +impl_reinterpret_signed!(u128, i128); + +/// Reinterpret the bits of a signed value as unsigned. +/// +/// # Remarks +/// This is different than a cast. For example, -1 becomes UINT_MAX +/// when interpreted as an unsigned integer. +pub trait ReinterpretAsUnsigned { + /// The output type of the reinterpretation. + type Output; + + /// Reinterpret the bits of a signed value as unsigned. + fn reinterpret_as_unsigned(self) -> Self::Output; +} + +macro_rules! impl_reinterpret_unsigned { + ($st:ty, $ut:ty) => { + impl ReinterpretAsUnsigned for $st { + type Output = $ut; + + #[inline(always)] + fn reinterpret_as_unsigned(self) -> Self::Output { + unsafe { std::mem::transmute(self) } + } + } + }; +} + +impl_reinterpret_unsigned!(i8, u8); +impl_reinterpret_unsigned!(i16, u16); +impl_reinterpret_unsigned!(i32, u32); +impl_reinterpret_unsigned!(i64, u64); +impl_reinterpret_unsigned!(i128, u128); + +/// A type that can be converted from and to the fourier domain. +pub trait FrequencyTransform { + /// Original domain representation. + type BaseRepr; + + /// Fourier domain representation. + type FrequencyRepr; + + /// Perform a fourier transform. + fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]); + + /// Perform an inverse fourier transform. + fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]); +} + +/// A trait that allows types to specify how many leading zeros or ones are in +/// the value. +pub trait LeadingBits { + /// Count the number of leading zeros in the value. + fn leading_zeros(self) -> u32; + + /// Count the number of leading ones in the value. + fn leading_ones(self) -> u32; +} diff --git a/sunscreen_tfhe/src/math/polynomial.rs b/sunscreen_tfhe/src/math/polynomial.rs new file mode 100644 index 000000000..94c998adc --- /dev/null +++ b/sunscreen_tfhe/src/math/polynomial.rs @@ -0,0 +1,353 @@ +use std::{ + num::Wrapping, + ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign}, +}; + +use num::traits::MulAdd; + +use crate::{ + dst::FromMutSlice, entities::PolynomialRef, scratch::allocate_scratch, ToF64, Torus, TorusOps, +}; + +/// Polynomial subtraction in place. This is equivalent to `a -= b` for each +/// coefficient in the polynomial. +pub fn polynomial_sub_assign(lhs: &mut PolynomialRef, rhs: &PolynomialRef) +where + S: SubAssign + Copy, +{ + for (a, b) in lhs + .coeffs_mut() + .iter_mut() + .zip(rhs.coeffs().iter().copied()) + { + *a -= b; + } +} + +/// Negate a polynomial in place. This is equivalent to `a = -a` for each +/// coefficient in the polynomial. +pub fn polynomial_negate(c: &mut PolynomialRef) +where + S: Clone + Copy + Neg, +{ + for c in c.coeffs_mut().iter_mut() { + *c = -*c; + } +} + +/// Compute `c = a * s` where `s` is scalar. +pub fn polynomial_scalar_mul(c: &mut PolynomialRef, a: &PolynomialRef, s: U) +where + S: Clone, + T: Clone + Copy + Mul, + U: Clone + Copy, +{ + for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) { + *c = *a * s + } +} + +/// Compute `c += a * s`, where `s` is scalar. +pub fn polynomial_scalar_mad(c: &mut PolynomialRef, a: &PolynomialRef, s: U) +where + S: Clone + Copy + Add, + T: Clone + Copy + MulAdd, + U: Clone + Copy, +{ + for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) { + *c = a.mul_add(s, *c); + } +} + +/// Compute `c = a + b` where a, b, and c are polynomials. +pub fn polynomial_add(c: &mut PolynomialRef, a: &PolynomialRef, b: &PolynomialRef) +where + S: Clone + Copy + Add, +{ + assert_eq!(c.len(), a.len()); + assert_eq!(c.len(), b.len()); + + for (c, (a, b)) in c + .as_mut_slice() + .iter_mut() + .zip(a.as_slice().iter().zip(b.as_slice().iter())) + { + *c = *a + *b; + } +} + +/// Polynomial addition in place. This is equivalent to `a += b` for each +/// coefficient in the polynomial. +pub fn polynomial_add_assign(lhs: &mut PolynomialRef, rhs: &PolynomialRef) +where + S: AddAssign + Copy, +{ + for (a, b) in lhs + .coeffs_mut() + .iter_mut() + .zip(rhs.coeffs().iter().copied()) + { + *a += b; + } +} + +/// Compute `c = a - b` where a, b, and c are polynomials. +pub fn polynomial_sub(c: &mut PolynomialRef, a: &PolynomialRef, b: &PolynomialRef) +where + S: Clone + Copy + Sub, +{ + assert_eq!(c.len(), a.len()); + assert_eq!(c.len(), b.len()); + + for (c, (a, b)) in c + .as_mut_slice() + .iter_mut() + .zip(a.as_slice().iter().zip(b.as_slice().iter())) + { + *c = *a - *b; + } +} + +/// Compute `c += a \[*\] b` where `a` in Z\[X\]/f and `c, b` in T\[X\]/f, and +/// \[*\] is the external product between these rings. +pub fn polynomial_external_mad( + c: &mut PolynomialRef>, + a: &PolynomialRef>, + b: &PolynomialRef, +) where + S: TorusOps, +{ + polynomial_mad_impl(c, a, b); +} + +/// Compute `c += a * b` where `*` the multiplication of the two polynomials of +/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and +/// hence has O(N^2) time. +pub fn polynomial_mad( + c: &mut PolynomialRef>, + a: &PolynomialRef>, + b: &PolynomialRef>, +) where + S: TorusOps, + Wrapping: Sub, Output = Wrapping> + + Add, Output = Wrapping> + + Mul, Output = Wrapping>, +{ + polynomial_mad_impl(c, a, b); +} + +/// Compute `c += a * b` where `*` the multiplication of the two polynomials of +/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and +/// hence has O(N^2) time. +fn polynomial_mad_impl( + c: &mut PolynomialRef, + a: &PolynomialRef, + b: &PolynomialRef, +) where + U: Clone + Copy + ToF64, + T: Mul + Clone + Copy + ToF64, + S: Sub + Add + Clone + Copy, +{ + assert!(a.len().is_power_of_two()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.len(), c.len()); + + let len = a.len(); + + // Polynomial's length is a power of 2, so use mask to perform modulus. + let mask = len - 1; + + let coeffs = c.coeffs_mut(); + + let mut poly_f64 = allocate_scratch::(a.len()); + let poly_f64_ref = PolynomialRef::from_mut_slice(poly_f64.as_mut_slice()); + + a.map_into(poly_f64_ref, |x| x.to_f64()); + + for (i, l) in a.coeffs().iter().copied().enumerate() { + for (j, r) in b.coeffs().iter().copied().enumerate() { + if i + j >= len { + // Reduce mod len and subtract due to negacyclic + // polynomials. + let index = (i + j) & mask; + coeffs[index] = coeffs[index] - l * r; + } else { + let index = i + j; + coeffs[index] = coeffs[index] + l * r; + } + } + } +} + +#[cfg(test)] +mod tests { + #[derive(BarrettConfig)] + #[barrett_config(modulus = "18446744073709551616", num_limbs = 2)] + pub struct U64Config; + + pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>; + + use crate::{ + entities::{Polynomial, PolynomialFft}, + normalized_torus_distance, ReinterpretAsSigned, ReinterpretAsUnsigned, + }; + + use super::*; + use num::{Complex, Zero}; + use rand::{thread_rng, RngCore}; + use sunscreen_math::{ + poly::Polynomial as BasePolynomial, + ring::{BarrettBackend, Zq}, + BarrettConfig, One, + }; + + #[test] + fn can_multiply_polynomials() { + // Compare our negacyclic polynomial implementation against + // sunscreen_math computing (a * b) mod (X^N + 1). + fn case(a: &PolynomialRef>, b: &PolynomialRef) { + let actual = a * b; + + let mut f = vec![::zero(); a.len() + 1]; + + let len = a.len(); + + f[0] = Zq64::one(); + f[a.len()] = Zq64::one(); + + let f = BasePolynomial { coeffs: f }; + + let a = BasePolynomial { + coeffs: a + .coeffs() + .iter() + .map(|x| Zq64::from(x.inner())) + .collect::>(), + }; + + let b = BasePolynomial { + coeffs: b.coeffs().iter().map(|x| Zq64::from(*x)).collect(), + }; + + let expected = a * b; + + let (_, expected) = expected.vartime_div_rem_restricted_rhs(&f); + + let expected = expected + .coeffs + .iter() + .take(len) + .map(|x| Torus::from(x.val.as_words()[0])) + .collect::>(); + + let expected = Polynomial::new(&expected); + + assert_eq!(actual, expected); + } + + for _ in 0..50 { + let len = thread_rng().next_u64() % 8 + 1; + let len = 0x1 << len; + + let a = (0..len) + .map(|_| Torus::from(thread_rng().next_u64())) + .collect::>(); + let a = Polynomial::new(&a); + + let b = (0..len) + .map(|_| thread_rng().next_u64()) + .collect::>(); + let b = Polynomial::new(&b); + + case(&a, &b); + } + } + + #[test] + fn can_roundtrip_polynomial() { + let poly = (0..1024u64).collect::>(); + let poly = Polynomial::new(&poly); + + let mut actual = Polynomial::::zero(poly.len()); + + let mut out = PolynomialFft::new(&vec![Complex::zero(); poly.len() / 2]); + + poly.fft(&mut out); + out.ifft(&mut actual); + + assert_eq!(poly, actual); + } + + #[test] + fn can_multiply_polynomials_fft() { + for _ in 0..100 { + let a = (0..1024u64) + .map(|x| Torus::from(x % 0x8000_0000)) + .collect::>(); + let a = Polynomial::new(&a); + let b = (0..1024u64).map(|x| x % 16).collect::>(); + let b = Polynomial::new(&b); + + let mut expected = Polynomial::>::zero(a.len()); + + polynomial_external_mad(&mut expected, &a, &b); + + let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]); + let mut b_fft = a_fft.clone(); + let mut c_fft = a_fft.clone(); + + a.fft(&mut a_fft); + b.fft(&mut b_fft); + + c_fft.multiply_add(&a_fft, &b_fft); + + let mut actual = Polynomial::>::zero(a.len()); + c_fft.ifft(&mut actual); + + assert_eq!(actual, expected); + } + } + + #[test] + fn can_approx_multiply_large_polynomials_fft() { + let n = 1024; + + for _ in 0..100 { + // a is uniform torus elements, b is "small" as we'll encounter + // during radix decomposition. + let a = (0..n) + .map(|_| Torus::from(rand::thread_rng().next_u64())) + .collect::>(); + let a = Polynomial::new(&a); + let b = (0..n) + .map(|_| { + let signed = (rand::thread_rng().next_u64() % 32).reinterpret_as_signed() - 16; + + signed.reinterpret_as_unsigned() + }) + .collect::>(); + let b = Polynomial::new(&b); + + let mut expected = Polynomial::>::zero(a.len()); + + polynomial_external_mad(&mut expected, &a, &b); + + let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]); + let mut b_fft = a_fft.clone(); + let mut c_fft = a_fft.clone(); + + a.fft(&mut a_fft); + b.fft(&mut b_fft); + + c_fft.multiply_add(&a_fft, &b_fft); + + let mut actual = Polynomial::>::zero(a.len()); + c_fft.ifft(&mut actual); + + for (a, e) in actual.coeffs().iter().zip(expected.coeffs().iter()) { + let err = normalized_torus_distance(a, e).abs(); + assert!(err < 1e-12); + } + } + } +} diff --git a/sunscreen_tfhe/src/math/radix.rs b/sunscreen_tfhe/src/math/radix.rs new file mode 100644 index 000000000..4e4c738a6 --- /dev/null +++ b/sunscreen_tfhe/src/math/radix.rs @@ -0,0 +1,399 @@ +use crate::{ + entities::{PolynomialIterator, PolynomialRef}, + math::{Torus, TorusOps}, + polynomial::polynomial_scalar_mad, + RadixCount, RadixDecomposition, RadixLog, +}; + +// Needed by allow_scratch_ref +#[allow(unused_imports)] +use crate::dst::FromMutSlice; + +/// An iterator from least to most significant radix decomposition of a value. +pub struct ScalarRadixIterator +where + S: TorusOps, +{ + cur: S, + level: usize, + radix: RadixDecomposition, +} + +impl ScalarRadixIterator { + /// Creates a new [`ScalarRadixIterator`] for the given [Torus] value. + #[inline(always)] + pub fn new(val: Torus, radix: &RadixDecomposition) -> Self { + Self { + cur: round(val, radix), + level: 0, + radix: *radix, + } + } +} + +#[inline(always)] +fn get_next_digit(cur: &mut S, radix_log: usize) -> S { + let mask = S::from_u64((0x1u64 << radix_log) - 1); + + // Interpreting the digits over [-B/2,B/2) reduces noise by half a bit on average. + let mut digit = *cur & mask; + *cur = *cur >> radix_log; + let carry = digit >> (radix_log - 1); + *cur = *cur + carry; + digit = digit.wrapping_sub(&(carry << radix_log)); + + digit +} + +impl Iterator for ScalarRadixIterator { + type Item = S; + + #[inline(always)] + fn next(&mut self) -> Option { + if self.level == self.radix.count.0 { + return None; + } + + let digit = get_next_digit(&mut self.cur, self.radix.radix_log.0); + + self.level += 1; + + Some(digit) + } +} + +/// An iterator from least to most significant radix decomposition of the coefficients +/// of a polynomial. +pub struct PolynomialRadixIterator<'a, S> +where + S: TorusOps, +{ + scratch: &'a mut PolynomialRef, + level: usize, + radix: RadixDecomposition, +} + +impl<'a, S> PolynomialRadixIterator<'a, S> +where + S: TorusOps, +{ + /// Creates a new [`PolynomialRadixIterator`] for the given polynomial. + pub fn new( + poly: &PolynomialRef>, + scratch: &'a mut PolynomialRef, + radix: &RadixDecomposition, + ) -> Self { + assert!(radix.radix_log.0 * radix.count.0 < S::BITS as usize); + assert_ne!(radix.radix_log.0 * radix.count.0, 0); + + poly.map_into(scratch, |x| round(*x, radix)); + + Self { + scratch, + level: 0, + radix: radix.to_owned(), + } + } + + /// Writes the next polynomial decomposition to `dst` and returns `Some(())` if there is a next digit. + pub fn write_next(&mut self, dst: &mut PolynomialRef) -> Option<()> { + if self.level == self.radix.count.0 { + return None; + } + + self.level += 1; + + for (s, r) in self + .scratch + .coeffs_mut() + .iter_mut() + .zip(dst.coeffs_mut().iter_mut()) + { + *r = get_next_digit(s, self.radix.radix_log.0); + } + + Some(()) + } +} + +/// Recomposes a polynomial from its `digits` decomposition and adds it to `dst`. +/// +/// # Remarks +/// The digits should iterate from least to most significant. +pub fn recompose_and_add( + dst: &mut PolynomialRef>, + digits: &mut PolynomialIterator, + radix: RadixLog, + count: RadixCount, +) where + S: TorusOps, +{ + let shift_amount = S::BITS as usize - radix.0 * count.0; + let mut cur_radix = S::from_u64(0x1 << shift_amount); + let mut actual_count = 0; + + for d in digits { + polynomial_scalar_mad(dst, d.as_torus(), cur_radix); + + actual_count += 1; + cur_radix = cur_radix << radix.0; + } + + assert_eq!(count.0, actual_count); +} + +#[inline(always)] +/// Multiply `val` by q / B^(j + 1). +pub fn scale_by_decomposition_factor( + val: S, + j: usize, + radix: &RadixDecomposition, +) -> S { + let shift = S::BITS as usize - radix.radix_log.0 * (j + 1); + let factor = S::one() << shift; + + val.wrapping_mul(&factor) +} + +#[inline(always)] +/// Rounds the given [`Torus`] element and returns the value interpreted as an integer. +fn round(x: Torus, radix: &RadixDecomposition) -> S { + let shift = S::BITS as usize - radix.radix_log.0 * radix.count.0; + let round_bit = (x.inner() >> (shift - 1)) & S::from_u64(0x1); + + (x.inner() >> shift).wrapping_add(&round_bit) +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{Polynomial, PolynomialList}, + rand::uniform_torus, + scratch::allocate_scratch_ref, + PlaintextBits, PolynomialDegree, + }; + + use super::*; + + #[test] + fn can_round_values() { + assert_eq!( + round( + Torus::from(0x12348FFF_FFFFFFFFu64), + &RadixDecomposition { + radix_log: RadixLog(4), + count: RadixCount(4) + } + ), + 0x1235 + ); + + assert_eq!( + round( + Torus::from(0x12347FFF_FFFFFFFFu64), + &RadixDecomposition { + radix_log: RadixLog(4), + count: RadixCount(4) + } + ), + 0x1234 + ); + } + + #[test] + fn can_decompose() { + let x = Polynomial::new(&[Torus::encode(7u64, PlaintextBits(4))]); + + allocate_scratch_ref!(scratch, PolynomialRef, (PolynomialDegree(x.len()))); + + let mut radix_iter = PolynomialRadixIterator::new( + &x, + scratch, + &RadixDecomposition { + radix_log: RadixLog(2), + count: RadixCount(2), + }, + ); + + let mut dst = Polynomial::zero(1); + + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1)); + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(2)); + assert_eq!(radix_iter.write_next(&mut dst), None); + + let x = Polynomial::new(&[Torus::encode(1u64, PlaintextBits(1))]); + + let radix_log = 4; + let mut radix_iter = PolynomialRadixIterator::new( + &x, + scratch, + &RadixDecomposition { + radix_log: RadixLog(radix_log), + count: RadixCount(3), + }, + ); + + let mut dst = Polynomial::zero(1); + + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!(dst.coeffs()[0], 0u64); + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!(dst.coeffs()[0], 0u64); + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + + // 2^(beta - 1) - 2^beta + assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1 << (radix_log - 1))); + assert_eq!(radix_iter.write_next(&mut dst), None); + } + + #[test] + fn can_decompose_polynomial() { + let x = Polynomial::new(&[ + // Decomposes to 1 [for 2^1] + 0 [for 2^2], or 1 + 0 + Torus::encode(1u64, PlaintextBits(4)), + // Decomposes to -2 [for 2^1] + 1 [for 2^2], or -2 + 4 + Torus::encode(2u64, PlaintextBits(4)), + // Decomposes to -1 [for 2^1] + 1 [for 2^2], or -1 + 4 + Torus::encode(3u64, PlaintextBits(4)), + // Decomposes to 0 [for 2^1] + 1 [for 2^2], or 0 + 4 + Torus::encode(4u64, PlaintextBits(4)), + ]); + + allocate_scratch_ref!(scratch, PolynomialRef, (PolynomialDegree(x.len()))); + + let mut radix_iter = PolynomialRadixIterator::new( + &x, + scratch, + &RadixDecomposition { + radix_log: RadixLog(2), + count: RadixCount(2), + }, + ); + + let mut dst = Polynomial::zero(4); + + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!( + dst.coeffs(), + [1u64, 0u64.wrapping_sub(2), 0u64.wrapping_sub(1), 0] + ); + assert_eq!(radix_iter.write_next(&mut dst), Some(())); + assert_eq!(dst.coeffs(), [0, 1, 1, 1]); + assert_eq!(radix_iter.write_next(&mut dst), None); + } + + #[test] + fn can_decompose_recompose() { + let d = PolynomialDegree(8); + + for _ in 0..50 { + let radix = (thread_rng().next_u32() % 7) + 1; + let radix = RadixLog(radix as usize); + + let count = loop { + let count = (thread_rng().next_u32() % 7) + 1; + + if (radix.0 * count as usize) < u64::BITS as usize { + break count; + } + }; + + let count = RadixCount(count as usize); + + let x = (0..d.0).map(|_| uniform_torus::()).collect::>(); + let x = Polynomial::new(&x); + + let expected = x.map(|c| { + let radix_bits = radix.0 * count.0; + let lsb = u64::BITS as usize - radix_bits; + let round_loc = lsb - 1; + + let round = (c.inner() >> round_loc) & 0x1; + let mask = 0xFFFFFFFF_FFFFFFFFu64 << lsb; + Torus::from((c.inner() & mask).wrapping_add(round << lsb)) + }); + + let mut digits = PolynomialList::new(d, count.0); + + allocate_scratch_ref!(scratch, PolynomialRef, (PolynomialDegree(x.len()))); + + let mut radix_iter = PolynomialRadixIterator::new( + &x, + scratch, + &RadixDecomposition { + radix_log: radix, + count, + }, + ); + + for d in digits.iter_mut(d) { + radix_iter.write_next(d); + } + + let mut result = Polynomial::zero(x.len()); + + recompose_and_add(&mut result, &mut digits.iter(d), radix, count); + + assert_eq!(expected, result); + } + } + + fn random_radix() -> RadixDecomposition { + let radix = (thread_rng().next_u32() % 7) + 1; + let radix = RadixLog(radix as usize); + + let count = loop { + let count = (thread_rng().next_u32() % 7) + 1; + + if (radix.0 * count as usize) < u64::BITS as usize { + break count; + } + }; + + let count = RadixCount(count as usize); + + RadixDecomposition { + radix_log: radix, + count, + } + } + + #[test] + fn can_decompose_scalar() { + for _ in 0..50 { + let radix = random_radix(); + let val = Torus::from(thread_rng().next_u64()); + /*let radix = RadixDecomposition { + count: RadixCount(3), + radix_log: RadixLog(4), + }; + let val = Torus::from(0xDEADBEEF_FEEDF00Du64); + */ + + let decomp = ScalarRadixIterator::new(val, &radix); + let mut digits = vec![]; + + for digit in decomp { + digits.push(digit); + } + + let actual = digits.iter().enumerate().fold(0u64, |s, (i, d)| { + let shift_amount = u64::BITS as usize - radix.radix_log.0 * (radix.count.0 - i); + let cur_radix = 0x1u64 << shift_amount; + + (cur_radix.wrapping_mul(*d)).wrapping_add(s) + }); + + // Round shifts our value down to the LSB places, so move it back up to compare + // against a torus element. + let expected = + round(val, &radix) << (u64::BITS as usize - radix.radix_log.0 * (radix.count.0)); + + assert_eq!(actual, expected); + } + } +} diff --git a/sunscreen_tfhe/src/math/torus.rs b/sunscreen_tfhe/src/math/torus.rs new file mode 100644 index 000000000..07e9d381a --- /dev/null +++ b/sunscreen_tfhe/src/math/torus.rs @@ -0,0 +1,619 @@ +use bytemuck::{Pod as BytemuckPod, Zeroable}; +use num::traits::{ + Bounded, MulAdd, Num, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr, + WrappingSub, +}; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Binary, Debug, LowerHex, UpperHex}, + num::Wrapping, + ops::{Add, AddAssign, BitAnd, Deref, Mul, Neg, Shl, Shr, Sub, SubAssign}, +}; +use sunscreen_math::{refify_binary_op, Zero}; + +use crate::{ + math::{ReinterpretAsSigned, ReinterpretAsUnsigned}, + scratch::Pod, + PlaintextBits, +}; + +/// Number of bits used in the representation of a type. +pub trait NumBits { + /// The number of bits used in the representation of a type. + const BITS: u32; +} + +impl NumBits for u32 { + const BITS: u32 = u32::BITS; +} + +impl NumBits for u64 { + const BITS: u32 = u64::BITS; +} + +/// Convert a type into a 64-bit floating point number. +pub trait ToF64 { + /// Approximately convert the value to an f64. + fn to_f64(self) -> f64; +} + +/// Convert a 64-bit floating point number into a type. +pub trait FromF64 +where + Self: Sized, +{ + /// Approximately convert an f64 into a value. + fn from_f64(x: f64) -> Self; +} + +/// A type that supports operations on a Torus. +pub trait TorusOps: + BitAnd + + WrappingAdd + + WrappingSub + + WrappingMul + + WrappingShl + + WrappingShr + + WrappingNeg + + BitAnd + + ReinterpretAsSigned + + Num + + NumBits + + From + + TryFrom + + FromU64 + + Clone + + Copy + + Binary + + LowerHex + + UpperHex + + std::fmt::Debug + + Ord + + Zero + + Pod + + BytemuckPod + + Bounded + + ToF64 + + FromF64 + + ToU64 + + NumBits +{ +} + +// Sound since Torus is a transparent wrapper and `S` impl `Pod` +unsafe impl Zeroable for Torus {} +unsafe impl BytemuckPod for Torus {} + +/// Convert a type into a 64-bit unsigned integer. +pub trait ToU64 { + /// Convert the value to a 64-bit unsigned integer. + fn to_u64(self) -> u64; +} + +impl ToU64 for u32 { + fn to_u64(self) -> u64 { + self as u64 + } +} + +impl ToU64 for u64 { + fn to_u64(self) -> u64 { + self + } +} + +impl ToU64 for Torus +where + T: TorusOps, +{ + fn to_u64(self) -> u64 { + self.0.to_u64() + } +} + +/// Convert a 64-bit unsigned integer into a type. +pub trait FromU64 { + /// For the given 64-bit value, take the N most significant bits, where + /// N is the bitlength of this type. + fn from_u64(val: u64) -> Self; +} + +impl FromU64 for u32 { + fn from_u64(val: u64) -> Self { + (val & 0xFFFFFFFF) as Self + } +} + +impl FromU64 for u64 { + fn from_u64(val: u64) -> Self { + val + } +} + +macro_rules! impl_tof64 { + ($t:ty) => { + impl ToF64 for $t { + fn to_f64(self) -> f64 { + self as f64 + } + } + }; +} + +impl_tof64!(i8); +impl_tof64!(i16); +impl_tof64!(i32); +impl_tof64!(i64); +impl_tof64!(i128); +impl_tof64!(u8); +impl_tof64!(u16); +impl_tof64!(u32); +impl_tof64!(u64); +impl_tof64!(u128); + +impl ToF64 for Wrapping +where + T: ToF64, +{ + fn to_f64(self) -> f64 { + self.0.to_f64() + } +} + +impl ToF64 for Torus +where + T: TorusOps, +{ + fn to_f64(self) -> f64 { + self.0.to_f64() + } +} + +macro_rules! impl_unsigned_fromf64 { + ($t:ty,$st:ty) => { + impl FromF64 for $t { + #[inline(always)] + fn from_f64(x: f64) -> $t { + let x = x as $st; + + x.reinterpret_as_unsigned() + } + } + }; +} + +impl_unsigned_fromf64!(u128, i128); +impl_unsigned_fromf64!(u64, i64); +impl_unsigned_fromf64!(u32, i32); +impl_unsigned_fromf64!(u16, i16); + +impl FromF64 for Torus +where + T: TorusOps, +{ + fn from_f64(x: f64) -> Self { + Self(T::from_f64(x)) + } +} + +impl FromF64 for Wrapping +where + T: FromF64, +{ + fn from_f64(x: f64) -> Self { + Wrapping(T::from_f64(x)) + } +} + +impl TorusOps for u64 {} +impl TorusOps for u32 {} + +/// A wrapper around a type that supports Torus operations. +#[repr(transparent)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Torus(S); + +/// Compute the distance between two Torus values, normalized to the unit torus. +/// The first argument is taken as the reference point. +/// +/// For example, if a is at a normalized position of 0.2, and b is at a +/// normalized position of 0.8, then there are two distances between them: +/// 0.6 and -0.4. This function will return the shorter of the two distances, +/// -0.4. +/// +/// In other words: +/// +/// ```text +/// b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b) (mod 1.0) +/// ``` +pub fn normalized_torus_distance(a: &Torus, b: &Torus) -> f64 { + let a_minus_b = a - b; + let b_minus_a = b - a; + + let modulus = 2_f64.powi(S::BITS as i32); + + let difference = if a_minus_b < b_minus_a { + -a_minus_b.to_f64() + } else { + b_minus_a.to_f64() + }; + + difference / modulus +} + +impl Deref for Torus { + type Target = S; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl NumBits for Torus { + const BITS: u32 = S::BITS; +} + +impl Torus { + /// Compute the normalized position of this [Torus] value on the unit torus [0.0, 1.0). + pub fn normalized_torus(&self) -> f64 { + let modulus: f64 = 2_f64.powi(S::BITS as i32); + let val: f64 = self.0.to_f64(); + + val / modulus + } + + /// Compute the distance between this [Torus] value and another, normalized to + /// the unit torus. + pub fn normalized_torus_distance(&self, other: &Self) -> f64 { + normalized_torus_distance(self, other) + } + + /// Encode a value into a [Torus] that supports up to plain_bits of values. + /// This encodes the value on the equispaced `2^plaintext_bits` positions on + /// a larger torus. + pub fn encode(val: S, plain_bits: PlaintextBits) -> Self { + assert!(plain_bits.0 < S::BITS); + + let encoded = val.wrapping_shl(S::BITS - plain_bits.0); + + Self(encoded) + } + + /// Decode a value from a [Torus] that supports up to plain_bits of values. + pub fn decode(&self, plain_bits: PlaintextBits) -> S { + assert!(plain_bits.0 < S::BITS); + + let round_bit = self.0.wrapping_shr(S::BITS - plain_bits.0 - 1) & S::from(0x1); + let mask = S::from((0x1 << plain_bits.0) - 1); + + (self.0.wrapping_shr(S::BITS - plain_bits.0) + round_bit) & mask + } + + /// Scale a Torus element to a different modulus. Assumes that the two moduli are + /// powers of 2. + pub fn switch_modulus_smaller(&self) -> Torus + where + T: TorusOps + TryFrom, + >::Error: Debug, + { + let shift = (S::BITS - T::BITS) as usize; + + // We don't wrap the bits we don't need + let y = self.0 >> shift; + + Torus(T::try_from(y).expect("impossible error on switch_modulus_smaller")) + } + + /// Return the underlying type of the Torus. + #[inline(always)] + pub fn inner(&self) -> S { + self.0 + } +} + +impl From for Torus { + #[inline(always)] + fn from(value: S) -> Self { + Self(value) + } +} + +impl Zero for Torus { + fn zero() -> Self { + Self(S::from(0)) + } + + fn vartime_is_zero(&self) -> bool { + self.inner() == ::zero() + } +} + +impl Neg for Torus { + type Output = Self; + + fn neg(self) -> Self::Output { + Self::Output::from(self.0.wrapping_neg()) + } +} + +impl WrappingNeg for Torus { + fn wrapping_neg(&self) -> Self { + Self::from(self.0.wrapping_neg()) + } +} + +#[refify_binary_op] +impl Add<&Torus> for &Torus { + type Output = Torus; + + fn add(self, rhs: &Torus) -> Self::Output { + Self::Output::from(self.0.wrapping_add(&rhs.0)) + } +} + +impl WrappingAdd for Torus { + fn wrapping_add(&self, rhs: &Self) -> Self { + self + rhs + } +} + +#[refify_binary_op] +impl Sub<&Torus> for &Torus { + type Output = Torus; + + fn sub(self, rhs: &Torus) -> Self::Output { + Self::Output::from(self.0.wrapping_sub(&rhs.0)) + } +} + +impl WrappingSub for Torus { + fn wrapping_sub(&self, rhs: &Self) -> Self { + self - rhs + } +} + +#[refify_binary_op] +impl Mul<&S> for &Torus { + type Output = Torus; + + fn mul(self, rhs: &S) -> Self::Output { + Self::Output::from(self.wrapping_mul(rhs)) + } +} + +#[refify_binary_op] +impl BitAnd<&Torus> for &Torus { + type Output = Torus; + + fn bitand(self, rhs: &Torus) -> Self::Output { + Torus::from(self.0 & rhs.0) + } +} + +#[refify_binary_op] +impl Shr<&usize> for &Torus { + type Output = Torus; + + fn shr(self, rhs: &usize) -> Self::Output { + Torus::from(self.0 >> *rhs) + } +} + +#[refify_binary_op] +impl Shl<&usize> for &Torus { + type Output = Torus; + + fn shl(self, rhs: &usize) -> Self::Output { + Torus::from(self.0 << *rhs) + } +} + +impl AddAssign for Torus { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs + } +} + +impl AddAssign<&Self> for Torus { + fn add_assign(&mut self, rhs: &Self) { + *self = *self + rhs + } +} + +impl SubAssign for Torus { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } +} + +impl SubAssign<&Self> for Torus { + fn sub_assign(&mut self, rhs: &Self) { + *self = *self - rhs; + } +} + +impl std::iter::Sum for Torus { + fn sum>(iter: I) -> Self { + iter.fold(Self::zero(), |acc, x| acc + x) + } +} + +impl ReinterpretAsSigned for Torus +where + S: TorusOps, +{ + type Output = ::Output; + + #[inline(always)] + fn reinterpret_as_signed(self) -> Self::Output { + self.0.reinterpret_as_signed() + } +} + +impl num::Zero for Torus { + fn zero() -> Self { + Self(::zero()) + } + + fn is_zero(&self) -> bool { + self.0.is_zero() + } +} + +impl MulAdd for Torus { + type Output = Self; + + #[inline(always)] + fn mul_add(self, a: S, b: Self) -> Self::Output { + self * a + b + } +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use super::*; + + #[test] + fn can_negate() { + let x = Torus::::from(0); + assert_eq!(-x, Torus::from(0)); + + let x = Torus::::from(1); + assert_eq!(-x, Torus::from(u64::MAX)); + + let x = Torus::::from(u64::MAX); + assert_eq!(-x, Torus::from(1)); + } + + #[test] + fn can_encode_decode() { + assert_eq!( + Torus::::encode(7, PlaintextBits(4)).0, + 0x70000000_00000000 + ); + + let x = Torus::::from(0x70000000_00000000); + + assert_eq!(x.decode(PlaintextBits(4)), 7); + + let x = Torus::::from(0x7FFFFFFF_FFFFFFFF); + + assert_eq!(x.decode(PlaintextBits(4)), 8); + } + + #[test] + fn can_decode_off_center() { + let t = Torus::::from(((u64::MAX as f64) * 0.6) as u64); + let r = t.decode(PlaintextBits(1)); + assert_eq!(r, 1); + + let t = Torus::::from(((u64::MAX as f64) * 0.3) as u64); + let r = t.decode(PlaintextBits(1)); + assert_eq!(r, 1); + + let t = Torus::::from(((u64::MAX as f64) * 0.8) as u64); + let r = t.decode(PlaintextBits(1)); + assert_eq!(r, 0); + + let t = Torus::::from(((u64::MAX as f64) * 0.2) as u64); + let r = t.decode(PlaintextBits(1)); + assert_eq!(r, 0); + } + + #[test] + fn can_normalize() { + let x = Torus::::from(0); + assert_eq!(x.normalized_torus(), 0.0); + + let x = Torus::::from(u64::MAX / 4); + assert_eq!(x.normalized_torus(), 0.25); + + let x = Torus::::from(u64::MAX / 2); + assert_eq!(x.normalized_torus(), 0.5); + + let x = Torus::::from(u64::MAX / 4 * 3); + assert_eq!(x.normalized_torus(), 0.75); + + let x = Torus::::from(u64::MAX / 8 * 7); + assert_eq!(x.normalized_torus(), 0.875); + } + + #[test] + fn can_compute_distance() { + let a = Torus::::from(0); + let b = Torus::::from(u64::MAX / 4); + + assert_eq!(normalized_torus_distance(&a, &b), 0.25); + assert_eq!(normalized_torus_distance(&b, &a), -0.25); + + let a = Torus::::from(u64::MAX / 4); + let b = Torus::::from(u64::MAX / 2); + + assert_eq!(normalized_torus_distance(&a, &b), 0.25); + assert_eq!(normalized_torus_distance(&b, &a), -0.25); + + let a = Torus::::from(0); + let b = Torus::::from(u64::MAX / 4 * 3); + + assert_eq!(normalized_torus_distance(&a, &b), -0.25); + assert_eq!(normalized_torus_distance(&b, &a), 0.25); + + let a = Torus::::from(u64::MAX / 8); + let b = Torus::::from(u64::MAX / 4 * 3); + + assert_eq!(normalized_torus_distance(&a, &b), -0.375); + assert_eq!(normalized_torus_distance(&b, &a), 0.375); + } + + #[test] + fn test_normalized_relation() { + // Tests the relation: + // b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b) + + for _ in 0..100 { + let a = Torus::::from(thread_rng().next_u64()); + let b = Torus::::from(thread_rng().next_u64()); + + let a_norm = a.normalized_torus(); + let b_norm = b.normalized_torus(); + + let dist = normalized_torus_distance(&a, &b); + + let b_norm_from_dist = a_norm + dist; + + let b_norm_from_dist = if b_norm_from_dist < 0.0 { + b_norm_from_dist + 1.0 + } else if b_norm_from_dist >= 1.0 { + b_norm_from_dist - 1.0 + } else { + b_norm_from_dist + }; + + let diff = (b_norm - b_norm_from_dist).abs(); + + assert!( + diff < 1e-12, + "Normalized torus relation test failed: a = {:?}, b = {:?}, a_norm = {}, b_norm = {}, dist: {}, b_norm_from_dist = {}, diff = {}", + a, + b, + a_norm, + b_norm, + dist, + b_norm_from_dist, + diff + ); + } + } + + #[test] + fn can_modulus_switch() { + let x = Torus::::from(0x12345678_9ABCDEF0); + + let y = x.switch_modulus_smaller::(); + + assert_eq!(y.0, 0x12345678); + } +} diff --git a/sunscreen_tfhe/src/ops/bootstrapping/blind_rotation.rs b/sunscreen_tfhe/src/ops/bootstrapping/blind_rotation.rs new file mode 100644 index 000000000..fa9b6f998 --- /dev/null +++ b/sunscreen_tfhe/src/ops/bootstrapping/blind_rotation.rs @@ -0,0 +1,473 @@ +use num::Complex; + +use crate::{ + dst::FromMutSlice, + entities::{BlindRotationShiftFftRef, GgswCiphertext, GlweCiphertextRef, GlweSecretKeyRef}, + ops::{encryption::encrypt_ggsw_ciphertext_scalar, fft_ops::cmux}, + scratch::allocate_scratch_ref, + GlweDef, PlaintextBits, RadixDecomposition, TorusOps, +}; + +/// Rotate the given ciphertext message polynomial by the given amount as if it +/// had been multiplied by a monomial with either positive or negative degree. +/// Since this is a negacyclic rotation, a rotation to the left negates the last +/// `rotation` coefficients, while a rotation to the right negates the first +/// `rotation` coefficients. +/// +/// Mathematically, this is equivalent to multiplying the underlying polynomial +/// message by X^{rotation} mod (X^N + 1), where N is the polynomial degree. +/// There are some convenient relations to remember over any integer $k$ and +/// $i$: +/// +/// - Rotations are modulus 2N: $X^{k*2N} = 1$ and $X^{k*2N ± i} = X^{± i}$. +/// - Rotations that are equivalent to a shift by -N are equivalent to negating +/// the polynomial: $X^{k*2N + N} = -1$. +/// - A negative rotation has an equivalent positive rotation: $X^{-i} = -X^{N - i}$. +/// +/// # Example +/// +/// ``` +/// use sunscreen_tfhe::{ +/// high_level::{keygen, encryption}, +/// entities::{GlweCiphertext, Polynomial}, +/// ops::bootstrapping::rotate_glwe_monomial_negacyclic, +/// params::{ +/// GlweDef, +/// GlweSize, +/// GlweDimension, +/// PlaintextBits, +/// PolynomialDegree, +/// }, +/// rand::Stddev, +/// }; +/// +/// // Define the GLWE parameters +/// let params = GlweDef { +/// dim: GlweDimension { +/// size: GlweSize(1), +/// polynomial_degree: PolynomialDegree(8), +/// }, +/// std: Stddev(0.0000000444778278004718), +/// }; +/// let plaintext_bits = PlaintextBits(4); +/// +/// // Generate the GLWE secret key +/// let sk = keygen::generate_binary_glwe_sk(¶ms); +/// +/// // Define and encrypt a message +/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]); +/// let ct = encryption::encrypt_glwe(&msg, &sk, ¶ms, plaintext_bits); +/// +/// // Rotate the message polynomial by 1 to the right +/// let mut rotated_ct = GlweCiphertext::new(¶ms); +/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, 1, ¶ms); +/// +/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits); +/// +/// assert_eq!(decrypted_msg, Polynomial::new(&[8, 1, 2, 3, 4, 5, 6, 7])); +/// +/// // Rotate the message polynomial by 1 to the left +/// let mut rotated_ct = GlweCiphertext::new(¶ms); +/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, -1, ¶ms); +/// +/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits); +/// +/// // Since this is a negacyclic rotation, the element moved to the end is +/// // negated. +/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15])); +/// ``` +pub fn rotate_glwe_monomial_negacyclic( + output: &mut GlweCiphertextRef, + ct: &GlweCiphertextRef, + rotation: isize, + params: &GlweDef, +) where + S: TorusOps, +{ + let (output_a, output_b) = output.a_b_mut(params); + let (ct_a, ct_b) = ct.a_b(params); + + let output_all_coefficients = output_a.chain(std::iter::once(output_b)); + let ct_all_coefficients = ct_a.chain(std::iter::once(ct_b)); + + for (o, a) in output_all_coefficients.zip(ct_all_coefficients) { + o.clone_from_ref(a); + o.mul_by_monomial_negacyclic(rotation); + } +} + +/// Rotate the given ciphertext message polynomial by the given amount as if it +/// had been multiplied by a monomial. This is equivalent to shifting +/// all the coefficients left by `rotation` and negating the last `rotation` +/// coefficients. Mathematically, this is equivalent to multiplying by +/// X^{-rotation} mod (X^N + 1), or equivalently -X^{N - rotation}. +/// +/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both +/// positive and negative rotations, plus an example. +pub fn rotate_glwe_negative_monomial_negacyclic( + output: &mut GlweCiphertextRef, + ct: &GlweCiphertextRef, + rotation: usize, + params: &GlweDef, +) where + S: TorusOps, +{ + rotate_glwe_monomial_negacyclic(output, ct, -(rotation as isize), params) +} + +/// Rotate the given ciphertext message polynomial by the given amount as if it +/// had been multiplied by a positive monomial. This is equivalent to shifting +/// all the coefficients right by `rotation` and negating the first `rotation` +/// coefficients. Mathematically, this is equivalent to multiplying by +/// X^{rotation} mod (X^N + 1). +/// +/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both +/// positive and negative rotations, plus an example. +pub fn rotate_glwe_positive_monomial_negacyclic( + output: &mut GlweCiphertextRef, + ct: &GlweCiphertextRef, + rotation: usize, + params: &GlweDef, +) where + S: TorusOps, +{ + rotate_glwe_monomial_negacyclic(output, ct, rotation as isize, params) +} + +/// Rotate the given ciphertext message polynomial by negative encrypted shift. +/// In practice this is not often used on its own; bootstrapping performs a +/// different procedure. +/// +/// See +/// - [`generate_blind_rotation_shift`] for a way to encrypt a rotation amount. +/// - [`rotate_glwe_monomial_negacyclic`] for how negacyclic rotation works +/// when the rotation amount is public. +/// +/// # Example +/// +/// ``` +/// use sunscreen_tfhe::{ +/// high_level::{keygen, encryption}, +/// entities::{GlweCiphertext, Polynomial}, +/// ops::bootstrapping::{blind_rotation, generate_blind_rotation_shift}, +/// params::{ +/// GlweDef, +/// GlweSize, +/// GlweDimension, +/// PlaintextBits, +/// PolynomialDegree, +/// RadixDecomposition, +/// RadixCount, +/// RadixLog, +/// }, +/// rand::Stddev, +/// }; +/// +/// // Define the GLWE parameters +/// let params = GlweDef { +/// dim: GlweDimension { +/// size: GlweSize(1), +/// polynomial_degree: PolynomialDegree(8), +/// }, +/// std: Stddev(0.0000000444778278004718), +/// }; +/// let radix = RadixDecomposition { +/// count: RadixCount(3), +/// radix_log: RadixLog(4), +/// }; +/// let plaintext_bits = PlaintextBits(4); +/// +/// // Generate the GLWE secret key +/// let sk = keygen::generate_binary_glwe_sk(¶ms); +/// +/// // Define and encrypt a message +/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]); +/// let ct = encryption::encrypt_glwe(&msg, &sk, ¶ms, plaintext_bits); +/// +/// // Generate a blind rotation amount +/// let mut blind_rotation_index = sunscreen_tfhe::entities::BlindRotationShiftFft::new(¶ms, &radix); +/// generate_blind_rotation_shift(&mut blind_rotation_index, 1, &sk, ¶ms, &radix, plaintext_bits); +/// +/// // Rotate the message polynomial by the blind rotation amount +/// let mut rotated_ct = GlweCiphertext::new(¶ms); +/// blind_rotation(&mut rotated_ct, &blind_rotation_index, &ct, ¶ms, &radix); +/// +/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits); +/// +/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15])); +/// ``` +pub fn blind_rotation( + output: &mut GlweCiphertextRef, + blind_rotation_index: &BlindRotationShiftFftRef>, + ct: &GlweCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + // Initialize with the unrotated message m + output.clone_from_ref(ct); + allocate_scratch_ref!(rotated_ct, GlweCiphertextRef, (params.dim)); + + for (i, index_select) in blind_rotation_index.rows(params, radix).enumerate() { + let rotation = 1 << i; + + rotate_glwe_negative_monomial_negacyclic(rotated_ct, output, rotation, params); + + let tmp = output.to_owned(); + cmux(output, &tmp, rotated_ct, index_select, params, radix); + } +} + +/// Encrypt an amount to rotate the message polynomial by. +/// +/// This function is mostly provided as a convenience. Bootstrapping will rotate +/// a message without encrypting using a cmux tree, so this function is not +/// strictly necessary. +pub fn generate_blind_rotation_shift( + bootstrap_key: &mut BlindRotationShiftFftRef>, + rotation: usize, + sk: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, +{ + let degree = params.dim.polynomial_degree.0; + assert!(rotation < degree); + + for (i, ggsw_fft) in bootstrap_key.rows_mut(params, radix).enumerate() { + let bit = ((rotation >> i) & 1) as u64; + let mut ct = GgswCiphertext::new(params, radix); + + encrypt_ggsw_ciphertext_scalar( + &mut ct, + S::from_u64(bit), + sk, + params, + radix, + plaintext_bits, + ); + + ct.fft(ggsw_fft, params, radix); + } +} + +#[cfg(test)] +mod tests { + use blind_rotation::generate_blind_rotation_shift; + + use crate::{ + entities::{ + BlindRotationShiftFft, GgswCiphertext, GlweCiphertext, GlweSecretKey, Polynomial, + }, + high_level::{TEST_GLWE_DEF_1, TEST_RADIX}, + ops::{ + bootstrapping::{blind_rotation, rotate_glwe_monomial_negacyclic}, + encryption::decrypt_ggsw_ciphertext, + }, + polynomial::polynomial_external_mad, + GlweDef, GlweDimension, GlweSize, PlaintextBits, PolynomialDegree, Torus, + }; + + #[test] + fn can_rotate() { + let params = GlweDef { + dim: GlweDimension { + polynomial_degree: PolynomialDegree(8), + size: GlweSize(2), + }, + ..TEST_GLWE_DEF_1 + }; + let plaintext_bits = PlaintextBits(4); + + let modulus = 1 << plaintext_bits.0; + let degree = params.dim.polynomial_degree.0; + + let sk = GlweSecretKey::::generate_binary(¶ms); + + let msg_coeffs = (0..degree) + .map(|i| (i % modulus) as u64) + .collect::>(); + let msg = Polynomial::new(&msg_coeffs); + + let ct = sk.encode_encrypt_glwe(&msg, ¶ms, plaintext_bits); + + for rotation in (-2i64 * (degree as i64))..=(2i64 * (degree as i64)) { + println!("Rotation: {}", rotation); + let mut rotation_polynomial = vec![Torus::from(0u64); degree]; + + let direction = if rotation < 0 { -1 } else { 1 }; + let original_rotation = rotation; + let rotation = rotation.unsigned_abs() as usize; + + #[allow(clippy::collapsible_else_if)] + if direction == 1 { + // Positive rotation + if rotation == 0 || rotation == 2 * degree { + rotation_polynomial[0] = Torus::from(1); + } else if rotation < degree { + rotation_polynomial[rotation] = Torus::from(1); + } else if rotation == degree { + rotation_polynomial[0] = -Torus::from(1); + } else { + rotation_polynomial[rotation % degree] = -Torus::from(1); + } + } else { + // Negative rotation + if rotation == 0 || rotation == 2 * degree { + rotation_polynomial[0] = Torus::from(1); + } else if rotation < degree { + rotation_polynomial[(degree - rotation) % degree] = -Torus::from(1); + } else if rotation == degree { + rotation_polynomial[0] = -Torus::from(1); + } else { + rotation_polynomial[(2 * degree - rotation) % degree] = Torus::from(1); + } + } + + let rotation_polynomial = Polynomial::new(&rotation_polynomial); + + let mut expected = Polynomial::>::zero(degree); + polynomial_external_mad(&mut expected, &rotation_polynomial, &msg); + + // Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually. + let expected = expected.map(|x| x.inner() % (modulus as u64)); + + // Perform encrypted rotation + let mut output_ct = GlweCiphertext::new(¶ms); + + rotate_glwe_monomial_negacyclic( + &mut output_ct, + &ct, + original_rotation as isize, + ¶ms, + ); + + let output_msg = sk.decrypt_decode_glwe(&output_ct, ¶ms, plaintext_bits); + + assert_eq!(output_msg, expected); + } + } + + #[test] + fn rotation_shift_encrypted_properly() { + let params = GlweDef { + dim: GlweDimension { + polynomial_degree: PolynomialDegree(8), + size: GlweSize(2), + }, + ..TEST_GLWE_DEF_1 + }; + let radix = TEST_RADIX; + let degree = params.dim.polynomial_degree.0; + + let sk = GlweSecretKey::::generate_binary(¶ms); + + for rotation in 0..(degree - 1) { + let mut ggsw_index = BlindRotationShiftFft::new(¶ms, &radix); + generate_blind_rotation_shift( + &mut ggsw_index, + rotation, + &sk, + ¶ms, + &radix, + PlaintextBits(4), + ); + + let mut encrypted_rotation = 0u64; + for (i, bit_fft) in ggsw_index.rows(¶ms, &radix).enumerate() { + let mut bit = GgswCiphertext::::new(¶ms, &radix); + bit_fft.ifft(&mut bit, ¶ms, &radix); + + let mut pt = Polynomial::zero(degree); + decrypt_ggsw_ciphertext(&mut pt, &bit, &sk, ¶ms, &radix); + + encrypted_rotation |= pt.coeffs()[0].inner() << i; + } + + assert_eq!(encrypted_rotation, rotation as u64); + } + } + + #[test] + fn can_blind_rotate() { + let params = GlweDef { + dim: GlweDimension { + polynomial_degree: PolynomialDegree(8), + size: GlweSize(2), + }, + ..TEST_GLWE_DEF_1 + }; + let radix = TEST_RADIX; + let plaintext_bits = PlaintextBits(4); + + let modulus = 1 << plaintext_bits.0; + let degree = params.dim.polynomial_degree.0; + let num_bits = (degree as u64).ilog2() as usize; + + let sk = GlweSecretKey::::generate_binary(¶ms); + + let msg_coeffs = (0..degree) + .map(|i| (i % modulus) as u64) + .collect::>(); + let msg = Polynomial::new(&msg_coeffs); + + let ct = sk.encode_encrypt_glwe(&msg, ¶ms, plaintext_bits); + + #[allow(clippy::needless_range_loop)] + for rotation in 0..=(degree - 1) { + let mut expected = Polynomial::>::new(msg.map(|x| Torus::from(*x)).coeffs()); + + for i in 0..num_bits { + let bit = ((rotation >> i) & 1) as u64; + + // We don't perform this rotation + if bit == 0 { + continue; + } + + let mut rotation_polynomial = vec![Torus::from(0u64); degree]; + + rotation_polynomial[degree - (1 << i)] = -Torus::::from(1); + + let rotation_polynomial = Polynomial::new(&rotation_polynomial); + + let tmp = expected.map(|x| x.inner()); + expected = Polynomial::>::zero(degree); + + polynomial_external_mad(&mut expected, &rotation_polynomial, &tmp); + } + + // Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually. + let expected = expected.map(|x| x.inner() % (modulus as u64)); + + // Perform encrypted rotation + let mut ggsw_index = BlindRotationShiftFft::new(¶ms, &radix); + generate_blind_rotation_shift( + &mut ggsw_index, + rotation, + &sk, + ¶ms, + &radix, + plaintext_bits, + ); + let mut output_ct = GlweCiphertext::new(¶ms); + blind_rotation(&mut output_ct, &ggsw_index, &ct, ¶ms, &radix); + let output_msg = sk.decrypt_decode_glwe(&output_ct, ¶ms, plaintext_bits); + + // Make sure the zero point is rotated the correct amount. + assert_eq!(output_msg.coeffs()[(degree - rotation) % degree], 0); + + // Make sure we have moved the element in the rotation position to the zero position. + assert_eq!(output_msg.coeffs()[0], msg_coeffs[rotation]); + + assert_eq!( + &output_msg, &expected, + "CT encrypted message: {:?}, expected message: {:?}", + &output_msg, &expected + ); + } + } +} diff --git a/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs b/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs new file mode 100644 index 000000000..a4f6e1b69 --- /dev/null +++ b/sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs @@ -0,0 +1,469 @@ +use num::Complex; + +use crate::{ + dst::FromMutSlice, + entities::{ + BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef, + LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef, + }, + ops::{ + bootstrapping::programmable_bootstrap, homomorphisms::rotate, + keyswitch::private_functional_keyswitch::private_functional_keyswitch, + }, + scratch::allocate_scratch_ref, + GlweDef, LweDef, PlaintextBits, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus, + TorusOps, +}; + +/// Bootstraps a LWE ciphertext to a GGSW ciphertext. +#[allow(clippy::too_many_arguments)] +/// Transform [`LweCiphertextRef`] `input` encrypted under parameters `lwe_0` into +/// the [`GgswCiphertextRef`] `output` encrypted under parameters `glwe_1` with +/// radix decomposition `cbs_radix`. This resets the noise in `output` in the +/// process. +/// +/// [`GgswCiphertext`](crate::entities::GgswCiphertext)s can be used as select +/// inputs for [`cmux`](crate::ops::fft_ops::cmux) operations. +/// +/// # Remarks +/// The following diagram illustrates how circuit bootstrapping works +/// +/// ![Circuit Bootstrapping](LINK TO GITHUB) +/// +/// We perform `cbs_radix.count` programmable bootstrapping (PBS) operations to +/// decompose the original message m under radix `2^cbs_radix.radix_log`. These PBS +/// operations use a bootstrapping key encrypting the level 0 LWE secret key under +/// the level 2 GLWE secret key and internally perform their own radix decomposition +/// parameterized by `pbs_radix`. After performing bootstrapping, we now have +/// `cbs_radix.count` LWE ciphertexts encrypted under the level 2 GLWE secret key +/// (reinterpreted as an LWE key). +/// +/// Next, we take each of these `cbs_radix.count` level 2 LWE ciphertexts and +/// perform `glwe_1.dim.size + 1` private functional keyswitching operations (` +/// (glwe_1.dim.size + 1) * cbs_radix.count` in total). For the first `glwe_1.dim. +/// size` rows of the [`GgswCiphertextRef`] output, this multiplies the radix +/// decomposed message by the negative corresponding secret key. For the last +/// row, we simply multiply our radix decomposed messages by 1. +/// +/// Recall that [`private_functional_keyswitch`] (PFKS) transforms a list of LWE +/// ciphertexts into a [`GlweCiphertext`](crate::entities::GlweCiphertext). In +/// our case, this list contains a single +/// [`LweCiphertext`](crate::entities::LweCiphertext) for each PFKS operation. +/// Each row of the output [`GgswCiphertext`](crate::entities::GgswCiphertext) +/// corresponds to a different PFKS key, encapsulated in `cbsksk`. +/// +/// These PFKS operations switch from a key under parameters `glwe_2` (interpreted +/// as LWE) to `glwe_1` with [`RadixDecomposition`] `pfks_radix`. +/// +/// # Panics +/// * If `bsk` is not valid for bootrapping from parameters `lwe_0` to `glwe_2` +/// (reinterpreted as LWE) with radix decomposition `pbs_radix`. +/// * If `cbsksk` is not a valid keyswitch key set for switching from `glwe_2` +/// (reintrerpreted as LWE) to `glwe_1` with `glwe_1.dim.size` entries and radix +/// decomposition `pfks_radix`. +/// * If `output` is not the correct length for a GGSW ciphertext under `glwe_1` +/// parameters with `cbs_radix` decomposition. +/// * If `input` is not a valid LWE ciphertext under `lwe_0` parameters. +/// * If `lwe_0`, `glwe_1`, `glwe_2`, `cbs_radix`, `pfks_radix`, `pbs_radix` are +/// invalid. +/// +/// # Example +/// ``` +/// use sunscreen_tfhe::{ +/// high_level, +/// high_level::{keygen, encryption, fft}, +/// entities::GgswCiphertext, +/// ops::bootstrapping::circuit_bootstrap, +/// params::{ +/// GLWE_5_256_80, +/// GLWE_1_1024_80, +/// LWE_512_80, +/// PlaintextBits, +/// RadixDecomposition, +/// RadixCount, +/// RadixLog +/// } +/// }; +/// +/// let pbs_radix = RadixDecomposition { +/// count: RadixCount(2), +/// radix_log: RadixLog(16), +/// }; +/// let cbs_radix = RadixDecomposition { +/// count: RadixCount(2), +/// radix_log: RadixLog(5), +/// }; +/// let pfks_radix = RadixDecomposition { +/// count: RadixCount(3), +/// radix_log: RadixLog(11), +/// }; +/// +/// let level_2_params = GLWE_5_256_80; +/// let level_1_params = GLWE_1_1024_80; +/// let level_0_params = LWE_512_80; +/// +/// let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params); +/// let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params); +/// let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params); +/// +/// let bsk = keygen::generate_bootstrapping_key( +/// &sk_0, +/// &sk_2, +/// &level_0_params, +/// &level_2_params, +/// &pbs_radix, +/// ); +/// let bsk = +/// high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix); +/// +/// let cbsksk = keygen::generate_cbs_ksk( +/// sk_2.to_lwe_secret_key(), +/// &sk_1, +/// &level_2_params.as_lwe_def(), +/// &level_1_params, +/// &pfks_radix, +/// ); +/// +/// let val = 1; +/// let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1)); +/// +/// let mut ggsw = GgswCiphertext::new(&level_1_params, &cbs_radix); +/// +/// // ggsw will contain `val` +/// circuit_bootstrap( +/// &mut ggsw, +/// &ct, +/// &bsk, +/// &cbsksk, +/// &level_0_params, +/// &level_1_params, +/// &level_2_params, +/// &pbs_radix, +/// &cbs_radix, +/// &pfks_radix, +/// ); +/// ``` +pub fn circuit_bootstrap( + output: &mut GgswCiphertextRef, + input: &LweCiphertextRef, + bsk: &BootstrapKeyFftRef>, + cbsksk: &CircuitBootstrappingKeyswitchKeysRef, + lwe_0: &LweDef, + glwe_1: &GlweDef, + glwe_2: &GlweDef, + pbs_radix: &RadixDecomposition, + cbs_radix: &RadixDecomposition, + pfks_radix: &RadixDecomposition, +) { + glwe_1.assert_valid(); + glwe_2.assert_valid(); + lwe_0.assert_valid(); + pbs_radix.assert_valid::(); + cbs_radix.assert_valid::(); + pfks_radix.assert_valid::(); + cbsksk.assert_valid(&glwe_2.as_lwe_def(), glwe_1, pfks_radix); + bsk.assert_valid(lwe_0, glwe_2, pbs_radix); + output.assert_valid(glwe_1, cbs_radix); + input.assert_valid(lwe_0); + + // Step 1, for each l in cbs_radix.count, use bootstrapping to base decompose the + // plaintext in input. We bootstrap from level 0 -> level 2. + allocate_scratch_ref!( + level_2_lwes, + LweCiphertextListRef, + (glwe_2.as_lwe_def().dim, cbs_radix.count.0) + ); + + level_0_to_level_2( + level_2_lwes, + input, + bsk, + lwe_0, + glwe_2, + pbs_radix, + cbs_radix, + ); + + level_2_to_level1( + output, + level_2_lwes, + cbsksk, + glwe_2, + glwe_1, + pfks_radix, + cbs_radix, + ); +} + +#[allow(dead_code)] +#[inline(always)] +fn level_0_to_level_2( + lwes_2: &mut LweCiphertextListRef, + input: &LweCiphertextRef, + bsk: &BootstrapKeyFftRef>, + lwe_0: &LweDef, + glwe_2: &GlweDef, + pbs_radix: &RadixDecomposition, + cbs_radix: &RadixDecomposition, +) { + allocate_scratch_ref!(lut, UnivariateLookupTableRef, (glwe_2.dim)); + allocate_scratch_ref!(lwe_rotated, LweCiphertextRef, (lwe_0.dim)); + allocate_scratch_ref!( + lwe_bootstrapped, + LweCiphertextRef, + (glwe_2.as_lwe_def().dim) + ); + + // Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on + // -q/4. + rotate( + lwe_rotated, + input, + Torus::encode(S::one(), PlaintextBits(2)), + lwe_0, + ); + + for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() { + let cur_level = i + 1; + + // Treat value as a T_{b^l+1} with one extra place for rounding as the last + // step. + let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32); + + // Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1} + // everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will + // give 1. Since we've shifted our input lwe by q/4, a 1 plaintext + // value will map to 1 and a 0 will map to -1. + let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one(); + + lut.fill_with_constant(minus_one, glwe_2, plaintext_bits); + + programmable_bootstrap( + lwe_bootstrapped, + lwe_rotated, + lut, + bsk, + lwe_0, + glwe_2, + pbs_radix, + ); + + // Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits). + // This will overflow -1 to 0 and cause 1 to wrap to 2. + rotate( + lwe_2, + lwe_bootstrapped, + Torus::encode(S::one(), plaintext_bits), + &glwe_2.as_lwe_def(), + ); + } +} + +/// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext. +pub fn level_2_to_level1( + result: &mut GgswCiphertextRef, + lwes_2: &LweCiphertextListRef, + cbsksk: &CircuitBootstrappingKeyswitchKeysRef, + glwe_2: &GlweDef, + glwe_1: &GlweDef, + pfks_radix: &RadixDecomposition, + cbs_radix: &RadixDecomposition, +) { + for (glev, pfksk) in result.rows_mut(glwe_1, cbs_radix).zip(cbsksk.keys( + &glwe_2.as_lwe_def(), + glwe_1, + pfks_radix, + )) { + for (decomp, glwe) in lwes_2 + .ciphertexts(&glwe_2.as_lwe_def()) + .zip(glev.glwe_ciphertexts_mut(glwe_1)) + { + private_functional_keyswitch( + glwe, + &[decomp], + pfksk, + &glwe_2.as_lwe_def(), + glwe_1, + pfks_radix, + &PrivateFunctionalKeyswitchLweCount(1), + ); + } + } +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{GgswCiphertext, LweCiphertextList}, + high_level::{self, encryption, fft, keygen, TEST_LWE_DEF_1}, + PlaintextBits, RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80, + LWE_512_80, + }; + + use super::{circuit_bootstrap, level_0_to_level_2}; + + #[test] + fn can_level_0_to_level_2() { + let pbs_radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(16), + }; + let cbs_radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(5), + }; + + let glwe_params = GLWE_5_256_80; + + let mut level_2 = + LweCiphertextList::::new(&glwe_params.as_lwe_def(), cbs_radix.count.0); + + let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); + let bsk = keygen::generate_bootstrapping_key( + &sk, + &glwe_sk, + &TEST_LWE_DEF_1, + &glwe_params, + &pbs_radix, + ); + let bsk = fft::fft_bootstrap_key(&bsk, &TEST_LWE_DEF_1, &glwe_params, &pbs_radix); + + let lwe = sk.encrypt(0, &TEST_LWE_DEF_1, PlaintextBits(1)).0; + + level_0_to_level_2( + &mut level_2, + &lwe, + &bsk, + &TEST_LWE_DEF_1, + &glwe_params, + &pbs_radix, + &cbs_radix, + ); + + for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() { + let cur_level = i + 1; + + let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32); + + let actual = + glwe_sk + .to_lwe_secret_key() + .decrypt(lwe_2, &glwe_params.as_lwe_def(), bits); + + assert_eq!(actual, 0); + } + + let lwe = sk.encrypt(1, &TEST_LWE_DEF_1, PlaintextBits(1)).0; + + level_0_to_level_2( + &mut level_2, + &lwe, + &bsk, + &TEST_LWE_DEF_1, + &glwe_params, + &pbs_radix, + &cbs_radix, + ); + + for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() { + let cur_level = i + 1; + + let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32); + + let actual = + glwe_sk + .to_lwe_secret_key() + .decrypt(lwe_2, &glwe_params.as_lwe_def(), bits); + + assert_eq!(actual, 1); + } + } + + #[test] + fn can_circuit_bootstrap() { + let pbs_radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(16), + }; + let cbs_radix = RadixDecomposition { + count: RadixCount(2), + radix_log: RadixLog(5), + }; + let pfks_radix = RadixDecomposition { + count: RadixCount(3), + radix_log: RadixLog(11), + }; + + let level_2_params = GLWE_5_256_80; + let level_1_params = GLWE_1_1024_80; + let level_0_params = LWE_512_80; + + let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params); + let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params); + let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params); + + let bsk = keygen::generate_bootstrapping_key( + &sk_0, + &sk_2, + &level_0_params, + &level_2_params, + &pbs_radix, + ); + let bsk = + high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix); + + let cbsksk = keygen::generate_cbs_ksk( + sk_2.to_lwe_secret_key(), + &sk_1, + &level_2_params.as_lwe_def(), + &level_1_params, + &pfks_radix, + ); + + for _ in 0..1 { + let val = thread_rng().next_u64() % 2; + + let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1)); + + let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix); + + circuit_bootstrap( + &mut actual, + &ct, + &bsk, + &cbsksk, + &level_0_params, + &level_1_params, + &level_2_params, + &pbs_radix, + &cbs_radix, + &pfks_radix, + ); + + let expected = + encryption::encrypt_ggsw(val, &sk_1, &level_1_params, &cbs_radix, PlaintextBits(1)); + + for (a, e) in actual + .rows(&level_1_params, &cbs_radix) + .zip(expected.rows(&level_1_params, &cbs_radix)) + { + for (i, (a, e)) in a + .glwe_ciphertexts(&level_1_params) + .zip(e.glwe_ciphertexts(&level_1_params)) + .enumerate() + { + let plaintext_bits = (i + 1) * cbs_radix.radix_log.0; + let plaintext_bits = PlaintextBits(plaintext_bits as u32); + + let a = encryption::decrypt_glwe(a, &sk_1, &level_1_params, plaintext_bits); + let e = encryption::decrypt_glwe(e, &sk_1, &level_1_params, plaintext_bits); + + assert_eq!(a, e); + } + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/bootstrapping/mod.rs b/sunscreen_tfhe/src/ops/bootstrapping/mod.rs new file mode 100644 index 000000000..195df1fb4 --- /dev/null +++ b/sunscreen_tfhe/src/ops/bootstrapping/mod.rs @@ -0,0 +1,8 @@ +mod blind_rotation; +pub use blind_rotation::*; + +mod circuit_bootstrapping; +pub use circuit_bootstrapping::*; + +mod programmable_bootstrapping; +pub use programmable_bootstrapping::*; diff --git a/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs b/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs new file mode 100644 index 000000000..0c3715d84 --- /dev/null +++ b/sunscreen_tfhe/src/ops/bootstrapping/programmable_bootstrapping.rs @@ -0,0 +1,824 @@ +use num::Complex; + +use crate::{ + dst::FromMutSlice, + entities::{ + BivariateLookupTableRef, BootstrapKeyFftRef, BootstrapKeyRef, GlweCiphertextRef, + GlweSecretKeyRef, LweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialRef, + UnivariateLookupTableRef, + }, + ops::{ + bootstrapping::rotate_glwe_positive_monomial_negacyclic, + ciphertext::{add_lwe_inplace, modulus_switch, sample_extract, scalar_mul_ciphertext_mad}, + encryption::encrypt_ggsw_ciphertext_scalar, + fft_ops::cmux, + }, + scratch::allocate_scratch_ref, + CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps, +}; + +use super::rotate_glwe_negative_monomial_negacyclic; + +/// Generate a bootstrap key from a LWE secret key to a GLWE secret key. +/// +/// Mathematically, this key is a list of GGSW ciphertexts, one for each bit of +/// the secret key being encrypted. +/// +/// See +/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap) +/// for an example of how to use this key. +pub fn generate_bootstrap_key( + bootstrap_key: &mut BootstrapKeyRef, + sk_to_encrypt: &LweSecretKeyRef, + sk: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + sk.assert_valid(params); + radix.assert_valid::(); + + for (s_i, ggsw) in sk_to_encrypt + .s() + .iter() + .zip(bootstrap_key.rows_mut(params, radix)) + { + encrypt_ggsw_ciphertext_scalar(ggsw, *s_i, sk, params, radix, PlaintextBits(1)); + } +} + +/// Generate a negacyclic LUT for bootstrapping. Another name for this structure +/// is a test polynomial. +/// +/// The map function passed in must have the following negacyclic property, +/// where N is the size of the polynomial: +/// +/// ```text +/// map(N + i) = -map(i) +/// ``` +#[allow(dead_code)] +fn generate_negacyclic_lut( + output: &mut Polynomial>, + map: F, + params: &GlweDef, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, + F: Fn(u64) -> u64, +{ + let p = (1 << plaintext_bits.0) as u64; + let n = params.dim.polynomial_degree.0 as u64; + + let stride = 2 * n / p; + + let delta = S::BITS - plaintext_bits.0; + + let c = output.coeffs_mut(); + + // Written out this way because when we get to programmable boot strapping, + // this will involve replacing p_i with f(p_i) + for (j, p_i_unmapped) in (0..=p / 2).enumerate() { + let j = j as u64; + + let p_i = map(p_i_unmapped); + assert!(p_i < p, "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i); + + let p_i = p_i << delta; + + if j == 0 { + for k in 0..(stride / 2) { + c[k as usize] = Torus::from(S::from_u64(p_i)); + } + } else if j == p / 2 { + for k in (n - (stride / 2))..n { + c[k as usize] = Torus::from(S::from_u64(p_i)); + } + } else { + for k in (stride / 2 + (j - 1) * stride)..(stride / 2 + j * stride) { + c[k as usize] = Torus::from(S::from_u64(p_i)); + } + } + } +} + +/// Generates a lookup table (LUT) to be used with bootstrapping. This LUT is +/// not negacyclic, and hence must be used with LWE inputs that have at least +/// one padding bit. +/// +/// The input `map` is used for generating programmable bootstrapping LUTs. This +/// function takes an element in the plaintext space and must produce another +/// element in the plaintext space. +pub(crate) fn generate_lut( + output: &mut PolynomialRef>, + map: F, + params: &GlweDef, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, + F: Fn(u64) -> u64, +{ + let p = (1 << plaintext_bits.0) as usize; + let n = params.dim.polynomial_degree.0; + + assert!(n >= p); + + let stride = n / p; + + let delta = S::BITS - plaintext_bits.0; + + let c = output.coeffs_mut(); + + for (j, p_i_unmapped) in (0..=p - 1).enumerate() { + let p_i = map(p_i_unmapped as u64); + + assert!(p_i < (p as u64), "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i); + + let p_i = p_i << delta; + + // Insert a stride amount into the LUT + c[j * stride..(j + 1) * stride].iter_mut().for_each(|c| { + *c = Torus::from(S::from_u64(p_i)); + }); + } + + // Negate the first half of p_0 in the LUT in preparation for it to be + // rotated. + c[0..stride / 2].iter_mut().for_each(|c| { + *c = num::traits::WrappingNeg::wrapping_neg(c); + }); + + c.rotate_left(stride / 2); +} + +/// Programmable bootstrapping with a univariate function. +/// +/// The LUT this is a table that maps two inputs into a single output. For +/// example, say we want to encode the negation function `f(x) = (x + 1) % 2` +/// into a lookup table. We would create a +/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable) that +/// implements this function and then execute it on the input ciphertexts. +/// +/// Important note: This function does not perform key switching. The output +/// ciphertext will be encrypted under the LWE key extracted from the GLWE +/// secret key used for the bootstrapping key. To perform a keyswitch, use +/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe) +/// after the bootstrapping operation. +/// +/// # Example +/// +/// ``` +/// use sunscreen_tfhe::{ +/// high_level::{keygen, encryption, fft}, +/// entities::{UnivariateLookupTable, LweCiphertext}, +/// ops::bootstrapping::programmable_bootstrap, +/// params::{ +/// GLWE_1_1024_80, +/// LWE_512_80, +/// CarryBits, +/// PlaintextBits, +/// RadixDecomposition, +/// RadixCount, +/// RadixLog +/// }, +/// }; +/// +/// // Parameters defining the scheme we are using +/// let lwe_params = LWE_512_80; +/// let glwe_params = GLWE_1_1024_80; +/// let radix = RadixDecomposition { +/// count: RadixCount(3), +/// radix_log: RadixLog(4), +/// }; +/// +/// // We will be showing a binary univariate function. Note that for +/// // programmable bootstrapping to work in general, you will need to include at +/// // least one padding bit to the input. +/// let plaintext_bits = PlaintextBits(1); +/// let carry_bits = CarryBits(1); +/// let plaintext_bits_carry = PlaintextBits(2); +/// +/// // The univariate function we want to evaluate, encoded as a lookup table. +/// let negate = |x| (x + 1) % (1 << plaintext_bits.0); +/// let lut = UnivariateLookupTable::trivial_from_fn( +/// &negate, +/// &glwe_params, +/// plaintext_bits, +/// ); +/// +/// // Generate the secret keys and the bootstrapping key +/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params); +/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); +/// +/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix); +/// let bsk = +/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix); +/// +/// // Specify the inputs +/// let input_plain = 0; +/// +/// // Encrypt the inputs. Note we are adding carry bits to the inputs. +/// let input = encryption::encrypt_lwe_secret( +/// input_plain, +/// &lwe_sk, +/// &lwe_params, +/// plaintext_bits_carry +/// ); +/// +/// // Perform the programmable bootstrapping +/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def()); +/// programmable_bootstrap( +/// &mut result, +/// &input, +/// &lut, +/// &bsk, +/// &lwe_params, +/// &glwe_params, +/// &radix, +/// ); +/// +/// // Check the result matches our plaintext function. +/// let decrypted = encryption::decrypt_lwe( +/// &result, +/// &glwe_sk.to_lwe_secret_key(), +/// &glwe_params.as_lwe_def(), +/// plaintext_bits, +/// ); +/// +/// let expected = negate(input_plain); +/// assert_eq!(expected, decrypted); +/// ``` +/// +/// # See also +/// +/// For the bivariate version of programmable bootstrapping, see +/// [`programmable_bootstrap_bivariate`](programmable_bootstrap_bivariate) and +/// its associated LUT +/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable). +pub fn programmable_bootstrap( + output: &mut LweCiphertextRef, + input: &LweCiphertextRef, + lut: &UnivariateLookupTableRef, + bootstrap_key: &BootstrapKeyFftRef>, + lwe_params: &LweDef, + glwe_params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + // Steps: + // 1. Modulus switch the ciphertext to 2N. + // 2. Use a cmux tree to blind rotate V using the elements of the bootstrap key (the input LWE secret key bits). + // 3. Sample extract. + // 4. (Optional, done outside of this method) Key switch to the output LWE + // secret key (should be the one extracted from the GLWE key). + + let degree = glwe_params.dim.polynomial_degree.0; + let two_n = degree.ilog2() + 1; + + // 1. Modulus switch the ciphertext to 2N. + let mut ct = input.to_owned(); + modulus_switch(&mut ct, S::BITS, two_n, lwe_params); + + let (ct_a, ct_b) = ct.a_b(lwe_params); + + // 2. Use a cmux tree to blind rotate V using the elements of the bootstrap + // key (the input LWE secret key bits). + + // Perform V_0 ^ X^{-b} + allocate_scratch_ref!(cmux_output, GlweCiphertextRef, (glwe_params.dim)); + cmux_output.clear(); + + rotate_glwe_negative_monomial_negacyclic( + cmux_output, + lut.glwe(), + ct_b.inner().to_u64() as usize, + glwe_params, + ); + + allocate_scratch_ref!(rotated_ct, GlweCiphertextRef, (glwe_params.dim)); + + // Perform the cmux tree from the bootstrap key with the relation + // V_n = V_{n-1} ^ X^{a_{n-1} s_{n-1}} + for (a_i, index_select) in ct_a.iter().zip(bootstrap_key.rows(glwe_params, radix)) { + let tmp = cmux_output.to_owned(); + + // This operation performs a copy so the rotated_ct doesn't need to be + // cleared. + rotate_glwe_positive_monomial_negacyclic( + rotated_ct, + cmux_output, + a_i.inner().to_u64() as usize, + glwe_params, + ); + + cmux( + cmux_output, + &tmp, + rotated_ct, + index_select, + glwe_params, + radix, + ); + } + + // 3. Sample extract. + sample_extract(output, cmux_output, 0, glwe_params); +} + +/// Evaluate a bivariate function on a packed input. +fn bivariate_function(map: F, input: u64, plaintext_bits: PlaintextBits) -> u64 +where + F: Fn(u64, u64) -> u64, +{ + let modulus = 1 << plaintext_bits.0; + let lhs = (input / modulus) % modulus; + let rhs = input % modulus; + + let result = map(lhs, rhs); + + assert!( + result < modulus, + "The result of the bivariate function must be less than the plaintext modulus" + ); + + result +} + +/// Generate a lookup table that takes two inputs and produces a single output. +pub(crate) fn generate_bivariate_lut( + output: &mut PolynomialRef>, + map: F, + params: &GlweDef, + plaintext_bits: PlaintextBits, + carry_bits: CarryBits, +) where + S: TorusOps, + F: Fn(u64, u64) -> u64, +{ + assert!( + plaintext_bits.0 <= carry_bits.0, + "The number of plaintext bits must be less than or equal to the number of carry bits" + ); + + let wrapped_func = |input: u64| bivariate_function(&map, input, plaintext_bits); + + generate_lut( + output, + wrapped_func, + params, + PlaintextBits(plaintext_bits.0 + carry_bits.0), + ); +} + +/// Programmable bootstrapping with a bivariate function. +/// +/// The LUT this is a table that maps two inputs into a single output. +/// For example, say we want to encode the xor function `f(x, y) = (x + y) % 2` +/// into a lookup table. We would create a +/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable) that +/// implements this function and then execute it on the input ciphertexts. +/// +/// Important note: This function does not perform key switching. The output +/// ciphertext will be encrypted under the LWE key extracted from the GLWE +/// secret key used for the bootstrapping key. To perform a keyswitch, use +/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe) +/// after the bootstrapping operation. +/// +/// # Example +/// +/// ``` +/// use sunscreen_tfhe::{ +/// high_level::{keygen, encryption, fft}, +/// entities::{BivariateLookupTable, LweCiphertext}, +/// ops::bootstrapping::programmable_bootstrap_bivariate, +/// params::{ +/// GLWE_1_1024_80, +/// LWE_512_80, +/// CarryBits, +/// PlaintextBits, +/// RadixDecomposition, +/// RadixCount, +/// RadixLog +/// }, +/// }; +/// +/// // Parameters defining the scheme we are using +/// let lwe_params = LWE_512_80; +/// let glwe_params = GLWE_1_1024_80; +/// let radix = RadixDecomposition { +/// count: RadixCount(3), +/// radix_log: RadixLog(4), +/// }; +/// +/// // We will be showing a binary bivariate function, but bivariate +/// // bootstrapping can be done on more plaintext bits. Note that the effective +/// // number of plaintext bits used is twice the number of plaintext bits +/// // specified because the inputs are packed into one ciphertext inside +/// // `programmable_bootstrap_bivariate`. The number of carry bits must always +/// // be greater than or equal to the number of plaintext bits. +/// let plaintext_bits = PlaintextBits(1); +/// let plaintext_bits_carry = PlaintextBits(2); +/// let carry_bits = CarryBits(1); +/// +/// // The bivariate function we want to evaluate, encoded as a lookup table. +/// let xor = |x, y| (x + y) % (1 << plaintext_bits.0); +/// let lut = BivariateLookupTable::trivial_from_fn( +/// &xor, +/// &glwe_params, +/// plaintext_bits, +/// carry_bits +/// ); +/// +/// // Generate the secret keys and the bootstrapping key +/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params); +/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); +/// +/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix); +/// let bsk = +/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix); +/// +/// // Specify the inputs +/// let left_input_plain = 0; +/// let right_input_plain = 1; +/// +/// // Encrypt the inputs. Note we are adding carry bits to the inputs. +/// let left_input = encryption::encrypt_lwe_secret( +/// left_input_plain, +/// &lwe_sk, +/// &lwe_params, +/// plaintext_bits_carry +/// ); +/// let right_input = encryption::encrypt_lwe_secret( +/// right_input_plain, +/// &lwe_sk, +/// &lwe_params, +/// plaintext_bits_carry +/// ); +/// +/// // Perform the programmable bootstrapping +/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def()); +/// programmable_bootstrap_bivariate( +/// &mut result, +/// &left_input, +/// &right_input, +/// &lut, +/// &bsk, +/// &lwe_params, +/// &glwe_params, +/// plaintext_bits, +/// &radix, +/// ); +/// +/// // Check the result matches our plaintext function. +/// let decrypted = encryption::decrypt_lwe_with_carry( +/// &result, +/// &glwe_sk.to_lwe_secret_key(), +/// &glwe_params.as_lwe_def(), +/// plaintext_bits, +/// carry_bits +/// ); +/// +/// let expected = xor(left_input_plain, right_input_plain); +/// assert_eq!(expected, decrypted); +/// ``` +/// +/// # See also +/// +/// For the univariate version of programmable bootstrapping, see +/// [`programmable_bootstrap`](programmable_bootstrap) and its associated LUT +/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable). +#[allow(clippy::too_many_arguments)] +pub fn programmable_bootstrap_bivariate( + output: &mut LweCiphertextRef, + left_input: &LweCiphertextRef, + right_input: &LweCiphertextRef, + lut: &BivariateLookupTableRef, + bootstrap_key: &BootstrapKeyFftRef>, + lwe_params: &LweDef, + glwe_params: &GlweDef, + plaintext_bits: PlaintextBits, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + // The general operation for a bivariate PBS is + // + // 1. Ensure that the number of carry bits is equal to the size of the + // message or greater. + // 2. Define a LUT where the function takes in one input and decomposes that + // input into the higher n plaintext and lower n plaintext bits. The + // higher n bits are the left input to the bivariate function, while the + // lower n bits are the right input to the bivariate function. + // 3. Encrypt the two input ciphertexts using the number of carry bits and + // the plaintext bits, with padding. + // 4. On the left encrypted input, shift it up by the number of plaintext + // bits by multiplying the ciphertext by the plaintext modulus. + // 5. Add the left and right encrypted inputs together. + // 6. Perform the programmable bootstrapping with this combined input. + + let shift = (1 << plaintext_bits.0) as u64; + + allocate_scratch_ref!(pbs_input, LweCiphertextRef, (lwe_params.dim)); + pbs_input.clear(); + + // (left * modulus) + right to pack the two inputs into a single LWE + scalar_mul_ciphertext_mad(pbs_input, &S::from_u64(shift), left_input, lwe_params); + add_lwe_inplace(pbs_input, right_input, lwe_params); + + programmable_bootstrap( + output, + pbs_input, + lut.as_univariate(), + bootstrap_key, + lwe_params, + glwe_params, + radix, + ) +} + +#[cfg(test)] +mod tests { + + use crate::{ + entities::{ + BivariateLookupTable, BootstrapKey, BootstrapKeyFft, LweCiphertext, LweKeyswitchKey, + UnivariateLookupTable, + }, + high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX}, + ops::{ + encryption::{decrypt_ggsw_ciphertext, encrypt_lwe_ciphertext}, + keyswitch::lwe_keyswitch_key::generate_keyswitch_key_lwe, + }, + RoundedDiv, GLWE_1_1024_80, + }; + + use super::*; + + fn generate_negacyclic_lut_from_formula( + params: &GlweDef, + plaintext_bits: PlaintextBits, + ) -> Polynomial> + where + S: TorusOps, + { + let mut output = Polynomial::>::zero(params.dim.polynomial_degree.0); + + let p = (1 << plaintext_bits.0) as u64; + let n = params.dim.polynomial_degree.0 as u64; + + let divisor = 2 * n; + + for (j, c) in output.coeffs_mut().iter_mut().enumerate() { + let v_i = ((p * (j as u64)).div_rounded(divisor)) % p; + let v_i = v_i << (S::BITS - plaintext_bits.0); + *c = Torus::from(S::from_u64(v_i)); + } + + output + } + + #[test] + fn can_generate_negacyclic_lut() { + let p = PlaintextBits(4); + let params = TEST_GLWE_DEF_1; + + let mut poly = Polynomial::>::zero(params.dim.polynomial_degree.0); + generate_negacyclic_lut(&mut poly, |x| x, ¶ms, p); + + let expected = generate_negacyclic_lut_from_formula(¶ms, p); + + assert_eq!(expected, poly); + } + + #[test] + fn can_generate_bootstrap_key() { + let lwe_params = TEST_LWE_DEF_1; + let glwe_params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + + let sk = keygen::generate_binary_lwe_sk(&lwe_params); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); + + let mut bootstrap_key = BootstrapKey::new(&lwe_params, &glwe_params, &radix); + generate_bootstrap_key(&mut bootstrap_key, &sk, &glwe_sk, &glwe_params, &radix); + + let mut count = 0; + for (s_i, ct) in sk.s().iter().zip(bootstrap_key.rows(&glwe_params, &radix)) { + let mut msg = Polynomial::>::zero(glwe_params.dim.polynomial_degree.0); + decrypt_ggsw_ciphertext(&mut msg, ct, &glwe_sk, &glwe_params, &radix); + + assert_eq!(msg.coeffs()[0].inner(), *s_i); + + count += 1 + } + + assert_eq!(count, sk.s().len()); + } + + fn bootstrap_helper(map: impl Fn(u64) -> u64) { + let bits = PlaintextBits(3); + let lwe = TEST_LWE_DEF_1; + let glwe = GLWE_1_1024_80; + let radix = TEST_RADIX; + + let original_sk = keygen::generate_binary_lwe_sk(&lwe); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe); + + // We want to switch from the sample extracted key to the new key. + let mut ksk = LweKeyswitchKey::::new(&glwe.as_lwe_def(), &lwe, &radix); + generate_keyswitch_key_lwe( + &mut ksk, + glwe_sk.to_lwe_secret_key(), + &original_sk, + &lwe, + &radix, + ); + + let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix); + generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix); + + let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix); + bsk_nonfft.fft(&mut bsk, &glwe, &radix); + + // Generate the LUT + let lut = UnivariateLookupTable::trivial_from_fn(&map, &glwe, bits); + + let mut failed = Vec::new(); + for msg in 0..(1 << bits.0) { + let mut original_ct = LweCiphertext::new(&lwe); + + // Adding a padding bit + let encoded_msg = msg << (64 - bits.0 - 1); + encrypt_lwe_ciphertext( + &mut original_ct, + &original_sk, + Torus::from(encoded_msg), + &lwe, + ); + + let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def()); + + programmable_bootstrap(&mut new_ct, &original_ct, &lut, &bsk, &lwe, &glwe, &radix); + + let decoded = glwe_sk + .to_lwe_secret_key() + .decrypt(&new_ct, &glwe.as_lwe_def(), bits); + + let result = map(msg); + if result != decoded { + failed.push((result, decoded)); + } + } + + if !failed.is_empty() { + panic!( + "Failed to decrypt the following messages and decrypted values: {:?}", + failed + ); + } + } + + #[test] + fn can_bootstrap() { + bootstrap_helper(|x| x); + } + + #[test] + fn can_bootstrap_with_map() { + bootstrap_helper(|x| (x + 3) % 8); + } + + fn bivariate_bootstrap_helper(map: impl Fn(u64, u64) -> u64) { + let lwe = TEST_LWE_DEF_1; + let glwe = TEST_GLWE_DEF_1; + let _radix = TEST_RADIX; + let bits = PlaintextBits(1); + + let carry_bits = CarryBits(1); + let radix = TEST_RADIX; + + let original_sk = keygen::generate_binary_lwe_sk(&lwe); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe); + + // We want to switch from the sample extracted key to the new key. + let mut ksk = LweKeyswitchKey::::new(&glwe.as_lwe_def(), &lwe, &radix); + generate_keyswitch_key_lwe( + &mut ksk, + glwe_sk.to_lwe_secret_key(), + &original_sk, + &lwe, + &radix, + ); + + let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix); + generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix); + + let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix); + bsk_nonfft.fft(&mut bsk, &glwe, &radix); + + // Generate the LUT + let lut = BivariateLookupTable::trivial_from_fn(&map, &glwe, bits, carry_bits); + + let mut failed = Vec::new(); + let mut succeeded = Vec::new(); + + for left_msg in 0..(1 << bits.0) { + for right_msg in 0..(1 << bits.0) { + let mut left_ct = LweCiphertext::new(&lwe); + let mut right_ct = LweCiphertext::new(&lwe); + + // Adding a padding bit, hence the - 1 + let encoded_left_msg = left_msg << (64 - bits.0 - carry_bits.0 - 1); + let encoded_right_msg = right_msg << (64 - bits.0 - carry_bits.0 - 1); + + encrypt_lwe_ciphertext( + &mut left_ct, + &original_sk, + Torus::from(encoded_left_msg), + &lwe, + ); + + encrypt_lwe_ciphertext( + &mut right_ct, + &original_sk, + Torus::from(encoded_right_msg), + &lwe, + ); + + let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def()); + + programmable_bootstrap_bivariate( + &mut new_ct, + &left_ct, + &right_ct, + &lut, + &bsk, + &lwe, + &glwe, + bits, + &radix, + ); + + let decrypted = glwe_sk + .to_lwe_secret_key() + .decrypt_without_decode(&new_ct, &glwe.as_lwe_def()); + + // We manually decode here because the + let plain_bits = bits; + + let round_bit = decrypted + .inner() + .wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1) + & 0x1; + let mask = (0x1 << plain_bits.0) - 1; + + let decoded = (decrypted + .inner() + .wrapping_shr(64 - plain_bits.0 - carry_bits.0) + + round_bit) + & mask; + + let result = map(left_msg, right_msg); + if result != decoded { + failed.push(((left_msg, right_msg), result, decoded)); + } else { + succeeded.push(((left_msg, right_msg), result, decoded)); + } + } + } + if !failed.is_empty() { + panic!( + "Failed to decrypt the following messages and decrypted values (as ((left input, right_input), expected, decrypted)): {:?}. However, the following messages and decrypted values succeeded: {:?}", + failed, succeeded + ); + } + } + + fn bivariate_test_function(left: u64, right: u64) -> u64 { + (left + right) % 2 + } + + #[test] + fn can_bootstrap_with_bivariate_map() { + bivariate_bootstrap_helper(bivariate_test_function); + } + + #[test] + fn can_decompose_bivariate_map() { + let plaintext_bits = PlaintextBits(2); + let modulus = 1 << plaintext_bits.0; + + let map = &bivariate_test_function; + + for left in 0u64..(plaintext_bits.0 as u64) { + for right in 0u64..(plaintext_bits.0 as u64) { + let left_shifted = left * modulus; + let input = left_shifted + right; + let result = bivariate_function(map, input, plaintext_bits); + + assert_eq!(result, map(left, right)); + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/ciphertext/glev_ciphertext_ops.rs b/sunscreen_tfhe/src/ops/ciphertext/glev_ciphertext_ops.rs new file mode 100644 index 000000000..b8f93416e --- /dev/null +++ b/sunscreen_tfhe/src/ops/ciphertext/glev_ciphertext_ops.rs @@ -0,0 +1,58 @@ +use crate::{ + entities::{GlevCiphertextRef, GlweCiphertextRef, Polynomial}, + radix::{PolynomialRadixIterator, ScalarRadixIterator}, + GlweDef, TorusOps, +}; + +use super::{glwe_polynomial_mad, glwe_scalar_mad}; + +/// Compute `c += (G^-1 * a) \[*\] b`, where +/// * `G^-1 * a`` is the radix decomposition of `a` +/// * `b` is a GLEV ciphertext. +/// * `c` is a GLWE ciphertext. +/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials +/// +/// # Remarks +/// This functions takes a PolynomialRadixIterator to perform the decomposition. +pub fn decomposed_polynomial_glev_mad( + c: &mut GlweCiphertextRef, + mut a: PolynomialRadixIterator, + b: &GlevCiphertextRef, + params: &GlweDef, +) where + S: TorusOps, +{ + // a = decomp(a_i) + // b = r + + let b_glwe = b.glwe_ciphertexts(params); + let mut cur_radix: Polynomial = Polynomial::zero(params.dim.polynomial_degree.0); + + // The decomposition of + // + // can be performed using + // sum_{j = 1}^l gamma_j * C_j + // where gamma_j is the polynomial to decompose multiplied by q/B^{j+1} + // Note the reverse of the GLWE ciphertexts here! The decomposition iterator + // returns the decomposed values in the opposite order. + for b in b_glwe.rev() { + a.write_next(&mut cur_radix); + glwe_polynomial_mad(c, b, &cur_radix, params); + } +} + +/// Compute `c += (G^-1 * a) \[*\] b`, where +/// * `G^-1 * a`` is the radix decomposition of `a` +/// * `b` is a GLEV ciphertext. +pub fn decomposed_scalar_glev_mad( + c: &mut GlweCiphertextRef, + a: ScalarRadixIterator, + b: &GlevCiphertextRef, + params: &GlweDef, +) where + S: TorusOps, +{ + for (b, a) in b.glwe_ciphertexts(params).rev().zip(a) { + glwe_scalar_mad(c, b, a, params); + } +} diff --git a/sunscreen_tfhe/src/ops/ciphertext/glwe_ciphertext_ops.rs b/sunscreen_tfhe/src/ops/ciphertext/glwe_ciphertext_ops.rs new file mode 100644 index 000000000..89e460b69 --- /dev/null +++ b/sunscreen_tfhe/src/ops/ciphertext/glwe_ciphertext_ops.rs @@ -0,0 +1,549 @@ +use crate::{ + dst::FromMutSlice, + entities::{ + GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, LweCiphertextRef, PolynomialRef, + }, + ops::ciphertext::decomposed_polynomial_glev_mad, + polynomial::{ + polynomial_add, polynomial_external_mad, polynomial_negate, polynomial_scalar_mad, + polynomial_sub, + }, + radix::PolynomialRadixIterator, + scratch::allocate_scratch_ref, + GlweDef, RadixDecomposition, TorusOps, +}; + +/** + * Extract a specific coefficient in a message M in a GLWE ciphertext as a LWE + * ciphertext under the LWE extracted secret key (extracted from the GLWE secret + * key). + * + * # Arguments + * + * * `output` - The output LWE ciphertext + * * `glwe` - The input GLWE ciphertext + * * `h` - The index of the coefficient to extract + * + * # Remarks + * For a GLWE ciphertext of size k and dimension N, the output LWE ciphertext + * passed in must have size k*N. + */ +pub fn sample_extract( + output: &mut LweCiphertextRef, + glwe: &GlweCiphertextRef, + h: usize, + params: &GlweDef, +) where + S: TorusOps, +{ + // We are copying parts of the GLWE ciphertext out according to the following rule: + // a_{N*i + j} = a_{i, h - j} for 0 <= i < k, 0 <= j <= h + // a_{N*i + j} = -a_{i, h - j + n} for 0 <= i < k, n + 1 <= j < N + // b = b_n + + #[allow(non_snake_case)] + let N = params.dim.polynomial_degree.0; + let k = params.dim.size.0; + + let lwe_size = k * N; + + let (a_lwe, b_lwe) = output.a_b_mut(¶ms.as_lwe_def()); + + // Make sure that the correctly sized LWE was passed in. + assert_eq!(lwe_size, a_lwe.len()); + + let (a_glwe, b_glwe) = glwe.a_b(params); + + for (i, a_gwe_i) in a_glwe.enumerate() { + #[allow(non_snake_case)] + let Ni = N * i; + let a_glwe_i_coeffs = a_gwe_i.coeffs(); + + for j in 0..=h { + a_lwe[Ni + j] = a_glwe_i_coeffs[h - j]; + } + + for j in (h + 1)..N { + // Note we add N to h first, otherwise h - j might underflow. + a_lwe[Ni + j] = num::traits::WrappingNeg::wrapping_neg(&a_glwe_i_coeffs[h + N - j]); + } + } + + *b_lwe = b_glwe.coeffs()[h]; +} + +/// Add two GLWE ciphertexts together, storing the result in `c`. +pub fn add_glwe_ciphertexts( + c: &mut GlweCiphertextRef, + a: &GlweCiphertextRef, + b: &GlweCiphertextRef, + params: &GlweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + let (b_a, b_b) = b.a_b(params); + + assert_eq!(c_a.len(), a_a.len()); + assert_eq!(c_a.len(), b_a.len()); + + for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) { + polynomial_add(c, a, b); + } + + polynomial_add(c_b, a_b, b_b); +} + +/// Subtract two GLWE ciphertexts together, storing the result in `c`. +pub fn sub_glwe_ciphertexts( + c: &mut GlweCiphertextRef, + a: &GlweCiphertextRef, + b: &GlweCiphertextRef, + params: &GlweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + let (b_a, b_b) = b.a_b(params); + + assert_eq!(c_a.len(), a_a.len()); + assert_eq!(c_a.len(), b_a.len()); + + for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) { + polynomial_sub(c, a, b); + } + + polynomial_sub(c_b, a_b, b_b); +} + +/// Homomorphically compute -ct. +/// +/// # Remarks +/// This operation is noiseless. +pub fn glwe_negate_inplace(ct: &mut GlweCiphertextRef, params: &GlweDef) +where + S: TorusOps, +{ + for a in ct.a_mut(params) { + polynomial_negate(a); + } + + polynomial_negate(ct.b_mut(params)); +} + +/// Compute c += a \[*\] b where \[*\] is the external product between a GLWE +/// ciphertext and an polynomial in Z\[X\]/(X^N + 1). +/// +/// # Remarks +/// For this to produce the correct result, degree(b) must be "small" (ideally +/// 0) and the coefficient must be small (i.e. less than the message size). +pub fn glwe_polynomial_mad( + c: &mut GlweCiphertextRef, + a: &GlweCiphertextRef, + b: &PolynomialRef, + params: &GlweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + + assert_eq!(c_a.len(), params.dim.size.0); + assert_eq!(a_a.len(), params.dim.size.0); + + for (c, a) in c_a.zip(a_a) { + polynomial_external_mad(c, a, b); + } + + polynomial_external_mad(c_b, a_b, b); +} + +/// Compute `c += a \[*\] b`` where +/// * `a` is a GLWE ciphertext +/// * `b` is a scalar +/// * `\[*\]` is the external product operator GLWE \[*\] Z -> GLWE +pub fn glwe_scalar_mad( + c: &mut GlweCiphertextRef, + a: &GlweCiphertextRef, + b: S, + params: &GlweDef, +) where + S: TorusOps, +{ + for (c, a) in c.a_mut(params).zip(a.a(params)) { + polynomial_scalar_mad(c, a, b); + } + + polynomial_scalar_mad(c.b_mut(params), a.b(params), b); +} + +/// Compute `c += a \[*\] b`` where +/// * `a` is a GLWE ciphertext +/// * `b` is a GGSW cipheetext +/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE` +pub fn glwe_ggsw_mad( + c: &mut GlweCiphertextRef, + a: &GlweCiphertextRef, + b: &GgswCiphertextRef, + glwe_def: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + let (a_a, a_b) = a.a_b(glwe_def); + let rows = b.rows(glwe_def, radix); + + // Generate an iterator that includes a_a and a_b + let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b)); + + allocate_scratch_ref!(scratch, PolynomialRef, (glwe_def.dim.polynomial_degree)); + + // Performs the external operation + // + // GGSW ⊡ GLWE = sum_i=0^k + // + // where + // * `beta` is the decomposition base + // * `l` is the decomposition level + // * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B} + // * `C_i` is the i-th row of the GGSW ciphertext + for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) { + // For each a polynomial, compute the external product with the GLEV ciphertext + // at index i in the GGSW ciphertext. + let decomp = PolynomialRadixIterator::new(a_i, scratch, radix); + + decomposed_polynomial_glev_mad(c, decomp, r, glwe_def); + } +} + +/// Compute external product of a GLWE ciphertext and a GGSW ciphertext. +/// GGSW ⊡ GLWE -> GLWE +pub fn external_product_ggsw_glwe( + ggsw: &GgswCiphertextRef, + glwe: &GlweCiphertextRef, + params: &GlweDef, + radix: &RadixDecomposition, +) -> GlweCiphertext +where + S: TorusOps, +{ + // Zero GLWE to sum over. + let mut result = GlweCiphertext::new(params); + glwe_ggsw_mad(&mut result, glwe, ggsw, params, radix); + + result +} + +#[cfg(test)] +mod tests { + use crate::{ + entities::{GgswCiphertext, LweCiphertext, Polynomial}, + high_level::*, + high_level::{keygen, TEST_GLWE_DEF_1}, + ops::encryption::{ + decrypt_ggsw_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret, + trivially_encrypt_glwe_ciphertext, + }, + polynomial::polynomial_mad, + PlaintextBits, Torus, + }; + + use super::*; + use rand::{thread_rng, RngCore}; + + #[test] + fn polynomial_iteration_mut() { + let glwe = TEST_GLWE_DEF_1; + + let mut sk = keygen::generate_binary_glwe_sk(&glwe); + + assert_eq!(sk.s(&glwe).count(), glwe.dim.size.0); + + for s_i in sk.s_mut(&glwe) { + assert_eq!(s_i.len(), glwe.dim.polynomial_degree.0); + + for s in s_i.coeffs() { + assert!(*s == 0 || *s == 1); + } + } + } + + #[test] + fn can_add_glwe_ciphertexts() { + let bits = PlaintextBits(4); + let glwe = TEST_GLWE_DEF_1; + + let sk = keygen::generate_binary_glwe_sk(&glwe); + + let plaintext = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 4) + .collect::>(), + ); + + let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits); + let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits); + + let c = a + b; + + let dec = sk.decrypt_decode_glwe(&c, &glwe, bits); + + for (i, c) in dec.coeffs().iter().enumerate() { + assert_eq!(*c, 2 * (i as u64 % 4)); + } + } + + #[test] + fn can_sub_glwe_ciphertexts() { + let glwe = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_glwe_sk(&glwe); + + let plaintext = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 4) + .collect::>(), + ); + + let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits); + let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits); + + let c = a - b; + + let dec = sk.decrypt_decode_glwe(&c, &glwe, bits); + + for c in dec.coeffs() { + assert_eq!(*c, 0); + } + } + + #[test] + fn can_internal_product_glwe_polynomial() { + let bits = PlaintextBits(4); + let glwe = TEST_GLWE_DEF_1; + + let sk = keygen::generate_binary_glwe_sk(&glwe); + + let large_poly = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 4) + .collect::>(), + ); + let small_poly = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| if x < 1 { 3 } else { 0 }) + .collect::>(), + ); + + // Do the external product with an encryption of the large polynomial. + let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits); + let mut c = GlweCiphertext::new(&glwe); + + glwe_polynomial_mad(&mut c, &a, &small_poly, &glwe); + + let actual = sk.decrypt_decode_glwe(&c, &glwe, bits); + + let mut expected = Polynomial::::zero(glwe.dim.polynomial_degree.0); + + polynomial_mad( + expected.as_wrapping_mut(), + large_poly.as_wrapping(), + small_poly.as_wrapping(), + ); + + assert_eq!(expected, actual); + + // Now do reverse the large and small polynomials' roles. + let a = sk.encode_encrypt_glwe(&small_poly, &glwe, bits); + let mut c = GlweCiphertext::new(&glwe); + + glwe_polynomial_mad(&mut c, &a, &large_poly, &glwe); + + let actual = sk.decrypt_decode_glwe(&c, &glwe, bits); + + assert_eq!(expected, actual); + } + + #[test] + fn can_external_product_ggsw_glwe() { + let glwe = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + let radix = TEST_RADIX; + + let sk = keygen::generate_binary_glwe_sk(&glwe); + + // Constant polynomial: [1, 0, 0, ...] + let ggsw_plaintext_polynomial = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| if x < 1 { 1 } else { 0 }) + .collect::>(), + ); + let ggsw_plaintext_polynomial_torus = ggsw_plaintext_polynomial.map(|x| Torus::from(*x)); + + let mut ggsw_ct = GgswCiphertext::new(&glwe, &radix); + encrypt_ggsw_ciphertext( + &mut ggsw_ct, + &ggsw_plaintext_polynomial, + &sk, + &glwe, + &radix, + bits, + ); + + let mut ggsw_decrypt = Polynomial::zero(glwe.dim.polynomial_degree.0); + decrypt_ggsw_ciphertext(&mut ggsw_decrypt, &ggsw_ct, &sk, &glwe, &radix); + assert_eq!( + ggsw_decrypt, ggsw_plaintext_polynomial_torus, + "GGSW decrypt does not match plaintext" + ); + + // [1, 2, ...] + let glwe_plaintext_polynomial = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % (1 << bits.0)) + .collect::>(), + ); + + let glwe_plaintext_polynomial_encoded = + glwe_plaintext_polynomial.map(|x| Torus::encode(*x, bits)); + + // Do the external product with an encryption of the large polynomial. + let mut glwe_ct = GlweCiphertext::new(&glwe); + encrypt_glwe_ciphertext_secret( + &mut glwe_ct, + &glwe_plaintext_polynomial_encoded, + &sk, + &glwe, + ); + + let glwe_decrypt = sk.decrypt_decode_glwe(&glwe_ct, &glwe, bits); + assert_eq!( + glwe_decrypt, glwe_plaintext_polynomial, + "GLWE decrypt does not match plaintext" + ); + + let encrypted_result = external_product_ggsw_glwe(&ggsw_ct, &glwe_ct, &glwe, &radix); + let result = sk.decrypt_decode_glwe(&encrypted_result, &glwe, bits); + + // expected is the polynomial multiplication of + // ggsw_plaintext_polynomial and glwe_plaintext_polynomial + let mut expected = Polynomial::::zero(glwe.dim.polynomial_degree.0); + + polynomial_mad( + expected.as_wrapping_mut(), + ggsw_plaintext_polynomial.as_wrapping(), + glwe_plaintext_polynomial.as_wrapping(), + ); + + // Reduce modulo 2^BITS + let expected = expected.map(|x| x % (1 << bits.0)); + + // Verify that the ciphertext multiplication is correct + assert_eq!( + result, expected, + "External product does not match polynomial multiplication" + ); + } + + #[test] + fn can_add_glwe_with_trivial_glwe() { + let glwe = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_glwe_sk(&glwe); + + let delta = 1u64 << (64 - bits.0); + + let large_poly = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 4) + .collect::>(), + ); + let small_poly = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| if x < 1 { 3 } else { 0 }) + .collect::>(), + ); + let small_poly_scaled = small_poly.map(|x| Torus::from(x * delta)); + + let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits); + + let mut b = GlweCiphertext::new(&glwe); + trivially_encrypt_glwe_ciphertext(&mut b, &small_poly_scaled, &glwe); + + let c = a.as_ref() + b.as_ref(); + + let actual = sk.decrypt_decode_glwe(&c, &glwe, bits); + let expected = small_poly + large_poly; + + assert_eq!(expected, actual); + + let c2 = b.as_ref() + a.as_ref(); + + let actual = sk.decrypt_decode_glwe(&c2, &glwe, bits); + assert_eq!(expected, actual); + } + + #[test] + fn test_sample_extract() { + let bits = PlaintextBits(2); + let glwe_params = TEST_GLWE_DEF_1; + let lwe_params = glwe_params.as_lwe_def(); + + let sk = keygen::generate_binary_glwe_sk(&glwe_params); + let extracted_lwe_sk = sk.to_lwe_secret_key(); + + let large_poly = Polynomial::new( + &(0..glwe_params.dim.polynomial_degree.0 as u64) + .map(|x| x % 4) + .collect::>(), + ); + + let glwe = sk.encode_encrypt_glwe(&large_poly, &glwe_params, bits); + + for h in 0..glwe_params.dim.polynomial_degree.0 { + let mut lwe = LweCiphertext::new(&lwe_params); + sample_extract(&mut lwe, &glwe, h, &glwe_params); + + let lwe_msg = extracted_lwe_sk.decrypt(&lwe, &lwe_params, bits); + + let expected = large_poly.coeffs()[h]; + + assert_eq!(expected, lwe_msg); + } + } + + #[test] + fn can_glwe_scalar_mad() { + for _ in 0..20 { + let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1); + + let plaintext_bits = (thread_rng().next_u64()) % 8 + 1; + let plaintext_bits = PlaintextBits(plaintext_bits as u32); + + let scalar = thread_rng().next_u64() % 64; + + let pt = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0) + .map(|_| thread_rng().next_u64() % plaintext_bits.0 as u64) + .collect::>(); + let pt = Polynomial::new(&pt); + + let ct = sk.encode_encrypt_glwe(&pt, &TEST_GLWE_DEF_1, plaintext_bits); + + let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1); + + glwe_scalar_mad(&mut result, &ct, scalar, &TEST_GLWE_DEF_1); + + let actual = sk.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits); + + for (pt, actual) in pt.coeffs().iter().zip(actual.coeffs()) { + let expected = pt.wrapping_mul(scalar) % (0x1u64 << plaintext_bits.0); + + assert_eq!(expected, *actual); + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/ciphertext/lev_ciphertext_ops.rs b/sunscreen_tfhe/src/ops/ciphertext/lev_ciphertext_ops.rs new file mode 100644 index 000000000..cd8d7cc80 --- /dev/null +++ b/sunscreen_tfhe/src/ops/ciphertext/lev_ciphertext_ops.rs @@ -0,0 +1,42 @@ +use crate::{ + entities::{LevCiphertextRef, LweCiphertextRef, Polynomial}, + radix::PolynomialRadixIterator, + LweDef, TorusOps, +}; + +use super::scalar_mul_ciphertext_mad; + +/// Compute `c += (G^-1 * a) \[*\] b`, where +/// * `G^-1 * a`` is the radix decomposition of `a` +/// * `b` is a LEV ciphertext. +/// * `c` is a LWE ciphertext. +/// * \[*\] is the external product between a LEV ciphertext and the decomposed +/// LWE ciphertext. +/// +/// # Remarks +/// This functions takes a PolynomialRadixIterator to perform the decomposition. +pub fn decomposed_scalar_lev_mad( + c: &mut LweCiphertextRef, + mut a: PolynomialRadixIterator, + b: &LevCiphertextRef, + params: &LweDef, +) where + S: TorusOps, +{ + let b_lwe = b.lwe_ciphertexts(params); + let mut cur_radix: Polynomial = Polynomial::zero(1); + + // The decomposition of + // + // can be performed using + // sum_{j = 1}^l gamma_j * C_j + // where gamma_j is the polynomial to decompose multiplied by q/B^{j+1} + // Note the reverse of the GLWE ciphertexts here! The decomposition iterator + // returns the decomposed values in the opposite order. + for b in b_lwe.rev() { + a.write_next(&mut cur_radix); + let radix = cur_radix.coeffs()[0]; + + scalar_mul_ciphertext_mad(c, &radix, b, params); + } +} diff --git a/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs b/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs new file mode 100644 index 000000000..dbfa8b019 --- /dev/null +++ b/sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs @@ -0,0 +1,93 @@ +use crate::{entities::LweCiphertextRef, LweDef, RoundedDiv, Torus, TorusOps}; + +/// Add the coefficients of a to the coefficients of c in place. +pub fn add_lwe_inplace(c: &mut LweCiphertextRef, a: &LweCiphertextRef, params: &LweDef) +where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + + assert_eq!(c_a.len(), a_a.len()); + + for (c, a) in c_a.iter_mut().zip(a_a.iter()) { + *c = num::traits::WrappingAdd::wrapping_add(c, a); + } + + *c_b = num::traits::WrappingAdd::wrapping_add(c_b, a_b); +} + +/// Subtract one LWE ciphertext from another, storing the result in the provided +/// output variable. Mostly meant to be used reduce the number of allocations +/// and with functions like [allocate_scratch_ref]. +pub(crate) fn sub_lwe_ciphertexts( + c: &mut LweCiphertextRef, + a: &LweCiphertextRef, + b: &LweCiphertextRef, + params: &LweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + let (b_a, b_b) = b.a_b(params); + + assert_eq!(c_a.len(), a_a.len()); + assert_eq!(c_a.len(), b_a.len()); + + for (c, (a, b)) in c_a.iter_mut().zip(a_a.iter().zip(b_a.iter())) { + *c = num::traits::WrappingSub::wrapping_sub(a, b); + } + + *c_b = num::traits::WrappingSub::wrapping_sub(a_b, b_b); +} + +/// Multiplies an LWE ciphertext by a scalar, storing the result in the provided +/// output variable. Mostly meant to be used reduce the number of allocations +/// and with functions like [allocate_scratch_ref]. +pub(crate) fn scalar_mul_ciphertext_mad( + c: &mut LweCiphertextRef, + scalar: &S, + a: &LweCiphertextRef, + params: &LweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + + assert_eq!(c_a.len(), a_a.len()); + + for (c, a) in c_a.iter_mut().zip(a_a.iter()) { + *c += a * scalar; + } + + *c_b += a_b * scalar; +} + +/// Perform modulus switching on a ciphertext. We are assuming that moduli are +/// both powers of two, and that the original number of bits is greater than the +/// new number of bits. +pub fn modulus_switch( + ct: &mut LweCiphertextRef, + original_bits: u32, + new_bits: u32, + params: &LweDef, +) where + S: TorusOps, +{ + let (c_a, c_b) = ct.a_b_mut(params); + + // We specifically want to zero out the MSBs instead of shifting them back + // around. + for a in c_a { + let c = a.inner().to_u64() as u128; + let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128); + *a = Torus::from(S::from_u64(res as u64)); + } + + let c = c_b.inner().to_u64() as u128; + let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128); + + *c_b = Torus::from(S::from_u64(res as u64)); +} diff --git a/sunscreen_tfhe/src/ops/ciphertext/mod.rs b/sunscreen_tfhe/src/ops/ciphertext/mod.rs new file mode 100644 index 000000000..5eee2fe42 --- /dev/null +++ b/sunscreen_tfhe/src/ops/ciphertext/mod.rs @@ -0,0 +1,11 @@ +mod lwe_ciphertext_ops; +pub use lwe_ciphertext_ops::*; + +mod lev_ciphertext_ops; +pub use lev_ciphertext_ops::*; + +mod glev_ciphertext_ops; +pub use glev_ciphertext_ops::*; + +mod glwe_ciphertext_ops; +pub use glwe_ciphertext_ops::*; diff --git a/sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs b/sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs new file mode 100644 index 000000000..99487cd2d --- /dev/null +++ b/sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs @@ -0,0 +1,403 @@ +use num::Zero; + +use crate::{ + entities::{GgswCiphertextRef, GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef}, + polynomial::{polynomial_external_mad, polynomial_scalar_mad, polynomial_scalar_mul}, + GlweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps, +}; + +use super::{ + decrypt_glwe_ciphertext, encrypt_glwe_ciphertext_secret, + trivially_encrypt_glwe_with_sk_argument, +}; + +/// Perform a ggsw encryption. This is generic in case a trivial GGSW encryption +/// is wanted (for example, for testing purposes). +pub(crate) fn encrypt_ggsw_ciphertext_generic( + ggsw_ciphertext: &mut GgswCiphertextRef, + msg: &PolynomialRef, + glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, + encrypt: impl Fn( + &mut GlweCiphertextRef, + &PolynomialRef>, + &GlweSecretKeyRef, + &GlweDef, + ), +) where + S: TorusOps, +{ + let max_val = S::from_u64(0x1 << plaintext_bits.0); + assert!(msg.coeffs().iter().all(|x| *x < max_val)); + + let decomposition_radix_log = radix.radix_log.0; + let polynomial_degree = params.dim.polynomial_degree.0; + let glwe_size = params.dim.size.0; + + // k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption + // of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of + // M/B^{j+1} + for (i, row) in ggsw_ciphertext.rows_mut(params, radix).enumerate() { + let mut m_times_s = Polynomial::>::zero(polynomial_degree); + + let m_times_s = if i < glwe_size { + // The message is composed of the negated secret key and the message + // for all but the last row. + let s = glwe_secret_key.s(params).nth(i).unwrap(); + polynomial_external_mad(&mut m_times_s, msg.as_torus(), s); + + // Negate the product. + for c in m_times_s.coeffs_mut().iter_mut() { + // Have to call the trait directly because deref is implemented on Torus + *c = num::traits::WrappingNeg::wrapping_neg(c); + } + + &m_times_s + } else { + // Last row isn't multiplied by secret key. + msg.as_torus() + }; + + for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() { + let mut scaled_msg = Polynomial::zero(polynomial_degree); + + // The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to + // multiplying by 2^{log2(q) - log2(B) * (i + 1)} + let decomp_factor = + S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1))); + + polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor); + + encrypt(col, &scaled_msg, glwe_secret_key, params); + } + } +} + +/// Encrypt a GGSW ciphertext with a given message polynomial and secret key. +pub fn encrypt_ggsw_ciphertext( + ggsw_ciphertext: &mut GgswCiphertextRef, + msg: &PolynomialRef, + glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, +{ + encrypt_ggsw_ciphertext_generic( + ggsw_ciphertext, + msg, + glwe_secret_key, + params, + radix, + plaintext_bits, + encrypt_glwe_ciphertext_secret, + ); +} + +/// Encrypt a GGSW ciphertext with a given message polynomial and secret key. +/// This is a trivial encryption that doesn't use the secret key and is not +/// secure. +pub fn trivially_encrypt_ggsw_ciphertext( + ggsw_ciphertext: &mut GgswCiphertextRef, + msg: &PolynomialRef, + glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, +{ + encrypt_ggsw_ciphertext_generic( + ggsw_ciphertext, + msg, + glwe_secret_key, + params, + radix, + plaintext_bits, + trivially_encrypt_glwe_with_sk_argument, + ); +} + +/// Encrypt scalar (i.e. degree 0 polynomial) msg as a GGSW ciphertext. +pub fn encrypt_ggsw_ciphertext_scalar( + ggsw_ciphertext: &mut GgswCiphertextRef, + msg: S, + glwe_secret_key: &GlweSecretKeyRef, + glwe_def: &GlweDef, + radix: &RadixDecomposition, + plaintext_bits: PlaintextBits, +) where + S: TorusOps, +{ + let max_val = S::from_u64(0x1 << plaintext_bits.0); + assert!(msg < max_val); + + let decomposition_radix_log = radix.radix_log.0; + let polynomial_degree = glwe_def.dim.polynomial_degree.0; + let glwe_size = glwe_def.dim.size.0; + + // k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption + // of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of + // M/B^{j+1} + for (i, row) in ggsw_ciphertext.rows_mut(glwe_def, radix).enumerate() { + let mut m_times_s = Polynomial::>::zero(polynomial_degree); + let m_times_s = if i < glwe_size { + let s = glwe_secret_key.s(glwe_def).nth(i).unwrap(); + polynomial_scalar_mad(&mut m_times_s, s.as_torus(), msg); + &m_times_s + } else { + // Last row isn't multiplied by secret key. + m_times_s.clear(); + m_times_s.coeffs_mut()[0] = Torus::from(msg); + &m_times_s + }; + + for (j, col) in row.glwe_ciphertexts_mut(glwe_def).enumerate() { + let mut scaled_msg = Polynomial::zero(polynomial_degree); + // The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to + // multiplying by 2^{log2(q) - log2(B) * (i + 1)} + let decomp_factor = + S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1))); + + if i < glwe_size { + let decomp_factor = decomp_factor.wrapping_neg(); + + polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor); + } else { + scaled_msg.coeffs_mut()[0] = Torus::from(msg.wrapping_mul(&decomp_factor)); + + for c in scaled_msg.coeffs_mut().iter_mut().skip(1) { + *c = Torus::zero(); + } + } + + encrypt_glwe_ciphertext_secret(col, &scaled_msg, glwe_secret_key, glwe_def); + } + } +} + +fn decrypt_glwe_in_ggsw( + msg: &mut PolynomialRef>, + ggsw_ciphertext: &GgswCiphertextRef, + glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + row: usize, + column: usize, +) -> Option<()> +where + S: TorusOps, +{ + let decomposition_radix_log = radix.radix_log.0; + + // To decrypt a GGSW ciphertext, it suffices to decrypt the first GLWE ciphertext in + // the last row and divide by its decomposition factor. + let glev = ggsw_ciphertext.rows(params, radix).nth(row)?; + let glwe = glev.glwe_ciphertexts(params).nth(column)?; + + // Decrypt that specific GLWE ciphertext, which should have a message of + // q / beta ^ {column + 1} * SM, where SM is the message times the secret + // every row but the last (-SM) and M for the last row. + decrypt_glwe_ciphertext(msg, glwe, glwe_secret_key, params); + + let mask = (0x1 << decomposition_radix_log) - 1; + + for c in msg.coeffs_mut() { + let val = c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1)); + let r = (c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1) - 1)) + & S::from_u64(0x1); + + *c = Torus::from((val + r) & S::from_u64(mask)); + } + + Some(()) +} + +/// Decrypt a GGSW ciphertext with a given secret key. +pub fn decrypt_ggsw_ciphertext( + msg: &mut PolynomialRef>, + ggsw_ciphertext: &GgswCiphertextRef, + glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + let row = params.dim.size.0; + + decrypt_glwe_in_ggsw(msg, ggsw_ciphertext, glwe_secret_key, params, radix, row, 0).unwrap(); +} + +#[cfg(test)] +mod tests { + use crate::{entities::GgswCiphertext, high_level::TEST_GLWE_DEF_1, high_level::*}; + + use super::*; + + #[test] + fn can_encrypt_decrypt_gsw_const_coeff() { + let params = TEST_GLWE_DEF_1; + let radix = &TEST_RADIX; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let msg = 1; + + let ct = encryption::encrypt_ggsw(msg, &sk, ¶ms, radix, bits); + let pt = encryption::decrypt_ggsw(&ct, &sk, ¶ms, radix, bits); + + assert_eq!(pt.coeffs()[0], msg); + + for c in pt.coeffs().iter().skip(1) { + assert_eq!(*c, 0); + } + } + + /// Test that each of the rows in the GGSW ciphertext is a GLWE ciphertext that encodes the + /// appropriate message (usually the decomposed message times the secret key) + #[test] + fn can_decrypt_all_elements_ggsw() { + let params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let coeffs = (0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(); + let msg = Polynomial::new(&coeffs); + + let mut ct = GgswCiphertext::new(¶ms, &radix); + encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, ¶ms, &radix, bits); + + let mut pt = Polynomial::zero(params.dim.polynomial_degree.0); + decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, ¶ms, &radix); + let pt = pt.map(|x| x.inner()); + + // Ensure that the basic decryption works. + assert_eq!(pt, msg); + + let n_rows = ct.rows(¶ms, &radix).len(); + let n_cols = ct + .rows(¶ms, &radix) + .next() + .unwrap() + .glwe_ciphertexts(¶ms) + .len(); + + // Beta + let decomposition_radix_log = radix.radix_log.0; + + for i in 0..n_rows { + let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0); + let m_times_s = if i < params.dim.size.0 { + // The message is composed of the negated secret key and the message + // for all but the last row. + let s = sk.s(¶ms).nth(i).unwrap(); + polynomial_external_mad(&mut m_times_s, msg.as_torus(), s); + + // Negate the product. + for c in m_times_s.coeffs_mut().iter_mut() { + // Have to call the trait directly because deref is implemented on Torus + *c = num::traits::WrappingNeg::wrapping_neg(c); + } + + &m_times_s + } else { + // Last row isn't multiplied by secret key. + msg.as_torus() + }; + + for j in 0..n_cols { + let mut pt = Polynomial::zero(params.dim.polynomial_degree.0); + let mut msg = m_times_s.to_owned(); + + let mask = (0x1 << decomposition_radix_log) - 1; + + for c in msg.coeffs_mut() { + *c = Torus::from(c.inner() & mask); + } + + decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, ¶ms, &radix, i, j).unwrap(); + + assert_eq!(pt, msg); + } + } + } + + #[test] + fn can_trivially_decrypy_ggsw() { + let params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let coeffs = (0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(); + let msg = Polynomial::new(&coeffs); + + let mut ct = GgswCiphertext::new(¶ms, &radix); + trivially_encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, ¶ms, &radix, bits); + + let mut pt = Polynomial::zero(params.dim.polynomial_degree.0); + decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, ¶ms, &radix); + let pt = pt.map(|x| x.inner()); + + // Ensure that the basic decryption works. + assert_eq!(pt, msg); + + let n_rows = ct.rows(¶ms, &radix).len(); + let n_cols = ct + .rows(¶ms, &radix) + .next() + .unwrap() + .glwe_ciphertexts(¶ms) + .len(); + + // Beta + let decomposition_radix_log = radix.radix_log.0; + + for i in 0..n_rows { + let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0); + let m_times_s = if i < params.dim.size.0 { + // The message is composed of the negated secret key and the message + // for all but the last row. + let s = sk.s(¶ms).nth(i).unwrap(); + polynomial_external_mad(&mut m_times_s, msg.as_torus(), s); + + // Negate the product. + for c in m_times_s.coeffs_mut().iter_mut() { + // Have to call the trait directly because deref is implemented on Torus + *c = num::traits::WrappingNeg::wrapping_neg(c); + } + + &m_times_s + } else { + // Last row isn't multiplied by secret key. + msg.as_torus() + }; + + for j in 0..n_cols { + let mut pt = Polynomial::zero(params.dim.polynomial_degree.0); + let mut msg = m_times_s.to_owned(); + + let mask = (0x1 << decomposition_radix_log) - 1; + + for c in msg.coeffs_mut() { + *c = Torus::from(c.inner() & mask); + } + + decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, ¶ms, &radix, i, j).unwrap(); + + assert_eq!(pt, msg); + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs b/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs new file mode 100644 index 000000000..d0994f60b --- /dev/null +++ b/sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs @@ -0,0 +1,184 @@ +use num::Zero; + +use crate::{ + entities::{GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef}, + polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign}, + rand::{normal_torus, uniform_torus}, + GlweDef, Torus, TorusOps, +}; + +pub(crate) fn trivially_encrypt_glwe_with_sk_argument( + glwe_ciphertext: &mut GlweCiphertextRef, + msg: &PolynomialRef>, + _glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, +) where + S: TorusOps, +{ + trivially_encrypt_glwe_ciphertext(glwe_ciphertext, msg, params); +} + +/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.` +pub fn encrypt_glwe_ciphertext_secret_generic( + c: &mut GlweCiphertextRef, + msg: &PolynomialRef>, + sk: &GlweSecretKeyRef, + params: &GlweDef, +) where + S: TorusOps, +{ + let mut tmp = Polynomial::zero(params.dim.polynomial_degree.0); + + let (a, b) = c.a_b_mut(params); + + // tmp = A_i * S_i + for (a_i, s_i) in a.zip(sk.s(params)) { + // Fill a_i with uniform data + for c in a_i.coeffs_mut() { + *c = uniform_torus(); + } + + polynomial_external_mad(&mut tmp, a_i, s_i); + } + + // b = A * S + polynomial_add_assign(b, &tmp); + + // b = A * S + m + polynomial_add_assign(b, msg); + + let e = Polynomial::new( + &(0..msg.len()) + .map(|_| normal_torus::(params.std)) + .collect::>(), + ); + + // b = A * S + m + e + polynomial_add_assign(b, &e); +} + +/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.` +pub fn encrypt_glwe_ciphertext_secret( + c: &mut GlweCiphertextRef, + msg: &PolynomialRef>, + sk: &GlweSecretKeyRef, + params: &GlweDef, +) where + S: TorusOps, +{ + encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params) +} + +/// Generate a trivial GLWE encryption. Note that the caller will need to scale +/// the message appropriately; a factor like delta is not automatically applied. +pub fn trivially_encrypt_glwe_ciphertext( + c: &mut GlweCiphertextRef, + msg: &PolynomialRef>, + params: &GlweDef, +) where + S: TorusOps, +{ + let (a, b) = c.a_b_mut(params); + + // tmp = A_i * S_i + for a_i in a { + // Fill a_i with zero data + for c in a_i.coeffs_mut() { + *c = Torus::zero(); + } + } + + // b = m + b.clone_from_ref(msg); +} + +/// Decrypt GLWE ciphertext `ct` into `msg` using secret key `sk`. +pub fn decrypt_glwe_ciphertext( + msg: &mut PolynomialRef>, + ct: &GlweCiphertextRef, + sk: &GlweSecretKeyRef, + params: &GlweDef, +) where + S: TorusOps, +{ + let (a, b) = ct.a_b(params); + + let mut tmp = Polynomial::zero(b.len()); + + // msg = b + msg.clone_from_ref(b); + + // tmp = A_i * S_i + for (a_i, s_i) in a.zip(sk.s(params)) { + polynomial_external_mad(&mut tmp, a_i, s_i); + } + + // msg = b - A * S = m + e + polynomial_sub_assign(msg, &tmp); +} + +#[cfg(test)] +mod tests { + use crate::{high_level::*, PlaintextBits}; + + use super::*; + + // Encryption + + #[test] + fn can_encrypt_decrypt() { + let params = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let plaintext = Polynomial::new( + &(0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(), + ); + + let ct = encryption::encrypt_glwe(&plaintext, &sk, ¶ms, bits); + let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits); + + assert_eq!(dec, plaintext); + } + + #[test] + fn can_encrypt_decrypt_uniform() { + let params = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_uniform_glwe_sk(¶ms); + + let plaintext = Polynomial::new( + &(0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(), + ); + + let ct = encryption::encrypt_glwe(&plaintext, &sk, ¶ms, bits); + let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits); + + assert_eq!(dec, plaintext); + } + + #[test] + fn trivial_glwe_decrypts() { + let params = TEST_GLWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_glwe_sk(¶ms); + + let plaintext = Polynomial::new( + &(0..params.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(), + ); + + let ct = encryption::trivial_glwe(&plaintext, ¶ms, bits); + let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits); + + assert_eq!(dec, plaintext); + } +} diff --git a/sunscreen_tfhe/src/ops/encryption/lwe_encryption.rs b/sunscreen_tfhe/src/ops/encryption/lwe_encryption.rs new file mode 100644 index 000000000..4060ffd3b --- /dev/null +++ b/sunscreen_tfhe/src/ops/encryption/lwe_encryption.rs @@ -0,0 +1,114 @@ +use sunscreen_math::Zero; + +use crate::{ + entities::{LweCiphertextRef, LweSecretKeyRef}, + math::{Torus, TorusOps}, + rand::{normal_torus, uniform_torus}, + LweDef, PlaintextBits, +}; + +/// Generate a trivial GLWE encryption. Note that the caller will need to scale +/// the message appropriately; a factor like delta is not automatically applied. +pub fn trivially_encrypt_lwe_ciphertext( + c: &mut LweCiphertextRef, + msg: &Torus, + params: &LweDef, +) where + S: TorusOps, +{ + let (a, b) = c.a_b_mut(params); + + // tmp = A_i * S_i + for a_i in a { + *a_i = Torus::zero(); + } + + // b = m + *b = *msg; +} + +/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the +/// randomness used to generate the ciphertext. +pub fn encrypt_lwe_ciphertext( + ct: &mut LweCiphertextRef, + sk: &LweSecretKeyRef, + msg: Torus, + params: &LweDef, +) -> Torus +where + S: TorusOps, +{ + let (a, b) = ct.a_b_mut(params); + + for (a_i, d_i) in a.iter_mut().zip(sk.as_slice().iter()) { + *a_i = uniform_torus::(); + *b += *a_i * d_i; + } + + let e = normal_torus(params.std); + *b += msg + e; + + e +} + +/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the +/// randomness used to generate the ciphertext. +pub fn encode_and_encrypt_lwe_ciphertext( + ct: &mut LweCiphertextRef, + sk: &LweSecretKeyRef, + msg: S, + params: &LweDef, + plaintext_bits: PlaintextBits, +) -> Torus +where + S: TorusOps, +{ + let msg = Torus::::encode(msg, plaintext_bits); + + encrypt_lwe_ciphertext(ct, sk, msg, params) +} + +#[cfg(test)] +mod tests { + + use crate::{high_level::*, PlaintextBits}; + + #[test] + fn can_encrypt_decrypt() { + let params = TEST_LWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let ct = encryption::encrypt_lwe_secret(4, &sk, ¶ms, bits); + let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits); + + assert_eq!(pt, 4); + } + + #[test] + fn can_encrypt_decrypt_uniform() { + let params = TEST_LWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_uniform_lwe_sk(¶ms); + + let ct = encryption::encrypt_lwe_secret(4, &sk, ¶ms, bits); + let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits); + + assert_eq!(pt, 4); + } + + #[test] + fn can_trivially_decrypt() { + let params = TEST_LWE_DEF_1; + let bits = PlaintextBits(4); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let ct = encryption::trivial_lwe(4, ¶ms, bits); + let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits); + + assert_eq!(pt, 4); + } +} diff --git a/sunscreen_tfhe/src/ops/encryption/mod.rs b/sunscreen_tfhe/src/ops/encryption/mod.rs new file mode 100644 index 000000000..7df920f97 --- /dev/null +++ b/sunscreen_tfhe/src/ops/encryption/mod.rs @@ -0,0 +1,8 @@ +mod lwe_encryption; +pub use lwe_encryption::*; + +mod glwe_encryption; +pub use glwe_encryption::*; + +mod ggsw_encryption; +pub use ggsw_encryption::*; diff --git a/sunscreen_tfhe/src/ops/fft_ops.rs b/sunscreen_tfhe/src/ops/fft_ops.rs new file mode 100644 index 000000000..6a0c56f4c --- /dev/null +++ b/sunscreen_tfhe/src/ops/fft_ops.rs @@ -0,0 +1,285 @@ +use num::Complex; + +use crate::{ + dst::{FromMutSlice, OverlaySize}, + entities::{ + GgswCiphertextFftRef, GlevCiphertextFftRef, GlweCiphertextFftRef, GlweCiphertextRef, + PolynomialFftRef, PolynomialRef, + }, + ops::ciphertext::{add_glwe_ciphertexts, sub_glwe_ciphertexts}, + radix::PolynomialRadixIterator, + scratch::{allocate_scratch, allocate_scratch_ref}, + GlweDef, RadixDecomposition, TorusOps, +}; + +/// Compute `c += a \[*] b`` where +/// * `a` is a GLWE ciphertext +/// * `b` is a GGSW cipheetext +/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE` +pub fn glwe_ggsw_mad( + c_fft: &mut GlweCiphertextFftRef>, + a: &GlweCiphertextRef, + b_fft: &GgswCiphertextFftRef>, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + let (a_a, a_b) = a.a_b(params); + let rows = b_fft.rows(params, radix); + + // Generate an iterator that includes a_a and a_b + let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b)); + + allocate_scratch_ref!(scratch, PolynomialRef, (params.dim.polynomial_degree)); + + // Performs the external operation + // + // GGSW ⊡ GLWE = sum_i=0^k + // + // where + // * `beta` is the decomposition base + // * `l` is the decomposition level + // * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B} + // * `C_i` is the i-th row of the GGSW ciphertext + for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) { + // For each a polynomial, compute the external product with the GLEV ciphertext + // at index i in the GGSW ciphertext. + let decomp = PolynomialRadixIterator::new(a_i, scratch, radix); + + decomposed_polynomial_glev_mad(c_fft, decomp, r, params); + } +} + +/// Compute `c += (G^-1 * a) \[*\] b`, where +/// * `G^-1 * a`` is the radix decomposition of `a` +/// * `b` is a GLEV ciphertext. +/// * `c` is a GLWE ciphertext. +/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials +/// +/// # Remarks +/// This functions takes a PolynomialRadixIterator to perform the decomposition. +pub fn decomposed_polynomial_glev_mad( + c: &mut GlweCiphertextFftRef>, + mut a: PolynomialRadixIterator, + b: &GlevCiphertextFftRef>, + params: &GlweDef, +) where + S: TorusOps, +{ + let b_glwe = b.glwe_ciphertexts(params); + + let mut cur_radix = allocate_scratch::(params.dim.polynomial_degree.0); + let cur_radix = PolynomialRef::from_mut_slice(cur_radix.as_mut_slice()); + + let mut decomp_fft = allocate_scratch(PolynomialFftRef::>::size( + params.dim.polynomial_degree, + )); + let decomp_fft = PolynomialFftRef::from_mut_slice(decomp_fft.as_mut_slice()); + + // The decomposition of + // + // can be performed using + // sum_{j = 1}^l gamma_j * C_j + // where gamma_j is the polynomial to decompose multiplied by q/B^{j+1} + // Note the reverse of the GLWE ciphertexts here! The decomposition iterator + // returns the decomposed values in the opposite order. + for b in b_glwe.rev() { + a.write_next(cur_radix); + cur_radix.fft(decomp_fft); + + glwe_polynomial_mad(c, b, decomp_fft, params); + } +} + +/// Compute c += a \[*\] b where \[*\] is the external product +/// between a GLWE ciphertext and an polynomial in Z\[X\]/(X^N + 1). +/// +/// # Remarks +/// For this to produce the correct result, degree(b) must be "small" +/// (ideally 0) and the coefficient must be small (i.e. less than the +/// message size). +pub fn glwe_polynomial_mad( + c: &mut GlweCiphertextFftRef>, + a: &GlweCiphertextFftRef>, + b: &PolynomialFftRef>, + params: &GlweDef, +) { + let (c_a, c_b) = c.a_b_mut(params); + let (a_a, a_b) = a.a_b(params); + + assert_eq!(c_a.len(), params.dim.size.0); + assert_eq!(a_a.len(), params.dim.size.0); + + for (c, a) in c_a.zip(a_a) { + c.multiply_add(a, b); + } + + c_b.multiply_add(a_b, b); +} + +/// Performs a CMUX operation, which enables one of two GLWE ciphertexts +/// to be selected from an encrypted boolean GGSW ciphertext. The result +/// is stored in `c`. +/// +/// Conceptually, this can be seen as the following operation in Rust: +/// +/// ```text +/// let c = if b_fft { d_1 } else { d_0 } +/// ``` +/// +/// where the output `c` is a different encryption than either of the initial +/// inputs. Note that this will result in higher noise than in the original +/// ciphertexts. +pub fn cmux( + c: &mut GlweCiphertextRef, + d_0: &GlweCiphertextRef, + d_1: &GlweCiphertextRef, + b_fft: &GgswCiphertextFftRef>, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + allocate_scratch_ref!(diff, GlweCiphertextRef, (params.dim)); + + sub_glwe_ciphertexts(diff, d_1, d_0, params); + + allocate_scratch_ref!(prod_fft, GlweCiphertextFftRef>, (params.dim)); + + prod_fft.clear(); + + glwe_ggsw_mad(prod_fft, diff, b_fft, params, radix); + + allocate_scratch_ref!(prod, GlweCiphertextRef, (params.dim)); + + prod_fft.ifft(prod, params); + + add_glwe_ciphertexts(c, prod, d_0, params); +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{GgswCiphertextFft, GlweCiphertext, GlweCiphertextFft, Polynomial}, + high_level::*, + PlaintextBits, Torus, + }; + + use super::*; + + #[test] + fn can_fft_external_product_glwe_ggsw() { + let glwe_params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_glwe_sk(&glwe_params); + + for _ in 0..100 { + let sel = thread_rng().next_u64() % 2; + + let ggsw = encryption::encrypt_ggsw(sel, &sk, &glwe_params, &radix, bits); + + let glwe_pt = (0..glwe_params.dim.polynomial_degree.0) + .map(|_| thread_rng().next_u64() % 2) + .collect::>(); + let glwe_pt = Polynomial::new(&glwe_pt); + + let glwe = encryption::encrypt_glwe(&glwe_pt, &sk, &glwe_params, bits); + + let mut ggsw_fft = GgswCiphertextFft::new(&glwe_params, &radix); + let mut res_fft = GlweCiphertextFft::new(&glwe_params); + let mut res = GlweCiphertext::::new(&glwe_params); + + ggsw.fft(&mut ggsw_fft, &glwe_params, &radix); + + glwe_ggsw_mad(&mut res_fft, &glwe, &ggsw_fft, &glwe_params, &radix); + + res_fft.ifft(&mut res, &glwe_params); + + let actual = encryption::decrypt_glwe(&res, &sk, &glwe_params, bits); + + if sel == 1 { + assert_eq!(actual, glwe_pt); + } else { + assert_eq!( + actual, + Polynomial::zero(glwe_params.dim.polynomial_degree.0) + ); + } + } + } + + #[test] + fn can_cmux_fft() { + let glwe = TEST_GLWE_DEF_1; + let sk = keygen::generate_binary_glwe_sk(&glwe); + let radix = TEST_RADIX; + let bits = PlaintextBits(1); + + for _ in 0..100 { + let sel = thread_rng().next_u64() % 2; + + let sel_ct = encryption::encrypt_ggsw(sel, &sk, &glwe, &radix, bits); + + let a = (0..glwe.dim.polynomial_degree.0) + .map(|_| thread_rng().next_u64() % 2) + .collect::>(); + let a = Polynomial::new(&a); + + let a_ct = encryption::encrypt_glwe(&a, &sk, &glwe, bits); + + let b = (0..glwe.dim.polynomial_degree.0) + .map(|_| thread_rng().next_u64() % 2) + .collect::>(); + let b = Polynomial::new(&b); + + let b_ct = encryption::encrypt_glwe(&b, &sk, &glwe, bits); + + let sel_fft = fft::fft_ggsw(&sel_ct, &glwe, &radix); + + let mut res_ct = GlweCiphertext::new(&glwe); + + cmux(&mut res_ct, &a_ct, &b_ct, &sel_fft, &glwe, &radix); + + let actual = encryption::decrypt_glwe(&res_ct, &sk, &glwe, bits); + + if sel == 1 { + assert_eq!(actual, b); + } else { + assert_eq!(actual, a); + } + } + } + + #[test] + fn cmux_trivial_ciphertexts_yields_nontrivial() { + let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1); + + let plaintext_bits = crate::PlaintextBits(1); + + let a = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(); + let a = encryption::trivial_glwe(&a, &TEST_GLWE_DEF_1, plaintext_bits); + let b = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64) + .map(|x| (x + 1) % 2) + .collect::>(); + let b = encryption::trivial_glwe(&b, &TEST_GLWE_DEF_1, plaintext_bits); + + let sel = encryption::encrypt_ggsw(1, &sk, &TEST_GLWE_DEF_1, &TEST_RADIX, plaintext_bits); + + let sel = fft::fft_ggsw(&sel, &TEST_GLWE_DEF_1, &TEST_RADIX); + + let res = evaluation::cmux(&sel, &a, &b, &TEST_GLWE_DEF_1, &TEST_RADIX); + + for a in res.a(&TEST_GLWE_DEF_1) { + let zero = Polynomial::>::zero(TEST_GLWE_DEF_1.dim.polynomial_degree.0); + + assert_ne!(a.to_owned(), zero); + } + } +} diff --git a/sunscreen_tfhe/src/ops/homomorphisms/lwe.rs b/sunscreen_tfhe/src/ops/homomorphisms/lwe.rs new file mode 100644 index 000000000..955f90b89 --- /dev/null +++ b/sunscreen_tfhe/src/ops/homomorphisms/lwe.rs @@ -0,0 +1,67 @@ +use crate::{entities::LweCiphertextRef, LweDef, Torus, TorusOps}; + +/// Add `amount` to each torus element (mod q) in the ciphertext. +/// This shifts where messages lie on the torus and adds no noise. +/// +/// # Remark +/// Suppose we have plaintexts 0 and 1 that lie centered at 0 and q/2 respectively. +/// If we rotate by q/4, then the 0 lies centered at q/4 and 1 lies at 3q/4 == -q/4. +pub fn rotate( + output: &mut LweCiphertextRef, + input: &LweCiphertextRef, + amount: Torus, + lwe: &LweDef, +) { + output.assert_valid(lwe); + input.assert_valid(lwe); + + output.a_mut(lwe).clone_from_slice(input.a(lwe)); + *output.b_mut(lwe) = input.b(lwe) + amount; +} + +#[cfg(test)] +mod tests { + use crate::{ + entities::LweCiphertext, + high_level::{keygen, TEST_LWE_DEF_1}, + PlaintextBits, Torus, + }; + + use super::rotate; + + #[test] + fn can_rotate() { + let lwe_params = TEST_LWE_DEF_1; + + for _ in 0..100 { + let sk = keygen::generate_binary_lwe_sk(&lwe_params); + let val = sk.encrypt(0, &lwe_params, PlaintextBits(1)).0; + + let mut res = LweCiphertext::new(&lwe_params); + + rotate( + &mut res, + &val, + Torus::encode(1, PlaintextBits(2)), + &lwe_params, + ); + + let t = sk.decrypt_without_decode(&res, &lwe_params); + + assert!(t.inner() < 0x1u64 << 63); + + let val = sk.encrypt(1, &lwe_params, PlaintextBits(1)).0; + + rotate( + &mut res, + &val, + Torus::encode(1, PlaintextBits(2)), + &lwe_params, + ); + + let t = sk.decrypt_without_decode(&res, &lwe_params); + + assert!(t.inner() > 0x1u64 << 63); + } + } +} diff --git a/sunscreen_tfhe/src/ops/homomorphisms/mod.rs b/sunscreen_tfhe/src/ops/homomorphisms/mod.rs new file mode 100644 index 000000000..b5baa4414 --- /dev/null +++ b/sunscreen_tfhe/src/ops/homomorphisms/mod.rs @@ -0,0 +1,2 @@ +mod lwe; +pub use lwe::*; diff --git a/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch.rs b/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch.rs new file mode 100644 index 000000000..629304d0a --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch.rs @@ -0,0 +1,107 @@ +use crate::{ + dst::FromMutSlice, + entities::{GlweCiphertext, GlweCiphertextRef, GlweKeyswitchKeyRef, PolynomialRef}, + ops::{ + ciphertext::{decomposed_polynomial_glev_mad, sub_glwe_ciphertexts}, + encryption::trivially_encrypt_glwe_ciphertext, + }, + radix::PolynomialRadixIterator, + scratch::allocate_scratch_ref, + GlweDef, RadixDecomposition, TorusOps, +}; + +/// Switches a ciphertext under the original key to a ciphertext under the new +/// key using a keyswitch key. +/// +/// # Remark +/// +/// This performs the following operation: +/// +/// ```text +/// switched_ciphertext = trivial_encrypt(ciphertext_b) - sum_i() +/// ``` +/// +/// where `trivial_encrypt` is the encryption of the body of the original +/// ciphertext. +pub fn keyswitch_glwe_to_glwe( + output: &mut GlweCiphertextRef, + ciphertext_under_original_key: &GlweCiphertextRef, + keyswitch_key: &GlweKeyswitchKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(params); + + let keyswitch_glevs = keyswitch_key.rows(params, radix); + + let mut a_i_decomp_sum = GlweCiphertext::new(params); + allocate_scratch_ref!(scratch, PolynomialRef, (params.dim.polynomial_degree)); + + // sum_i() + for (a_i, glev_i) in ciphertext_a.zip(keyswitch_glevs) { + let decomp = PolynomialRadixIterator::new(a_i, scratch, radix); + + decomposed_polynomial_glev_mad(&mut a_i_decomp_sum, decomp, glev_i, params); + } + + // trivial_encrypt(ciphertext_b) + let mut trivial_b = GlweCiphertext::new(params); + trivially_encrypt_glwe_ciphertext(&mut trivial_b, ciphertext_b, params); + + // output = trivial_encrypt(ciphertext_b) - sum_i() + sub_glwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, params); +} + +#[cfg(test)] +mod tests { + + use crate::{ + entities::{GlweCiphertext, GlweKeyswitchKey, Polynomial}, + high_level::*, + ops::keyswitch::{ + glwe_keyswitch::keyswitch_glwe_to_glwe, glwe_keyswitch_key::generate_keyswitch_key_glwe, + }, + PlaintextBits, + }; + + #[test] + fn keyswitch_glwe() { + let glwe = TEST_GLWE_DEF_1; + let bits = PlaintextBits(1); + + let original_sk = keygen::generate_binary_glwe_sk(&glwe); + let new_sk = keygen::generate_binary_glwe_sk(&glwe); + + let mut ksk = GlweKeyswitchKey::::new(&TEST_GLWE_DEF_1, &TEST_RADIX); + generate_keyswitch_key_glwe( + &mut ksk, + &original_sk, + &new_sk, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + ); + + let msg = Polynomial::new( + &(0..glwe.dim.polynomial_degree.0 as u64) + .map(|x| x % 2) + .collect::>(), + ); + + let original_ct = original_sk.encode_encrypt_glwe(&msg, &glwe, bits); + + let mut new_ct = GlweCiphertext::new(&glwe); + keyswitch_glwe_to_glwe( + &mut new_ct, + &original_ct, + &ksk, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + ); + + let new_decrypted = new_sk.decrypt_decode_glwe(&new_ct, &glwe, bits); + + assert_eq!(new_decrypted.coeffs(), msg.coeffs()); + } +} diff --git a/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch_key.rs b/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch_key.rs new file mode 100644 index 000000000..4e705efd8 --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch_key.rs @@ -0,0 +1,150 @@ +use crate::{ + entities::{ + GlweCiphertextRef, GlweKeyswitchKeyRef, GlweSecretKeyRef, Polynomial, PolynomialRef, + }, + ops::encryption::encrypt_glwe_ciphertext_secret_generic, + polynomial::polynomial_scalar_mul, + GlweDef, RadixDecomposition, Torus, TorusOps, +}; + +fn encrypt_glwe_ciphertext_secret_with_keyswitch_noise( + c: &mut GlweCiphertextRef, + msg: &PolynomialRef>, + sk: &GlweSecretKeyRef, + params: &GlweDef, +) where + S: TorusOps, +{ + encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params); +} + +/** + * Generates a keyswitch key from the original key to the new key. The resulting + * keyswitch key is encrypted under the new key. This function is generic over + * the encrypt function should there be a need. + * + * The specific operation on each GLev ciphertext row inside a keyswitch key is + * + * ```text + * KSK_i = (GLWE_{s', }) + * ``` + */ +fn encrypt_keyswitch_key_generic( + keyswitch_key: &mut GlweKeyswitchKeyRef, + original_glwe_secret_key: &GlweSecretKeyRef, + new_glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, + encrypt: impl Fn( + &mut GlweCiphertextRef, + &PolynomialRef>, + &GlweSecretKeyRef, + &GlweDef, + ), +) where + S: TorusOps, +{ + let decomposition_radix_log = radix.radix_log.0; + let polynomial_degree = params.dim.polynomial_degree.0; + + for (i, row) in keyswitch_key.rows_mut(params, radix).enumerate() { + let s = original_glwe_secret_key + .s(params) + .nth(i) + .unwrap() + .map(|x| Torus::from(*x)); + + for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() { + let mut scaled_original_key = Polynomial::zero(polynomial_degree); + // The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to + // multiplying by 2^{log2(q) - log2(B) * (i + 1)} + let decomp_factor = + S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1))); + + polynomial_scalar_mul(&mut scaled_original_key, &s, decomp_factor); + + encrypt(col, &scaled_original_key, new_glwe_secret_key, params); + } + } +} + +/// Generate a keyswitch key from the original key to the new key. +/// For use with +/// [`keyswitch_glwe_to_glwe`](crate::ops::keyswitch::glwe_keyswitch::keyswitch_glwe_to_glwe). +pub fn generate_keyswitch_key_glwe( + keyswitch_key: &mut GlweKeyswitchKeyRef, + original_glwe_secret_key: &GlweSecretKeyRef, + new_glwe_secret_key: &GlweSecretKeyRef, + params: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + encrypt_keyswitch_key_generic( + keyswitch_key, + original_glwe_secret_key, + new_glwe_secret_key, + params, + radix, + encrypt_glwe_ciphertext_secret_with_keyswitch_noise, + ) +} + +#[cfg(test)] +mod tests { + + use crate::{ + dst::FromSlice, + entities::{GlweKeyswitchKey, GlweKeyswitchKeyRef}, + high_level::{TEST_GLWE_DEF_1, TEST_RADIX}, + Torus, + }; + + #[test] + fn test_generate_keyswitch_key_glwe() { + // Size of the arrary should be (k + 1) * l * k * poly_degree as we are + // GLev encrypting all S_i with a given radix count and base. + let ksk = GlweKeyswitchKey::::new(&TEST_GLWE_DEF_1, &TEST_RADIX); + + let k = TEST_GLWE_DEF_1.dim.size.0; + let l = TEST_RADIX.count.0; + let poly_degree = TEST_GLWE_DEF_1.dim.polynomial_degree.0; + assert_eq!(ksk.as_slice().len(), (k + 1) * l * k * poly_degree); + + // Generate fake data to iterate through. + let ksk_data = (0..ksk.as_slice().len()) + .map(|x| Torus::from(x as u64)) + .collect::>>(); + + let ksk = GlweKeyswitchKeyRef::::from_slice(&ksk_data); + + // Check that the data is correct. + let glwe_size = TEST_GLWE_DEF_1.dim.polynomial_degree.0 * (TEST_GLWE_DEF_1.dim.size.0 + 1); + let mut count = 0; + for row in ksk.rows(&TEST_GLWE_DEF_1, &TEST_RADIX) { + for glwe in row.glwe_ciphertexts(&TEST_GLWE_DEF_1) { + let (a, b) = glwe.a_b(&TEST_GLWE_DEF_1); + + let expected_vector = (count..(glwe_size + count)) + .map(|x| Torus::from(x as u64)) + .collect::>(); + + let (a_expected, b_expected) = expected_vector + .split_at(TEST_GLWE_DEF_1.dim.size.0 * TEST_GLWE_DEF_1.dim.polynomial_degree.0); + + let a_i_expected = a_expected + .chunks(TEST_GLWE_DEF_1.dim.polynomial_degree.0) + .map(|x| x.to_vec()) + .collect::>>>(); + + for (a_j, a_j_expected) in a.zip(a_i_expected) { + assert_eq!(a_j.coeffs(), a_j_expected); + } + + assert_eq!(b.coeffs(), b_expected); + + count += glwe_size; + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs b/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs new file mode 100644 index 000000000..0a41188b9 --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs @@ -0,0 +1,91 @@ +use crate::{ + dst::{FromMutSlice, FromSlice}, + entities::{LweCiphertext, LweCiphertextRef, LweKeyswitchKeyRef, PolynomialRef}, + ops::{ + ciphertext::{decomposed_scalar_lev_mad, sub_lwe_ciphertexts}, + encryption::trivially_encrypt_lwe_ciphertext, + }, + radix::PolynomialRadixIterator, + scratch::allocate_scratch_ref, + LweDef, PolynomialDegree, RadixDecomposition, TorusOps, +}; + +/// Switches a ciphertext under the original key to a ciphertext under the new +/// key using a keyswitch key. +/// +/// Arguments: +/// +/// * output: the output ciphertext +/// * ciphertext_under_original_key: the input ciphertext +/// * keyswitch_key: the keyswitch key +/// * old_params: the parameters of the original ciphertext +/// * new_params: the parameters of the output ciphertext +pub fn keyswitch_lwe_to_lwe( + output: &mut LweCiphertextRef, + ciphertext_under_original_key: &LweCiphertextRef, + keyswitch_key: &LweKeyswitchKeyRef, + old_params: &LweDef, + new_params: &LweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + keyswitch_key.assert_valid(old_params, new_params, radix); + + let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(old_params); + + let keyswitch_levs = keyswitch_key.rows(new_params, radix); + + let mut a_i_decomp_sum = LweCiphertext::new(new_params); + + allocate_scratch_ref!(scratch, PolynomialRef, (PolynomialDegree(1))); + + // sum_i() + for (a_i, lev_i) in ciphertext_a.iter().zip(keyswitch_levs) { + let decomp = + PolynomialRadixIterator::new(PolynomialRef::from_slice(&[*a_i]), scratch, radix); + + decomposed_scalar_lev_mad(&mut a_i_decomp_sum, decomp, lev_i, new_params); + } + + // trivial_encrypt(ciphertext_b) + let mut trivial_b = LweCiphertext::new(new_params); + trivially_encrypt_lwe_ciphertext(&mut trivial_b, ciphertext_b, new_params); + + // output = trivial_encrypt(ciphertext_b) - sum_i() + sub_lwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, new_params); +} + +#[cfg(test)] +mod tests { + + use rand::{thread_rng, RngCore}; + + use crate::{high_level::*, PlaintextBits}; + + #[test] + fn keyswitch_lwe() { + let bits = PlaintextBits(4); + let from_lwe = TEST_LWE_DEF_1; + let to_lwe = TEST_LWE_DEF_2; + let radix = TEST_RADIX; + + for _ in 0..50 { + let original_sk = keygen::generate_binary_lwe_sk(&from_lwe); + let new_sk = keygen::generate_binary_lwe_sk(&to_lwe); + + let ksk = keygen::generate_ksk(&original_sk, &new_sk, &from_lwe, &to_lwe, &radix); + + let msg = thread_rng().next_u64() % (1 << bits.0); + + let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0; + + let new_ct = + evaluation::keyswitch_lwe_to_lwe(&original_ct, &ksk, &from_lwe, &to_lwe, &radix); + + let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits); + + assert_eq!(new_decrypted, msg); + } + } +} diff --git a/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch_key.rs b/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch_key.rs new file mode 100644 index 000000000..4ae22547a --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch_key.rs @@ -0,0 +1,41 @@ +use crate::{ + entities::{LweKeyswitchKeyRef, LweSecretKeyRef}, + ops::encryption::encrypt_lwe_ciphertext, + LweDef, RadixDecomposition, Torus, TorusOps, +}; + +/// Generates a keyswitch key from an original LWE key to a new LWE key. The +/// resulting keyswitch key is encrypted under the new key. +/// +/// Arguments: +/// +/// * keyswitch_key: the resulting keyswitch key +/// * original_lwe_secret_key: the original LWE secret key +/// * new_lwe_secret_key: the new LWE secret key +/// * new_params: the parameters of the new LWE secret key +pub fn generate_keyswitch_key_lwe( + keyswitch_key: &mut LweKeyswitchKeyRef, + original_lwe_secret_key: &LweSecretKeyRef, + new_lwe_secret_key: &LweSecretKeyRef, + new_params: &LweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, +{ + let decomposition_radix_log = radix.radix_log.0; + + for (i, row) in keyswitch_key.rows_mut(new_params, radix).enumerate() { + let s_i = original_lwe_secret_key.s()[i]; + + for (j, col) in row.lwe_ciphertexts_mut(new_params).enumerate() { + // The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to + // multiplying by 2^{log2(q) - log2(B) * (i + 1)} + let decomp_factor = + S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1))); + + let msg = decomp_factor * s_i; + + encrypt_lwe_ciphertext(col, new_lwe_secret_key, Torus::from(msg), new_params); + } + } +} diff --git a/sunscreen_tfhe/src/ops/keyswitch/mod.rs b/sunscreen_tfhe/src/ops/keyswitch/mod.rs new file mode 100644 index 000000000..1715cfcee --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/mod.rs @@ -0,0 +1,17 @@ +/// Methods for performing a private functional keyswitch (PFKS) +pub mod private_functional_keyswitch; + +/// Methods for performing a public functional keyswitch (PuFKS) +pub mod public_functional_keyswitch; + +/// Generate LWE keyswitch keys. +pub mod lwe_keyswitch_key; + +/// Generate GLWE keyswitch keys. +pub mod glwe_keyswitch_key; + +/// Methods for performing a LWE keyswitch. +pub mod lwe_keyswitch; + +/// Methods for performing a GLWE keyswitch. +pub mod glwe_keyswitch; diff --git a/sunscreen_tfhe/src/ops/keyswitch/private_functional_keyswitch.rs b/sunscreen_tfhe/src/ops/keyswitch/private_functional_keyswitch.rs new file mode 100644 index 000000000..38df1d0c8 --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/private_functional_keyswitch.rs @@ -0,0 +1,350 @@ +use sunscreen_math::Zero; + +use crate::{ + dst::FromMutSlice, + entities::{ + CircuitBootstrappingKeyswitchKeysRef, GlweCiphertextRef, GlweSecretKeyRef, + LweCiphertextRef, LweSecretKeyRef, PolynomialRef, PrivateFunctionalKeyswitchKeyRef, + }, + ops::{ + ciphertext::{decomposed_scalar_glev_mad, glwe_negate_inplace}, + encryption::encrypt_glwe_ciphertext_secret, + }, + radix::{scale_by_decomposition_factor, ScalarRadixIterator}, + scratch::allocate_scratch_ref, + GlweDef, LweDef, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus, TorusOps, +}; + +/// Initialize `output`, a +/// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey), +/// under the given scheme parameters for the given secret mapping `map`. +/// Conceptually, this map transforms a list of torus plaintexts into a +/// polynomial plaintext. +/// +/// # Remarks +/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`. +/// +/// The first parameter in `map` is the output +/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter +/// is initialized to 0. +/// +/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`. +/// +/// # Security +/// To prevent side channels, `map` must run in constant time. +/// +/// # Panics +/// * If `output` is not valid for the given `from_lwe`, `to_glwe`, `radix`, `lwe_count`. +/// * If any of `from_lwe`, `to_glwe`, `radix`, `lwe_count` are invalid. +#[allow(clippy::too_many_arguments)] +pub fn generate_private_functional_keyswitch_key( + output: &mut PrivateFunctionalKeyswitchKeyRef, + from_key: &LweSecretKeyRef, + to_key: &GlweSecretKeyRef, + map: F, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + lwe_count: &PrivateFunctionalKeyswitchLweCount, +) where + S: TorusOps, + F: Fn(&mut PolynomialRef>, &[Torus]), +{ + output.assert_valid(from_lwe, to_glwe, radix, lwe_count); + radix.assert_valid::(); + from_key.assert_valid(from_lwe); + to_key.assert_valid(to_glwe); + to_glwe.assert_valid(); + from_lwe.assert_valid(); + lwe_count.assert_valid(); + + allocate_scratch_ref!( + pt_poly, + PolynomialRef>, + (to_glwe.dim.polynomial_degree) + ); + allocate_scratch_ref!(pt_touri, [Torus], lwe_count.0); + + let mut glevs = output.glevs_mut(to_glwe, radix); + let minus_one = ::zero().wrapping_sub(&::one()); + + for z in 0..lwe_count.0 { + for s_i in from_key.s().iter().chain([minus_one].iter()) { + let glev = glevs.next().unwrap(); + + for (j, glwe) in glev.glwe_ciphertexts_mut(to_glwe).enumerate() { + let scaled_s_i = scale_by_decomposition_factor(*s_i, j, radix); + + pt_poly.clear(); + pt_touri.iter_mut().for_each(|x| *x = Torus::zero()); + + pt_touri[z] = Torus::from(scaled_s_i); + + map(pt_poly, pt_touri); + + encrypt_glwe_ciphertext_secret(glwe, pt_poly, to_key, to_glwe); + } + } + } +} + +/// Perform a private functional keyswitch. See +/// [`module`](crate::ops::keyswitch::private_functional_keyswitch) documentation for more +/// details. +pub fn private_functional_keyswitch( + output: &mut GlweCiphertextRef, + inputs: &[&LweCiphertextRef], + pfksk: &PrivateFunctionalKeyswitchKeyRef, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, + lwe_count: &PrivateFunctionalKeyswitchLweCount, +) { + output.assert_valid(to_glwe); + pfksk.assert_valid(from_lwe, to_glwe, radix, lwe_count); + from_lwe.assert_valid(); + to_glwe.assert_valid(); + radix.assert_valid::(); + lwe_count.assert_valid(); + + assert_eq!(lwe_count.0, inputs.len()); + + let mut ksk_glevs = pfksk.glevs(to_glwe, radix); + + for input in inputs.iter() { + for i in 0..from_lwe.dim.0 + 1 { + // Treating the z'th ciphertext as slice of length n + 1 allows us to iterate + // over a || b. + let ab = input.as_slice(); + + let glev = ksk_glevs.next().unwrap(); + let decomp = ScalarRadixIterator::new(ab[i], radix); + + decomposed_scalar_glev_mad(output, decomp, glev, to_glwe); + } + } + + // Return minus output. + glwe_negate_inplace(output, to_glwe); +} + +/// Generate the keys for a private functional keyswitch. +pub fn generate_circuit_bootstrapping_pfks_keys( + output: &mut CircuitBootstrappingKeyswitchKeysRef, + from_key: &LweSecretKeyRef, + to_key: &GlweSecretKeyRef, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, +) { + output.assert_valid(from_lwe, to_glwe, radix); + from_key.assert_valid(from_lwe); + to_glwe.assert_valid(); + to_key.assert_valid(to_glwe); + radix.assert_valid::(); + from_lwe.assert_valid(); + + // Fill in k pfks keys that multiply each of the "a" GLEVs by the corresponding + // polynomial in the GLWE secret key. + for (pfksk, s) in output + .keys_mut(from_lwe, to_glwe, radix) + .zip(to_key.s(to_glwe)) + .take(to_glwe.dim.size.0) + { + let map = |poly: &mut PolynomialRef>, x: &[Torus]| { + for (c, a) in poly.coeffs_mut().iter_mut().zip(s.coeffs().iter()) { + *c = -x[0] * a; + } + }; + + generate_private_functional_keyswitch_key( + pfksk, + from_key, + to_key, + map, + from_lwe, + to_glwe, + radix, + &PrivateFunctionalKeyswitchLweCount(1), + ); + } + + // Now fill in the "b" GLEV. + // TODO: We could compute this row with public key switching. Is it worth it? + let b = output + .keys_mut(from_lwe, to_glwe, radix) + .nth(to_glwe.dim.size.0) + .unwrap(); + + let map = |poly: &mut PolynomialRef>, x: &[Torus]| { + poly.clear(); + poly.coeffs_mut()[0] = x[0]; + }; + + generate_private_functional_keyswitch_key( + b, + from_key, + to_key, + map, + from_lwe, + to_glwe, + radix, + &PrivateFunctionalKeyswitchLweCount(1), + ) +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{GlweCiphertext, PrivateFunctionalKeyswitchKey}, + high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX}, + PlaintextBits, PrivateFunctionalKeyswitchLweCount, + }; + + use super::*; + + #[test] + fn can_create_private_functional_keyswitch_key() { + for _ in 0..5 { + let lwe_count = + PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1); + + let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1); + let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1); + + let mut pfks_key = PrivateFunctionalKeyswitchKey::::new( + &TEST_LWE_DEF_1, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + &lwe_count, + ); + + fn map(poly: &mut PolynomialRef>, inputs: &[Torus]) { + for (i, input) in inputs.iter().enumerate() { + poly.coeffs_mut()[i] = *input; + } + } + + generate_private_functional_keyswitch_key( + &mut pfks_key, + &lwe_key, + &glwe_key, + map, + &TEST_LWE_DEF_1, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + &lwe_count, + ); + + let mut glevs = pfks_key.glevs(&TEST_GLWE_DEF_1, &TEST_RADIX); + + let minus_one = u64::MAX; + + for z in 0..lwe_count.0 { + for s_i in lwe_key.s().iter().chain([minus_one].iter()) { + let glev = glevs.next().unwrap(); + + for (j, glwe) in glev.glwe_ciphertexts(&TEST_GLWE_DEF_1).enumerate() { + let plaintext_bits = (j + 1) * TEST_RADIX.radix_log.0; + let plaintext_bits = PlaintextBits(plaintext_bits as u32); + + let pt = + glwe_key.decrypt_decode_glwe(glwe, &TEST_GLWE_DEF_1, plaintext_bits); + + for i in 0..pt.coeffs().len() { + if *s_i == minus_one && i == z { + let expected = (0x1u64 << ((j + 1) * TEST_RADIX.radix_log.0)) - 1; + let actual = pt.coeffs()[i]; + + assert_eq!(actual, expected); + } else if i == z { + assert_eq!(pt.coeffs()[i], *s_i); + } else { + assert_eq!(pt.coeffs()[i], 0); + } + } + } + } + } + } + } + + #[test] + fn can_private_functional_keyswitch() { + for _ in 0..5 { + let lwe_count = + PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1); + + let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1); + let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1); + + let mut pfks_key = PrivateFunctionalKeyswitchKey::::new( + &TEST_LWE_DEF_1, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + &lwe_count, + ); + + fn map(poly: &mut PolynomialRef>, inputs: &[Torus]) { + for (i, input) in inputs.iter().enumerate() { + poly.coeffs_mut()[i] = *input; + } + } + + generate_private_functional_keyswitch_key( + &mut pfks_key, + &lwe_key, + &glwe_key, + map, + &TEST_LWE_DEF_1, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + &lwe_count, + ); + + let plaintext_bits = PlaintextBits(4); + + let pts = (0..lwe_count.0) + .map(|_x| thread_rng().next_u64() % (0x1u64 << plaintext_bits.0)) + .collect::>(); + let lwe_cts = pts + .iter() + .map(|x| lwe_key.encrypt(*x, &TEST_LWE_DEF_1, plaintext_bits)) + .collect::>(); + let mut lwe_ct_refs: Vec<&LweCiphertextRef<_>> = vec![]; + + for ct in lwe_cts.iter() { + lwe_ct_refs.push(&ct.0); + } + + let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1); + + private_functional_keyswitch( + &mut result, + &lwe_ct_refs, + &pfks_key, + &TEST_LWE_DEF_1, + &TEST_GLWE_DEF_1, + &TEST_RADIX, + &lwe_count, + ); + + let actual = glwe_key.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits); + + for (i, (c, pt)) in actual + .coeffs() + .iter() + .zip(pts.iter().cycle().take(actual.coeffs().len())) + .enumerate() + { + if i < lwe_count.0 { + assert_eq!(*c, *pt); + } else { + assert_eq!(*c, 0); + } + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/keyswitch/public_functional_keyswitch.rs b/sunscreen_tfhe/src/ops/keyswitch/public_functional_keyswitch.rs new file mode 100644 index 000000000..55f8ca449 --- /dev/null +++ b/sunscreen_tfhe/src/ops/keyswitch/public_functional_keyswitch.rs @@ -0,0 +1,264 @@ +use num::Complex; + +use crate::dst::FromMutSlice; +use crate::entities::{ + GlevCiphertextFftRef, GlweCiphertextRef, LweCiphertextRef, LweSecretKeyRef, PolynomialRef, +}; +use crate::ops::ciphertext::glwe_negate_inplace; +use crate::ops::encryption::encrypt_glwe_ciphertext_secret; +use crate::ops::fft_ops::decomposed_polynomial_glev_mad; +use crate::polynomial::polynomial_add_assign; +use crate::radix::PolynomialRadixIterator; +use crate::scratch::allocate_scratch; +use crate::Torus; +use crate::{ + entities::{GlweCiphertextFftRef, GlweSecretKeyRef, PublicFunctionalKeyswitchKeyRef}, + radix::scale_by_decomposition_factor, + scratch::allocate_scratch_ref, + GlweDef, LweDef, RadixDecomposition, TorusOps, +}; + +/// Generate a public functional keyswitch key, which is used to transform a +/// list of LWE ciphertexts into a GLWE ciphertext while applying a provided +/// function that converts the scalars in the LWE ciphertexts to the polynomial +/// message space in the GLWE ciphertext. +/// +/// See +/// [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch) +/// for more details. +pub fn generate_public_functional_keyswitch_key( + output: &mut PublicFunctionalKeyswitchKeyRef, + from_sk: &LweSecretKeyRef, + to_sk: &GlweSecretKeyRef, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, +) { + from_sk.assert_valid(from_lwe); + to_sk.assert_valid(to_glwe); + output.assert_valid(from_lwe, to_glwe, radix); + + allocate_scratch_ref!(pt, PolynomialRef>, (to_glwe.dim.polynomial_degree)); + pt.clear(); + + for (s_i, glev) in from_sk.s().iter().zip(output.glevs_mut(to_glwe, radix)) { + for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts_mut(to_glwe)) { + let x = scale_by_decomposition_factor(*s_i, j, radix); + + pt.coeffs_mut()[0] = Torus::from(x); + + encrypt_glwe_ciphertext_secret(glwe_ct, pt, to_sk, to_glwe); + } + } +} + +/// Perform a public functional keyswitch, where a list of LWE ciphertexts are +/// transformed into a GLWE ciphertext while applying a provided function. +/// Conceptually, this map transforms a list of torus plaintexts into a +/// polynomial plaintext. +/// +/// This operation is called "public" because the function F is public; a +/// variant of this operation called +/// [`private_functional_keyswitch`](crate::ops::keyswitch::private_functional_keyswitch) +/// is also available, where the function F is secret encoded in the keyswitch +/// key. +/// +/// # Remarks +/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`. +/// +/// The first parameter in `map` is the output +/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter +/// is initialized to 0. +/// +/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`. +pub fn public_functional_keyswitch( + output: &mut GlweCiphertextRef, + inputs: &[&LweCiphertextRef], + pufksk: &PublicFunctionalKeyswitchKeyRef, + f: F, + from_lwe: &LweDef, + to_glwe: &GlweDef, + radix: &RadixDecomposition, +) where + S: TorusOps, + F: Fn(&mut PolynomialRef>, &[Torus]), +{ + pufksk.assert_valid(from_lwe, to_glwe, radix); + output.assert_valid(to_glwe); + + for i in inputs { + i.assert_valid(from_lwe); + } + + assert!(inputs.len() <= to_glwe.dim.polynomial_degree.0); + + allocate_scratch_ref!( + poly, + PolynomialRef>, + (to_glwe.dim.polynomial_degree) + ); + output.clear(); + allocate_scratch_ref!( + decomp_scratch, + PolynomialRef, + (to_glwe.dim.polynomial_degree) + ); + let mut a_buf = allocate_scratch::>(inputs.len()); + let lwe_vals = a_buf.as_mut_slice(); + allocate_scratch_ref!( + glev_fft, + GlevCiphertextFftRef>, + (to_glwe.dim, radix.count) + ); + allocate_scratch_ref!( + output_fft, + GlweCiphertextFftRef>, + (to_glwe.dim) + ); + + output_fft.clear(); + + // Compute all the a terms + for (i, row) in pufksk.glevs(to_glwe, radix).enumerate() { + for (j, a_i) in inputs.iter().map(|x| x.a(from_lwe)[i]).enumerate() { + lwe_vals[j] = a_i; + } + + row.fft(glev_fft, to_glwe); + + f(poly, lwe_vals); + + let decomp = PolynomialRadixIterator::new(poly, decomp_scratch, radix); + + decomposed_polynomial_glev_mad(output_fft, decomp, glev_fft, to_glwe); + } + + // Compute the b term + for (j, b) in inputs.iter().map(|x| x.b(from_lwe)).enumerate() { + lwe_vals[j] = *b; + } + + f(poly, lwe_vals); + + output_fft.ifft(output, to_glwe); + + // Compute (0, b) - output + glwe_negate_inplace(output, to_glwe); + polynomial_add_assign(output.b_mut(to_glwe), poly); +} + +#[cfg(test)] +mod tests { + + use rand::{thread_rng, RngCore}; + + use crate::{ + entities::{GlweCiphertext, PublicFunctionalKeyswitchKey}, + high_level::{encryption, keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX}, + PlaintextBits, + }; + + use super::*; + + #[test] + fn can_generate_public_functional_keyswitch_key() { + let lwe_params = TEST_LWE_DEF_1; + let glwe_params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + + let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX); + + let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); + + generate_public_functional_keyswitch_key( + &mut pufksk, + &lwe_sk, + &glwe_sk, + &lwe_params, + &glwe_params, + &radix, + ); + + for (s_i, glev) in lwe_sk.s().iter().zip(pufksk.glevs(&glwe_params, &radix)) { + for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts(&glwe_params)) { + let plaintext_bits = ((j + 1) * radix.radix_log.0) as u32; + let plaintext_bits = PlaintextBits(plaintext_bits); + + let pt = encryption::decrypt_glwe(glwe_ct, &glwe_sk, &glwe_params, plaintext_bits); + + assert_eq!(pt.coeffs()[0], *s_i); + + for x in pt.coeffs().iter().skip(1) { + assert_eq!(*x, 0); + } + } + } + } + + #[test] + fn can_public_functional_keyswitch() { + let lwe_params = TEST_LWE_DEF_1; + let glwe_params = TEST_GLWE_DEF_1; + let radix = TEST_RADIX; + + let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX); + + let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params); + let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params); + + let plaintext_bits = PlaintextBits(4); + + generate_public_functional_keyswitch_key( + &mut pufksk, + &lwe_sk, + &glwe_sk, + &lwe_params, + &glwe_params, + &radix, + ); + + for _ in 0..10 { + let lwe_count = thread_rng().next_u64() as usize % glwe_params.dim.polynomial_degree.0; + + let pts = (0..lwe_count) + .map(|_| thread_rng().next_u64() % (0x1 << plaintext_bits.0)) + .collect::>(); + + let lwes = pts + .iter() + .map(|x| encryption::encrypt_lwe_secret(*x, &lwe_sk, &lwe_params, plaintext_bits)) + .collect::>(); + + let mut lwe_refs: Vec<&LweCiphertextRef> = vec![]; + + for x in lwes.iter() { + lwe_refs.push(x); + } + + let mut output = GlweCiphertext::new(&glwe_params); + + fn map(poly: &mut PolynomialRef>, tori: &[Torus]) { + for (c, t) in poly.coeffs_mut().iter_mut().zip(tori.iter()) { + *c = *t; + } + } + + public_functional_keyswitch( + &mut output, + &lwe_refs, + &pufksk, + map, + &lwe_params, + &glwe_params, + &radix, + ); + + let actual = encryption::decrypt_glwe(&output, &glwe_sk, &glwe_params, plaintext_bits); + + for (a, e) in actual.coeffs().iter().zip(pts.iter()) { + assert_eq!(a, e); + } + } + } +} diff --git a/sunscreen_tfhe/src/ops/mod.rs b/sunscreen_tfhe/src/ops/mod.rs new file mode 100644 index 000000000..f3b8ef43e --- /dev/null +++ b/sunscreen_tfhe/src/ops/mod.rs @@ -0,0 +1,19 @@ +/// Ciphertext operations where one of the operands is in FFT form. +pub mod fft_ops; + +/// Methods for key switching a ciphertext from one key to another, potentially +/// switching the parameters at the same time. +pub mod keyswitch; + +/// Methods for bootstrapping an LWE ciphertext from one key to another, while +/// refreshing the noise in the ciphertext. +pub mod bootstrapping; + +/// Methods for homomorphic operations on ciphertexts. +pub mod homomorphisms; + +/// Methods for operating on different ciphertext types. +pub mod ciphertext; + +/// Methods for encrypting and decrypting to various ciphertext types. +pub mod encryption; diff --git a/sunscreen_tfhe/src/params.rs b/sunscreen_tfhe/src/params.rs new file mode 100644 index 000000000..b3d4a5ae1 --- /dev/null +++ b/sunscreen_tfhe/src/params.rs @@ -0,0 +1,310 @@ +use serde::{Deserialize, Serialize}; + +use crate::rand::Stddev; +use crate::TorusOps; + +use sunscreen_math::security::{lwe_std_to_security_level, SecurityLevelResults}; + +trait SecurityLevel { + fn security_level(&self) -> f64; + + fn assert_security_level(&self, specified_security_level: usize) { + // Note that the underlying approximation is accurate up to 4 bits, so + // that is what we use here. + let tolerance = 4.0; + + let security_level = self.security_level(); + let security_difference = (security_level - specified_security_level as f64).abs(); + + assert!( + security_difference <= tolerance, + "Security level mismatch: expected {}, got {}", + specified_security_level, + security_level + ) + } +} + +fn generic_security_level(dimension: usize, std: f64) -> f64 { + match lwe_std_to_security_level(dimension, std) { + SecurityLevelResults::Level(level) => level, + SecurityLevelResults::BelowStandardDeviationBound => { + panic!("Standard deviation is too small for the given dimension") + } + SecurityLevelResults::AboveStandardDeviationBound => { + panic!("Standard deviation is too large for the given dimension") + } + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of torus elements in the LWE lattice. +pub struct LweDimension(pub usize); + +impl LweDimension { + /// Asserts this LWE problem is well-formed. + pub fn assert_valid(&self) { + assert_ne!(self.0, 0); + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The degree of the modulus polynomial `(x^N+1)` in a GLWE instance. +/// +/// # Remarks +/// GLWE encryption uses polynomials in `Z_q\[X\]/(X^N + 1)` where N is a power +/// of two. I.e. negacyclic polynomials modulo (X^N + 1) where the coefficients +/// are integers mod `q`. +pub struct PolynomialDegree(pub usize); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of polynomials in a GLWE instance. +pub struct GlweSize(pub usize); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of plaintext bits to encode into a message. +/// +/// # Remarks +/// Packing too much data into messages can result in incorrect decryptions due +/// to noise. +/// +/// For binary, set this to one. +pub struct PlaintextBits(pub u32); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of padding bits to include in an LWE ciphertext. +pub struct CarryBits(pub u32); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of digits to decompose a value into. +pub struct RadixCount(pub usize); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of bits in a digit output during base decomposition. +pub struct RadixLog(pub usize); + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The number of [`LweCiphertext`](crate::entities::LweCiphertext)s that get +/// mapped into a [`GlweCiphertext`](crate::entities::GlweCiphertext) during +/// private functional keyswitching. +pub struct PrivateFunctionalKeyswitchLweCount(pub usize); + +impl PrivateFunctionalKeyswitchLweCount { + #[inline(always)] + /// Assert this [`PrivateFunctionalKeyswitchLweCount`] is valid. + pub fn assert_valid(&self) { + assert_ne!(self.0, 0); + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +/// The parameters defining how to do approximately perform base decomposition. I.e. +/// decompose values into digits. +/// +/// # Validity +/// For [`RadixDecomposition`] parameters to be valid: +/// * `count` must be greater than zero. +/// * `radix_log` must be greater than zero. +/// * `count * radix_log` must be less than equal to the number of bits in the +/// [`Torus`](crate::Torus) element used in the operation(s) performing radix +/// decomposition. +/// +/// Calling [`assert_valid`](Self::assert_valid) will panic if the parameters are invalid. +pub struct RadixDecomposition { + /// The number of digits to decompose a value into. + pub count: RadixCount, + + /// The number of bits in a digit output during base decomposition. + pub radix_log: RadixLog, +} + +impl RadixDecomposition { + #[inline(always)] + /// Panics if these [`RadixDecomposition`] parameters are invalid. + pub fn assert_valid(&self) { + assert!(self.count.0 > 0); + assert!(self.radix_log.0 > 0); + assert!(self.count.0 * self.radix_log.0 <= S::BITS as usize); + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +/// A [`PolynomialDegree`] and [`GlweSize`] in a GLWE instance. +pub struct GlweDimension { + /// The degree of the polynomial in a GLWE instance. + pub polynomial_degree: PolynomialDegree, + + /// The number of polynomials in a GLWE instance. + pub size: GlweSize, +} + +impl GlweDimension { + /// Reinterpret this GLWE problem instance as an LWE problem instance, returning + /// the [`LweDimension`]. + pub fn as_lwe_dimension(&self) -> LweDimension { + LweDimension(self.polynomial_degree.0 * self.size.0) + } + + #[inline(always)] + /// Assert these GLWE parameters are valid. + pub fn assert_valid(&self) { + assert!(self.polynomial_degree.0.is_power_of_two()); + assert!(self.polynomial_degree.0 > 0); + assert!(self.size.0 > 0); + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +/// Parameters that define an LWE problem instance. +/// +/// # Security +/// `dim` and `std` must be properly set to attain the desired security level. Improper +/// values will result in an insecure scheme. +pub struct LweDef { + /// The dimension of the LWE lattice. + pub dim: LweDimension, + + /// The standard deviation of the noise in the LWE lattice. + pub std: Stddev, +} + +impl LweDef { + #[inline(always)] + /// Asserts this LWE problem is well-formed. + pub fn assert_valid(&self) { + self.dim.assert_valid(); + } +} + +impl SecurityLevel for LweDef { + fn security_level(&self) -> f64 { + generic_security_level(self.dim.0, self.std.0) + } +} + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +/// Parameters that define a GLWE problem instance. +/// +/// # Security +/// `dim` and `std` must be properly set to attain the desired security level. Improper +/// values will result in an insecure scheme. +pub struct GlweDef { + /// The dimension of the GLWE instance. + pub dim: GlweDimension, + + /// The standard deviation of the noise in the GLWE instance. + pub std: Stddev, +} + +impl GlweDef { + /// Reinterpret this GLWE instance as an LWE instance of the same lattice dimension. + pub fn as_lwe_def(&self) -> LweDef { + LweDef { + dim: self.dim.as_lwe_dimension(), + std: self.std, + } + } + + #[inline(always)] + /// Assert the GLWE instance is valid. + pub fn assert_valid(&self) { + self.dim.assert_valid(); + } +} + +impl SecurityLevel for GlweDef { + fn security_level(&self) -> f64 { + self.as_lwe_def().security_level() + } +} + +/// 128-bit secure parameters for an LWE instance with a dimension of 512. +pub const LWE_512_128: LweDef = LweDef { + dim: LweDimension(512), + std: Stddev(0.0004899836456140595), +}; + +/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024. +pub const GLWE_1_1024_128: GlweDef = GlweDef { + dim: GlweDimension { + size: GlweSize(1), + polynomial_degree: PolynomialDegree(1024), + }, + std: Stddev(0.0000000444778278004718), +}; + +/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 2048. +pub const GLWE_1_2048_128: GlweDef = GlweDef { + dim: GlweDimension { + size: GlweSize(1), + polynomial_degree: PolynomialDegree(2048), + }, + std: Stddev(0.00000000000000034667670193445625), +}; + +/// 80-bit secure parameters for an LWE instance with a dimension of 512. +pub const LWE_512_80: LweDef = LweDef { + dim: LweDimension(512), + std: Stddev(0.000001842343446823844), +}; + +/// 80-bit secure parameters for a GLWE instance with 5 polynomials of degree 256. +pub const GLWE_5_256_80: GlweDef = GlweDef { + dim: GlweDimension { + size: GlweSize(5), + polynomial_degree: PolynomialDegree(256), + }, + std: Stddev(0.000000000000002106764669572764), +}; + +/// 80-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024. +pub const GLWE_1_1024_80: GlweDef = GlweDef { + dim: GlweDimension { + size: GlweSize(1), + polynomial_degree: PolynomialDegree(1024), + }, + std: Stddev(0.0000000000010900242107812643), +}; + +#[cfg(test)] +mod tests { + + use sunscreen_math::security::lwe_security_level_to_std; + + use super::*; + + #[test] + fn check_security_levels() { + let actual_lwe_std = lwe_security_level_to_std(512, 128.0); + println!("LWE 512 128: {}", actual_lwe_std); + LWE_512_128.assert_security_level(128); + + let actual_glwe_std = lwe_security_level_to_std(1024, 128.0); + println!("GLWE 1 1024 128: {}", actual_glwe_std); + GLWE_1_1024_128.assert_security_level(128); + + let actual_glwe_std = lwe_security_level_to_std(2048, 128.0); + println!("GLWE 1 2048 128: {}", actual_glwe_std); + GLWE_1_2048_128.assert_security_level(128); + + let actual_lwe_std = lwe_security_level_to_std(512, 80.0); + println!("LWE 512 80: {}", actual_lwe_std); + LWE_512_80.assert_security_level(80); + + let actual_glwe_std = lwe_security_level_to_std(256 * 5, 80.0); + println!("GLWE 5 256 80: {}", actual_glwe_std); + GLWE_5_256_80.assert_security_level(80); + + let actual_glwe_std = lwe_security_level_to_std(1024, 80.0); + println!("GLWE 1 1024 80: {}", actual_glwe_std); + GLWE_1_1024_80.assert_security_level(80); + } +} diff --git a/sunscreen_tfhe/src/rand.rs b/sunscreen_tfhe/src/rand.rs new file mode 100644 index 000000000..6d522ff2a --- /dev/null +++ b/sunscreen_tfhe/src/rand.rs @@ -0,0 +1,97 @@ +use std::fmt::Debug; + +use rand::{thread_rng, Rng, RngCore}; +use rand_distr::Normal; +use serde::{Deserialize, Serialize}; + +use crate::math::{Torus, TorusOps}; + +#[derive(Debug, Copy, Clone, Serialize, Deserialize)] +#[serde(transparent)] +/// The standard deviation of a Gaussian distribution normalized over the torus +/// `T_q`. +pub struct Stddev(pub f64); + +/// Sample a random torus element from the a normal distribution +/// with a mean of 0 and the given stddev +pub fn normal_torus(std: Stddev) -> Torus { + let dist = Normal::new(0., std.0).unwrap(); + + let e_0 = thread_rng().sample(dist); + let q = (S::BITS as f64).exp2(); + + let e = f64::round(e_0 * q) as i64; + let e: u64 = unsafe { std::mem::transmute(e) }; + + Torus::from(S::from_u64(e)) +} + +/// Generate a random torus element uniformly +pub fn uniform_torus() -> Torus { + Torus::from(S::from_u64(thread_rng().next_u64())) +} + +/// Generate a random binary torus element +pub fn binary() -> S { + S::from_u64(thread_rng().next_u64() % 2) +} + +#[cfg(test)] +mod tests { + use std::mem::transmute_copy; + + use crate::math::ToF64; + + use super::*; + + #[test] + fn can_produce_random_torus() { + pub fn case() + where + S: TorusOps, + I: ToF64 + Copy + Debug, + >::Error: Debug, + { + let q = (S::BITS as f64).exp2(); + let n: i32 = 100_000; + + let dev = Stddev(0.000_448_516_698_238_696_5); + + let data = (0..n) + .map(|_| { + let t = normal_torus::(dev).inner(); + unsafe { transmute_copy::(&t) } + }) + .collect::>(); + + // Reinterpreting the torus points as i64 values should give a mean of approximately zero. + let mean = data + .iter() + .copied() + .map(|x| x.to_f64()) + .fold(0., |s, x| s + x) + / (q * n as f64); + + assert!(mean < 1e-5); + + // Scale the integer values back to the range [0, 1) and compute the stddev + let measured_std = data + .iter() + .copied() + .map(|x| { + let val = (x.to_f64() / q) - mean; + + val * val + }) + .fold(0f64, |s, x| s + x) + / n as f64; + + let measured_std = measured_std.sqrt(); + + assert!((measured_std - dev.0).abs() < 0.00001); + } + + case::(); + case::(); + } +} diff --git a/sunscreen_tfhe/src/scratch.rs b/sunscreen_tfhe/src/scratch.rs new file mode 100644 index 000000000..61f4f6c38 --- /dev/null +++ b/sunscreen_tfhe/src/scratch.rs @@ -0,0 +1,427 @@ +use linked_list::{Cursor, LinkedList}; +use num::{Complex, Float}; +use rustfft::FftNum; +use std::{ + alloc::Layout, + cell::RefCell, + marker::PhantomData, + mem::{size_of, transmute}, +}; + +use crate::{Torus, TorusOps}; + +thread_local! { + static SCRATCH: RefCell> = RefCell::new(None); +} + +macro_rules! allocate_scratch_ref { + ($out_ident:ident,$ref_t:ident<$t_arg:ty>, ($($args:expr),*)) => { + let mut tmp = crate::scratch::allocate_scratch(<$ref_t<$t_arg> as crate::dst::OverlaySize>::size(($($args),*))); + let $out_ident = <$ref_t<$t_arg>>::from_mut_slice(tmp.as_mut_slice()); + }; + ($out_ident:ident,[$ref_t:ident<$t_arg:ty>], $len:expr) => { + let mut tmp = crate::scratch::allocate_scratch::<$ref_t<$t_arg>>(<[$ref_t<$t_arg>] as crate::dst::OverlaySize>::size($len)); + let $out_ident = tmp.as_mut_slice(); + } +} + +pub(crate) use allocate_scratch_ref; + +/// Indicates this is a "Plain Old Data" type. For `T` qualify as such, +/// all bit patterns must be considered a properly initialized instance of +/// `T`. +/// +/// # Safety +/// Implementing this trait on types that don't meet the above requirements +/// may result in undefined behavior. +pub unsafe trait Pod {} + +unsafe impl Pod for u8 {} +unsafe impl Pod for u16 {} +unsafe impl Pod for u32 {} +unsafe impl Pod for u64 {} +unsafe impl Pod for u128 {} +unsafe impl Pod for i8 {} +unsafe impl Pod for i16 {} +unsafe impl Pod for i32 {} +unsafe impl Pod for i64 {} +unsafe impl Pod for i128 {} +unsafe impl Pod for f32 {} +unsafe impl Pod for f64 {} +unsafe impl Pod for Complex where T: Float + FftNum {} +unsafe impl Pod for Torus where S: TorusOps {} + +/// Allocate a scratch buffer in a cache efficient manner. Freed scratch +/// buffers are reused in subsequent allocations. +/// +/// # Remarks +/// The returned [`ScratchBuffer`] will be aligned to `align_of::()` and +/// have a length equal to count. +/// +/// # Panics +/// If `T` is a zero-sized type (e.g. `()`). +pub fn allocate_scratch(count: usize) -> ScratchBuffer<'static, T> +where + T: Pod, +{ + SCRATCH.with(|s| { + if s.borrow().is_none() { + let new_scratch = Scratch::new(); + *s.borrow_mut() = Some(new_scratch); + } + + (*s.borrow_mut()).as_mut().unwrap().allocate::(count) + }) +} + +/// An "allocator" designed for allocating scratch memory. +/// +/// # Remarks +/// Internally, this data structure is a [`LinkedList`] of +/// [`Vec`]s treated as a stack. +/// +/// [`Scratch`] is designed to provide a cache locality for +/// temporary buffers by reusing allocations. +/// +/// Upon allocation, we update the "top" of the stack to be the +/// furthest consecutive free buffer from the actual top. +/// +/// The only way to use this type is through the `thread_local` +/// singleton, as this is the only way to guarantee soundness. +/// The references doled out by allocate have `'static`` +/// lifetimes. This is needed so you can have mutable references +/// to different allocations at the same time. +struct Scratch { + // Only accessed through the cursor, so compiler thinks it's unused. + #[allow(unused)] + stack: Box>, + top: *mut Cursor<'static, Allocation>, +} + +impl Drop for Scratch { + fn drop(&mut self) { + let top = unsafe { Box::from_raw(self.top) }; + + std::mem::drop(top); + } +} + +impl Scratch { + /// We require this to be private. [`allocate_scratch`] should be + /// the only way to use scratch memory, which will allocate memory + /// using a thread_local allocator. + fn new() -> Self { + let mut list = Box::new(LinkedList::new()); + + let cursor = Box::new(list.cursor()); + let top = unsafe { transmute(Box::into_raw(cursor)) }; + + Self { stack: list, top } + } + + /// Allocate a buffer matching the given specification. + fn allocate(&mut self, count: usize) -> ScratchBuffer<'static, T> + where + T: Pod, + { + assert_ne!(size_of::(), 0); + + let top = unsafe { &mut *self.top }; + + // Push the top as far down until we hit the bottom or an allocation + // currently in use. + loop { + let prev = top.peek_prev(); + + if let Some(x) = prev { + if x.is_free { + top.prev().unwrap(); + continue; + } + } + + break; + } + + let layout = Layout::array::(count).unwrap(); + let req_len = layout.size() + layout.align(); + + let allocation = match top.peek_next() { + Some(d) => { + assert!(d.is_free); + + // Resize the allocation if needed. + if d.data.len() < req_len { + d.data.resize(req_len, 0u8); + } + + d.requested_len = count; + top.next().unwrap() + } + None => { + let data = vec![0u8; req_len]; + + let allocation = Allocation { + requested_len: count, + is_free: false, + data, + }; + + top.insert(allocation); + top.next().unwrap() + } + }; + + allocation.is_free = false; + + ScratchBuffer { + allocation: allocation as *mut Allocation, + _phantom: PhantomData, + } + } +} + +struct Allocation { + requested_len: usize, + data: Vec, + is_free: bool, +} + +pub struct ScratchBuffer<'a, T> { + allocation: *mut Allocation, + _phantom: PhantomData<&'a T>, +} + +impl<'a, T> ScratchBuffer<'a, T> { + #[allow(unused)] + /// Get a slice to the underlying data. + /// + /// # Remarks + /// While not extremely expensive, this operation does require capturing + /// an aligned slice of data in an underlying allocation. As such, + /// you should avoid repeated calls. + pub fn as_slice(&self) -> &[T] { + let count = unsafe { (*self.allocation).requested_len }; + let (_, slice, _) = unsafe { (*self.allocation).data.align_to::() }; + unsafe { transmute(&slice[0..count]) } + } + + /// Get a mutable slice to the underlying data. + /// + /// # Remarks + /// While not extremely expensive, this operation does require capturing + /// an aligned slice of data in an underlying allocation. As such, + /// you should avoid repeated calls. + pub fn as_mut_slice(&mut self) -> &mut [T] { + let count = unsafe { (*self.allocation).requested_len }; + let (_pre, slice, _post) = unsafe { (*self.allocation).data.align_to_mut::() }; + unsafe { transmute(&mut slice[0..count]) } + } +} + +impl<'a, T> Drop for ScratchBuffer<'a, T> { + fn drop(&mut self) { + unsafe { (*self.allocation).is_free = true }; + } +} + +#[cfg(test)] +mod tests { + use std::mem::align_of; + + use super::*; + + #[test] + fn can_allocate() { + let mut scratch = Scratch::new(); + + let mut d = scratch.allocate::(64); + + let d = d.as_mut_slice(); + assert_eq!(d.len(), 64); + + for (i, d_i) in d.iter_mut().enumerate() { + *d_i = i as u64; + } + + assert_eq!(scratch.stack.len(), 1); + } + + #[test] + fn buffers_get_reused() { + let mut scratch = Scratch::new(); + + let b = scratch.allocate::(64); + + assert_eq!(scratch.stack.len(), 1); + let b_slice = b.as_slice(); + assert_eq!(b_slice.len(), 64); + assert_eq!(b_slice.as_ptr().align_offset(align_of::()), 0); + let first_ptr = b_slice.as_ptr(); + + std::mem::drop(b); + + let mut b = scratch.allocate::(64); + let b_slice = b.as_mut_slice(); + assert_eq!(first_ptr, b_slice.as_ptr()); + assert_eq!(scratch.stack.len(), 1); + assert_eq!(b_slice.len(), 64); + + for (i, b_i) in b_slice.iter_mut().enumerate() { + *b_i = i as u64; + } + } + + #[test] + #[ignore] + fn reallocate_on_bigger_request() { + let mut scratch = Scratch::new(); + + let mut b = scratch.allocate::(64); + + assert_eq!(scratch.stack.len(), 1); + let b_slice = b.as_mut_slice(); + assert_eq!(b_slice.len(), 64); + assert_eq!(b_slice.as_ptr().align_offset(align_of::()), 0); + let first_ptr = b_slice.as_ptr(); + + for (i, b_i) in b_slice.iter_mut().enumerate() { + *b_i = i as u64; + } + + std::mem::drop(b); + + let mut b = scratch.allocate::(16384); + let b = b.as_mut_slice(); + assert_ne!(first_ptr, b.as_ptr()); + assert_eq!(scratch.stack.len(), 1); + assert_eq!(b.len(), 16384); + + for (i, b_i) in b.iter_mut().enumerate() { + *b_i = i as u64; + } + } + + #[test] + fn allocate_two_buffers() { + let mut scratch = Scratch::new(); + + let mut a = scratch.allocate::(12); + let mut b = scratch.allocate::(12); + + let a = a.as_mut_slice(); + let b = b.as_mut_slice(); + + assert_eq!(a.len(), 12); + assert_eq!(b.len(), 12); + assert_eq!(scratch.stack.len(), 2); + assert_ne!(a.as_mut_ptr(), b.as_mut_ptr()); + + for i in 0..a.len() { + a[i] = i as u64; + b[i] = i as u64; + } + } + + #[test] + fn align_16() { + let mut scratch = Scratch::new(); + let mut buffers = (0..10) + .map(|_| scratch.allocate::(10)) + .collect::>(); + + assert_eq!(scratch.stack.len(), 10); + + for b in buffers.iter_mut() { + let b = b.as_mut_slice(); + assert_eq!(b.len(), 10); + + for (i, b_i) in b.iter_mut().enumerate() { + *b_i = i as u128; + } + } + } + + #[test] + fn align_65536() { + // Chose an alignment larger than any reasonable OS's page size + // to try to force the alignment algorithm into play. + #[repr(C, align(65536))] + #[derive(Copy, Clone)] + struct Foo { + x: u32, + } + + unsafe impl Pod for Foo {} + + let mut scratch = Scratch::new(); + let mut b = scratch.allocate::(15); + + let b_slice = b.as_mut_slice(); + assert_eq!(b_slice.len(), 15); + + for b in b_slice { + b.x = 22; + let ptr = b as *mut Foo; + + // Check that each item is memory aligned to the proper + // location. + assert_eq!(ptr.align_offset(align_of::()), 0); + } + } + + #[test] + fn stack_coalesces_correctly() { + let mut scratch = Scratch::new(); + let a = scratch.allocate::(16); + let mut b: ScratchBuffer<'_, u64> = scratch.allocate::(16); + let b_ptr = b.as_mut_slice().as_mut_ptr(); + + let c: ScratchBuffer<'_, u64> = scratch.allocate::(16); + let d: ScratchBuffer<'_, u64> = scratch.allocate::(16); + + std::mem::drop(b); + assert_eq!(scratch.stack.len(), 4); + + // We can't reuse b's buffer until c, d, e get dropped. + let mut e: ScratchBuffer<'_, u64> = scratch.allocate::(16); + assert_ne!(b_ptr, e.as_mut_slice().as_mut_ptr()); + + assert_eq!(scratch.stack.len(), 5); + + std::mem::drop(c); + std::mem::drop(d); + std::mem::drop(e); + + // Now we can reuse b's buffer. + let mut f = scratch.allocate::(16); + assert_eq!(f.as_mut_slice().as_mut_ptr(), b_ptr); + assert_eq!(scratch.stack.len(), 5); + + std::mem::drop(a); + } + + #[test] + fn zero_size_allocations() { + let mut scratch = Scratch::new(); + let a = scratch.allocate::(2); + let b = scratch.allocate::(0); + + let a_slice = a.as_slice(); + let b_slice = b.as_slice(); + + assert_eq!(scratch.stack.len(), 2); + assert_eq!(a_slice.len(), 2); + assert_eq!(b_slice.len(), 0); + } + + #[test] + #[should_panic] + fn zst_allocations_should_panic() { + struct Foo {} + unsafe impl Pod for Foo {} + + let mut scratch = Scratch::new(); + let _ = scratch.allocate::(0x1 << 48); + } +} diff --git a/sunscreen_tfhe/src/zkp.rs b/sunscreen_tfhe/src/zkp.rs new file mode 100644 index 000000000..3fdd49b53 --- /dev/null +++ b/sunscreen_tfhe/src/zkp.rs @@ -0,0 +1,986 @@ +use logproof::{ + crypto::CryptoHash, + linear_algebra::{Matrix, PolynomialMatrix}, + math::ModSwitch, + rings::ZqRistretto, + Bounds, LogProofProverKnowledge, LogProofVerifierKnowledge as VerifierKnowledge, +}; +use sunscreen_math::{ + poly::Polynomial, + ring::{BarrettBackend, Ring, RingModulus, Zq}, + BarrettConfig, One, Zero, +}; + +use crate::{ + entities::{LweCiphertext, LwePublicKey, LweSecretKey, TlwePublicEncRandomness}, + math::{Torus, TorusOps}, + LweDef, PlaintextBits, +}; + +/// Proof statements for the SDLP proof system when applied to TFHE. +#[derive(Debug)] +pub enum ProofStatement<'a, 'b, S: TorusOps + TorusZq> { + /// A private key encryption statement. + PrivateKeyEncryption { + /// The message ID being encrypted. + message_id: usize, + + /// The ciphertext being encrypted. + ciphertext: &'a LweCiphertext, + }, + + /// A public key encryption statement. + PublicKeyEncryption { + /// The message ID being encrypted. + message_id: usize, + + /// The encrypted data under the provided public key. + ciphertext: &'a LweCiphertext, + + /// The public key being used to encrypt the message. + public_key: &'b LwePublicKey, + }, +} + +/// Witness information for the SDLP proof system when applied to TFHE. +/// This is the private information used when generating a proof. +#[derive(Debug)] +pub enum Witness<'a, 'b, S: TorusOps + TorusZq> { + /// A private key encryption witness. + PrivateKeyEncryption { + /// The randomness used in the encryption. + randomness: Torus, + + /// The private key used in the encryption. + private_key: &'a LweSecretKey, + }, + + /// A public key encryption witness. + PublicKeyEncryption { + /// The randomness used in the encryption. + randomness: &'b TlwePublicEncRandomness, + }, +} + +/// Generate LogProofProverKnowledge for the SDLP proof system. +pub fn generate_tfhe_sdlp_prover_knowledge( + statements: &[ProofStatement], + messages: &[Torus], + witness: &[Witness], + lwe: &LweDef, + plaintext_bits: PlaintextBits, +) -> LogProofProverKnowledge { + let vk = generate_tfhe_sdlp_verifier_knowledge(statements, lwe, plaintext_bits); + + let s = compute_s(statements, witness, messages, lwe); + + LogProofProverKnowledge { vk, s } +} + +/// Modulus for u32. +#[derive(BarrettConfig)] +#[barrett_config(modulus = "4294967296", num_limbs = 1)] +pub struct U32Config; + +/// Field for u32. +pub type Zq32 = Zq<1, BarrettBackend<1, U32Config>>; + +/// Modulus for u64. +#[derive(BarrettConfig)] +#[barrett_config(modulus = "18446744073709551616", num_limbs = 2)] +pub struct U64Config; + +/// Field for u64. +pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>; + +/// Torus properties on a discrete ring `Z_q`. +pub trait TorusZq +where + Self: Sized, +{ + /// The discrete ring `Z_q`. + type Zq: Ring + + From + + CryptoHash + + ModSwitch + + RingModulus<4> + + Ord + + From; +} + +impl TorusZq for u32 { + type Zq = Zq32; +} + +impl TorusZq for u64 { + type Zq = Zq64; +} + +/// Computes the public information needed to prove and verify public and private key +/// TLWE encryptions. +/// +/// # Remarks +/// Using only private key encryption results in significantly faster runtime since we +/// can characterize `Z_q[X]/f` with `f = X + 1`. +/// +/// # Details +/// Let `N` be the TLWE lattice dimension. +/// let `M` be the number of messages encrypted in in the ciphertexts +/// Note that multiple ciphertexts can encrypt the same message and SDLP can prove they contain +/// the same message. +/// +/// ## A matrix's structure +/// SDLP proves a linear relation `AS=T` where +/// * `A in Z_q[X]^{m x k}/f` +/// * `S in Z_q[X]^{k x n}/f` +/// * `T in Z_q[X]^{m x n}/f` +/// * `f = X^D + 1` +/// +/// For proving encryptions of TLWE ciphertexts, we have: +/// * `m` is the number of proof statements. +/// * `n = 1`. +/// * `k = num_messages + num_public_keys * (N + 1) + num_private_keys * N + num_private_encs + num_public_encs * (N + 1)`. +/// * See [quotient ring modulus `f`](#quotient-ring-modulus-f). +/// +/// The matrix `A` is arranged into rows, one per proof statement whose column structure depends +/// on whether the statement is for public or private key encryption. +/// +/// ### Private key statements +/// Each private key encryption statement row has the following column arrangement: +/// ```ignore +/// 1 at msg_idx pub_key (N/A) a at pvt_key_id * N 1 at pvt_stmt_idx e_idx (N/A) +/// V V V V V +/// [ 0 .. 1 .. 0 0 .. 0 0 .. a_0, a_1, .. a_N .. 0 0 .. 1 .. 0 0 .. 0] +/// ``` +/// ### Public key statements +/// Recall that a TLWE ciphertext consists of (a_1, ... a_N, b), where `a_i, b` in the discrete +/// torus `T`. +/// +/// Furthermore, the public key `P = (p_1, ... p_N)` where `p_i` is a secret key encryption of +/// zero. +/// +/// Each public key encryption statement row has the following column arrangement: +/// ```ignore +/// 1 at msg_idx p at pub_key_id * N a (N/A) b (N/A) e_idx +/// V V V V V +/// [ 0 .. 1 .. 0 0 .. p_0, p_1, .. p_N .. 0 0 .. 0 0 .. 0 0 .. 1 .. 0 ] +/// ``` +/// For these entries, we reinterpret each `p_i` as a polynomial `mod (X^N + 1)`. +/// +/// ## T's structure +/// While SDLP support T as a matrix, we only need a vector of `m` rows. +/// +/// The structure of each row in T depends on whether said the corresponding statement describes +/// a public or private key encryption. +/// +/// ### Private key statements +/// The row's polynomial is simply `b * X^0`. +/// +/// ### Public key statements +/// The rows's polynomial is `a_0 * X^0, a_1 * X^1, ... a_N * X^{N - 1} b * X^N` +/// +/// ## Quotient ring modulus `f` +/// When `statements` contains only secret key encryption statements, `f = X + 1`. Otherwise, +/// `f = X^{N + 1} + 1`. +pub fn generate_tfhe_sdlp_verifier_knowledge( + statements: &[ProofStatement], + lwe: &LweDef, + plaintext_bits: PlaintextBits, +) -> VerifierKnowledge { + // If we need to prove any public encryption statements, then `f = X^{lwe_dimension + 1} + 1`. + // If not, we can use the more efficient X + 1. + let mut f_coeffs = vec![S::Zq::from(::zero()); f_degree(statements, lwe) + 1]; + f_coeffs[0] = S::Zq::from(S::one()); + + let last_coeff = f_coeffs.len() - 1; + f_coeffs[last_coeff] = S::Zq::from(S::one()); + + let f = Polynomial::new(&f_coeffs); + + let a = compute_a(statements, lwe, plaintext_bits); + let t = compute_t(statements, lwe); + let bounds = compute_bounds(statements, lwe, plaintext_bits); + + VerifierKnowledge::new(a, t, f, bounds) +} + +fn compute_bounds( + statements: &[ProofStatement], + lwe: &LweDef, + plaintext_bits: PlaintextBits, +) -> Matrix { + let (_, cols) = proof_matrix_dim(statements, lwe.dim.0); + let offsets = compute_a_column_offsets(statements, lwe); + + let mut bounds = Matrix::::new(cols, 1); + + let num_messages = num_messages(statements); + let num_coeffs = f_degree(statements, lwe); + let lwe_dimension = lwe.dim.0; + + // Bounds for messages + for i in 0..num_messages { + let mut b = Bounds(vec![0; num_coeffs]); + b.0[0] = 0x1u64 << plaintext_bits.0; + debug_assert_eq!(bounds[(i, 0)].0, &[]); + bounds[(i, 0)] = b; + } + + // Public r and e + for i in 0..num_public(statements) { + for j in 0..lwe_dimension { + let mut b = Bounds(vec![0; num_coeffs]); + + // Values of r are binary + b.0[0] = 0x1u64 << plaintext_bits.0; + debug_assert_eq!( + bounds[(offsets.public_keys + i * lwe_dimension + j, 0)].0, + &[] + ); + bounds[(offsets.public_keys + i * lwe_dimension + j, 0)] = b; + } + + // e is normal distributed over the torus. + // TODO: This bound is too high. Get a tighter bound. + let b = Bounds(vec![0x1u64 << (60 - plaintext_bits.0); num_coeffs]); + + debug_assert_eq!(bounds[(offsets.public_e + i, 0)].0, &[]); + bounds[(offsets.public_e + i, 0)] = b; + } + + // Private s and e + for i in 0..num_private(statements) { + for j in 0..lwe_dimension { + let mut b = Bounds(vec![0; num_coeffs]); + + // Values of s are binary + b.0[0] = 0x1u64 << plaintext_bits.0; + debug_assert_eq!( + bounds[(offsets.private_a + j + i * lwe_dimension, 0)].0, + &[] + ); + bounds[(offsets.private_a + j + i * lwe_dimension, 0)] = b; + } + + let mut b = Bounds(vec![0; num_coeffs]); + + // e is normal distributed over the torus. + // TODO: This bound is too high. Get a tighter bound. + b.0[0] = 0x1u64 << (62 - plaintext_bits.0); + debug_assert_eq!(bounds[(offsets.private_e + i, 0)].0, &[]); + bounds[(offsets.private_e + i, 0)] = b; + } + + bounds +} + +fn f_degree(statements: &[ProofStatement], lwe: &LweDef) -> usize { + let lwe_dimension = lwe.dim.0; + + if num_public(statements) > 0 { + lwe_dimension + 1 + } else { + 1 + } +} + +fn encoding_factor(plaintext_bits: PlaintextBits) -> S::Zq { + let x = S::from_u64(0x1u64 << (S::BITS - plaintext_bits.0)); + S::Zq::from(x) +} + +struct IdxOffsets { + public_keys: usize, + public_e: usize, + private_a: usize, + private_e: usize, +} + +#[inline] +fn compute_a_column_offsets( + statements: &[ProofStatement], + lwe: &LweDef, +) -> IdxOffsets { + let lwe_dimension = lwe.dim.0; + let public_keys = num_messages(statements); + let public_e = public_keys + num_public(statements) * lwe_dimension; + let private_a = public_e + num_public(statements); + let private_e = private_a + num_private(statements) * lwe_dimension; + + IdxOffsets { + public_keys, + public_e, + private_a, + private_e, + } +} + +fn compute_a( + statements: &[ProofStatement], + lwe: &LweDef, + plaintext_bits: PlaintextBits, +) -> Matrix> { + let lwe_dimension = lwe.dim.0; + let offsets = compute_a_column_offsets(statements, lwe); + + let (rows, cols) = proof_matrix_dim(statements, lwe_dimension); + + let mut a = Matrix::>::new(rows, cols); + + assert!(plaintext_bits.0 > 0); + + let msg_encode = encoding_factor::(plaintext_bits); + + let mut cur_public = 0; + let mut cur_private = 0; + + for (i, s) in statements.iter().enumerate() { + let (is_public, public_key, message_id, ciphertext) = match s { + ProofStatement::PrivateKeyEncryption { + ciphertext, + message_id, + } => (false, None, *message_id, *ciphertext), + ProofStatement::PublicKeyEncryption { + public_key, + ciphertext, + message_id, + } => (true, Some(public_key), *message_id, *ciphertext), + }; + + // Insert the message coefficient. For private encryptions, this goes in + // the constant coefficient. For public, it goes in the d-1 coefficient. + let coeffs = if !is_public { + vec![msg_encode.clone()] + } else { + let mut coeffs = vec![S::Zq::from(::zero()); lwe_dimension + 1]; + coeffs[lwe_dimension] = msg_encode.clone(); + coeffs + }; + + debug_assert_eq!(a[(i, message_id)], Polynomial::zero()); + a[(i, message_id)] = Polynomial::new(&coeffs); + + if is_public { + // Push the public key + let pk = public_key.unwrap(); + + for (j, z) in pk.enc_zeros(lwe).enumerate() { + let (z_a, z_b) = z.a_b(lwe); + let mut coeffs = z_a + .iter() + .map(|x| S::Zq::from(x.inner())) + .collect::>(); + coeffs.push(S::Zq::from(z_b.inner())); + + let public_key_idx = offsets.public_keys + cur_public * lwe_dimension + j; + + debug_assert_eq!(a[(i, public_key_idx)], Polynomial::zero()); + a[(i, public_key_idx)] = Polynomial::new(&coeffs); + } + + // Push the randomness + debug_assert_eq!(a[(i, offsets.public_e + cur_public)], Polynomial::zero()); + a[(i, offsets.public_e + cur_public)] = Polynomial::one(); + cur_public += 1; + } else { + let (c_a, _c_b) = ciphertext.a_b(lwe); + + // Place the a values of the cipertext in the matrix. + for (j, a_j) in c_a.iter().enumerate() { + let private_key_idx = offsets.private_a + j + cur_private * lwe_dimension; + + debug_assert_eq!(a[(i, private_key_idx)], Polynomial::zero()); + a[(i, private_key_idx)] = Polynomial::new(&[S::Zq::from(a_j.inner())]); + } + + debug_assert_eq!(a[(i, offsets.private_e)], Polynomial::zero()); + + a[(i, offsets.private_e + i)] = Polynomial::one(); + cur_private += 1; + } + } + + a +} + +fn compute_t( + statements: &[ProofStatement], + lwe: &LweDef, +) -> PolynomialMatrix { + let mut t = PolynomialMatrix::new(statements.len(), 1); + + for (i, s) in statements.iter().enumerate() { + match s { + ProofStatement::PrivateKeyEncryption { + ciphertext: c, + message_id: _, + } => { + let (_, c_b) = c.a_b(lwe); + + debug_assert_eq!(t[(i, 0)], Polynomial::zero()); + t[(i, 0)] = Polynomial::new(&[S::Zq::from(c_b.inner())]); + } + ProofStatement::PublicKeyEncryption { + public_key: _, + message_id: _, + ciphertext: c, + } => { + let (c_a, c_b) = c.a_b(lwe); + + let mut coeffs = c_a + .iter() + .map(|x| S::Zq::from(x.inner())) + .collect::>(); + + coeffs.push(S::Zq::from(c_b.inner())); + + debug_assert_eq!(t[(i, 0)], Polynomial::zero()); + t[(i, 0)] = Polynomial::new(&coeffs); + } + } + } + + t +} + +#[inline] +fn compute_s( + statements: &[ProofStatement], + witness: &[Witness], + messages: &[Torus], + lwe: &LweDef, +) -> PolynomialMatrix { + assert_eq!(statements.len(), witness.len()); + + let lwe_dimension = lwe.dim.0; + + // If the a matrix is rows x cols, the S witness must have 'cols' rows. + let (_, cols) = proof_matrix_dim(statements, lwe_dimension); + + // Column offsets in the A matrix become row offsets in the S witness. + let offsets = compute_a_column_offsets(statements, lwe); + + let mut s = PolynomialMatrix::new(cols, 1); + + // Put the messages into the witness + for (i, m) in messages.iter().take(num_messages(statements)).enumerate() { + debug_assert_eq!(s[(i, 0)], Polynomial::zero()); + s[(i, 0)] = Polynomial::new(&[S::Zq::from(m.inner())]); + } + + let public_entries = witness.iter().filter_map(|x| match x { + Witness::PublicKeyEncryption { randomness } => Some(randomness), + _ => None, + }); + + // Put public 'r' and 'e' randomness into the witness + for (i, rnd) in public_entries.enumerate() { + for (j, r) in rnd.r.iter().enumerate() { + let public_key_index = offsets.public_keys + j + i * lwe_dimension; + + debug_assert_eq!(s[(public_key_index, 0)], Polynomial::zero()); + s[(public_key_index, 0)] = Polynomial::new(&[S::Zq::from(*r)]); + } + + let mut coeffs = rnd + .e + .a_b(lwe) + .0 + .iter() + .map(|x| S::Zq::from(x.inner())) + .collect::>(); + coeffs.push(S::Zq::from(rnd.e.a_b(lwe).1.inner())); + + let public_randomness_index = offsets.public_e + i; + + debug_assert_eq!(s[(public_randomness_index, 0)], Polynomial::zero()); + s[(public_randomness_index, 0)] = Polynomial::new(&coeffs); + } + + let private_keys = witness.iter().filter_map(|x| match x { + Witness::PrivateKeyEncryption { + randomness, + private_key, + } => Some((private_key, randomness)), + _ => None, + }); + + // Put the secret keys into the witness. + for (i, (sk, e)) in private_keys.enumerate() { + for (j, sk) in sk.s().iter().enumerate() { + let private_key_index = offsets.private_a + j + i * lwe_dimension; + + debug_assert_eq!(s[(private_key_index, 0)], Polynomial::zero()); + s[(private_key_index, 0)] = Polynomial::new(&[S::Zq::from(*sk)]); + } + + let private_randomness_index = offsets.private_e + i; + + debug_assert_eq!(s[(private_randomness_index, 0)], Polynomial::zero()); + s[(private_randomness_index, 0)] = Polynomial::new(&[S::Zq::from(e.inner())]); + } + + s +} + +#[inline(always)] +fn num_private(statements: &[ProofStatement]) -> usize { + // Public key statements require 2 rows while private key statements require only one. + statements + .iter() + .filter(|x| { + matches!( + x, + ProofStatement::PrivateKeyEncryption { + ciphertext: _, + message_id: _ + } + ) + }) + .count() +} + +#[inline(always)] +fn num_public(statements: &[ProofStatement]) -> usize { + statements.len() - num_private(statements) +} + +/// Returns a tuple of the `(rows, cols)` in SDLP's `A` matrix. +#[inline] +fn proof_matrix_dim( + statements: &[ProofStatement], + lwe_dimension: usize, +) -> (usize, usize) { + let num_rows = statements.len(); + + let num_cols = num_messages(statements) // message terms + + num_public(statements) * lwe_dimension // public key terms + + num_private(statements) * lwe_dimension // a terms + + num_rows; // Public + private e term '1' coeffs + + (num_rows, num_cols) +} + +#[inline] +fn num_messages(statements: &[ProofStatement]) -> usize { + statements.iter().fold(0usize, |max, x| { + let message_id = match x { + ProofStatement::PublicKeyEncryption { + public_key: _, + ciphertext: _, + message_id, + } => message_id, + ProofStatement::PrivateKeyEncryption { + ciphertext: _, + message_id, + } => message_id, + }; + + usize::max(max, *message_id) + }) + 1 +} + +#[cfg(test)] +mod tests { + use logproof::{InnerProductVerifierKnowledge, LogProof, LogProofGenerators}; + use merlin::Transcript; + use rand::{thread_rng, RngCore}; + + use crate::{ + high_level::*, + zkp::{num_private, num_public}, + LweDef, LweDimension, LWE_512_80, + }; + + use super::*; + + #[test] + fn can_compute_a_column_offsets() { + let lwe = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let lwe_dimension = lwe.dim.0; + + let ct = LweCiphertext::::zero(&lwe); + + let sk = keygen::generate_binary_lwe_sk(&lwe); + let pk = keygen::generate_lwe_pk(&sk, &lwe); + + let statements = [ + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 1, + }, + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 0, + }, + ProofStatement::PublicKeyEncryption { + ciphertext: &ct, + public_key: &pk, + message_id: 0, + }, + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 1, + }, + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 0, + }, + ProofStatement::PublicKeyEncryption { + ciphertext: &ct, + public_key: &pk, + message_id: 1, + }, + ProofStatement::PublicKeyEncryption { + ciphertext: &ct, + public_key: &pk, + message_id: 1, + }, + ]; + + let idx = compute_a_column_offsets(&statements, &lwe); + + let num_messages = num_messages(&statements); + let num_private = num_private(&statements); + let num_public = num_public(&statements); + + assert_eq!(num_messages, 2); + assert_eq!(num_private, 4); + assert_eq!(num_public, 3); + + // We have 2 messages that precede this. + assert_eq!(idx.public_keys, num_messages); + // We have 3 different public encryptions, so there are 3 public key + // column entries. That 2 use the same key is immaterial, as they + // must match different randomness in the witness. + assert_eq!(idx.public_e, idx.public_keys + num_public * lwe_dimension); + // We have 3 public key encryptions that contribute to this offset. + // Each 'e' is a polynomial of degree `lwe_dimension`. + assert_eq!(idx.private_a, idx.public_e + num_public); + // Finally, each secret_a entry takes `lwe_dimension` columns. + assert_eq!(idx.private_e, idx.private_a + num_private * lwe_dimension); + } + + #[test] + fn num_messages_works() { + let lwe = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + + let sk = keygen::generate_binary_lwe_sk(&lwe); + let pk = keygen::generate_lwe_pk(&sk, &lwe); + + let ct = encryption::trivial_lwe(0, &lwe, PlaintextBits(1)); + + let num_messages = num_messages(&[ + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 0, + }, + ProofStatement::PrivateKeyEncryption { + ciphertext: &ct, + message_id: 1, + }, + ProofStatement::PublicKeyEncryption { + public_key: &pk, + ciphertext: &ct, + message_id: 0, + }, + ]); + + assert_eq!(num_messages, 2); + } + + fn prove_and_verify(pk: &LogProofProverKnowledge) { + let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize); + let u = InnerProductVerifierKnowledge::get_u(); + let mut p_t = Transcript::new(b"test"); + + let proof = LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u); + + let mut v_t = Transcript::new(b"test"); + + proof.verify(&mut v_t, &pk.vk, &gen.g, &gen.h, &u).unwrap(); + } + + #[test] + fn one_secret_key() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let (ct, rng) = + encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, PlaintextBits(1)); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &[ProofStatement::PrivateKeyEncryption { + message_id: 0, + ciphertext: &ct, + }], + &[Torus::from(1)], + &[Witness::PrivateKeyEncryption { + randomness: rng, + private_key: &sk, + }], + ¶ms, + PlaintextBits(1), + ); + + prove_and_verify::(&pk); + } + + #[test] + fn two_secret_key() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + + let (ct0, rng0) = + encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits); + let (ct1, rng1) = + encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &[ + ProofStatement::PrivateKeyEncryption { + message_id: 0, + ciphertext: &ct0, + }, + ProofStatement::PrivateKeyEncryption { + message_id: 1, + ciphertext: &ct1, + }, + ], + &[Torus::from(1), Torus::from(1)], + &[ + Witness::PrivateKeyEncryption { + randomness: rng0, + private_key: &sk, + }, + Witness::PrivateKeyEncryption { + randomness: rng1, + private_key: &sk, + }, + ], + ¶ms, + bits, + ); + + prove_and_verify::(&pk); + } + + #[test] + fn one_public_key() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + let (ct, rng) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &[ProofStatement::PublicKeyEncryption { + message_id: 0, + public_key: &pk, + ciphertext: &ct, + }], + &[Torus::from(1)], + &[Witness::PublicKeyEncryption { randomness: &rng }], + ¶ms, + bits, + ); + + prove_and_verify::(&pk); + } + + #[test] + fn two_public_key() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + let (ct0, rng0) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits); + let (ct1, rng1) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &[ + ProofStatement::PublicKeyEncryption { + message_id: 0, + public_key: &pk, + ciphertext: &ct0, + }, + ProofStatement::PublicKeyEncryption { + message_id: 1, + public_key: &pk, + ciphertext: &ct1, + }, + ], + &[Torus::from(1), Torus::from(1)], + &[ + Witness::PublicKeyEncryption { randomness: &rng0 }, + Witness::PublicKeyEncryption { randomness: &rng1 }, + ], + ¶ms, + bits, + ); + + prove_and_verify::(&pk); + } + + #[test] + fn one_public_one_private() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let bits = PlaintextBits(1); + + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + let (ct_priv, rng_priv) = + encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits); + let (ct_pub, rng_pub) = + encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &[ + ProofStatement::PrivateKeyEncryption { + message_id: 0, + ciphertext: &ct_priv, + }, + ProofStatement::PublicKeyEncryption { + message_id: 0, + public_key: &pk, + ciphertext: &ct_pub, + }, + ], + &[Torus::from(1)], + &[ + Witness::PrivateKeyEncryption { + randomness: rng_priv, + private_key: &sk, + }, + Witness::PublicKeyEncryption { + randomness: &rng_pub, + }, + ], + ¶ms, + bits, + ); + + prove_and_verify::(&pk); + } + + #[ignore] + #[test] + fn complex_examples() { + let params = LweDef { + dim: LweDimension(4), + ..LWE_512_80 + }; + let bits = PlaintextBits(1); + + let case = || { + let sk = keygen::generate_binary_lwe_sk(¶ms); + let pk = keygen::generate_lwe_pk(&sk, ¶ms); + + let num_messages = thread_rng().next_u64() as usize % 7 + 1; + let num_secret_encryptions = thread_rng().next_u64() as usize % 8; + let num_public_encryptions = thread_rng().next_u64() as usize % 8; + + let messages = (0..num_messages) + .map(|_| thread_rng().next_u64() % 2) + .collect::>(); + + // Skip trivial cases. I don't think SDLP allows it and it's boring.. + if num_public_encryptions == 0 && num_secret_encryptions == 0 { + return; + } + + let mut statements = vec![]; + let mut witnesses = vec![]; + let mut private_info = vec![]; + let mut public_info = vec![]; + + for _ in 0..num_secret_encryptions { + let msg_id = thread_rng().next_u64() as usize % num_messages; + + let (ct, noise) = encryption::encrypt_lwe_secret_and_return_randomness( + messages[msg_id], + &sk, + ¶ms, + bits, + ); + private_info.push((ct, noise, msg_id)); + } + + for (ct, noise, msg_id) in private_info.iter() { + statements.push(ProofStatement::PrivateKeyEncryption { + message_id: *msg_id, + ciphertext: ct, + }); + witnesses.push(Witness::PrivateKeyEncryption { + randomness: *noise, + private_key: &sk, + }); + } + + for _ in 0..num_public_encryptions { + let msg_id = thread_rng().next_u64() as usize % num_messages; + + let (ct, noise) = encryption::encrypt_lwe_and_return_randomness( + messages[msg_id], + &pk, + ¶ms, + bits, + ); + public_info.push((ct, noise, msg_id)); + } + + for (ct, noise, msg_id) in public_info.iter() { + statements.push(ProofStatement::PublicKeyEncryption { + message_id: *msg_id, + ciphertext: ct, + public_key: &pk, + }); + witnesses.push(Witness::PublicKeyEncryption { randomness: noise }); + } + + let messages = messages.iter().map(|x| Torus::from(*x)).collect::>(); + + let pk = generate_tfhe_sdlp_prover_knowledge( + &statements, + &messages, + &witnesses, + ¶ms, + bits, + ); + + prove_and_verify::(&pk); + }; + + for _ in 0..5 { + case(); + } + } +}