mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-01-09 13:48:06 -05:00
Sunscreen's TFHE implementation (#349)
Co-authored-by: Sam Tay <samctay@pm.me>
This commit is contained in:
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -60,14 +60,14 @@ jobs:
|
||||
|
||||
# Build package in prep for user docs
|
||||
- name: Build sunscreen and bincode
|
||||
run: cargo build --release --features bulletproofs,linkedproofs --package sunscreen --package bincode
|
||||
run: cargo build --profile mdbook --features bulletproofs,linkedproofs --package sunscreen --package bincode
|
||||
# Build mdbook
|
||||
- name: Build mdBook
|
||||
run: cargo build --release
|
||||
working-directory: ./mdBook
|
||||
# Build user documentation
|
||||
- name: Test docs
|
||||
run: ../mdBook/target/release/mdbook test -L dependency=../target/release/deps --extern sunscreen=../target/release/libsunscreen.rlib --extern bincode=../target/release/libbincode.rlib
|
||||
run: ../mdBook/target/release/mdbook test -L dependency=../target/mdbook/deps --extern sunscreen=../target/mdbook/libsunscreen.rlib --extern bincode=../target/mdbook/libbincode.rlib
|
||||
working-directory: ./sunscreen_docs
|
||||
|
||||
lint:
|
||||
|
||||
188
.gitignore
vendored
188
.gitignore
vendored
@@ -1,5 +1,189 @@
|
||||
sunscreen_docs/book
|
||||
/target
|
||||
.DS_Store
|
||||
__bindgen.ii
|
||||
__bindgen.cpp
|
||||
|
||||
# Ignore output from the proptest crate
|
||||
proptest-regressions/
|
||||
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=emacs,vim,visualstudiocode,macos,windows,linux,direnv,rust
|
||||
|
||||
### direnv ###
|
||||
.direnv
|
||||
.envrc
|
||||
|
||||
### Emacs ###
|
||||
# -*- mode: gitignore; -*-
|
||||
*~
|
||||
\#*\#
|
||||
/.emacs.desktop
|
||||
/.emacs.desktop.lock
|
||||
*.elc
|
||||
auto-save-list
|
||||
tramp
|
||||
.\#*
|
||||
|
||||
# Org-mode
|
||||
.org-id-locations
|
||||
*_archive
|
||||
|
||||
# flymake-mode
|
||||
*_flymake.*
|
||||
|
||||
# eshell files
|
||||
/eshell/history
|
||||
/eshell/lastdir
|
||||
|
||||
# elpa packages
|
||||
/elpa/
|
||||
|
||||
# reftex files
|
||||
*.rel
|
||||
|
||||
# AUCTeX auto folder
|
||||
/auto/
|
||||
|
||||
# cask packages
|
||||
.cask/
|
||||
dist/
|
||||
|
||||
# Flycheck
|
||||
flycheck_*.el
|
||||
|
||||
# server auth directory
|
||||
/server/
|
||||
|
||||
# projectiles files
|
||||
.projectile
|
||||
|
||||
# directory configuration
|
||||
.dir-locals.el
|
||||
|
||||
# network security
|
||||
/network-security.data
|
||||
|
||||
|
||||
### Linux ###
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### macOS Patch ###
|
||||
# iCloud generated files
|
||||
*.icloud
|
||||
|
||||
### Rust ###
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
|
||||
# MSVC Windows builds of rustc generate these, which store debugging information
|
||||
*.pdb
|
||||
|
||||
### Vim ###
|
||||
# Swap
|
||||
[._]*.s[a-v][a-z]
|
||||
!*.svg # comment out if you don't need vector files
|
||||
[._]*.sw[a-p]
|
||||
[._]s[a-rt-v][a-z]
|
||||
[._]ss[a-gi-z]
|
||||
[._]sw[a-p]
|
||||
|
||||
# Session
|
||||
Session.vim
|
||||
Sessionx.vim
|
||||
|
||||
# Temporary
|
||||
.netrwhist
|
||||
# Auto-generated tag files
|
||||
tags
|
||||
# Persistent undo
|
||||
[._]*.un~
|
||||
|
||||
### VisualStudioCode ###
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
!.vscode/launch.json
|
||||
!.vscode/extensions.json
|
||||
!.vscode/*.code-snippets
|
||||
|
||||
# Local History for Visual Studio Code
|
||||
.history/
|
||||
|
||||
# Built Visual Studio Code Extensions
|
||||
*.vsix
|
||||
|
||||
### VisualStudioCode Patch ###
|
||||
# Ignore all local history of files
|
||||
.history
|
||||
.ionide
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
|
||||
88
Cargo.lock
generated
88
Cargo.lock
generated
@@ -643,6 +643,8 @@ dependencies = [
|
||||
"num-traits 0.2.17",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@@ -1593,6 +1595,12 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-list"
|
||||
version = "0.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4dacf969043dc69f1f731b5042eb05e030d264bcf34f2242889fcbdc7a65f06"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.10"
|
||||
@@ -2196,6 +2204,15 @@ dependencies = [
|
||||
"syn 2.0.38",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "primal-check"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9df7f93fd637f083201473dab4fee2db4c429d32e55e3299980ab3957ab916a0"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "private_tx_linkedproof"
|
||||
version = "0.1.0"
|
||||
@@ -2220,9 +2237,9 @@ checksum = "f89dff0959d98c9758c88826cc002e2c3d0b9dfac4139711d1f30de442f1139b"
|
||||
|
||||
[[package]]
|
||||
name = "proptest"
|
||||
version = "1.3.1"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7c003ac8c77cb07bb74f5f198bce836a689bcd5a42574612bf14d17bfd08c20e"
|
||||
checksum = "31b476131c3c86cb68032fdc5cb6d5a1045e3e42d96b69fa599fd77701e1f5bf"
|
||||
dependencies = [
|
||||
"bit-set",
|
||||
"bit-vec",
|
||||
@@ -2232,7 +2249,7 @@ dependencies = [
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rand_xorshift",
|
||||
"regex-syntax 0.7.5",
|
||||
"regex-syntax 0.8.2",
|
||||
"rusty-fork",
|
||||
"tempfile",
|
||||
"unarray",
|
||||
@@ -2356,6 +2373,15 @@ dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "realfft"
|
||||
version = "3.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "953d9f7e5cdd80963547b456251296efc2626ed4e3cbf36c869d9564e0220571"
|
||||
dependencies = [
|
||||
"rustfft",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.2.11"
|
||||
@@ -2393,9 +2419,9 @@ checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.5"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da"
|
||||
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
|
||||
|
||||
[[package]]
|
||||
name = "remove_dir_all"
|
||||
@@ -2485,6 +2511,21 @@ dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustfft"
|
||||
version = "6.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43806561bc506d0c5d160643ad742e3161049ac01027b5e6d7524091fd401d86"
|
||||
dependencies = [
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits 0.2.17",
|
||||
"primal-check",
|
||||
"strength_reduce",
|
||||
"transpose",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.20"
|
||||
@@ -2773,6 +2814,12 @@ dependencies = [
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strength_reduce"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
@@ -2994,6 +3041,27 @@ dependencies = [
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen_tfhe"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"criterion 0.5.1",
|
||||
"linked-list",
|
||||
"logproof",
|
||||
"merlin",
|
||||
"num",
|
||||
"paste",
|
||||
"proptest",
|
||||
"rand",
|
||||
"rand_distr",
|
||||
"realfft",
|
||||
"rustfft",
|
||||
"serde",
|
||||
"sunscreen_math",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sunscreen_zkp_backend"
|
||||
version = "0.8.1"
|
||||
@@ -3195,6 +3263,16 @@ dependencies = [
|
||||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "transpose"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6522d49d03727ffb138ae4cbc1283d3774f0d10aa7f9bf52e6784c45daf9b23"
|
||||
dependencies = [
|
||||
"num-integer",
|
||||
"strength_reduce",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "try-lock"
|
||||
version = "0.2.3"
|
||||
|
||||
13
Cargo.toml
13
Cargo.toml
@@ -17,7 +17,7 @@ members = [
|
||||
"sunscreen_math",
|
||||
"sunscreen_math_macros",
|
||||
"sunscreen_runtime",
|
||||
"sunscreen_compiler_common",
|
||||
"sunscreen_tfhe",
|
||||
"sunscreen_zkp_backend",
|
||||
]
|
||||
exclude = ["mdBook", "rust-playground"]
|
||||
@@ -25,6 +25,17 @@ exclude = ["mdBook", "rust-playground"]
|
||||
[profile.release]
|
||||
split-debuginfo = "packed"
|
||||
debug = true
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
|
||||
[profile.bench]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
|
||||
[profile.mdbook]
|
||||
inherits = "release"
|
||||
lto = false
|
||||
codegen-units = 16
|
||||
|
||||
[workspace.dependencies]
|
||||
bytemuck = "1.13.0"
|
||||
|
||||
81
sunscreen_tfhe/.vscode/launch.json
vendored
Normal file
81
sunscreen_tfhe/.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in library 'sunscreen_tfhe'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--lib",
|
||||
"--package=sunscreen_tfhe"
|
||||
],
|
||||
"filter": {
|
||||
"name": "sunscreen_tfhe",
|
||||
"kind": "lib"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug benchmark 'tfhe_proof'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--bench=tfhe_proof",
|
||||
"--package=sunscreen_tfhe"
|
||||
],
|
||||
"filter": {
|
||||
"name": "tfhe_proof",
|
||||
"kind": "bench"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug benchmark 'fft'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--bench=fft",
|
||||
"--package=sunscreen_tfhe"
|
||||
],
|
||||
"filter": {
|
||||
"name": "fft",
|
||||
"kind": "bench"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug benchmark 'ops'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--bench=ops",
|
||||
"--package=sunscreen_tfhe"
|
||||
],
|
||||
"filter": {
|
||||
"name": "ops",
|
||||
"kind": "bench"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
||||
3
sunscreen_tfhe/.vscode/settings.json
vendored
Normal file
3
sunscreen_tfhe/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"rust-analyzer.showUnlinkedFileNotification": false
|
||||
}
|
||||
52
sunscreen_tfhe/Cargo.toml
Normal file
52
sunscreen_tfhe/Cargo.toml
Normal file
@@ -0,0 +1,52 @@
|
||||
[package]
|
||||
name = "sunscreen_tfhe"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
authors = ["Sunscreen"]
|
||||
rust-version = "1.56.0"
|
||||
license = "AGPL-3.0-only"
|
||||
description = "This crate contains the Sunscreen Torus FHE (TFHE) implementation"
|
||||
homepage = "https://sunscreen.tech"
|
||||
repository = "https://github.com/Sunscreen-tech/Sunscreen"
|
||||
documentation = "https://docs.sunscreen.tech"
|
||||
keywords = ["FHE", "TFHE", "lattice", "cryptography"]
|
||||
categories = ["cryptography"]
|
||||
readme = "crates-io.md"
|
||||
|
||||
|
||||
[dependencies]
|
||||
bytemuck = { workspace = true }
|
||||
# TODO: Remove when Rust stabilizes Cursor API
|
||||
linked-list = "0.0.3"
|
||||
logproof = { workspace = true, optional = true }
|
||||
num = { workspace = true }
|
||||
paste = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true }
|
||||
realfft = "3.3.0"
|
||||
rustfft = "6.1.0"
|
||||
serde = { workspace = true }
|
||||
sunscreen_math = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5.1"
|
||||
merlin = "3.0.0"
|
||||
proptest = "1.4.0"
|
||||
|
||||
[features]
|
||||
logproof = ["dep:logproof"]
|
||||
|
||||
[[bench]]
|
||||
name = "tfhe_proof"
|
||||
harness = false
|
||||
required-features= ["logproof"]
|
||||
|
||||
[[bench]]
|
||||
name = "fft"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "ops"
|
||||
harness = false
|
||||
35
sunscreen_tfhe/barrett.py
Normal file
35
sunscreen_tfhe/barrett.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import math
|
||||
import sys
|
||||
|
||||
radix = 10
|
||||
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: barrett <number of 64-bit limbs> <modulus> [<modulus radix>]")
|
||||
|
||||
n = int(sys.argv[1])
|
||||
|
||||
if len(sys.argv) == 4:
|
||||
radix = int(sys.argv[3])
|
||||
|
||||
p = int(sys.argv[2], radix)
|
||||
|
||||
def compute_vals(n, p):
|
||||
r = math.floor(2**(64 * n) // p)
|
||||
s = math.floor(2**(128 * n) // p) - 2**(64 * n) * r
|
||||
t = 2**(64*n) - r * p
|
||||
|
||||
return (n, r, s, t)
|
||||
|
||||
(_, r, s, t) = compute_vals(n, p)
|
||||
|
||||
def print_value(n, name, x):
|
||||
print(name + " = [")
|
||||
|
||||
for i in range(n):
|
||||
print(" " + str((x >> (64 * i)) & 0xFFFFFFFFFFFFFFFF) + ",")
|
||||
|
||||
print("]")
|
||||
|
||||
print_value(n, "r", r)
|
||||
print_value(n, "s", s)
|
||||
print_value(n, "t", t)
|
||||
47
sunscreen_tfhe/benches/fft.rs
Normal file
47
sunscreen_tfhe/benches/fft.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use num::Complex;
|
||||
use sunscreen_tfhe::{math::fft::negacyclic::TwistedFft, FrequencyTransform};
|
||||
|
||||
fn negacyclic_fft(c: &mut Criterion) {
|
||||
let n = 2048;
|
||||
|
||||
let plan = TwistedFft::<f64>::new(n);
|
||||
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut y = vec![Complex::from(0.0); x.len() / 2];
|
||||
|
||||
c.bench_function("FFT 2048", |s| {
|
||||
s.iter(|| {
|
||||
plan.forward(&x, &mut y);
|
||||
});
|
||||
});
|
||||
|
||||
let n = 1024;
|
||||
|
||||
let plan = TwistedFft::<f64>::new(n);
|
||||
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut y = vec![Complex::from(0.0); x.len() / 2];
|
||||
|
||||
c.bench_function("FFT 1024", |s| {
|
||||
s.iter(|| {
|
||||
plan.forward(&x, &mut y);
|
||||
});
|
||||
});
|
||||
|
||||
let n: usize = 256;
|
||||
|
||||
let plan = TwistedFft::<f64>::new(n);
|
||||
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut y = vec![Complex::from(0.0); x.len() / 2];
|
||||
|
||||
c.bench_function("FFT 256", |s| {
|
||||
s.iter(|| {
|
||||
plan.forward(&x, &mut y);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, negacyclic_fft);
|
||||
criterion_main!(benches);
|
||||
235
sunscreen_tfhe/benches/ops.rs
Normal file
235
sunscreen_tfhe/benches/ops.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use criterion::{
|
||||
criterion_group, criterion_main, measurement::WallTime, BenchmarkGroup, Criterion,
|
||||
};
|
||||
|
||||
use sunscreen_tfhe::{
|
||||
entities::{
|
||||
GgswCiphertext, GgswCiphertextFft, GlweCiphertext, Polynomial, UnivariateLookupTable,
|
||||
},
|
||||
high_level::*,
|
||||
ops::bootstrapping::circuit_bootstrap,
|
||||
rand::Stddev,
|
||||
GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PlaintextBits, PolynomialDegree,
|
||||
RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80, LWE_512_80,
|
||||
};
|
||||
|
||||
fn cmux(c: &mut Criterion) {
|
||||
struct CmuxParams {
|
||||
gsw_radix: RadixDecomposition,
|
||||
glwe: GlweDef,
|
||||
}
|
||||
|
||||
fn cmux_params(params: &CmuxParams, c: &mut Criterion) {
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms.glwe);
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let msg = (0..params.glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let msg = Polynomial::new(&msg);
|
||||
|
||||
let a = encryption::encrypt_glwe(&msg, &sk, ¶ms.glwe, bits);
|
||||
let b = a.clone();
|
||||
let sel = encryption::encrypt_ggsw(1, &sk, ¶ms.glwe, ¶ms.gsw_radix, bits);
|
||||
let mut sel_fft = GgswCiphertextFft::new(¶ms.glwe, ¶ms.gsw_radix);
|
||||
|
||||
sel.fft(&mut sel_fft, ¶ms.glwe, ¶ms.gsw_radix);
|
||||
|
||||
let name = format!(
|
||||
"cmux N={} k={} l={}",
|
||||
params.glwe.dim.polynomial_degree.0, params.glwe.dim.size.0, params.gsw_radix.count.0
|
||||
);
|
||||
|
||||
let mut result = GlweCiphertext::new(¶ms.glwe);
|
||||
|
||||
c.bench_function(&name, |bench| {
|
||||
bench.iter(|| {
|
||||
sunscreen_tfhe::ops::fft_ops::cmux(
|
||||
&mut result,
|
||||
&a,
|
||||
&b,
|
||||
&sel_fft,
|
||||
¶ms.glwe,
|
||||
¶ms.gsw_radix,
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
let params = CmuxParams {
|
||||
gsw_radix: RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(10),
|
||||
},
|
||||
glwe: GLWE_5_256_80,
|
||||
};
|
||||
|
||||
cmux_params(¶ms, c);
|
||||
|
||||
let params = CmuxParams {
|
||||
gsw_radix: RadixDecomposition {
|
||||
count: RadixCount(1),
|
||||
radix_log: RadixLog(11),
|
||||
},
|
||||
glwe: GLWE_1_1024_80,
|
||||
};
|
||||
|
||||
cmux_params(¶ms, c);
|
||||
}
|
||||
|
||||
fn programmable_bootstrapping(c: &mut Criterion) {
|
||||
fn run_bench(
|
||||
name: &str,
|
||||
g: &mut BenchmarkGroup<WallTime>,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
bs_radix: &RadixDecomposition,
|
||||
) {
|
||||
let lwe_sk = keygen::generate_binary_lwe_sk(lwe);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(glwe);
|
||||
let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, lwe, glwe, bs_radix);
|
||||
let bsk = fft::fft_bootstrap_key(&bsk, lwe, glwe, bs_radix);
|
||||
|
||||
let ct = lwe_sk.encrypt(1, &glwe.as_lwe_def(), PlaintextBits(1)).0;
|
||||
let lut = UnivariateLookupTable::trivial_from_fn(|x| x, glwe, PlaintextBits(1));
|
||||
|
||||
g.bench_function(name, |b| {
|
||||
b.iter(|| {
|
||||
evaluation::univariate_programmable_bootstrap(&ct, &lut, &bsk, lwe, glwe, bs_radix);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
let mut g = c.benchmark_group("Bootstrapping");
|
||||
|
||||
// CBS parameters
|
||||
let radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(16),
|
||||
};
|
||||
|
||||
run_bench(
|
||||
"CBS parameters",
|
||||
&mut g,
|
||||
&LWE_512_80,
|
||||
&GLWE_5_256_80,
|
||||
&radix,
|
||||
);
|
||||
|
||||
// Binary PBS parameters
|
||||
let bs_radix = RadixDecomposition {
|
||||
count: RadixCount(3),
|
||||
radix_log: RadixLog(6),
|
||||
};
|
||||
|
||||
run_bench(
|
||||
"boolean PBS parameters",
|
||||
&mut g,
|
||||
&LweDef {
|
||||
dim: LweDimension(722),
|
||||
std: Stddev(0.000013071021089943935),
|
||||
},
|
||||
&GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(2),
|
||||
polynomial_degree: PolynomialDegree(512),
|
||||
},
|
||||
std: Stddev(0.00000004990272175010415),
|
||||
},
|
||||
&bs_radix,
|
||||
);
|
||||
|
||||
// 3-bit message 1-bit carry PBS parameters
|
||||
let bs_radix = RadixDecomposition {
|
||||
count: RadixCount(1),
|
||||
radix_log: RadixLog(23),
|
||||
};
|
||||
|
||||
run_bench(
|
||||
"3+1 message PBS parameters",
|
||||
&mut g,
|
||||
&LweDef {
|
||||
dim: LweDimension(742),
|
||||
std: Stddev(0.000007069849454709433),
|
||||
},
|
||||
&GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(1),
|
||||
polynomial_degree: PolynomialDegree(2048),
|
||||
},
|
||||
std: Stddev(0.00000000000000029403601535432533),
|
||||
},
|
||||
&bs_radix,
|
||||
);
|
||||
}
|
||||
|
||||
fn circuit_bootstrapping(c: &mut Criterion) {
|
||||
let pbs_radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(16),
|
||||
};
|
||||
let cbs_radix = RadixDecomposition {
|
||||
count: RadixCount(1),
|
||||
radix_log: RadixLog(11),
|
||||
};
|
||||
let pfks_radix = RadixDecomposition {
|
||||
count: RadixCount(3),
|
||||
radix_log: RadixLog(11),
|
||||
};
|
||||
|
||||
let level_2_params = GLWE_5_256_80;
|
||||
let level_1_params = GLWE_1_1024_80;
|
||||
let level_0_params = LWE_512_80;
|
||||
|
||||
let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
|
||||
let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
|
||||
let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
|
||||
|
||||
let bsk = keygen::generate_bootstrapping_key(
|
||||
&sk_0,
|
||||
&sk_2,
|
||||
&level_0_params,
|
||||
&level_2_params,
|
||||
&pbs_radix,
|
||||
);
|
||||
let bsk = fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
|
||||
|
||||
let cbsksk = keygen::generate_cbs_ksk(
|
||||
sk_2.to_lwe_secret_key(),
|
||||
&sk_1,
|
||||
&level_2_params.as_lwe_def(),
|
||||
&level_1_params,
|
||||
&pfks_radix,
|
||||
);
|
||||
|
||||
let val = 0;
|
||||
|
||||
let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
|
||||
|
||||
let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix);
|
||||
|
||||
c.bench_function("Circuit bootstrap", |b| {
|
||||
b.iter(|| {
|
||||
circuit_bootstrap(
|
||||
&mut actual,
|
||||
&ct,
|
||||
&bsk,
|
||||
&cbsksk,
|
||||
&level_0_params,
|
||||
&level_1_params,
|
||||
&level_2_params,
|
||||
&pbs_radix,
|
||||
&cbs_radix,
|
||||
&pfks_radix,
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
cmux,
|
||||
programmable_bootstrapping,
|
||||
circuit_bootstrapping
|
||||
);
|
||||
criterion_main!(benches);
|
||||
128
sunscreen_tfhe/benches/tfhe_proof.rs
Normal file
128
sunscreen_tfhe/benches/tfhe_proof.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
use criterion::{criterion_group, criterion_main, Criterion};
|
||||
use logproof::{
|
||||
InnerProductVerifierKnowledge, LogProof, LogProofGenerators, LogProofProverKnowledge,
|
||||
};
|
||||
use merlin::Transcript;
|
||||
use sunscreen_tfhe::{
|
||||
high_level::*,
|
||||
zkp::{generate_tfhe_sdlp_prover_knowledge, ProofStatement, TorusZq, Witness},
|
||||
PlaintextBits, Torus, TorusOps, LWE_512_80,
|
||||
};
|
||||
|
||||
fn make_proof<S: TorusOps + TorusZq>(pk: &LogProofProverKnowledge<S::Zq>) -> LogProof {
|
||||
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
|
||||
let u = InnerProductVerifierKnowledge::get_u();
|
||||
let mut p_t = Transcript::new(b"test");
|
||||
|
||||
LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u)
|
||||
}
|
||||
|
||||
fn tfhe_secret_proof(c: &mut Criterion) {
|
||||
let params = LWE_512_80;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let enc_data = (0..32)
|
||||
.map(|_| encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits))
|
||||
.collect::<Vec<_>>();
|
||||
let msg = vec![Torus::from(1u64); 32];
|
||||
|
||||
let statement = enc_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, d)| ProofStatement::PrivateKeyEncryption {
|
||||
message_id: i,
|
||||
ciphertext: &d.0,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let witness = enc_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(_i, d)| Witness::PrivateKeyEncryption {
|
||||
randomness: d.1,
|
||||
private_key: &sk,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, ¶ms, bits);
|
||||
|
||||
let p = make_proof::<u64>(&pk);
|
||||
|
||||
let mut g = c.benchmark_group("Secret key encryption");
|
||||
g.sample_size(10);
|
||||
|
||||
g.bench_function("Prove 32-bit secret encryption", |b| {
|
||||
b.iter(|| {
|
||||
let _ = make_proof::<u64>(&pk);
|
||||
});
|
||||
});
|
||||
|
||||
g.bench_function("Verify 32-bit secret encryption", |b| {
|
||||
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
|
||||
let u = InnerProductVerifierKnowledge::get_u();
|
||||
|
||||
b.iter(|| {
|
||||
let mut t = Transcript::new(b"test");
|
||||
|
||||
p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn tfhe_public_proof(c: &mut Criterion) {
|
||||
let params = LWE_512_80;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let public = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let enc_data = (0..32)
|
||||
.map(|_| encryption::encrypt_lwe_and_return_randomness(1, &public, ¶ms, bits))
|
||||
.collect::<Vec<_>>();
|
||||
let msg = vec![Torus::from(1u64); 32];
|
||||
|
||||
let statement = enc_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, d)| ProofStatement::PublicKeyEncryption {
|
||||
public_key: &public,
|
||||
message_id: i,
|
||||
ciphertext: &d.0,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let witness = enc_data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(_i, d)| Witness::PublicKeyEncryption { randomness: &d.1 })
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(&statement, &msg, &witness, ¶ms, bits);
|
||||
|
||||
let p = make_proof::<u64>(&pk);
|
||||
|
||||
let mut g = c.benchmark_group("Public key encryption");
|
||||
g.sample_size(10);
|
||||
|
||||
g.bench_function("Prove 32-bit public encryption", |b| {
|
||||
b.iter(|| {
|
||||
let _ = make_proof::<u64>(&pk);
|
||||
});
|
||||
});
|
||||
|
||||
g.bench_function("Verify 32-bit public encryption", |b| {
|
||||
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
|
||||
let u = InnerProductVerifierKnowledge::get_u();
|
||||
|
||||
b.iter(|| {
|
||||
let mut t = Transcript::new(b"test");
|
||||
|
||||
p.verify(&mut t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
criterion_group!(benches, tfhe_secret_proof, tfhe_public_proof,);
|
||||
criterion_main!(benches);
|
||||
BIN
sunscreen_tfhe/images/circuit_bootstrapping.graffle
Normal file
BIN
sunscreen_tfhe/images/circuit_bootstrapping.graffle
Normal file
Binary file not shown.
BIN
sunscreen_tfhe/images/circuit_bootstrapping.png
Normal file
BIN
sunscreen_tfhe/images/circuit_bootstrapping.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 103 KiB |
34
sunscreen_tfhe/mont.py
Normal file
34
sunscreen_tfhe/mont.py
Normal file
@@ -0,0 +1,34 @@
|
||||
p = 13
|
||||
r = 16
|
||||
rinv = (r**(p-2)) % p
|
||||
|
||||
p_inv = (p**(p-2)) % r
|
||||
p_prime = (p_inv * (r - 1)) % r
|
||||
print((p_inv * p) % r)
|
||||
|
||||
print((r * rinv) % p)
|
||||
|
||||
def to_mont(x):
|
||||
return (r * x) % p
|
||||
|
||||
def from_mont(x):
|
||||
return (x * rinv) % p
|
||||
|
||||
def mont_add(x, y):
|
||||
return (x + y) % p
|
||||
|
||||
def mont_mul(x, y):
|
||||
return (x * y * rinv) % p
|
||||
|
||||
def redc():
|
||||
pass
|
||||
|
||||
a = to_mont(5)
|
||||
b = to_mont(6)
|
||||
|
||||
|
||||
c = mont_mul(a, b)
|
||||
d = mont_add(a, b)
|
||||
|
||||
print(from_mont(c), from_mont(d))
|
||||
|
||||
293
sunscreen_tfhe/src/dst.rs
Normal file
293
sunscreen_tfhe/src/dst.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
use crate::scratch::Pod;
|
||||
|
||||
macro_rules! dst {
|
||||
($(#[$meta:meta])* $t:ty, $ref_t:ty, $wrapper:ty, ($($derive:ident),* $(,)? ), ($($t_bounds:ty),* $(,)? )) => {
|
||||
paste::paste! {
|
||||
|
||||
$(#[$meta])*
|
||||
#[derive($($derive,)*)]
|
||||
pub struct $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
data: Vec<$wrapper<T>>
|
||||
}
|
||||
|
||||
/// A reference to the data structure.
|
||||
#[repr(transparent)]
|
||||
pub struct $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
data: [$wrapper<T>],
|
||||
}
|
||||
|
||||
impl<T> $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
/// Clones the contents of rhs into self
|
||||
pub fn clone_from_ref(&mut self, rhs: &$ref_t<T>) {
|
||||
for (l, r) in self.data.iter_mut().zip(rhs.data.iter()) {
|
||||
*l = r.clone();
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a slice view of the data representing a $t.
|
||||
pub fn as_slice(&self) -> &[$wrapper<T>] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Returns a mutable slice view of the data representing a $t.
|
||||
pub fn as_mut_slice(&mut self) -> &mut [$wrapper<T>] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
/// Move the contents of rhs into self.
|
||||
pub fn move_from(&mut self, rhs: $t<T>) {
|
||||
for (l, r) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||
*l = r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> crate::dst::FromSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
fn from_slice(s: &[$wrapper<T>]) -> &$ref_t<T> {
|
||||
unsafe { &*(s as *const [$wrapper<T>] as *const $ref_t<T>) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> crate::dst::FromMutSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
fn from_mut_slice(s: &mut [$wrapper<T>]) -> &mut $ref_t<T> {
|
||||
unsafe { &mut *(s as *mut [$wrapper<T>] as *mut $ref_t<T>) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> $ref_t<T> where T: Clone $(+ $t_bounds)*, $wrapper<T>: num::Zero {
|
||||
/// Clears the contents of self to contain zero
|
||||
pub fn clear(&mut self) {
|
||||
|
||||
for x in self.as_mut_slice() {
|
||||
*x = <$wrapper<T> as num::Zero>::zero();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::borrow::Borrow< $ref_t <T>> for $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
fn borrow(&self) -> &$ref_t<T> {
|
||||
let ptr = self.data.as_slice() as *const [$wrapper<T>] as *const $ref_t<T>;
|
||||
|
||||
unsafe { &*ptr }
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::convert::AsRef< $ref_t <T>> for $t<T> where T: Clone $(+ $t_bounds)*
|
||||
{
|
||||
fn as_ref(&self) -> &$ref_t<T> {
|
||||
<Self as std::borrow::Borrow<$ref_t <T>>>::borrow(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::borrow::BorrowMut< $ref_t<T>> for $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
fn borrow_mut(&mut self) -> &mut $ref_t<T> {
|
||||
let ptr = self.data.as_mut_slice() as *mut [$wrapper<T>] as *mut $ref_t<T>;
|
||||
|
||||
unsafe { &mut *ptr }
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::borrow::ToOwned for $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
type Owned = $t<T>;
|
||||
|
||||
fn to_owned(&self) -> Self::Owned {
|
||||
$t { data: self.data.to_owned() }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::Deref for $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
type Target = $ref_t<T>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
<Self as std::borrow::Borrow::<$ref_t<T>>>::borrow(&self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
<Self as std::borrow::BorrowMut::<$ref_t<T>>>::borrow_mut(self)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! dst_iter {
|
||||
($t:ty, $t_mut:ty, $wrapper_type: ty, $item_ref:ty, ($($t_bounds:ty,)*)) => {
|
||||
paste::paste!{
|
||||
/// An iterator to access references to an underlying type.
|
||||
pub struct $t<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
data: &'a [$wrapper_type<T>],
|
||||
stride: usize,
|
||||
front_idx: usize,
|
||||
back_idx: i64
|
||||
}
|
||||
|
||||
impl<'a, T> $t<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
/// Create a new iterator that emits references to the contained type
|
||||
/// by striding over the underlying data.
|
||||
pub fn new(data: &'a [$wrapper_type<T>], stride: usize) -> Self {
|
||||
assert_eq!(data.len() % stride, 0);
|
||||
|
||||
Self {
|
||||
data,
|
||||
stride,
|
||||
front_idx: 0,
|
||||
back_idx: (data.len() as i64) - (stride as i64)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// The total number of items this iterator will emit and has emitted.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This method returns the same value regardless of how many times
|
||||
/// `next` has been called.
|
||||
///
|
||||
/// This operation does not consume the iterator.
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len() / self.stride
|
||||
}
|
||||
|
||||
/// Returns true if the iterator is empty.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::iter::Iterator for $t<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
type Item = &'a $item_ref<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.front_idx >= self.data.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if (self.front_idx as i64) == self.back_idx + self.stride as i64 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let data = <$item_ref<T> as crate::dst::FromSlice<$wrapper_type<T>>>::from_slice(
|
||||
&self.data[self.front_idx..self.front_idx + self.stride]
|
||||
);
|
||||
|
||||
self.front_idx += self.stride;
|
||||
|
||||
Some(data)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::iter::DoubleEndedIterator for $t<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
fn next_back(&mut self) -> Option<Self::Item> {
|
||||
if self.back_idx < 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
if (self.front_idx as i64) == self.back_idx + self.stride as i64 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let start = self.back_idx as usize;
|
||||
|
||||
let data = <$item_ref<T> as crate::dst::FromSlice<$wrapper_type<T>>>::from_slice(
|
||||
&self.data[start..start + self.stride]
|
||||
);
|
||||
|
||||
self.back_idx -= (self.stride as i64);
|
||||
|
||||
Some(data)
|
||||
}
|
||||
}
|
||||
|
||||
/// A mutable iterator to access references to an underlying type.
|
||||
pub struct $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
data: *mut $wrapper_type<T>,
|
||||
len: usize,
|
||||
stride: usize,
|
||||
idx: usize,
|
||||
_phantom: std::marker::PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<'a, T> $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
/// Create a new iterator that emits references to the contained type
|
||||
/// by striding over the underlying data, mutably.
|
||||
pub fn new(data: &'a mut [$wrapper_type<T>], stride: usize) -> Self {
|
||||
assert_eq!(data.len() % stride, 0);
|
||||
|
||||
Self {
|
||||
idx: 0,
|
||||
stride,
|
||||
data: data.as_mut_ptr(),
|
||||
len: data.len(),
|
||||
_phantom: std::marker::PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// The total number of items this iterator will emit and has emitted.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This method returns the same value regardless of how many times
|
||||
/// `next` has been called.
|
||||
///
|
||||
/// This operation does not consume the iterator.
|
||||
pub fn len(&self) -> usize {
|
||||
self.len / self.stride
|
||||
}
|
||||
|
||||
/// Returns true if the iterator is empty.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> std::iter::Iterator for $t_mut<'a, T> where T: Clone $(+ $t_bounds)* {
|
||||
type Item = &'a mut $item_ref<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.idx == self.len {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Since the slices emitted by this iterator don't overlap, this is sound.
|
||||
let data = unsafe {
|
||||
let slice = self.data.add(self.idx);
|
||||
std::slice::from_raw_parts_mut(slice, self.stride)
|
||||
};
|
||||
|
||||
self.idx += self.stride;
|
||||
|
||||
Some(<$item_ref<T> as crate::dst::FromMutSlice<$wrapper_type<T>>>::from_mut_slice(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub type NoWrapper<T> = T;
|
||||
|
||||
pub trait OverlaySize {
|
||||
type Inputs: Copy + Clone;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize;
|
||||
}
|
||||
|
||||
impl<S: Pod> OverlaySize for [S] {
|
||||
type Inputs = usize;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
t
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FromSlice<T> {
|
||||
fn from_slice(data: &[T]) -> &Self;
|
||||
}
|
||||
|
||||
pub trait FromMutSlice<T> {
|
||||
fn from_mut_slice(data: &mut [T]) -> &mut Self;
|
||||
}
|
||||
98
sunscreen_tfhe/src/entities/bivariate_lookup_table.rs
Normal file
98
sunscreen_tfhe/src/entities/bivariate_lookup_table.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice, OverlaySize},
|
||||
entities::PolynomialRef,
|
||||
ops::{bootstrapping::generate_bivariate_lut, encryption::trivially_encrypt_glwe_ciphertext},
|
||||
scratch::allocate_scratch_ref,
|
||||
CarryBits, GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{GlweCiphertextRef, UnivariateLookupTableRef};
|
||||
|
||||
dst! {
|
||||
/// Lookup table for a bivariate function used during
|
||||
/// [`programmable_bootstrap_bivariate`](crate::ops::bootstrapping::programmable_bootstrap_bivariate)
|
||||
/// See that function for more details.
|
||||
BivariateLookupTable,
|
||||
BivariateLookupTableRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for BivariateLookupTableRef<S> {
|
||||
type Inputs = GlweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlweCiphertextRef::<S>::size(t)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BivariateLookupTable<S> {
|
||||
/// Creates a [BivariateLookupTable] filled with the result of
|
||||
/// a function applied to every possible pair of plaintext inputs.
|
||||
pub fn trivial_from_fn<F>(
|
||||
map: F,
|
||||
glwe: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
carry_bits: CarryBits,
|
||||
) -> Self
|
||||
where
|
||||
F: Fn(u64, u64) -> u64,
|
||||
{
|
||||
let mut lut = BivariateLookupTable {
|
||||
data: vec![Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
};
|
||||
|
||||
lut.fill_trivial_from_fn(map, glwe, plaintext_bits, carry_bits);
|
||||
|
||||
lut
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BivariateLookupTableRef<S> {
|
||||
/// Convert a [BivariateLookupTableRef] to a [UnivariateLookupTableRef].
|
||||
pub fn as_univariate(&self) -> &UnivariateLookupTableRef<S> {
|
||||
// This works because a bivariate lookup table is just a univariate
|
||||
// lookup table.
|
||||
UnivariateLookupTableRef::from_slice(&self.data)
|
||||
}
|
||||
|
||||
/// Gets a copy of the underlying [GlweCiphertextRef] from the
|
||||
/// [BivariateLookupTableRef].
|
||||
pub fn glwe(&self) -> &GlweCiphertextRef<S> {
|
||||
GlweCiphertextRef::from_slice(&self.data)
|
||||
}
|
||||
|
||||
/// Gets a mutable copy of the underlying [GlweCiphertextRef] from the
|
||||
/// [BivariateLookupTableRef].
|
||||
pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef<S> {
|
||||
GlweCiphertextRef::from_mut_slice(&mut self.data)
|
||||
}
|
||||
|
||||
/// Fills the [BivariateLookupTableRef] with the result of a bivariate
|
||||
/// function.
|
||||
pub fn fill_trivial_from_fn<F: Fn(u64, u64) -> u64>(
|
||||
&mut self,
|
||||
map: F,
|
||||
glwe: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
carry_bits: CarryBits,
|
||||
) {
|
||||
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));
|
||||
|
||||
generate_bivariate_lut(poly, map, glwe, plaintext_bits, carry_bits);
|
||||
|
||||
trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
|
||||
}
|
||||
|
||||
/// Creates a lookup table filled with the same value at every entry.
|
||||
pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) {
|
||||
self.clear();
|
||||
for o in self.glwe_mut().b_mut(glwe).coeffs_mut() {
|
||||
*o = Torus::encode(val, plaintext_bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
137
sunscreen_tfhe/src/entities/blind_rotation_shift.rs
Normal file
137
sunscreen_tfhe/src/entities/blind_rotation_shift.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use num::{Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
entities::{
|
||||
GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef,
|
||||
GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef,
|
||||
},
|
||||
GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// An encrypted amount to rotate the polynomials in a GLWE ciphertext by.
|
||||
/// The [BlindRotationShiftFft] variant of this type is used by the
|
||||
/// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function.
|
||||
BlindRotationShift,
|
||||
BlindRotationShiftRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for BlindRotationShiftRef<S> {
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize;
|
||||
|
||||
GgswCiphertextRef::<S>::size(t) * n_bits
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BlindRotationShift<S> {
|
||||
/// Create a new zero [BlindRotationShift] with the given parameters.
|
||||
///
|
||||
/// A blind rotation shift is a collection of GGSW ciphertexts, each of
|
||||
/// which encrypts a single bit in position `i` representing how much to
|
||||
/// shift the input by as `2^i`.
|
||||
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len = BlindRotationShiftRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BlindRotationShiftRef<S> {
|
||||
/// Iterate over the rows of the [BlindRotationShift].
|
||||
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator<S> {
|
||||
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Iterate over the rows of the [BlindRotationShift] mutably.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextIteratorMut<S> {
|
||||
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
}
|
||||
|
||||
dst! {
|
||||
/// An encrypted amount to rotate the polynomials in a GLWE ciphertext by.
|
||||
/// Used by the
|
||||
/// [`blind_rotate`](crate::ops::bootstrapping::blind_rotation) function.
|
||||
/// The non-FFT version of this type is [BlindRotationShift].
|
||||
BlindRotationShiftFft,
|
||||
BlindRotationShiftFftRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
|
||||
impl OverlaySize for BlindRotationShiftFftRef<Complex<f64>> {
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
let n_bits = (t.0.polynomial_degree.0 as u64).ilog2() as usize;
|
||||
|
||||
GgswCiphertextFftRef::<Complex<f64>>::size(t) * n_bits
|
||||
}
|
||||
}
|
||||
|
||||
impl BlindRotationShiftFft<Complex<f64>> {
|
||||
/// Create a new zero [BlindRotationShiftFft] with the given parameters.
|
||||
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len = BlindRotationShiftFftRef::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BlindRotationShiftFftRef<Complex<f64>> {
|
||||
/// Iterate over the rows of the [BlindRotationShiftFft].
|
||||
pub fn rows(
|
||||
&self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextFftIterator<Complex<f64>> {
|
||||
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFftIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Iterate over the rows of the [BlindRotationShiftFft] mutably.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextFftIteratorMut<Complex<f64>> {
|
||||
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
/// Compute the inverse FFT of the [BlindRotationShiftFft] and store the
|
||||
/// result in the result [BlindRotationShift].
|
||||
pub fn ifft<S: TorusOps>(
|
||||
&self,
|
||||
result: &mut BlindRotationShiftRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
|
||||
s.ifft(r, params, radix);
|
||||
}
|
||||
}
|
||||
}
|
||||
160
sunscreen_tfhe/src/entities/bootstrap_key.rs
Normal file
160
sunscreen_tfhe/src/entities/bootstrap_key.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use num::{Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
entities::{
|
||||
GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, GgswCiphertextFftRef,
|
||||
GgswCiphertextIterator, GgswCiphertextIteratorMut, GgswCiphertextRef,
|
||||
},
|
||||
GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// Keys used for bootstrapping. The [BootstrapKeyFft] variant of this type
|
||||
/// is used by the bootstrapping functions such as
|
||||
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
|
||||
BootstrapKey,
|
||||
BootstrapKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for BootstrapKeyRef<S> {
|
||||
type Inputs = (LweDimension, GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GgswCiphertextRef::<S>::size((t.1, t.2)) * t.0 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BootstrapKey<S> {
|
||||
/// Create a new zero [BootstrapKey] with the given parameters.
|
||||
///
|
||||
/// A bootstrapping key is a collection of GGSW ciphertexts, each of which
|
||||
/// encrypts a single bit of an LWE secret key. This representation cannot
|
||||
/// be directly with the bootstrapping functions, but the FFT version of the
|
||||
/// bootstrapping key that can be used with the bootstrapping functions can
|
||||
/// be created by calling the [BootstrapKeyRef::fft] method
|
||||
pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len = BootstrapKeyRef::<S>::size((lwe_params.dim, glwe_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> BootstrapKeyRef<S> {
|
||||
/// Iterate over the rows of the [BootstrapKey].
|
||||
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextIterator<S> {
|
||||
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Iterate over the rows of the [BootstrapKey] mutably.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextIteratorMut<S> {
|
||||
let stride = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
/// Perform an FFT on the [BootstrapKey] to obtain a [BootstrapKeyFft].
|
||||
pub fn fft(
|
||||
&self,
|
||||
result: &mut BootstrapKeyFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
|
||||
s.fft(r, params, radix);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dst! {
|
||||
/// Keys used for bootstrapping. Used by the bootstrapping functions such as
|
||||
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
|
||||
/// The non-FFT variant of this type is [BootstrapKey].
|
||||
BootstrapKeyFft,
|
||||
BootstrapKeyFftRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
|
||||
impl OverlaySize for BootstrapKeyFftRef<Complex<f64>> {
|
||||
type Inputs = (LweDimension, GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GgswCiphertextFftRef::<Complex<f64>>::size((t.1, t.2)) * t.0 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl BootstrapKeyFft<Complex<f64>> {
|
||||
/// Create a new zero [BootstrapKeyFft] with the given parameters.
|
||||
///
|
||||
/// A bootstrapping key is a collection of GGSW ciphertexts, each of which
|
||||
/// encrypts a single bit of an LWE secret key. In this representation, the
|
||||
/// GGSW ciphertexts are in the frequency domain and can be used directly by
|
||||
/// the bootstrapping functions such as
|
||||
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
|
||||
pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BootstrapKeyFftRef<Complex<f64>> {
|
||||
/// Iterate over the rows of the [BootstrapKeyFft].
|
||||
pub fn rows(
|
||||
&self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextFftIterator<Complex<f64>> {
|
||||
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFftIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Iterate over the rows of the [BootstrapKeyFft] mutably.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextFftIteratorMut<Complex<f64>> {
|
||||
let stride = GgswCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
/// Perform an IFFT on the [BootstrapKeyFft] to obtain a [BootstrapKey].
|
||||
pub fn ifft<S: TorusOps>(
|
||||
&self,
|
||||
result: &mut BootstrapKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
|
||||
s.ifft(r, params, radix);
|
||||
}
|
||||
}
|
||||
|
||||
/// Asserts that the [BootstrapKeyFft] is valid for the given parameters.
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(&self, lwe: &LweDef, glwe: &GlweDef, radix: &RadixDecomposition) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
BootstrapKeyFftRef::size((lwe.dim, glwe.dim, radix.count))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize, GlweDef, GlweDimension, LweDef, LweDimension,
|
||||
PrivateFunctionalKeyswitchLweCount, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
PrivateFunctionalKeyswitchKeyIter, PrivateFunctionalKeyswitchKeyIterMut,
|
||||
PrivateFunctionalKeyswitchKeyRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// Key for Circuit Bootstrapping Key Switching.
|
||||
CircuitBootstrappingKeyswitchKeys,
|
||||
CircuitBootstrappingKeyswitchKeysRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for CircuitBootstrappingKeyswitchKeysRef<S> {
|
||||
type Inputs = (LweDimension, GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
t.0,
|
||||
t.1,
|
||||
t.2,
|
||||
PrivateFunctionalKeyswitchLweCount(1),
|
||||
)) * (t.1.size.0 + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> CircuitBootstrappingKeyswitchKeys<S> {
|
||||
/// Allocate a new [`CircuitBootstrappingKeyswitchKeys`] for the given parameters.
|
||||
pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len = CircuitBootstrappingKeyswitchKeysRef::<S>::size((
|
||||
from_lwe.dim,
|
||||
to_glwe.dim,
|
||||
radix.count,
|
||||
));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> CircuitBootstrappingKeyswitchKeysRef<S> {
|
||||
/// Get an iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
|
||||
pub fn keys(
|
||||
&self,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> PrivateFunctionalKeyswitchKeyIter<S> {
|
||||
let stride = PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
lwe.dim,
|
||||
glwe.dim,
|
||||
radix.count,
|
||||
PrivateFunctionalKeyswitchLweCount(1),
|
||||
));
|
||||
|
||||
PrivateFunctionalKeyswitchKeyIter::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Get a mutable iterator over the contained [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
|
||||
pub fn keys_mut(
|
||||
&mut self,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> PrivateFunctionalKeyswitchKeyIterMut<S> {
|
||||
let stride = PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
lwe.dim,
|
||||
glwe.dim,
|
||||
radix.count,
|
||||
PrivateFunctionalKeyswitchLweCount(1),
|
||||
));
|
||||
|
||||
PrivateFunctionalKeyswitchKeyIterMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Assert these keys are valid under the given parameters.
|
||||
pub fn assert_valid(&self, from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
CircuitBootstrappingKeyswitchKeysRef::<S>::size((
|
||||
from_lwe.dim,
|
||||
to_glwe.dim,
|
||||
radix.count,
|
||||
))
|
||||
);
|
||||
}
|
||||
}
|
||||
114
sunscreen_tfhe/src/entities/ggsw_ciphertext.rs
Normal file
114
sunscreen_tfhe/src/entities/ggsw_ciphertext.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use num::{Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize, ops::ciphertext::external_product_ggsw_glwe, GlweDef, GlweDimension,
|
||||
RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
GgswCiphertextFftRef, GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef,
|
||||
GlweCiphertext, GlweCiphertextRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// A GGSW ciphertext. For the FFT variant, see
|
||||
/// [`GgswCiphertextFft`](crate::entities::GgswCiphertextFft).
|
||||
GgswCiphertext,
|
||||
GgswCiphertextRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
dst_iter! { GgswCiphertextIterator, GgswCiphertextIteratorMut, Torus, GgswCiphertextRef, (TorusOps,)}
|
||||
|
||||
impl<S> OverlaySize for GgswCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlevCiphertextRef::<S>::size(t) * (t.0.size.0 + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GgswCiphertext<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Create a new zero GGSW ciphertext with the given parameters.
|
||||
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new GGSW ciphertext from a slice of Torus elements.
|
||||
pub fn from_slice(data: &[Torus<S>], params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
assert_eq!(data.len(), elems);
|
||||
|
||||
Self {
|
||||
data: data.to_vec(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the external product between a GGSW ciphertext and a GLWE ciphertext.
|
||||
/// GGSW ⊡ GLWE -> GLWE
|
||||
pub fn external_product(
|
||||
&self,
|
||||
glwe: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlweCiphertext<S> {
|
||||
external_product_ggsw_glwe(self, glwe, params, radix)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GgswCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the rows of the GGSW ciphertext, which are
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
|
||||
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextIterator::new(&self.data, stride)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the GGSW ciphertext, which are
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIteratorMut<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextIteratorMut::new(&mut self.data, stride)
|
||||
}
|
||||
|
||||
/// Compute the FFT of each of the GLWE ciphertexts in the GGSW ciphertext.
|
||||
/// The result is stored in `result`.
|
||||
pub fn fft(
|
||||
&self,
|
||||
result: &mut GgswCiphertextFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
|
||||
s.fft(r, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert that the GGSW ciphertext is valid for the given parameters.
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(&self, glwe: &GlweDef, radix: &RadixDecomposition) {
|
||||
assert_eq!(self.as_slice().len(), Self::size((glwe.dim, radix.count)));
|
||||
}
|
||||
}
|
||||
79
sunscreen_tfhe/src/entities/ggsw_ciphertext_fft.rs
Normal file
79
sunscreen_tfhe/src/entities/ggsw_ciphertext_fft.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use num::{Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
entities::GgswCiphertextRef,
|
||||
GlweDef, GlweDimension, RadixCount, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
use super::{GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, GlevCiphertextFftRef};
|
||||
|
||||
dst! {
|
||||
/// The FFT variant of a GGSW ciphertext. See
|
||||
/// [`GgswCiphertext`](crate::entities::GgswCiphertext) for more details.
|
||||
GgswCiphertextFft,
|
||||
GgswCiphertextFftRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
dst_iter! { GgswCiphertextFftIterator, GgswCiphertextFftIteratorMut, NoWrapper, GgswCiphertextFftRef, ()}
|
||||
|
||||
impl OverlaySize for GgswCiphertextFftRef<Complex<f64>> {
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlevCiphertextFftRef::<Complex<f64>>::size(t) * (t.0.size.0 + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl GgswCiphertextFft<Complex<f64>> {
|
||||
/// Creates a new GGSW ciphertext with FFT representation.
|
||||
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> GgswCiphertextFft<Complex<f64>> {
|
||||
let len = GgswCiphertextFftRef::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFft {
|
||||
data: vec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GgswCiphertextFftRef<Complex<f64>> {
|
||||
/// Returns an iterator over the rows of the GGSW ciphertext, which are
|
||||
/// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s.
|
||||
pub fn rows(
|
||||
&self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextFftIterator<Complex<f64>> {
|
||||
let stride = GlevCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextFftIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the GGSW ciphertext, which are
|
||||
/// [GlevCiphertextFft](crate::entities::GlevCiphertextFft)s.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextFftIteratorMut<Complex<f64>> {
|
||||
let stride = GlevCiphertextFftRef::<Complex<f64>>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextFftIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
/// Computes the inverse FFT of the GGSW ciphertexts and stores computation
|
||||
/// in `result`.
|
||||
pub fn ifft<S: TorusOps>(
|
||||
&self,
|
||||
result: &mut GgswCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
for (s, r) in self.rows(params, radix).zip(result.rows_mut(params, radix)) {
|
||||
s.ifft(r, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
58
sunscreen_tfhe/src/entities/glev_ciphertext.rs
Normal file
58
sunscreen_tfhe/src/entities/glev_ciphertext.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use num::Complex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{dst::OverlaySize, GlweDef, GlweDimension, RadixCount, Torus, TorusOps};
|
||||
|
||||
use super::{
|
||||
GlevCiphertextFftRef, GlweCiphertextIterator, GlweCiphertextIteratorMut, GlweCiphertextRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// A GLEV ciphertext. For the FFT variant, see
|
||||
/// [`GlevCiphertextFft`](crate::entities::GlevCiphertextFft).
|
||||
GlevCiphertext,
|
||||
GlevCiphertextRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps,)
|
||||
}
|
||||
dst_iter! { GlevCiphertextIterator, GlevCiphertextIteratorMut, Torus, GlevCiphertextRef, (TorusOps,)}
|
||||
|
||||
impl<S> OverlaySize for GlevCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlweCiphertextRef::<S>::size(t.0) * t.1 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlevCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the rows of the GLEV ciphertext, which are
|
||||
/// [`GlweCiphertext`](crate::entities::GlweCiphertext)s.
|
||||
pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextIterator<S> {
|
||||
GlweCiphertextIterator::new(&self.data, GlweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the GLEV ciphertext, which are
|
||||
/// [`GlweCiphertext](crate::entities::GlweCiphertext)s.
|
||||
pub fn glwe_ciphertexts_mut(&mut self, params: &GlweDef) -> GlweCiphertextIteratorMut<S> {
|
||||
GlweCiphertextIteratorMut::new(&mut self.data, GlweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
|
||||
/// Compute the FFT of each of the GLWE ciphertexts in the GLEV ciphertext.
|
||||
/// The result is stored in `result`.
|
||||
pub fn fft(&self, result: &mut GlevCiphertextFftRef<Complex<f64>>, params: &GlweDef) {
|
||||
for (i, fft) in self
|
||||
.glwe_ciphertexts(params)
|
||||
.zip(result.glwe_ciphertexts_mut(params))
|
||||
{
|
||||
i.fft(fft, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
65
sunscreen_tfhe/src/entities/glev_ciphertext_fft.rs
Normal file
65
sunscreen_tfhe/src/entities/glev_ciphertext_fft.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use num::Complex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
GlweDef, GlweDimension, RadixCount, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
GlevCiphertextRef, GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut,
|
||||
GlweCiphertextFftRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// The FFT variant of a GLEV ciphertext. See
|
||||
/// [GlevCiphertext](crate::entities::GlevCiphertext) for more details.
|
||||
GlevCiphertextFft,
|
||||
GlevCiphertextFftRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
dst_iter! { GlevCiphertextFftIterator, GlevCiphertextFftIteratorMut, NoWrapper, GlevCiphertextFftRef, ()}
|
||||
|
||||
impl OverlaySize for GlevCiphertextFftRef<Complex<f64>> {
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlweCiphertextFftRef::<Complex<f64>>::size(t.0) * t.1 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl GlevCiphertextFftRef<Complex<f64>> {
|
||||
/// Returns an iterator over the rows of the GLEV ciphertext, which are
|
||||
/// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s.
|
||||
pub fn glwe_ciphertexts(&self, params: &GlweDef) -> GlweCiphertextFftIterator<Complex<f64>> {
|
||||
GlweCiphertextFftIterator::new(
|
||||
&self.data,
|
||||
GlweCiphertextFftRef::<Complex<f64>>::size(params.dim),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the GLEV ciphertext, which are
|
||||
/// [`GlweCiphertextFft`](crate::entities::GlweCiphertextFft)s.
|
||||
pub fn glwe_ciphertexts_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
) -> GlweCiphertextFftIteratorMut<Complex<f64>> {
|
||||
GlweCiphertextFftIteratorMut::new(
|
||||
&mut self.data,
|
||||
GlweCiphertextFftRef::<Complex<f64>>::size(params.dim),
|
||||
)
|
||||
}
|
||||
|
||||
/// Computes the inverse FFT of the GLEV ciphertexts and stores computation
|
||||
/// in `result`.
|
||||
pub fn ifft<S: TorusOps>(&self, result: &mut GlevCiphertextRef<S>, params: &GlweDef) {
|
||||
for (i, ifft) in self
|
||||
.glwe_ciphertexts(params)
|
||||
.zip(result.glwe_ciphertexts_mut(params))
|
||||
{
|
||||
i.ifft(ifft, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
159
sunscreen_tfhe/src/entities/glwe_ciphertext.rs
Normal file
159
sunscreen_tfhe/src/entities/glwe_ciphertext.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use num::{Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice, OverlaySize},
|
||||
entities::GgswCiphertextRef,
|
||||
macros::{impl_binary_op, impl_unary_op},
|
||||
ops::ciphertext::external_product_ggsw_glwe,
|
||||
GlweDef, GlweDimension, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
GlweCiphertextFftRef, GlweSecretKeyRef, PolynomialIterator, PolynomialIteratorMut,
|
||||
PolynomialRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// A GLWE ciphertext.
|
||||
GlweCiphertext,
|
||||
GlweCiphertextRef,
|
||||
Torus,
|
||||
(Debug, Clone, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
dst_iter! { GlweCiphertextIterator, GlweCiphertextIteratorMut, Torus, GlweCiphertextRef, (TorusOps,) }
|
||||
|
||||
// Also implements the assign operators.
|
||||
impl_binary_op!(Add, GlweCiphertext, (TorusOps,));
|
||||
impl_binary_op!(Sub, GlweCiphertext, (TorusOps,));
|
||||
impl_unary_op!(Neg, GlweCiphertext);
|
||||
|
||||
impl<S> OverlaySize for GlweCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = GlweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
// We have `n` a polynomials plus 1 b polynomial each of degree d.
|
||||
GlweSecretKeyRef::<S>::size(t) + t.polynomial_degree.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweCiphertext<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Initialize an empty (zero) GLWE ciphertext
|
||||
pub fn new(params: &GlweDef) -> GlweCiphertext<S> {
|
||||
params.dim.assert_valid();
|
||||
|
||||
let len = GlweCiphertextRef::<S>::size(params.dim);
|
||||
|
||||
let data = (0..len).map(|_| Torus::<S>::zero()).collect::<Vec<_>>();
|
||||
|
||||
GlweCiphertext { data }
|
||||
}
|
||||
|
||||
/// Computes the external product of a GLWE ciphertext and a GGSW ciphertext.
|
||||
/// GGSW ⊡ GLWE -> GLWE
|
||||
pub fn external_product(
|
||||
&self,
|
||||
ggsw: &GgswCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlweCiphertext<S> {
|
||||
external_product_ggsw_glwe(ggsw, self, params, radix)
|
||||
}
|
||||
|
||||
/// Generate a GLWE ciphertext from a slice of [crate::TorusOps] elements.
|
||||
pub fn from_slice(data: &[S], params: &GlweDef) -> GlweCiphertext<S> {
|
||||
assert_eq!(data.len(), GlweCiphertextRef::<S>::size(params.dim));
|
||||
|
||||
GlweCiphertext {
|
||||
data: data
|
||||
.iter()
|
||||
.map(|x| Torus::from(*x))
|
||||
.collect::<Vec<Torus<S>>>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
|
||||
pub fn a_b(
|
||||
&self,
|
||||
params: &GlweDef,
|
||||
) -> (PolynomialIterator<Torus<S>>, &PolynomialRef<Torus<S>>) {
|
||||
let (a, b) = self.data.as_ref().split_at(self.split_idx(params));
|
||||
|
||||
(
|
||||
PolynomialIterator::new(a, params.dim.polynomial_degree.0),
|
||||
PolynomialRef::from_slice(b),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns an interator over the a polynomials in a GLWE ciphertext.
|
||||
pub fn a(&self, params: &GlweDef) -> PolynomialIterator<Torus<S>> {
|
||||
self.a_b(params).0
|
||||
}
|
||||
|
||||
/// Returns a reference to the b polynomial in a GLWE ciphertext.
|
||||
pub fn b(&self, params: &GlweDef) -> &PolynomialRef<Torus<S>> {
|
||||
self.a_b(params).1
|
||||
}
|
||||
|
||||
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
|
||||
pub fn a_b_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
) -> (
|
||||
PolynomialIteratorMut<Torus<S>>,
|
||||
&mut PolynomialRef<Torus<S>>,
|
||||
) {
|
||||
let polynomial_degree = params.dim.polynomial_degree;
|
||||
let split_idx = self.split_idx(params);
|
||||
|
||||
let (a, b) = self.data.as_mut().split_at_mut(split_idx);
|
||||
|
||||
(
|
||||
PolynomialIteratorMut::new(a, polynomial_degree.0),
|
||||
PolynomialRef::from_mut_slice(b),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the a polynomials in a GLWE ciphertext.
|
||||
pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut<Torus<S>> {
|
||||
self.a_b_mut(params).0
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the b polynomial in a GLWE ciphertext.
|
||||
pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialRef<Torus<S>> {
|
||||
self.a_b_mut(params).1
|
||||
}
|
||||
|
||||
fn split_idx(&self, params: &GlweDef) -> usize {
|
||||
params.dim.size.0 * params.dim.polynomial_degree.0
|
||||
}
|
||||
|
||||
/// Create an FFT transformed version of `self` stored to result.
|
||||
pub fn fft(&self, result: &mut GlweCiphertextFftRef<Complex<f64>>, params: &GlweDef) {
|
||||
for (a, fft) in self.a(params).zip(result.a_mut(params)) {
|
||||
a.fft(fft);
|
||||
}
|
||||
|
||||
self.b(params).fft(result.b_mut(params));
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(&self, params: &GlweDef) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
GlweCiphertextRef::<S>::size(params.dim)
|
||||
)
|
||||
}
|
||||
}
|
||||
139
sunscreen_tfhe/src/entities/glwe_ciphertext_fft.rs
Normal file
139
sunscreen_tfhe/src/entities/glwe_ciphertext_fft.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use num::{complex::Complex64, Complex, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize},
|
||||
GlweDef, GlweDimension, TorusOps,
|
||||
};
|
||||
|
||||
use super::{GlweCiphertextRef, PolynomialFftIterator, PolynomialFftIteratorMut, PolynomialFftRef};
|
||||
|
||||
dst! {
|
||||
/// The FFT variant of a GLWE ciphertext. See
|
||||
/// [`GlweCiphertext`](crate::entities::GlweCiphertext) for more details.
|
||||
GlweCiphertextFft,
|
||||
GlweCiphertextFftRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
dst_iter! { GlweCiphertextFftIterator, GlweCiphertextFftIteratorMut, NoWrapper, GlweCiphertextFftRef, ()}
|
||||
|
||||
impl OverlaySize for GlweCiphertextFftRef<Complex<f64>> {
|
||||
type Inputs = GlweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
// FFT polynomials are half the length of their standard counterparts.
|
||||
PolynomialFftRef::<Complex<f64>>::size(t.polynomial_degree) * (t.size.0 + 1)
|
||||
}
|
||||
}
|
||||
|
||||
impl GlweCiphertextFft<Complex<f64>> {
|
||||
/// Creates a new zero GLWE ciphertext in the frequency domain.
|
||||
pub fn new(params: &GlweDef) -> Self {
|
||||
let len = GlweCiphertextFftRef::size(params.dim);
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GlweCiphertextFftRef<Complex<f64>> {
|
||||
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
|
||||
pub fn a_b(
|
||||
&self,
|
||||
params: &GlweDef,
|
||||
) -> (
|
||||
PolynomialFftIterator<Complex<f64>>,
|
||||
&PolynomialFftRef<Complex<f64>>,
|
||||
) {
|
||||
let (a, b) = self.as_slice().split_at(self.split_idx(params));
|
||||
|
||||
(
|
||||
PolynomialFftIterator::new(a, params.dim.polynomial_degree.0 / 2),
|
||||
PolynomialFftRef::from_slice(b),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns an interator over the a polynomials in a GLWE ciphertext.
|
||||
pub fn a(&self, params: &GlweDef) -> PolynomialFftIterator<Complex<f64>> {
|
||||
self.a_b(params).0
|
||||
}
|
||||
|
||||
/// Returns a reference to the b polynomial in a GLWE ciphertext.
|
||||
pub fn b(&self, params: &GlweDef) -> &PolynomialFftRef<Complex64> {
|
||||
self.a_b(params).1
|
||||
}
|
||||
|
||||
/// Returns an iterator over the `a` polynomials and the `b` polynomial.
|
||||
pub fn a_b_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
) -> (
|
||||
PolynomialFftIteratorMut<Complex<f64>>,
|
||||
&mut PolynomialFftRef<Complex<f64>>,
|
||||
) {
|
||||
let polynomial_degree = params.dim.polynomial_degree;
|
||||
let split_idx = self.split_idx(params);
|
||||
|
||||
let (a, b) = self.as_mut_slice().split_at_mut(split_idx);
|
||||
|
||||
(
|
||||
PolynomialFftIteratorMut::new(a, polynomial_degree.0 / 2),
|
||||
PolynomialFftRef::from_mut_slice(b),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the a polynomials in a GLWE ciphertext.
|
||||
pub fn a_mut(&mut self, params: &GlweDef) -> PolynomialFftIteratorMut<Complex<f64>> {
|
||||
self.a_b_mut(params).0
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the b polynomial in a GLWE ciphertext.
|
||||
pub fn b_mut(&mut self, params: &GlweDef) -> &mut PolynomialFftRef<Complex<f64>> {
|
||||
self.a_b_mut(params).1
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn split_idx(&self, params: &GlweDef) -> usize {
|
||||
params.dim.size.0 * params.dim.polynomial_degree.0 / 2
|
||||
}
|
||||
|
||||
/// Computes the inverse FFT of the GLWE ciphertext and stores the
|
||||
/// computation in `result`.
|
||||
pub fn ifft<T: TorusOps>(&self, result: &mut GlweCiphertextRef<T>, params: &GlweDef) {
|
||||
for (a, fft) in self.a(params).zip(result.a_mut(params)) {
|
||||
a.ifft(fft);
|
||||
}
|
||||
|
||||
self.b(params).ifft(result.b_mut(params));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{entities::Polynomial, high_level::*, PlaintextBits, GLWE_1_1024_80};
|
||||
|
||||
#[test]
|
||||
fn can_decrypt_glwe_after_fft_roundtrip() {
|
||||
let params = GLWE_1_1024_80;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let pt = (0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let pt = Polynomial::new(&pt);
|
||||
|
||||
let mut ct = encryption::encrypt_glwe(&pt, &sk, ¶ms, bits);
|
||||
let fft = fft::fft_glwe(&ct, ¶ms);
|
||||
|
||||
fft.ifft(&mut ct, ¶ms);
|
||||
|
||||
let actual = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(actual, pt);
|
||||
}
|
||||
}
|
||||
75
sunscreen_tfhe/src/entities/glwe_keyswitch_key.rs
Normal file
75
sunscreen_tfhe/src/entities/glwe_keyswitch_key.rs
Normal file
@@ -0,0 +1,75 @@
|
||||
use num::Zero;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize, GlweDef, GlweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef};
|
||||
|
||||
// TODO: This GLWE keyswitch only works for switching to a new key with the same
|
||||
// parameter. Copy what is above but changed for polynomials to enable
|
||||
// converting to a different key parameter set.
|
||||
dst! {
|
||||
/// A GLWE keyswitch key used to switch a ciphertext from one key to another.
|
||||
/// See [`module`](crate::ops::keyswitch) documentation for more details.
|
||||
GlweKeyswitchKey,
|
||||
GlweKeyswitchKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps,)
|
||||
}
|
||||
|
||||
impl<S> OverlaySize for GlweKeyswitchKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = (GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlevCiphertextRef::<S>::size(t) * (t.0.size.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweKeyswitchKey<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Creates a new GLWE keyswitch key. This enables switching to a new key as
|
||||
/// well as switching from the `original_params` that define the first key
|
||||
/// to the `new_params` that define the second key.
|
||||
pub fn new(params: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
// TODO: Shouldn't this function take 2 GlweDefs?
|
||||
// Ryan: to whoever wrote this, yes, see the above todo next to the dst.
|
||||
let elems = GlweKeyswitchKeyRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweKeyswitchKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the rows of the GLWE keyswitch key, which are
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
|
||||
pub fn rows(&self, params: &GlweDef, radix: &RadixDecomposition) -> GlevCiphertextIterator<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextIterator::new(&self.data, stride)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the GLWE keyswitch key, which are
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIteratorMut<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
GlevCiphertextIteratorMut::new(&mut self.data, stride)
|
||||
}
|
||||
}
|
||||
385
sunscreen_tfhe/src/entities/glwe_secret_key.rs
Normal file
385
sunscreen_tfhe/src/entities/glwe_secret_key.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{FromSlice, NoWrapper, OverlaySize},
|
||||
entities::GgswCiphertext,
|
||||
macros::{impl_binary_op, impl_unary_op},
|
||||
ops::encryption::{
|
||||
decrypt_glwe_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret,
|
||||
},
|
||||
rand::{binary, uniform_torus},
|
||||
GlweDef, GlweDimension, PlaintextBits, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
GlweCiphertext, GlweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialIterator,
|
||||
PolynomialIteratorMut, PolynomialRef,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// A GLWE secret key. This is a list of `s` polynomials of degree `n` where
|
||||
/// `n` is the polynomial degree. The length of the list is `k` where `k` is
|
||||
/// the size of the GLWE secret key.
|
||||
GlweSecretKey,
|
||||
GlweSecretKeyRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps,)
|
||||
}
|
||||
|
||||
// We can use these macros (which does operations on the underlying data vector)
|
||||
// because addition, subtraction, and negation of polynomials is just element
|
||||
// wise addition, subtraction, and negation of the underlying data vector.
|
||||
impl_binary_op!(Add, GlweSecretKey, (TorusOps,));
|
||||
impl_binary_op!(Sub, GlweSecretKey, (TorusOps,));
|
||||
impl_unary_op!(Neg, GlweSecretKey);
|
||||
|
||||
impl<S> OverlaySize for GlweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = GlweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
PolynomialRef::<S>::size(t.polynomial_degree) * t.size.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweSecretKey<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
fn generate(params: &GlweDef, torus_element_generator: impl Fn() -> S) -> GlweSecretKey<S> {
|
||||
params.dim.assert_valid();
|
||||
|
||||
let len = GlweSecretKeyRef::<S>::size(params.dim);
|
||||
|
||||
GlweSecretKey {
|
||||
data: (0..len)
|
||||
.map(|_| torus_element_generator())
|
||||
.collect::<Vec<_>>(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random binary GLWE secret key.
|
||||
pub fn generate_binary(params: &GlweDef) -> GlweSecretKey<S> {
|
||||
Self::generate(params, binary)
|
||||
}
|
||||
|
||||
/// Generate a secret key with uniformly random coefficients. This can be
|
||||
/// used when performing threshold decryption, which needs random secret
|
||||
/// keys that are uniform over the entire ciphertext modulus. Uniform
|
||||
/// secret keys are also valid keys for encryption/decryption but are not
|
||||
/// widely used.
|
||||
pub fn generate_uniform(params: &GlweDef) -> GlweSecretKey<S> {
|
||||
Self::generate(params, || uniform_torus::<S>().inner())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the `s` polynomials in a GLWE secret key.
|
||||
pub fn s(&self, params: &GlweDef) -> PolynomialIterator<S> {
|
||||
PolynomialIterator::new(&self.data, params.dim.polynomial_degree.0)
|
||||
}
|
||||
|
||||
/// Decrypts and decodes a GLWE ciphertext into a polynomial.
|
||||
pub fn decrypt_decode_glwe(
|
||||
&self,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Polynomial<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let mut result = Polynomial::zero(ct.a_b(params).1.len());
|
||||
|
||||
decrypt_glwe_ciphertext(&mut result, ct, self, params);
|
||||
|
||||
result.map(|x| x.decode(plaintext_bits))
|
||||
}
|
||||
|
||||
/// Encodes and encrypts a message as a GLWE ciphertext using a secret key.
|
||||
pub fn encode_encrypt_glwe(
|
||||
&self,
|
||||
plaintext: &PolynomialRef<S>,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> GlweCiphertext<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let plaintext = plaintext.map(|x| Torus::encode(*x, plaintext_bits));
|
||||
|
||||
let mut ct = GlweCiphertext::new(params);
|
||||
|
||||
encrypt_glwe_ciphertext_secret(&mut ct, &plaintext, self, params);
|
||||
|
||||
ct
|
||||
}
|
||||
|
||||
/// Encodes and encrypts a message as a GGSW ciphertext using a secret key.
|
||||
pub fn encode_encrypt_ggsw(
|
||||
&self,
|
||||
msg: &PolynomialRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> GgswCiphertext<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let mut ggsw = GgswCiphertext::new(params, radix);
|
||||
|
||||
encrypt_ggsw_ciphertext(&mut ggsw, msg, self, params, radix, plaintext_bits);
|
||||
|
||||
ggsw
|
||||
}
|
||||
|
||||
/// Returns a representation of a GLWE secret key as an LWE secret key.
|
||||
/// This is a LWE secret key with dimension `N * k` where `N` is the
|
||||
/// polynomial degree and `k` is the size of the GLWE secret key.
|
||||
pub fn to_lwe_secret_key(&self) -> &LweSecretKeyRef<S> {
|
||||
LweSecretKeyRef::from_slice(&self.data)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(&self, params: &GlweDef) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
GlweSecretKeyRef::<S>::size(params.dim)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> GlweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an mutable iterator over the `s` polynomials in a GLWE secret
|
||||
/// key.
|
||||
pub fn s_mut(&mut self, params: &GlweDef) -> PolynomialIteratorMut<S> {
|
||||
PolynomialIteratorMut::new(&mut self.data, params.dim.polynomial_degree.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{high_level::*, GLWE_1_1024_80};
|
||||
|
||||
use num::traits::{WrappingAdd, WrappingNeg, WrappingSub};
|
||||
|
||||
#[test]
|
||||
fn secret_key_dimensions() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
assert_eq!(sk.s(¶ms).count(), params.dim.size.0);
|
||||
|
||||
for s_i in sk.s(¶ms) {
|
||||
assert_eq!(s_i.len(), params.dim.polynomial_degree.0);
|
||||
|
||||
for s in s_i.coeffs() {
|
||||
assert!(*s == 0 || *s == 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Addition
|
||||
|
||||
#[test]
|
||||
fn add_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk + sk2;
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_assign_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
sk2 += sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_secret_key_refs() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.as_ref() + sk2.as_ref();
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_add_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.wrapping_add(&sk2);
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
// Subtraction
|
||||
|
||||
#[test]
|
||||
fn sub_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk - sk2;
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_assign_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk2
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
sk2 -= sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_secret_key_refs() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.as_ref() - sk2.as_ref();
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_sub_secret_keys() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.wrapping_sub(&sk2);
|
||||
|
||||
assert_eq!(sk3_expected, sk3.data)
|
||||
}
|
||||
|
||||
// Negation
|
||||
|
||||
#[test]
|
||||
fn neg_secret_key() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = -sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_secret_key_ref() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = -sk.as_ref();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_neg_secret_key() {
|
||||
let params = GLWE_1_1024_80;
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = sk.wrapping_neg();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
}
|
||||
}
|
||||
44
sunscreen_tfhe/src/entities/lev_ciphertext.rs
Normal file
44
sunscreen_tfhe/src/entities/lev_ciphertext.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{dst::OverlaySize, LweDef, LweDimension, RadixCount, Torus, TorusOps};
|
||||
|
||||
use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef};
|
||||
|
||||
// Iteration over LWE ciphertexts
|
||||
dst! {
|
||||
/// A Lev Ciphertext is a collection of LWE ciphertexts.
|
||||
LevCiphertext,
|
||||
LevCiphertextRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps,)
|
||||
}
|
||||
dst_iter! { LevCiphertextIterator, LevCiphertextIteratorMut, Torus, LevCiphertextRef, (TorusOps,)}
|
||||
|
||||
impl<S> OverlaySize for LevCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = (LweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
LweCiphertextRef::<S>::size(t.0) * t.1 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LevCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the rows of the Lev ciphertext, which are
|
||||
/// [`LweCiphertext`](crate::entities::LweCiphertext)s.
|
||||
pub fn lwe_ciphertexts(&self, params: &LweDef) -> LweCiphertextIterator<S> {
|
||||
LweCiphertextIterator::new(&self.data, LweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the Lev ciphertext, which are
|
||||
/// [`LweCiphertext`](crate::entities::LweCiphertext)s.
|
||||
pub fn lwe_ciphertexts_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut<S> {
|
||||
LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
}
|
||||
192
sunscreen_tfhe/src/entities/lwe_ciphertext.rs
Normal file
192
sunscreen_tfhe/src/entities/lwe_ciphertext.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
use num::Zero;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize,
|
||||
macros::{impl_binary_op, impl_unary_op},
|
||||
LweDef, LweDimension, Torus, TorusOps,
|
||||
};
|
||||
|
||||
dst! {
|
||||
/// An LWE ciphertext.
|
||||
LweCiphertext,
|
||||
LweCiphertextRef,
|
||||
Torus,
|
||||
(Clone, Debug,Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
dst_iter! { LweCiphertextIterator, LweCiphertextIteratorMut, Torus, LweCiphertextRef, (TorusOps,) }
|
||||
|
||||
impl_binary_op!(Add, LweCiphertext, (TorusOps,));
|
||||
impl_binary_op!(Sub, LweCiphertext, (TorusOps,));
|
||||
impl_unary_op!(Neg, LweCiphertext);
|
||||
|
||||
impl<S> OverlaySize for LweCiphertextRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = LweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
t.0 + 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> LweCiphertext<S> {
|
||||
/// Create a new LWE ciphertext with all coefficients set to zero.
|
||||
pub fn new(params: &LweDef) -> Self {
|
||||
Self::zero(params)
|
||||
}
|
||||
|
||||
/// Create a new LWE ciphertext with all coefficients set to zero.
|
||||
pub fn zero(params: &LweDef) -> Self {
|
||||
let data = vec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
|
||||
|
||||
Self { data }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> LweCiphertextRef<S> {
|
||||
fn split_at_idx(&self, params: &LweDef) -> usize {
|
||||
params.dim.0
|
||||
}
|
||||
|
||||
/// Returns a reference to the mask A and body B in an LWE ciphertext.
|
||||
pub fn a_b(&self, params: &LweDef) -> (&[Torus<S>], &Torus<S>) {
|
||||
let (a, b) = self.data.split_at(self.split_at_idx(params));
|
||||
|
||||
(a, &b[0])
|
||||
}
|
||||
|
||||
/// Returns a reference to the mask A in an LWE ciphertext.
|
||||
pub fn a(&self, params: &LweDef) -> &[Torus<S>] {
|
||||
let (a, _) = self.a_b(params);
|
||||
|
||||
a
|
||||
}
|
||||
|
||||
/// Returns a reference to the body B in an LWE ciphertext.
|
||||
pub fn b(&self, params: &LweDef) -> &Torus<S> {
|
||||
let (_, b) = self.a_b(params);
|
||||
|
||||
b
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the mask A and body B in an LWE ciphertext.
|
||||
pub fn a_b_mut(&mut self, params: &LweDef) -> (&mut [Torus<S>], &mut Torus<S>) {
|
||||
let (a, b) = self.data.split_at_mut(self.split_at_idx(params));
|
||||
|
||||
(a, &mut b[0])
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the mask A in an LWE ciphertext.
|
||||
pub fn a_mut(&mut self, params: &LweDef) -> &mut [Torus<S>] {
|
||||
let (a, _) = self.a_b_mut(params);
|
||||
|
||||
a
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the body B in an LWE ciphertext.
|
||||
pub fn b_mut(&mut self, params: &LweDef) -> &mut Torus<S> {
|
||||
let (_, b) = self.a_b_mut(params);
|
||||
|
||||
b
|
||||
}
|
||||
|
||||
/// Asserts that the LWE ciphertext is valid for a given LWE dimension.
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(&self, params: &LweDef) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
LweCiphertextRef::<S>::size(params.dim)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{high_level::*, PlaintextBits, LWE_512_80};
|
||||
use proptest::prelude::*;
|
||||
|
||||
// Test that the negation of a ciphertext is the same as the negation of the
|
||||
// plaintext.
|
||||
proptest! {
|
||||
#[test]
|
||||
fn negation_homomorphism(a in any::<u64>()) {
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let params = LWE_512_80;
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits);
|
||||
|
||||
let a_enc_neg = -a_enc;
|
||||
|
||||
prop_assert_eq!(encryption::decrypt_lwe(&a_enc_neg, &sk, ¶ms, bits), a.wrapping_neg() % (0x1 << bits.0 as u64));
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the addition of ciphertexts is the same as the addition of the
|
||||
// plaintexts.
|
||||
proptest! {
|
||||
#[test]
|
||||
fn additive_homomorphism(a in any::<u64>(), b in any::<u64>()) {
|
||||
let params = LWE_512_80;
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits);
|
||||
let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits);
|
||||
|
||||
let c_enc = a_enc + b_enc;
|
||||
|
||||
prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, ¶ms, bits), a.wrapping_add(b) % (0x1 << bits.0 as u64));
|
||||
}
|
||||
}
|
||||
|
||||
// Test that the subtraction of ciphertexts is the same as the subtraction
|
||||
// of the plaintexts.
|
||||
proptest! {
|
||||
#[test]
|
||||
fn subtraction_homomorphism(a in any::<u64>(), b in any::<u64>()) {
|
||||
let params = LWE_512_80;
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits);
|
||||
let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits);
|
||||
|
||||
let c_enc = a_enc - b_enc;
|
||||
|
||||
prop_assert_eq!(encryption::decrypt_lwe(&c_enc, &sk, ¶ms, bits), a.wrapping_sub(b) % (0x1 << bits.0 as u64));
|
||||
}
|
||||
}
|
||||
|
||||
// Testing that the addition of a ciphertext and a negated ciphertext is the
|
||||
// same as the subtraction of the ciphertexts.
|
||||
proptest! {
|
||||
#[test]
|
||||
fn add_negative_is_subtraction(a in any::<u64>(), b in any::<u64>()) {
|
||||
let params = LWE_512_80;
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let a_enc = encryption::encrypt_lwe_secret(a, &sk, ¶ms, bits);
|
||||
let b_enc = encryption::encrypt_lwe_secret(b, &sk, ¶ms, bits);
|
||||
|
||||
let c_enc_by_add_neg = a_enc.as_ref() + (-(b_enc.as_ref())).as_ref();
|
||||
let c_enc_by_sub = a_enc.as_ref() - b_enc.as_ref();
|
||||
|
||||
// Test that the a values are the same
|
||||
for (a_enc_by_add_neg_i, a_enc_by_sub_i) in c_enc_by_add_neg.a(¶ms).iter().zip(c_enc_by_sub.a(¶ms).iter()) {
|
||||
assert_eq!(a_enc_by_add_neg_i, a_enc_by_sub_i);
|
||||
}
|
||||
|
||||
// Test that the b values are the same
|
||||
assert_eq!(c_enc_by_add_neg.b(¶ms), c_enc_by_sub.b(¶ms));
|
||||
}
|
||||
}
|
||||
}
|
||||
49
sunscreen_tfhe/src/entities/lwe_ciphertext_list.rs
Normal file
49
sunscreen_tfhe/src/entities/lwe_ciphertext_list.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{dst::OverlaySize, LweDef, LweDimension, Torus, TorusOps};
|
||||
|
||||
use super::{LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef};
|
||||
|
||||
dst! {
|
||||
/// A list of LWE ciphertexts. Used during
|
||||
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
|
||||
LweCiphertextList,
|
||||
LweCiphertextListRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for LweCiphertextListRef<S> {
|
||||
type Inputs = (LweDimension, usize);
|
||||
|
||||
#[inline(always)]
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
LweCiphertextRef::<S>::size(t.0) * t.1
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> LweCiphertextList<S> {
|
||||
/// Create a new zero [LweCiphertextList] with the given parameters.
|
||||
///
|
||||
/// This data structure represents is a list of LWE ciphertexts, used for
|
||||
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
|
||||
pub fn new(lwe: &LweDef, count: usize) -> Self {
|
||||
Self {
|
||||
data: vec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> LweCiphertextListRef<S> {
|
||||
/// Iterate over the LWE ciphertexts in the list.
|
||||
pub fn ciphertexts(&self, lwe: &LweDef) -> LweCiphertextIterator<S> {
|
||||
LweCiphertextIterator::new(self.as_slice(), LweCiphertextRef::<S>::size(lwe.dim))
|
||||
}
|
||||
|
||||
/// Iterate over the LWE ciphertexts in the list mutably.
|
||||
pub fn ciphertexts_mut(&mut self, lwe: &LweDef) -> LweCiphertextIteratorMut<S> {
|
||||
LweCiphertextIteratorMut::new(self.as_mut_slice(), LweCiphertextRef::<S>::size(lwe.dim))
|
||||
}
|
||||
}
|
||||
170
sunscreen_tfhe/src/entities/lwe_keyswitch_key.rs
Normal file
170
sunscreen_tfhe/src/entities/lwe_keyswitch_key.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
use num::Zero;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{LevCiphertextIterator, LevCiphertextIteratorMut, LevCiphertextRef};
|
||||
|
||||
dst! {
|
||||
/// A LWE keyswitch key used to switch a ciphertext from one key to another.
|
||||
/// See [`module`](crate::ops::keyswitch) documentation for more details.
|
||||
LweKeyswitchKey,
|
||||
LweKeyswitchKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps,)
|
||||
}
|
||||
|
||||
impl<S> OverlaySize for LweKeyswitchKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
// Old LWE dimension, new LWE dimension, radix count
|
||||
type Inputs = (LweDimension, LweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
// Number of rows should be equal to the number of elements in the original key
|
||||
let num_rows = t.0 .0;
|
||||
|
||||
// Each row is made up of encryptions under the new key
|
||||
let len_row = LevCiphertextRef::<S>::size((t.1, t.2));
|
||||
|
||||
// Encrypt the secret key s_i in each row
|
||||
len_row * (num_rows)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LweKeyswitchKey<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Creates a new LWE keyswitch key. This enables switching to a new key as
|
||||
/// well as switching from the `original_params` that define the first key
|
||||
/// to the `new_params` that define the second key.
|
||||
pub fn new(original_params: &LweDef, new_params: &LweDef, radix: &RadixDecomposition) -> Self {
|
||||
let elems =
|
||||
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LweKeyswitchKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns an iterator over the rows of the LWE keyswitch key, which are
|
||||
/// [`LevCiphertext`](crate::entities::LevCiphertext)s.
|
||||
pub fn rows(
|
||||
&self,
|
||||
new_params: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> LevCiphertextIterator<S> {
|
||||
let stride = LevCiphertextRef::<S>::size((new_params.dim, radix.count));
|
||||
|
||||
LevCiphertextIterator::new(&self.data, stride)
|
||||
}
|
||||
|
||||
/// Returns a mutable iterator over the rows of the LWE keyswitch key, which are
|
||||
/// [`LevCiphertext`](crate::entities::LevCiphertext)s.
|
||||
pub fn rows_mut(
|
||||
&mut self,
|
||||
new_params: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> LevCiphertextIteratorMut<S> {
|
||||
let stride = LevCiphertextRef::<S>::size((new_params.dim, radix.count));
|
||||
|
||||
LevCiphertextIteratorMut::new(&mut self.data, stride)
|
||||
}
|
||||
|
||||
/// Asserts that the keyswitch key is valid for the given parameters.
|
||||
#[inline(always)]
|
||||
pub(crate) fn assert_valid(
|
||||
&self,
|
||||
original_params: &LweDef,
|
||||
new_params: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count))
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{LweCiphertext, LweKeyswitchKey},
|
||||
high_level::*,
|
||||
high_level::{TEST_LWE_DEF_1, TEST_LWE_DEF_2, TEST_RADIX},
|
||||
ops::keyswitch::{
|
||||
lwe_keyswitch::keyswitch_lwe_to_lwe, lwe_keyswitch_key::generate_keyswitch_key_lwe,
|
||||
},
|
||||
PlaintextBits,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn keyswitch_lwe() {
|
||||
let bits = PlaintextBits(4);
|
||||
let from_lwe = TEST_LWE_DEF_1;
|
||||
let to_lwe = TEST_LWE_DEF_2;
|
||||
|
||||
for _ in 0..50 {
|
||||
let original_sk = keygen::generate_binary_lwe_sk(&from_lwe);
|
||||
let new_sk = keygen::generate_binary_lwe_sk(&to_lwe);
|
||||
|
||||
let mut ksk = LweKeyswitchKey::<u64>::new(&from_lwe, &to_lwe, &TEST_RADIX);
|
||||
generate_keyswitch_key_lwe(&mut ksk, &original_sk, &new_sk, &to_lwe, &TEST_RADIX);
|
||||
|
||||
let msg = thread_rng().next_u64() % (1 << bits.0);
|
||||
|
||||
let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0;
|
||||
|
||||
let mut new_ct = LweCiphertext::new(&to_lwe);
|
||||
keyswitch_lwe_to_lwe(
|
||||
&mut new_ct,
|
||||
&original_ct,
|
||||
&ksk,
|
||||
&from_lwe,
|
||||
&to_lwe,
|
||||
&TEST_RADIX,
|
||||
);
|
||||
|
||||
let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits);
|
||||
|
||||
assert_eq!(new_decrypted, msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lwe_keyswitch_keygen() {
|
||||
let from_lwe = TEST_LWE_DEF_1;
|
||||
let to_lwe = TEST_LWE_DEF_2;
|
||||
|
||||
for _ in 0..10 {
|
||||
let sk_1 = keygen::generate_binary_lwe_sk(&from_lwe);
|
||||
let sk_2 = keygen::generate_binary_lwe_sk(&to_lwe);
|
||||
|
||||
let mut ksk = LweKeyswitchKey::<u64>::new(&from_lwe, &to_lwe, &TEST_RADIX);
|
||||
generate_keyswitch_key_lwe(&mut ksk, &sk_1, &sk_2, &to_lwe, &TEST_RADIX);
|
||||
|
||||
for (i, r) in ksk.rows(&to_lwe, &TEST_RADIX).enumerate() {
|
||||
for (j, l) in r.lwe_ciphertexts(&to_lwe).enumerate() {
|
||||
let decomp = (j + 1) * TEST_RADIX.radix_log.0;
|
||||
|
||||
let res = sk_2.decrypt(l, &to_lwe, PlaintextBits(decomp as u32));
|
||||
|
||||
assert_eq!(res, sk_1.s()[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
158
sunscreen_tfhe/src/entities/lwe_public_key.rs
Normal file
158
sunscreen_tfhe/src/entities/lwe_public_key.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use num::Zero;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize,
|
||||
ops::encryption::encode_and_encrypt_lwe_ciphertext,
|
||||
rand::{binary, normal_torus},
|
||||
LweDef, LweDimension, PlaintextBits, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
LweCiphertext, LweCiphertextIterator, LweCiphertextIteratorMut, LweCiphertextRef,
|
||||
LweSecretKeyRef,
|
||||
};
|
||||
|
||||
/// Randomness used to encrypt a message with a public key.
|
||||
#[derive(Debug)]
|
||||
pub struct TlwePublicEncRandomness<S: TorusOps> {
|
||||
/// The binary selectors of the encryptions of zero in the public key.
|
||||
pub r: Vec<S>,
|
||||
|
||||
/// The gaussian noise added to make the LWE problem.
|
||||
pub e: LweCiphertext<S>,
|
||||
}
|
||||
|
||||
dst! {
|
||||
/// An LWE public key.
|
||||
LwePublicKey,
|
||||
LwePublicKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S> OverlaySize for LwePublicKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = LweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
LweCiphertextRef::<S>::size(t) * t.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LwePublicKey<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Generate an LWE public key from a given secret key. This is done by
|
||||
/// encrypting the LWE dimension number of zeros under the secret key, and
|
||||
/// then using the resulting ciphertext as the public key.
|
||||
pub fn generate(sk: &LweSecretKeyRef<S>, params: &LweDef) -> Self {
|
||||
let mut pk = LwePublicKey {
|
||||
data: vec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
|
||||
};
|
||||
let enc_zeros = pk.enc_zeros_mut(params);
|
||||
|
||||
for z in enc_zeros {
|
||||
encode_and_encrypt_lwe_ciphertext(z, sk, <S as Zero>::zero(), params, PlaintextBits(1));
|
||||
}
|
||||
|
||||
pk
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LwePublicKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Get the public key data as an iterator.
|
||||
pub fn enc_zeros(&self, params: &LweDef) -> LweCiphertextIterator<S> {
|
||||
LweCiphertextIterator::new(&self.data, LweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
|
||||
/// Get the public key data as a mutable iterator.
|
||||
pub fn enc_zeros_mut(&mut self, params: &LweDef) -> LweCiphertextIteratorMut<S> {
|
||||
LweCiphertextIteratorMut::new(&mut self.data, LweCiphertextRef::<S>::size(params.dim))
|
||||
}
|
||||
|
||||
/// Encrypt a message as an LWE ciphertext using a public key, returning the
|
||||
/// encrypted message and the randomness used.
|
||||
pub fn encrypt(
|
||||
&self,
|
||||
msg: S,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> (LweCiphertext<S>, TlwePublicEncRandomness<S>) {
|
||||
let msg = Torus::<S>::encode(msg, plaintext_bits);
|
||||
let lwe_dimension = params.dim.0;
|
||||
|
||||
let mut acc = LweCiphertext::zero(params);
|
||||
let (acc_a, acc_b) = acc.a_b_mut(params);
|
||||
|
||||
let mut r_noise = vec![];
|
||||
let mut e = LweCiphertext::zero(params);
|
||||
let (e_a, e_b) = e.a_b_mut(params);
|
||||
|
||||
for z in self.enc_zeros(params) {
|
||||
let (a, b) = z.a_b(params);
|
||||
let r = binary::<S>();
|
||||
r_noise.push(r);
|
||||
|
||||
for i in 0..lwe_dimension {
|
||||
acc_a[i] += a[i] * r;
|
||||
}
|
||||
|
||||
*acc_b += *b * r;
|
||||
}
|
||||
|
||||
for i in 0..lwe_dimension {
|
||||
let a_noise = normal_torus(params.std);
|
||||
e_a[i] = a_noise;
|
||||
acc_a[i] += a_noise;
|
||||
}
|
||||
|
||||
*acc_b += msg;
|
||||
*e_b = normal_torus(params.std);
|
||||
*acc_b += *e_b;
|
||||
|
||||
let noise = TlwePublicEncRandomness { r: r_noise, e };
|
||||
|
||||
(acc, noise)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
high_level::{encryption, keygen, TEST_LWE_DEF_1},
|
||||
PlaintextBits,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn public_key_is_zeros() {
|
||||
let params = TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
for ct in pk.enc_zeros(¶ms) {
|
||||
let pt = encryption::decrypt_lwe(ct, &sk, ¶ms, PlaintextBits(1));
|
||||
assert_eq!(pt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_public_key_encrypt() {
|
||||
let params = TEST_LWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let ct = encryption::encrypt_lwe(5, &pk, ¶ms, bits);
|
||||
assert_eq!(encryption::decrypt_lwe(&ct, &sk, ¶ms, bits), 5);
|
||||
}
|
||||
}
|
||||
337
sunscreen_tfhe/src/entities/lwe_secret_key.rs
Normal file
337
sunscreen_tfhe/src/entities/lwe_secret_key.rs
Normal file
@@ -0,0 +1,337 @@
|
||||
use num::Zero;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
macros::{impl_binary_op, impl_unary_op},
|
||||
ops::encryption::encode_and_encrypt_lwe_ciphertext,
|
||||
rand::{binary, uniform_torus},
|
||||
LweDef, LweDimension, PlaintextBits, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{LweCiphertext, LweCiphertextRef};
|
||||
|
||||
dst! {
|
||||
/// An LWE secret key.
|
||||
LweSecretKey,
|
||||
LweSecretKeyRef,
|
||||
NoWrapper,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
()
|
||||
}
|
||||
|
||||
impl_binary_op!(Add, LweSecretKey, (TorusOps,));
|
||||
impl_binary_op!(Sub, LweSecretKey, (TorusOps,));
|
||||
impl_unary_op!(Neg, LweSecretKey);
|
||||
|
||||
impl<S> OverlaySize for LweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Inputs = LweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
t.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LweSecretKey<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
fn generate(params: &LweDef, torus_element_generator: fn() -> S) -> Self {
|
||||
let len = LweSecretKeyRef::<S>::size(params.dim);
|
||||
|
||||
LweSecretKey {
|
||||
data: (0..len)
|
||||
.map(|_| torus_element_generator())
|
||||
.collect::<Vec<_>>(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random binary LWE secret key
|
||||
pub fn generate_binary(params: &LweDef) -> Self {
|
||||
Self::generate(params, binary)
|
||||
}
|
||||
|
||||
/// Generate a secret key with uniformly random coefficients. This can be
|
||||
/// used when performing threshold decryption, which needs random secret
|
||||
/// keys that are uniform over the entire ciphertext modulus. Uniform
|
||||
/// secret keys are also valid keys for encryption/decryption but are not
|
||||
/// widely used.
|
||||
pub fn generate_uniform(params: &LweDef) -> Self {
|
||||
Self::generate(params, || uniform_torus::<S>().inner())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Create an LWE ciphertext from a given message with a private key. The
|
||||
/// message should be in the plaintext space, and will be encoded onto the
|
||||
/// Torus automatically.
|
||||
pub fn encrypt(
|
||||
&self,
|
||||
msg: S,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> (LweCiphertext<S>, Torus<S>) {
|
||||
let mut ct = LweCiphertext::<S>::zero(params);
|
||||
|
||||
let e = encode_and_encrypt_lwe_ciphertext(&mut ct, self, msg, params, plaintext_bits);
|
||||
|
||||
(ct, e)
|
||||
}
|
||||
|
||||
/// Decrypts the given ciphertext, returning the message. The message will
|
||||
/// not be decoded into the plaintext space; the caller is responsible for
|
||||
/// performing operations like shifting by delta and rounding. See
|
||||
/// [Self::decrypt] for a function that performs the decoding automatically.
|
||||
pub fn decrypt_without_decode(&self, ct: &LweCiphertextRef<S>, params: &LweDef) -> Torus<S> {
|
||||
ct.assert_valid(params);
|
||||
|
||||
let (a, b) = ct.a_b(params);
|
||||
|
||||
let mut dot = Torus::<S>::zero();
|
||||
|
||||
for (a_i, d_i) in a.iter().zip(self.data.iter()) {
|
||||
dot += a_i * d_i
|
||||
}
|
||||
|
||||
b - dot
|
||||
}
|
||||
|
||||
/// Decrypts and decodes a ciphertext, returning the message. The message
|
||||
/// will be decoded into the plaintext space. See
|
||||
/// [Self::decrypt_without_decode] for a function that does not perform the
|
||||
/// decoding.
|
||||
pub fn decrypt(
|
||||
&self,
|
||||
ct: &LweCiphertextRef<S>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> S {
|
||||
let msg = self.decrypt_without_decode(ct, params);
|
||||
|
||||
msg.decode(plaintext_bits)
|
||||
}
|
||||
|
||||
/// Asserts that a given secret key is valid for a given LWE dimension.
|
||||
pub fn assert_valid(&self, params: &LweDef) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
LweSecretKeyRef::<S>::size(params.dim)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> LweSecretKeyRef<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Returns the secret key data as a slice.
|
||||
pub fn s(&self) -> &[S] {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::high_level::*;
|
||||
use num::traits::{WrappingAdd, WrappingNeg, WrappingSub};
|
||||
|
||||
// Addition
|
||||
|
||||
#[test]
|
||||
fn add_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk + sk2;
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_assign_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let mut sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk2_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
sk2 += sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_secret_key_refs() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.as_ref() + sk2.as_ref();
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_add_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.wrapping_add(&sk2);
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
// Subtraction
|
||||
|
||||
#[test]
|
||||
fn sub_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk - sk2;
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_assign_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let mut sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk2_expected = sk2
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk.s().iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
sk2 -= sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sub_secret_key_refs() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.as_ref() - sk2.as_ref();
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_sub_secret_keys() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
let sk2 = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk3_expected = sk
|
||||
.s()
|
||||
.iter()
|
||||
.zip(sk2.s().iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let sk3 = sk.wrapping_sub(&sk2);
|
||||
|
||||
assert_eq!(sk3_expected, sk3.s())
|
||||
}
|
||||
|
||||
// Negation
|
||||
|
||||
#[test]
|
||||
fn neg_secret_key() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = -sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn neg_secret_key_ref() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = -sk.as_ref();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.s())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wrapping_neg_secret_key() {
|
||||
let params = &TEST_LWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(params);
|
||||
|
||||
let sk2_expected = sk.s().iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2 = sk.wrapping_neg();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.s())
|
||||
}
|
||||
}
|
||||
71
sunscreen_tfhe/src/entities/mod.rs
Normal file
71
sunscreen_tfhe/src/entities/mod.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
mod public_functional_keyswitch_key;
|
||||
pub use public_functional_keyswitch_key::*;
|
||||
|
||||
mod blind_rotation_shift;
|
||||
pub use blind_rotation_shift::*;
|
||||
|
||||
mod lwe_ciphertext_list;
|
||||
pub use lwe_ciphertext_list::*;
|
||||
|
||||
mod private_functional_keyswitch_key;
|
||||
pub use private_functional_keyswitch_key::*;
|
||||
|
||||
mod circuit_bootstrapping_private_keyswitch_keys;
|
||||
pub use circuit_bootstrapping_private_keyswitch_keys::*;
|
||||
|
||||
mod bootstrap_key;
|
||||
pub use bootstrap_key::*;
|
||||
|
||||
mod univariate_lookup_table;
|
||||
pub use univariate_lookup_table::*;
|
||||
|
||||
mod bivariate_lookup_table;
|
||||
pub use bivariate_lookup_table::*;
|
||||
|
||||
mod glwe_secret_key;
|
||||
pub use glwe_secret_key::*;
|
||||
|
||||
mod glwe_ciphertext;
|
||||
pub use glwe_ciphertext::*;
|
||||
|
||||
mod glwe_ciphertext_fft;
|
||||
pub use glwe_ciphertext_fft::*;
|
||||
|
||||
mod lwe_ciphertext;
|
||||
pub use lwe_ciphertext::*;
|
||||
|
||||
mod lwe_secret_key;
|
||||
pub use lwe_secret_key::*;
|
||||
|
||||
mod lwe_public_key;
|
||||
pub use lwe_public_key::*;
|
||||
|
||||
mod glev_ciphertext;
|
||||
pub use glev_ciphertext::*;
|
||||
|
||||
mod glev_ciphertext_fft;
|
||||
pub use glev_ciphertext_fft::*;
|
||||
|
||||
mod ggsw_ciphertext;
|
||||
pub use ggsw_ciphertext::*;
|
||||
|
||||
mod ggsw_ciphertext_fft;
|
||||
pub use ggsw_ciphertext_fft::*;
|
||||
|
||||
mod lev_ciphertext;
|
||||
pub use lev_ciphertext::*;
|
||||
|
||||
mod lwe_keyswitch_key;
|
||||
pub use lwe_keyswitch_key::*;
|
||||
|
||||
mod glwe_keyswitch_key;
|
||||
pub use glwe_keyswitch_key::*;
|
||||
|
||||
mod polynomial;
|
||||
pub use polynomial::*;
|
||||
|
||||
mod polynomial_fft;
|
||||
pub use polynomial_fft::*;
|
||||
|
||||
mod polynomial_list;
|
||||
pub use polynomial_list::*;
|
||||
588
sunscreen_tfhe/src/entities/polynomial.rs
Normal file
588
sunscreen_tfhe/src/entities/polynomial.rs
Normal file
@@ -0,0 +1,588 @@
|
||||
use std::{
|
||||
num::Wrapping,
|
||||
ops::{Add, AddAssign, Mul, Sub, SubAssign},
|
||||
};
|
||||
|
||||
use num::{Complex, Zero};
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice, NoWrapper, OverlaySize},
|
||||
fft::negacyclic::get_fft,
|
||||
polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign},
|
||||
scratch::allocate_scratch,
|
||||
FrequencyTransform, PolynomialDegree, ReinterpretAsSigned, ToF64, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::PolynomialFftRef;
|
||||
|
||||
dst! {
|
||||
/// A type representing a polynomial.
|
||||
Polynomial,
|
||||
PolynomialRef,
|
||||
NoWrapper,
|
||||
(Debug, Clone, PartialEq, Eq,),
|
||||
()
|
||||
}
|
||||
dst_iter! { PolynomialIterator, PolynomialIteratorMut, NoWrapper, PolynomialRef, () }
|
||||
|
||||
impl<T> OverlaySize for PolynomialRef<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
type Inputs = PolynomialDegree;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
t.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Polynomial<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// Create a new polynomial from a slice of coefficients.
|
||||
pub fn new(data: &[T]) -> Polynomial<T> {
|
||||
Polynomial {
|
||||
data: data.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new polynomial filled with zeros of a specified length.
|
||||
pub fn zero(len: usize) -> Polynomial<T>
|
||||
where
|
||||
T: Zero,
|
||||
{
|
||||
Polynomial {
|
||||
data: vec![T::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FromIterator<T> for Polynomial<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
|
||||
Self {
|
||||
data: iter.into_iter().collect::<Vec<_>>(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PolynomialRef<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// Returns the coefficients of the polynomial in ascending order.
|
||||
pub fn coeffs(&self) -> &[T] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Returns the mutable coefficients of the polynomial in ascending order.
|
||||
pub fn coeffs_mut(&mut self) -> &mut [T] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
/// Apply a function to each coefficient of the polynomial and return a new
|
||||
/// polynomial.
|
||||
pub fn map<F, U>(&self, f: F) -> Polynomial<U>
|
||||
where
|
||||
F: Fn(&T) -> U,
|
||||
U: Clone,
|
||||
{
|
||||
Polynomial {
|
||||
data: self.data.iter().map(f).collect::<Vec<_>>(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps this polynomial using f into the dst [`PolynomialRef`].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `dst.len() != self.len()`
|
||||
pub fn map_into<F, U>(&self, dst: &mut PolynomialRef<U>, f: F)
|
||||
where
|
||||
F: Fn(&T) -> U,
|
||||
U: Clone,
|
||||
{
|
||||
assert_eq!(dst.len(), self.len());
|
||||
|
||||
dst.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(self.coeffs().iter())
|
||||
.for_each(|(d, s)| *d = f(s));
|
||||
}
|
||||
|
||||
/// Returns the number of coefficients in the polynomial.
|
||||
#[inline]
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Returns true if the polynomial has no coefficients.
|
||||
#[inline]
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PolynomialRef<T>
|
||||
where
|
||||
T: TorusOps,
|
||||
{
|
||||
/// Reinterpret the this polynomial as a polynomial of torus elements.
|
||||
pub fn as_torus(&self) -> &PolynomialRef<Torus<T>> {
|
||||
let as_torus = bytemuck::cast_slice(&self.data);
|
||||
|
||||
PolynomialRef::from_slice(as_torus)
|
||||
}
|
||||
|
||||
/// Reinterpret this polynomial of integers as having wrapping semantics.
|
||||
pub fn as_wrapping(&self) -> &PolynomialRef<Wrapping<T>> {
|
||||
let as_wrapping = bytemuck::cast_slice(&self.data);
|
||||
|
||||
PolynomialRef::from_slice(as_wrapping)
|
||||
}
|
||||
|
||||
/// Reinterpret this polynomial of integers as having wrapping semantics.
|
||||
pub fn as_wrapping_mut(&mut self) -> &mut PolynomialRef<Wrapping<T>> {
|
||||
let as_wrapping = bytemuck::cast_slice_mut(&mut self.data);
|
||||
|
||||
PolynomialRef::from_mut_slice(as_wrapping)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PolynomialRef<Torus<T>>
|
||||
where
|
||||
T: TorusOps,
|
||||
{
|
||||
/**
|
||||
* Multiply by a monomial X^-degree, returning a new Polynomial.
|
||||
*/
|
||||
pub fn mul_by_negative_monomial_negacyclic(&mut self, degree: usize) {
|
||||
let len = self.len();
|
||||
|
||||
// The behavior of the rotation is the same for every degree + k*2N for
|
||||
// k >= 0.
|
||||
let degree = degree % (2 * len);
|
||||
|
||||
// If the degree is 0 (or a multiple of 2N), the polynomial is unchanged.
|
||||
if degree == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// If the degree is N (or of the form N + k*2N), the polynomial is negated.
|
||||
if degree == len {
|
||||
self.data
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x));
|
||||
return;
|
||||
}
|
||||
|
||||
let shift = degree % len;
|
||||
self.data.rotate_left(shift);
|
||||
|
||||
let negate_segment = if degree < len {
|
||||
(len - shift)..len
|
||||
} else {
|
||||
0..(len - shift)
|
||||
};
|
||||
|
||||
for i in negate_segment {
|
||||
self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply by a monomial X^degree, returning a new Polynomial.
|
||||
*/
|
||||
pub fn mul_by_positive_monomial_negacyclic(&mut self, degree: usize) {
|
||||
let len = self.len();
|
||||
|
||||
// The behavior of the rotation is the same for every degree + k*2N for
|
||||
// k >= 0.
|
||||
let degree = degree % (2 * len);
|
||||
|
||||
// If the degree is 0 (or a multiple of 2N), the polynomial is unchanged.
|
||||
if degree == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
// If the degree is N (or of the form N + k*2N), the polynomial is negated.
|
||||
if degree == len {
|
||||
self.data
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = num::traits::WrappingNeg::wrapping_neg(x));
|
||||
return;
|
||||
}
|
||||
|
||||
let shift = degree % len;
|
||||
self.data.rotate_right(shift);
|
||||
|
||||
let negate_segment = if degree < len { 0..degree } else { shift..len };
|
||||
|
||||
for i in negate_segment {
|
||||
self.data[i] = num::traits::WrappingNeg::wrapping_neg(&self.data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply by a monomial X^degree, returning a new Polynomial. The degree
|
||||
* can be either positive or negative.
|
||||
*/
|
||||
pub fn mul_by_monomial_negacyclic(&mut self, degree: isize) {
|
||||
if degree < 0 {
|
||||
self.mul_by_negative_monomial_negacyclic(degree.unsigned_abs());
|
||||
} else {
|
||||
self.mul_by_positive_monomial_negacyclic(degree as usize);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, U> PolynomialRef<T>
|
||||
where
|
||||
U: ToF64,
|
||||
T: Clone + Copy + ReinterpretAsSigned<Output = U>,
|
||||
{
|
||||
/// Compute the FFT of the polynomial.
|
||||
pub fn fft(&self, out: &mut PolynomialFftRef<Complex<f64>>) {
|
||||
assert!(self.len().is_power_of_two());
|
||||
assert_eq!(self.len(), out.len() * 2);
|
||||
|
||||
let mut self_f64 = allocate_scratch::<f64>(self.len());
|
||||
let self_f64 = self_f64.as_mut_slice();
|
||||
|
||||
for (o, i) in self_f64.iter_mut().zip(self.coeffs().iter()) {
|
||||
// Reinterperet [0, 1) to [-q/2, q/2) to slightly increase
|
||||
// precision.
|
||||
*o = (*i).reinterpret_as_signed().to_f64();
|
||||
}
|
||||
|
||||
let log_n = self.len().ilog2() as usize;
|
||||
|
||||
let fft = get_fft(log_n);
|
||||
fft.forward(self_f64, out.as_mut_slice());
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Add<Polynomial<S>> for Polynomial<S>
|
||||
where
|
||||
S: Add<S, Output = S> + Copy,
|
||||
{
|
||||
type Output = Polynomial<S>;
|
||||
|
||||
fn add(self, rhs: Polynomial<S>) -> Self::Output {
|
||||
self.as_ref().add(rhs.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Add<&PolynomialRef<S>> for &PolynomialRef<S>
|
||||
where
|
||||
S: Add<S, Output = S> + Copy,
|
||||
{
|
||||
type Output = Polynomial<S>;
|
||||
|
||||
fn add(self, rhs: &PolynomialRef<S>) -> Self::Output {
|
||||
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
|
||||
|
||||
let coeffs = self
|
||||
.coeffs()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(rhs.coeffs().as_ref().iter())
|
||||
.map(|(a, b)| *a + *b)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Polynomial { data: coeffs }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> AddAssign<&PolynomialRef<S>> for PolynomialRef<S>
|
||||
where
|
||||
S: AddAssign<S> + Copy,
|
||||
{
|
||||
fn add_assign(&mut self, rhs: &PolynomialRef<S>) {
|
||||
polynomial_add_assign(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Sub<Polynomial<S>> for Polynomial<S>
|
||||
where
|
||||
S: Sub<S, Output = S> + Copy,
|
||||
{
|
||||
type Output = Polynomial<S>;
|
||||
|
||||
fn sub(self, rhs: Polynomial<S>) -> Self::Output {
|
||||
self.as_ref().sub(rhs.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Sub<&PolynomialRef<S>> for &PolynomialRef<S>
|
||||
where
|
||||
S: Sub<S, Output = S> + Copy,
|
||||
{
|
||||
type Output = Polynomial<S>;
|
||||
|
||||
fn sub(self, rhs: &PolynomialRef<S>) -> Self::Output {
|
||||
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
|
||||
|
||||
let coeffs = self
|
||||
.coeffs()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(rhs.coeffs().as_ref().iter())
|
||||
.map(|(a, b)| *a - *b)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Polynomial { data: coeffs }
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> SubAssign<&PolynomialRef<S>> for PolynomialRef<S>
|
||||
where
|
||||
S: SubAssign<S> + Copy,
|
||||
{
|
||||
fn sub_assign(&mut self, rhs: &PolynomialRef<S>) {
|
||||
polynomial_sub_assign(self, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> Mul<&PolynomialRef<S>> for &PolynomialRef<Torus<S>>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = Polynomial<Torus<S>>;
|
||||
|
||||
/// External product of T\[X\]/f * Z\[X\]/f
|
||||
/// TODO: use NTT to do in nlog(n) time.
|
||||
fn mul(self, rhs: &PolynomialRef<S>) -> Self::Output {
|
||||
assert_eq!(rhs.len(), self.len());
|
||||
|
||||
let mut c = Polynomial {
|
||||
data: vec![Torus::zero(); rhs.len()],
|
||||
};
|
||||
|
||||
polynomial_external_mad(&mut c, self, rhs);
|
||||
|
||||
c
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::ops::Deref;
|
||||
|
||||
use crate::{entities::Polynomial, Torus};
|
||||
|
||||
#[test]
|
||||
fn can_add_polynomials() {
|
||||
let a = Polynomial::new(&[1, 2, 3]);
|
||||
|
||||
let b = Polynomial::new(&[4, 5, 6]);
|
||||
|
||||
let expected = Polynomial::new(&[5, 7, 9]);
|
||||
|
||||
let c = a.deref() + b.deref();
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
|
||||
// A golden test but easier than recoding the logic again.
|
||||
#[test]
|
||||
fn can_multiply_by_positive_monomial_negacyclic() {
|
||||
let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::<u64>::from));
|
||||
|
||||
let mut shift_0 = original.clone();
|
||||
shift_0.mul_by_positive_monomial_negacyclic(0);
|
||||
let expected_0 = original.clone();
|
||||
|
||||
assert_eq!(shift_0, expected_0);
|
||||
|
||||
let mut shift_1 = original.clone();
|
||||
shift_1.mul_by_positive_monomial_negacyclic(1);
|
||||
let expected_1 = Polynomial::new(&[
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
Torus::from(2),
|
||||
Torus::from(3),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_1, expected_1);
|
||||
|
||||
let mut shift_2 = original.clone();
|
||||
shift_2.mul_by_positive_monomial_negacyclic(2);
|
||||
let expected_2 = Polynomial::new(&[
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
Torus::from(2),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_2, expected_2);
|
||||
|
||||
let mut shift_3 = original.clone();
|
||||
shift_3.mul_by_positive_monomial_negacyclic(3);
|
||||
let expected_3 = Polynomial::new(&[
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_3, expected_3);
|
||||
|
||||
let mut shift_4 = original.clone();
|
||||
shift_4.mul_by_positive_monomial_negacyclic(4);
|
||||
let expected_4 = Polynomial::new(&[
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_4, expected_4);
|
||||
|
||||
let mut shift_5 = original.clone();
|
||||
shift_5.mul_by_positive_monomial_negacyclic(5);
|
||||
let expected_5 = Polynomial::new(&[
|
||||
Torus::from(4u64),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_5, expected_5);
|
||||
|
||||
let mut shift_6 = original.clone();
|
||||
shift_6.mul_by_positive_monomial_negacyclic(6);
|
||||
let expected_6 = Polynomial::new(&[
|
||||
Torus::from(3u64),
|
||||
Torus::from(4u64),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_6, expected_6);
|
||||
|
||||
let mut shift_7 = original.clone();
|
||||
shift_7.mul_by_positive_monomial_negacyclic(7);
|
||||
let expected_7 = Polynomial::new(&[
|
||||
Torus::from(2u64),
|
||||
Torus::from(3u64),
|
||||
Torus::from(4u64),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_7, expected_7);
|
||||
|
||||
let mut shift_8 = original.clone();
|
||||
shift_8.mul_by_positive_monomial_negacyclic(8);
|
||||
let expected_8 = original.clone();
|
||||
|
||||
assert_eq!(shift_8, expected_8);
|
||||
|
||||
let mut shift_9 = original.clone();
|
||||
shift_9.mul_by_positive_monomial_negacyclic(9);
|
||||
let expected_9 = shift_1.clone();
|
||||
|
||||
assert_eq!(shift_9, expected_9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_by_negative_monomial_negacyclic() {
|
||||
let original = Polynomial::new(&[1, 2, 3, 4].map(Torus::<u64>::from));
|
||||
|
||||
let mut shift_0 = original.clone();
|
||||
shift_0.mul_by_negative_monomial_negacyclic(0);
|
||||
let expected_0 = original.clone();
|
||||
|
||||
assert_eq!(shift_0, expected_0);
|
||||
|
||||
let mut shift_1 = original.clone();
|
||||
shift_1.mul_by_negative_monomial_negacyclic(1);
|
||||
let expected_1 = Polynomial::new(&[
|
||||
Torus::from(2),
|
||||
Torus::from(3),
|
||||
Torus::from(4),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_1, expected_1);
|
||||
|
||||
let mut shift_2 = original.clone();
|
||||
shift_2.mul_by_negative_monomial_negacyclic(2);
|
||||
let expected_2 = Polynomial::new(&[
|
||||
Torus::from(3),
|
||||
Torus::from(4),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_2, expected_2);
|
||||
|
||||
let mut shift_3 = original.clone();
|
||||
shift_3.mul_by_negative_monomial_negacyclic(3);
|
||||
let expected_3 = Polynomial::new(&[
|
||||
Torus::from(4),
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_3, expected_3);
|
||||
|
||||
let mut shift_4 = original.clone();
|
||||
shift_4.mul_by_negative_monomial_negacyclic(4);
|
||||
let expected_4 = Polynomial::new(&[
|
||||
Torus::from(1u64.wrapping_neg()),
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_4, expected_4);
|
||||
|
||||
let mut shift_5 = original.clone();
|
||||
shift_5.mul_by_negative_monomial_negacyclic(5);
|
||||
let expected_5 = Polynomial::new(&[
|
||||
Torus::from(2u64.wrapping_neg()),
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_5, expected_5);
|
||||
|
||||
let mut shift_6 = original.clone();
|
||||
shift_6.mul_by_negative_monomial_negacyclic(6);
|
||||
let expected_6 = Polynomial::new(&[
|
||||
Torus::from(3u64.wrapping_neg()),
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
Torus::from(2),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_6, expected_6);
|
||||
|
||||
let mut shift_7 = original.clone();
|
||||
shift_7.mul_by_negative_monomial_negacyclic(7);
|
||||
let expected_7 = Polynomial::new(&[
|
||||
Torus::from(4u64.wrapping_neg()),
|
||||
Torus::from(1),
|
||||
Torus::from(2),
|
||||
Torus::from(3),
|
||||
]);
|
||||
|
||||
assert_eq!(shift_7, expected_7);
|
||||
|
||||
let mut shift_8 = original.clone();
|
||||
shift_8.mul_by_negative_monomial_negacyclic(8);
|
||||
let expected_8 = original.clone();
|
||||
|
||||
assert_eq!(shift_8, expected_8);
|
||||
|
||||
let mut shift_9 = original.clone();
|
||||
shift_9.mul_by_negative_monomial_negacyclic(9);
|
||||
let expected_9 = shift_1.clone();
|
||||
|
||||
assert_eq!(shift_9, expected_9);
|
||||
}
|
||||
}
|
||||
158
sunscreen_tfhe/src/entities/polynomial_fft.rs
Normal file
158
sunscreen_tfhe/src/entities/polynomial_fft.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use num::{traits::MulAdd, Complex};
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
fft::negacyclic::get_fft,
|
||||
scratch::allocate_scratch,
|
||||
FrequencyTransform, FromF64, NumBits, PolynomialDegree,
|
||||
};
|
||||
|
||||
use super::PolynomialRef;
|
||||
|
||||
dst! {
|
||||
/// The FFT of a polynomial. See [`Polynomial`](crate::entities::Polynomial)
|
||||
/// for the non-FFT variant.
|
||||
PolynomialFft,
|
||||
PolynomialFftRef,
|
||||
NoWrapper,
|
||||
(Debug, Clone, PartialEq, Eq,),
|
||||
()
|
||||
}
|
||||
dst_iter!(
|
||||
PolynomialFftIterator,
|
||||
PolynomialFftIteratorMut,
|
||||
NoWrapper,
|
||||
PolynomialFftRef,
|
||||
()
|
||||
);
|
||||
|
||||
impl<T> OverlaySize for PolynomialFftRef<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
type Inputs = PolynomialDegree;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
t.0 / 2
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PolynomialFft<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// Create a new polynomial with the given length in the fourier domain.
|
||||
pub fn new(data: &[T]) -> Self {
|
||||
Self {
|
||||
data: data.to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> PolynomialFftRef<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
/// Returns the coefficients of the polynomial in the fourier domain.
|
||||
pub fn coeffs(&self) -> &[T] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Returns the mutable coefficients of the polynomial in the fourier domain.
|
||||
pub fn coeffs_mut(&mut self) -> &mut [T] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
/// Returns the number of coefficients in the polynomial.
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Returns true if the polynomial has no coefficients.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl PolynomialFftRef<Complex<f64>> {
|
||||
/// Compute the inverse FFT of the polynomial.
|
||||
pub fn ifft<T>(&self, poly: &mut PolynomialRef<T>)
|
||||
where
|
||||
T: Clone + FromF64 + NumBits,
|
||||
{
|
||||
assert!(self.len().is_power_of_two());
|
||||
assert_eq!(self.len() * 2, poly.len());
|
||||
|
||||
let log_n = poly.len().ilog2() as usize;
|
||||
|
||||
let fft = get_fft(log_n);
|
||||
|
||||
let mut ifft = allocate_scratch::<f64>(poly.len());
|
||||
let ifft = ifft.as_mut_slice();
|
||||
|
||||
fft.reverse(&self.data, ifft);
|
||||
|
||||
// When the exponent != 0 && exponent != 1024,
|
||||
// IEEE-754 doubles are represented as -1**s * 1.m * 2**(e - 1023).
|
||||
//
|
||||
// m is 52 bits, e is 11 bits, and s is 1 bit.
|
||||
//
|
||||
// Thus, to compute 2**x, we set e = 1023 + x, m=0, and s = 0. So, we just
|
||||
// need to fill in EXP and shift it up 52 places.
|
||||
//
|
||||
// We first reduce modulo q
|
||||
let exp: u64 = 1023 + T::BITS as u64;
|
||||
let q: f64 = f64::from_bits(exp << 52);
|
||||
|
||||
let exp_div_2 = exp - 1;
|
||||
let q_div_2 = f64::from_bits(exp_div_2 << 52);
|
||||
|
||||
// Exploit the fact that q is a power of 2 when performing the modulo
|
||||
// reduction. Could possibly be even faster by masking and shifting
|
||||
// the mantissa and tweaking the exponent. However, profiling on ARM
|
||||
// indicates this is no longer a bottleneck with the code below.
|
||||
//
|
||||
// See https://stackoverflow.com/questions/49139283/are-there-any-numbers-that-enable-fast-modulo-calculation-on-floats
|
||||
//
|
||||
// Don't know why Rust decides not to inline this. Inlining allows
|
||||
// the below loop to get unrolled, vectorized, and division gets
|
||||
// replaced with multiplication since q is a known constant.
|
||||
#[inline(always)]
|
||||
fn mod_q(val: f64, q: f64) -> f64 {
|
||||
f64::mul_add(-(val / q).trunc(), q, val)
|
||||
}
|
||||
|
||||
for (o, ifft) in poly.coeffs_mut().iter_mut().zip(ifft.iter()) {
|
||||
let mut ifft = mod_q(*ifft, q);
|
||||
|
||||
// Next, we need to adjust x outside [-q/2, q/2) to wrap to the correct torus
|
||||
// point.
|
||||
if ifft >= q_div_2 {
|
||||
ifft -= q;
|
||||
} else if ifft <= -q_div_2 {
|
||||
ifft += q;
|
||||
}
|
||||
|
||||
*o = T::from_f64(ifft);
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the multiplication of two polynomials as `c += a * b`. This is
|
||||
/// more efficient than the naive method, and has a runtime of O(N). Note
|
||||
/// that performing the FFT and IFFT to get in and out of the fourier domain
|
||||
/// costs O(N log N).
|
||||
pub fn multiply_add(
|
||||
&mut self,
|
||||
a: &PolynomialFftRef<Complex<f64>>,
|
||||
b: &PolynomialFftRef<Complex<f64>>,
|
||||
) {
|
||||
for ((c, a), b) in self
|
||||
.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(a.coeffs().iter())
|
||||
.zip(b.coeffs().iter())
|
||||
{
|
||||
*c = a.mul_add(b, c);
|
||||
}
|
||||
}
|
||||
}
|
||||
52
sunscreen_tfhe/src/entities/polynomial_list.rs
Normal file
52
sunscreen_tfhe/src/entities/polynomial_list.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use num::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::{NoWrapper, OverlaySize},
|
||||
PolynomialDegree,
|
||||
};
|
||||
|
||||
use super::{PolynomialIterator, PolynomialIteratorMut, PolynomialRef};
|
||||
|
||||
dst! {
|
||||
/// A list of polynomials.
|
||||
PolynomialList,
|
||||
PolynomialListRef,
|
||||
NoWrapper,
|
||||
(Debug, Clone, PartialEq, Eq,),
|
||||
()
|
||||
}
|
||||
|
||||
impl<S: Clone> OverlaySize for PolynomialListRef<S> {
|
||||
type Inputs = (PolynomialDegree, usize);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
PolynomialRef::<S>::size(t.0) * t.1
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> PolynomialList<S>
|
||||
where
|
||||
S: Clone + Zero,
|
||||
{
|
||||
/// Create a new polynomial list, where each polynomial has the same degree.
|
||||
pub fn new(degree: PolynomialDegree, count: usize) -> Self {
|
||||
Self {
|
||||
data: vec![S::zero(); degree.0 * count],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> PolynomialListRef<S>
|
||||
where
|
||||
S: Clone + Zero,
|
||||
{
|
||||
/// Iterate over the polynomials in the list.
|
||||
pub fn iter(&self, degree: PolynomialDegree) -> PolynomialIterator<S> {
|
||||
PolynomialIterator::new(&self.data, PolynomialRef::<S>::size(degree))
|
||||
}
|
||||
|
||||
/// Iterate over the polynomials in the list mutably.
|
||||
pub fn iter_mut(&mut self, degree: PolynomialDegree) -> PolynomialIteratorMut<S> {
|
||||
PolynomialIteratorMut::new(&mut self.data, PolynomialRef::<S>::size(degree))
|
||||
}
|
||||
}
|
||||
134
sunscreen_tfhe/src/entities/private_functional_keyswitch_key.rs
Normal file
134
sunscreen_tfhe/src/entities/private_functional_keyswitch_key.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize,
|
||||
entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef},
|
||||
GlweDef, GlweDimension, LweDef, LweDimension, PrivateFunctionalKeyswitchLweCount, RadixCount,
|
||||
RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::LweSecretKeyRef;
|
||||
|
||||
dst! {
|
||||
/// Key for Private Functional Key Switching. See
|
||||
/// [`module`](crate::ops::keyswitch::private_functional_keyswitch)
|
||||
/// documentation for more details.
|
||||
PrivateFunctionalKeyswitchKey,
|
||||
PrivateFunctionalKeyswitchKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
dst_iter!(
|
||||
PrivateFunctionalKeyswitchKeyIter,
|
||||
PrivateFunctionalKeyswitchKeyIterMut,
|
||||
Torus,
|
||||
PrivateFunctionalKeyswitchKeyRef,
|
||||
(TorusOps,)
|
||||
);
|
||||
|
||||
impl<S: TorusOps> OverlaySize for PrivateFunctionalKeyswitchKeyRef<S> {
|
||||
type Inputs = (
|
||||
LweDimension,
|
||||
GlweDimension,
|
||||
RadixCount,
|
||||
PrivateFunctionalKeyswitchLweCount,
|
||||
);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlevCiphertextRef::<S>::size((t.1, t.2)) * (LweSecretKeyRef::<S>::size(t.0) + 1) * t.3 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> PrivateFunctionalKeyswitchKey<S> {
|
||||
/// Construct a new uninitialized [`PrivateFunctionalKeyswitchKey`]. This key is used
|
||||
/// compute a secret function mapping `lwe_count`
|
||||
/// [`LweCiphertext`](crate::entities::LweCiphertext)s to a
|
||||
/// [`GlweCiphertext`](crate::entities::GlweCiphertext).
|
||||
///
|
||||
/// # Remarks
|
||||
/// The key is composed of a `(from_lwe.dim + 1) * lwe_count.0` matrix of
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s. The leading
|
||||
/// dimension contains radix-scaled encryptions of the bits in a
|
||||
/// [`LweSecretKey`](crate::entities::LweSecretKey) followed by a scaled
|
||||
/// encryption of -1.
|
||||
///
|
||||
/// Plaintexts in the GLEV are scaled by the usual `q/beta^(j+1)`.
|
||||
///
|
||||
/// The trailing dimension iterates over `0..lwe_count.0`.
|
||||
pub fn new(
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
lwe_count: &PrivateFunctionalKeyswitchLweCount,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
Torus::zero();
|
||||
PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
from_lwe.dim,
|
||||
to_glwe.dim,
|
||||
radix.count,
|
||||
*lwe_count
|
||||
))
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> PrivateFunctionalKeyswitchKeyRef<S> {
|
||||
/// Returns an iterator over the
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that
|
||||
/// compose this key.
|
||||
///
|
||||
/// # See also
|
||||
/// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks).
|
||||
pub fn glevs(
|
||||
&self,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIterator<S> {
|
||||
GlevCiphertextIterator::new(
|
||||
self.as_slice(),
|
||||
GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count)),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a muitable iterator over the
|
||||
/// [`GlevCiphertext`](crate::entities::GlevCiphertext)s that compose this
|
||||
/// key.
|
||||
///
|
||||
/// # See also
|
||||
/// To make sense of the layout, see also [`PrivateFunctionalKeyswitchKey::new()`](./struct.PrivateFunctionalKeyswitchKey.html#remarks).
|
||||
pub fn glevs_mut(
|
||||
&mut self,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIteratorMut<S> {
|
||||
GlevCiphertextIteratorMut::new(
|
||||
self.as_mut_slice(),
|
||||
GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count)),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Assert this value is correct for the given parameters.
|
||||
pub fn assert_valid(
|
||||
&self,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
lwe_count: &PrivateFunctionalKeyswitchLweCount,
|
||||
) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
from_lwe.dim,
|
||||
to_glwe.dim,
|
||||
radix.count,
|
||||
*lwe_count
|
||||
))
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::OverlaySize,
|
||||
entities::{GlevCiphertextIterator, GlevCiphertextIteratorMut, GlevCiphertextRef},
|
||||
GlweDef, GlweDimension, LweDef, LweDimension, RadixCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::GlweCiphertextRef;
|
||||
|
||||
dst! {
|
||||
/// Public Functional Key Switching Key. See
|
||||
/// [`module`](crate::ops::keyswitch::public_functional_keyswitch)
|
||||
/// documentation for more details.
|
||||
PublicFunctionalKeyswitchKey,
|
||||
PublicFunctionalKeyswitchKeyRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for PublicFunctionalKeyswitchKeyRef<S> {
|
||||
type Inputs = (LweDimension, GlweDimension, RadixCount);
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlweCiphertextRef::<S>::size(t.1) * t.0 .0 * t.2 .0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> PublicFunctionalKeyswitchKey<S> {
|
||||
/// Construct a new uninitialized [`PublicFunctionalKeyswitchKey`]. This key is used
|
||||
/// when performing a [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch).
|
||||
pub fn new(from_lwe: &LweDef, to_glwe: &GlweDef, radix: &RadixDecomposition) -> Self {
|
||||
let len =
|
||||
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> PublicFunctionalKeyswitchKeyRef<S> {
|
||||
/// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`].
|
||||
pub fn glevs(
|
||||
&self,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIterator<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count));
|
||||
|
||||
GlevCiphertextIterator::new(self.as_slice(), stride)
|
||||
}
|
||||
|
||||
/// Iterate over the rows of the [`PublicFunctionalKeyswitchKey`] mutably.
|
||||
pub fn glevs_mut(
|
||||
&mut self,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlevCiphertextIteratorMut<S> {
|
||||
let stride = GlevCiphertextRef::<S>::size((to_glwe.dim, radix.count));
|
||||
|
||||
GlevCiphertextIteratorMut::new(self.as_mut_slice(), stride)
|
||||
}
|
||||
|
||||
/// Asserts that the key is valid for the given parameters.
|
||||
pub(crate) fn assert_valid(
|
||||
&self,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
assert_eq!(
|
||||
self.as_slice().len(),
|
||||
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count))
|
||||
);
|
||||
}
|
||||
}
|
||||
83
sunscreen_tfhe/src/entities/univariate_lookup_table.rs
Normal file
83
sunscreen_tfhe/src/entities/univariate_lookup_table.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice, OverlaySize},
|
||||
entities::PolynomialRef,
|
||||
ops::{bootstrapping::generate_lut, encryption::trivially_encrypt_glwe_ciphertext},
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, GlweDimension, PlaintextBits, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::GlweCiphertextRef;
|
||||
|
||||
dst! {
|
||||
/// Lookup table for a univariate function used during
|
||||
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap)
|
||||
/// and [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
|
||||
UnivariateLookupTable,
|
||||
UnivariateLookupTableRef,
|
||||
Torus,
|
||||
(Clone, Debug, Serialize, Deserialize),
|
||||
(TorusOps)
|
||||
}
|
||||
|
||||
impl<S: TorusOps> OverlaySize for UnivariateLookupTableRef<S> {
|
||||
type Inputs = GlweDimension;
|
||||
|
||||
fn size(t: Self::Inputs) -> usize {
|
||||
GlweCiphertextRef::<S>::size(t)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> UnivariateLookupTable<S> {
|
||||
/// Creates a lookup table that is trivially encrypted.
|
||||
pub fn trivial_from_fn<F>(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self
|
||||
where
|
||||
F: Fn(u64) -> u64,
|
||||
{
|
||||
let mut lut = UnivariateLookupTable {
|
||||
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
};
|
||||
|
||||
lut.fill_trivial_from_fn(map, glwe, plaintext_bits);
|
||||
|
||||
lut
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> UnivariateLookupTableRef<S> {
|
||||
/// Return the underlying GLWE representation of a lookup table.
|
||||
pub fn glwe(&self) -> &GlweCiphertextRef<S> {
|
||||
GlweCiphertextRef::from_slice(&self.data)
|
||||
}
|
||||
|
||||
/// Return a mutable representation of the underlying GLWE representation of
|
||||
/// a lookup table.
|
||||
pub fn glwe_mut(&mut self) -> &mut GlweCiphertextRef<S> {
|
||||
GlweCiphertextRef::from_mut_slice(&mut self.data)
|
||||
}
|
||||
|
||||
/// Generates a look up table filled with the values from the provided map,
|
||||
/// and trivially encrypts the lookup table.
|
||||
pub fn fill_trivial_from_fn<F: Fn(u64) -> u64>(
|
||||
&mut self,
|
||||
map: F,
|
||||
glwe: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) {
|
||||
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));
|
||||
|
||||
generate_lut(poly, map, glwe, plaintext_bits);
|
||||
|
||||
trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
|
||||
}
|
||||
|
||||
/// Creates a lookup table filled with the same value at every entry.
|
||||
pub fn fill_with_constant(&mut self, val: S, glwe: &GlweDef, plaintext_bits: PlaintextBits) {
|
||||
self.clear();
|
||||
for o in self.glwe_mut().b_mut(glwe).coeffs_mut() {
|
||||
*o = Torus::encode(val, plaintext_bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
4
sunscreen_tfhe/src/error.rs
Normal file
4
sunscreen_tfhe/src/error.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
#[derive(thiserror::Error)]
|
||||
pub enum Error {
|
||||
OutOfRange
|
||||
}
|
||||
939
sunscreen_tfhe/src/high_level.rs
Normal file
939
sunscreen_tfhe/src/high_level.rs
Normal file
@@ -0,0 +1,939 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use crate::{
|
||||
rand::Stddev, GlweDef, GlweDimension, GlweSize, LweDef, LweDimension, PolynomialDegree,
|
||||
RadixCount, RadixDecomposition, RadixLog,
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_RADIX: RadixDecomposition = RadixDecomposition {
|
||||
count: RadixCount(3),
|
||||
radix_log: RadixLog(4),
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_GLWE_DEF_1: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
polynomial_degree: PolynomialDegree(128),
|
||||
size: GlweSize(2),
|
||||
},
|
||||
std: Stddev(1e-16),
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_GLWE_DEF_2: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
polynomial_degree: PolynomialDegree(256),
|
||||
size: GlweSize(3),
|
||||
},
|
||||
std: Stddev(1e-16),
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_LWE_DEF_1: LweDef = LweDef {
|
||||
dim: LweDimension(128),
|
||||
std: Stddev(1e-16),
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_LWE_DEF_2: LweDef = LweDef {
|
||||
dim: LweDimension(256),
|
||||
std: Stddev(1e-16),
|
||||
};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const TEST_LWE_DEF_3: LweDef = LweDef {
|
||||
dim: LweDimension(128),
|
||||
std: Stddev(0.0),
|
||||
};
|
||||
|
||||
/// TFHE functionality related to key generation.
|
||||
pub mod keygen {
|
||||
use crate::{
|
||||
entities::{
|
||||
BootstrapKey, CircuitBootstrappingKeyswitchKeys, GlweSecretKey, GlweSecretKeyRef,
|
||||
LweKeyswitchKey, LwePublicKey, LweSecretKey, LweSecretKeyRef,
|
||||
},
|
||||
ops::{
|
||||
bootstrapping::generate_bootstrap_key,
|
||||
keyswitch::{
|
||||
lwe_keyswitch_key::generate_keyswitch_key_lwe,
|
||||
private_functional_keyswitch::generate_circuit_bootstrapping_pfks_keys,
|
||||
},
|
||||
},
|
||||
GlweDef, LweDef, RadixDecomposition,
|
||||
};
|
||||
|
||||
/// Generate a new binary [`LweSecretKey`] under the given LWE parameters.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Any functions that use this key will need the same [`LweDef`].
|
||||
///
|
||||
/// These keys may be used to create bootstrapping keys.
|
||||
///
|
||||
/// # Panics
|
||||
/// If [`LweDef`] is invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// These keys are *not* secure under some threshold cryptography settings.
|
||||
/// Under those settings, you should use [`generate_uniform_lwe_sk`].
|
||||
///
|
||||
/// This key is secret and care should be taken as to which parties
|
||||
/// possess it. Anyone who possesses the returned [`LweSecretKey`]
|
||||
/// can decrypt any messages encrypted under it.
|
||||
pub fn generate_binary_lwe_sk(params: &LweDef) -> LweSecretKey<u64> {
|
||||
LweSecretKey::generate_binary(params)
|
||||
}
|
||||
|
||||
/// Generate a new binary [`LweSecretKey`] under the given LWE parameters.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Any functions that use this key will need the same [`LweDef`].
|
||||
///
|
||||
/// These keys may *not* directly be used to create bootstrapping keys.
|
||||
/// However, in threshold schemes that use them, usually you derive
|
||||
/// a binary key from uniform key shares.
|
||||
///
|
||||
/// # Panics
|
||||
/// If [`LweDef`] is invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// This key is secret and care should be taken as to which parties
|
||||
/// possess it. Anyone who possesses the returned [`LweSecretKey`]
|
||||
/// can decrypt any messages encrypted under it.
|
||||
pub fn generate_uniform_lwe_sk(params: &LweDef) -> LweSecretKey<u64> {
|
||||
LweSecretKey::generate_uniform(params)
|
||||
}
|
||||
|
||||
/// Generate a new [`LwePublicKey`] under the given parameters. This
|
||||
/// public key is paired with `sk` - that is messages encrypted under
|
||||
/// this public key can be decrypted with `sk`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Any functions that use this key will need to use the same `params`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid
|
||||
/// If `params` doesn't correspond with `sk`.
|
||||
///
|
||||
/// # Security
|
||||
/// This key is public and sharing it does not compromise semantic
|
||||
/// security.
|
||||
pub fn generate_lwe_pk(sk: &LweSecretKeyRef<u64>, params: &LweDef) -> LwePublicKey<u64> {
|
||||
LwePublicKey::generate(sk, params)
|
||||
}
|
||||
|
||||
/// Generate a new GLWE secret key under the given GLWE parameters.
|
||||
/// The key will consist of {0,1}-valued coefficients.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Any functions that use this key will need the same [`GlweDef`].
|
||||
///
|
||||
/// # Panics
|
||||
/// If [`GlweDef`] is invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// Binary [GlweSecretKey]s are insecure in some threshold cryptography
|
||||
/// settings. Under those settings, you should use
|
||||
/// [`generate_uniform_glwe_sk`].
|
||||
///
|
||||
/// This key is secret and care should be taken as to which parties
|
||||
/// possess it. Anyone who possesses the returned [`GlweSecretKey`]
|
||||
/// can decrypt any messages encrypted under it.
|
||||
pub fn generate_binary_glwe_sk(params: &GlweDef) -> GlweSecretKey<u64> {
|
||||
GlweSecretKey::generate_binary(params)
|
||||
}
|
||||
|
||||
/// Generate a new GLWE secret key under the given GLWE parameters.
|
||||
/// The key will consist of uniform coefficients over the Torus's
|
||||
/// isomorphic ring.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Any functions that use this key will need the same [`GlweDef`].
|
||||
///
|
||||
/// This is used over binary in threshold cryptography settings.
|
||||
///
|
||||
/// # Panics
|
||||
/// If [`GlweDef`] is invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// This key is secret and care should be taken as to which parties
|
||||
/// possess it. Anyone who possesses the returned [`GlweSecretKey`]
|
||||
/// can decrypt any messages encrypted under it.
|
||||
pub fn generate_uniform_glwe_sk(params: &GlweDef) -> GlweSecretKey<u64> {
|
||||
GlweSecretKey::generate_uniform(params)
|
||||
}
|
||||
|
||||
/// Generate a new bootstrapping key, which is used in bootstrapping operations.
|
||||
///
|
||||
/// See also
|
||||
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap)
|
||||
/// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap).
|
||||
///
|
||||
/// # Remarks
|
||||
/// A bootstrapping key is an encryption of an LWE secret key under a different
|
||||
/// GLWE secret key reinterpreted as an LWE key.
|
||||
///
|
||||
/// `lwe` and `glwe` are the LWE and GLWE parameters used when you generated the
|
||||
/// [`LweSecretKey`] and [`GlweSecretKey`], respectively.
|
||||
///
|
||||
/// `radix` specifies the decomposition to use during bootstrapping.
|
||||
///
|
||||
/// You should use the same `lwe`, `glwe`, `radix` values here as when you
|
||||
/// call
|
||||
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap).
|
||||
///
|
||||
/// The returned bootstrapping key is not immediately useful outside of serialization.
|
||||
/// You need to FFT transform is first (see [fft_bootstrap_key](super::fft::fft_bootstrap_key)).
|
||||
///
|
||||
/// ## Circuit bootstrapping
|
||||
/// When using [`circuit_bootstrap`](super::evaluation::circuit_bootstrap), `pbs_radix`
|
||||
/// must match this `radix`, `lwe_0` must match `lwe`, and `glwe_2` must match `glwe`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `lwe`, `glwe`, or `radix` are invalid.
|
||||
/// If `glwe_key` isn't valid under `glwe`.
|
||||
/// If `sk` isn't valid under `lwe`.
|
||||
///
|
||||
/// # Security
|
||||
/// The returned key is public and does not compromise semantic security.
|
||||
/// However, anyone who possesses `glwe_key` can easily use the returned
|
||||
/// [`BootstrapKey`] to recover `sk`.
|
||||
pub fn generate_bootstrapping_key(
|
||||
sk: &LweSecretKey<u64>,
|
||||
glwe_key: &GlweSecretKey<u64>,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> BootstrapKey<u64> {
|
||||
let mut bsk = BootstrapKey::new(lwe, glwe, radix);
|
||||
|
||||
generate_bootstrap_key(&mut bsk, sk, glwe_key, glwe, radix);
|
||||
|
||||
bsk
|
||||
}
|
||||
|
||||
/// Generate an LWE keyswitch key. LWE keyswitching allows you take an encryption of `m`
|
||||
/// under [LWESecretKey](crate::entities::LweSecretKey) `from_sk` and turn it into an
|
||||
/// encryption of `m` under `to_sk`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `from_lwe` and `to_lwe` are the parameters under which you generated `from_sk` and
|
||||
/// `to_sk`, respectively.
|
||||
///
|
||||
/// `radix` specifies the decomposition to use during LWE keyswitching.
|
||||
///
|
||||
/// TODO: mention keyswitch operation.
|
||||
///
|
||||
/// # Panics
|
||||
/// If the `from_lwe` parameters aren't valid for `from_sk`.
|
||||
/// If the `to_lwe` parameters aren't valid for `to_sk`.
|
||||
/// If `from_lwe`, `to_lwe`, or `radix` parameters are invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// The returned key is public and sharing it does not compromise semantic
|
||||
/// security. However, anyone who possesses `to_lwe` will effectively
|
||||
/// be able to decrypt any message encrypted under `from_lwe` with the
|
||||
/// returned [`LweKeyswitchKey`].
|
||||
pub fn generate_ksk(
|
||||
from_sk: &LweSecretKeyRef<u64>,
|
||||
to_sk: &LweSecretKeyRef<u64>,
|
||||
from_lwe: &LweDef,
|
||||
to_lwe: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> LweKeyswitchKey<u64> {
|
||||
let mut ksk = LweKeyswitchKey::new(from_lwe, to_lwe, radix);
|
||||
|
||||
generate_keyswitch_key_lwe(&mut ksk, from_sk, to_sk, to_lwe, radix);
|
||||
|
||||
ksk
|
||||
}
|
||||
|
||||
/// Generate a set of [`CircuitBootstrappingKeyswitchKeys`] to use during
|
||||
/// [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Internally, [`CircuitBootstrappingKeyswitchKeys`] is a list of
|
||||
/// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey)s.
|
||||
///
|
||||
/// During [circuit_bootstrap](super::evaluation::circuit_bootstrap) operations,
|
||||
/// these keys are used to convert [`LweCiphertext`](crate::entities::LweCiphertext)s
|
||||
/// encrypted under `from_sk` into [`GlweCiphertext`](crate::entities::GlweCiphertext)s
|
||||
/// encrypted under `to_sk`. These [`GlweCiphertext`](crate::entities::GlweCiphertext)s
|
||||
/// together form a [`GgswCiphertext`](crate::entities::GgswCiphertext).
|
||||
///
|
||||
/// The `from_lwe` and `to_glwe` parameters correspond to those used when you generated
|
||||
/// `from_sk` and `to_sk`, respectively.
|
||||
///
|
||||
/// The `radix` parameter describes the [`RadixDecomposition`] to use during the private
|
||||
/// functional keyswitch operation. When performing a
|
||||
/// [circuit_bootstrap](super::evaluation::circuit_bootstrap), these same parameters should
|
||||
/// be passed as `pfks_radix`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `from_lwe`, `to_glwe`, or `radix` are invalid.
|
||||
/// If `from_lwe` or `to_glwe` parameters don't correspond with `to_sk` or `to_glwe`, respectively.
|
||||
///
|
||||
/// # Security
|
||||
/// The returned [`CircuitBootstrappingKeyswitchKeys`] are public and
|
||||
/// don't in of themselves compromise semantic security. However,
|
||||
/// anyone who possesses `to_sk` can easily recover `from_sk` using this
|
||||
/// information.
|
||||
pub fn generate_cbs_ksk(
|
||||
from_sk: &LweSecretKeyRef<u64>,
|
||||
to_sk: &GlweSecretKeyRef<u64>,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> CircuitBootstrappingKeyswitchKeys<u64> {
|
||||
let mut cbs_ksk = CircuitBootstrappingKeyswitchKeys::new(from_lwe, to_glwe, radix);
|
||||
|
||||
generate_circuit_bootstrapping_pfks_keys(
|
||||
&mut cbs_ksk,
|
||||
from_sk,
|
||||
to_sk,
|
||||
from_lwe,
|
||||
to_glwe,
|
||||
radix,
|
||||
);
|
||||
|
||||
cbs_ksk
|
||||
}
|
||||
}
|
||||
|
||||
/// TFHE functionality related to encryption.
|
||||
pub mod encryption {
|
||||
use crate::{
|
||||
entities::{
|
||||
GgswCiphertext, GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, GlweSecretKeyRef,
|
||||
LweCiphertext, LweCiphertextRef, LwePublicKeyRef, LweSecretKeyRef, Polynomial,
|
||||
PolynomialRef, TlwePublicEncRandomness,
|
||||
},
|
||||
ops::encryption::{encrypt_ggsw_ciphertext_scalar, trivially_encrypt_lwe_ciphertext},
|
||||
CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus,
|
||||
};
|
||||
|
||||
/// Create an [`LweCiphertext`] encryption of `val` under
|
||||
/// [LweSecretKey](crate::entities::LweSecretKey) `sk`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk) to
|
||||
/// generate an [`LweSecretKey`](crate::entities::LweSecretKey).
|
||||
///
|
||||
/// `params` should be the same parameters used when generating the secret key.
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid.
|
||||
/// If `params` don't correspond with `sk`.
|
||||
pub fn encrypt_lwe_secret(
|
||||
val: u64,
|
||||
sk: &LweSecretKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> LweCiphertext<u64> {
|
||||
sk.encrypt(val, params, plaintext_bits).0
|
||||
}
|
||||
|
||||
/// Create a tuple containing an [`LweCiphertext`] encryption of `val`
|
||||
/// under [LweSecretKey](crate::entities::LweSecretKey) `sk` and the
|
||||
/// randomness used to generate it.
|
||||
///
|
||||
/// This randomness can be used to produce zero-knowledge proofs.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk)
|
||||
/// to generate an [`LweSecretKey`](crate::entities::LweSecretKey).
|
||||
///
|
||||
/// `params` should be the same parameters used when generating the secret key.
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid.
|
||||
/// If `params` don't correspond with `sk`.
|
||||
///
|
||||
/// # Security
|
||||
/// Revealing the returned randomness compromises the confidentiality
|
||||
/// of the returned [`LweCiphertext`].
|
||||
pub fn encrypt_lwe_secret_and_return_randomness(
|
||||
val: u64,
|
||||
sk: &LweSecretKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> (LweCiphertext<u64>, Torus<u64>) {
|
||||
sk.encrypt(val, params, plaintext_bits)
|
||||
}
|
||||
|
||||
/// Create an [LweCiphertext] encryption of `val` under the secret
|
||||
/// key that pairs with `pk`.
|
||||
///
|
||||
/// This function uses the [`LwePublicKey`](crate::entities::LwePublicKey),
|
||||
/// which may be freely distributed without compromising security.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Use [`generate_lwe_pk`](super::keygen::generate_lwe_pk) to generate
|
||||
/// a [`LwePublicKey`](crate::entities::LwePublicKey).
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid.
|
||||
/// If `params` doesn't correspond with `pk`.
|
||||
pub fn encrypt_lwe(
|
||||
val: u64,
|
||||
pk: &LwePublicKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> LweCiphertext<u64> {
|
||||
pk.encrypt(val, params, plaintext_bits).0
|
||||
}
|
||||
|
||||
/// Create a tuple containing an [`LweCiphertext`] encryption of `val`
|
||||
/// under the [LweSecretKey](crate::entities::LweSecretKey) paired with
|
||||
/// `pk` and the randomness used to generate it.
|
||||
///
|
||||
/// This randomness can be used to produce zero-knowledge proofs.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Use [`generate_binary_lwe_sk`](super::keygen::generate_binary_lwe_sk)
|
||||
/// to generate an [`LweSecretKey`](crate::entities::LweSecretKey).
|
||||
///
|
||||
/// `params` should be the same parameters used when generating the secret key.
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid.
|
||||
/// If `params` don't correspond with `pk`.
|
||||
///
|
||||
/// # Security
|
||||
/// Revealing the returned randomness compromises the confidentiality
|
||||
/// of the returned [`LweCiphertext`].
|
||||
pub fn encrypt_lwe_and_return_randomness(
|
||||
val: u64,
|
||||
pk: &LwePublicKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> (LweCiphertext<u64>, TlwePublicEncRandomness<u64>) {
|
||||
pk.encrypt(val, params, plaintext_bits)
|
||||
}
|
||||
|
||||
/// Create a [`GlweCiphertext`] encryption of `pt` under `sk`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` should be a same as those used when creating `sk`.
|
||||
/// `pt.len()` should equal `sk.dim.polynomial_degree.0`.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// polynomial coefficients will contain the message. No coefficient in
|
||||
/// `pt` should exceed `2^plaintext_bits.0`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` is invalid.
|
||||
/// If `params` doesn't correspond with `sk`
|
||||
/// If `pt` doesn't have the same number of coefficients as
|
||||
/// `params.dim.polynomial_degree.0`.
|
||||
pub fn encrypt_glwe(
|
||||
pt: &PolynomialRef<u64>,
|
||||
sk: &GlweSecretKeyRef<u64>,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> GlweCiphertext<u64> {
|
||||
sk.encode_encrypt_glwe(pt, params, plaintext_bits)
|
||||
}
|
||||
|
||||
/// Create a trivial LWE encryption. Trivial encryptions have no noise and are thus
|
||||
/// insecure. However, they are useful for creating public constants in TFHE computations.
|
||||
///
|
||||
/// Trivial encryptions are valid encryptions of `val` under *every* LWE secret key
|
||||
/// using the same `params`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` are the parameters under which to create this encryption. Note that homomorphic
|
||||
/// operations require every operand use the same [`LweDef`] parameters.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` should not exceed `2^plaintext_bits.0`. Truncation
|
||||
/// of `val`'s most-significant bits will occur otherwise.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `plaintext_bits > 63`.
|
||||
/// If `params` are invalid.
|
||||
///
|
||||
/// # Security
|
||||
/// Trivial encryptions provide no cryptographic security.
|
||||
pub fn trivial_lwe(
|
||||
val: u64,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> LweCiphertext<u64> {
|
||||
let mut ct = LweCiphertext::new(params);
|
||||
|
||||
trivially_encrypt_lwe_ciphertext(&mut ct, &Torus::encode(val, plaintext_bits), params);
|
||||
|
||||
ct
|
||||
}
|
||||
|
||||
/// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey)
|
||||
/// `sk`. Decode this decrypted value and return it.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` must correspond with `ct` and `sk`.
|
||||
///
|
||||
/// If a different `sk` is used than the one that produced `ct`, the result will be
|
||||
/// garbage.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
|
||||
/// you should use the same `plaintext_bits` that were used during encryption.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` doesn't correspond with either `ct` or `sk`.
|
||||
/// If `params` is invalid.
|
||||
pub fn decrypt_lwe(
|
||||
ct: &LweCiphertextRef<u64>,
|
||||
sk: &LweSecretKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> u64 {
|
||||
sk.decrypt(ct, params, plaintext_bits)
|
||||
}
|
||||
|
||||
/// Decrypt [LweCiphertext] `ct` encrypted under [LweSecretKey](crate::entities::LweSecretKey)
|
||||
/// `sk` and decode it, handling carry bits.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` must correspond with `ct` and `sk`.
|
||||
///
|
||||
/// If a different `sk` is used than the one that produced `ct`, the result will be
|
||||
/// garbage.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// will contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
|
||||
/// you should use the same `plaintext_bits` that were used during encryption.
|
||||
///
|
||||
/// `carry_bits` describes how many of the most-significant bits of the
|
||||
/// [`Torus`] will contain the carry.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` doesn't correspond with either `ct` or `sk`.
|
||||
/// If `params` is invalid.
|
||||
pub fn decrypt_lwe_with_carry(
|
||||
ct: &LweCiphertextRef<u64>,
|
||||
sk: &LweSecretKeyRef<u64>,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
carry_bits: CarryBits,
|
||||
) -> u64 {
|
||||
let decrypted = sk.decrypt_without_decode(ct, params);
|
||||
|
||||
// We manually decode here because the padding bit.
|
||||
let plain_bits = plaintext_bits;
|
||||
|
||||
let round_bit = decrypted
|
||||
.inner()
|
||||
.wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1)
|
||||
& 0x1;
|
||||
let mask = (0x1 << plain_bits.0) - 1;
|
||||
|
||||
(decrypted
|
||||
.inner()
|
||||
.wrapping_shr(64 - plain_bits.0 - carry_bits.0)
|
||||
+ round_bit)
|
||||
& mask
|
||||
}
|
||||
|
||||
/// Decrypt [GlweCiphertext] `ct` encrypted under [GlweSecretKey](crate::entities::GlweSecretKey)
|
||||
/// `sk`. Decode this decrypted value and return it.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` must correspond with `ct` and `sk`.
|
||||
///
|
||||
/// If a different `sk` is used than the one that produced `ct`, the result will be
|
||||
/// garbage.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`] polynomial's
|
||||
/// coefficients contain the message. `val` will not exceed `2^plaintext_bits.0`. Generally,
|
||||
/// you should use the same `plaintext_bits` that were used during encryption.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` doesn't correspond with either `ct` or `sk`.
|
||||
/// If `params` is invalid.
|
||||
pub fn decrypt_glwe(
|
||||
ct: &GlweCiphertextRef<u64>,
|
||||
sk: &GlweSecretKeyRef<u64>,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Polynomial<u64> {
|
||||
sk.decrypt_decode_glwe(ct, params, plaintext_bits)
|
||||
}
|
||||
|
||||
/// Create a trivial encryption of `pt` as a [GlweCiphertext]. Trivial encryptions contain
|
||||
/// no noise and are thus insecure. However, they are useful as public constants in
|
||||
/// a TFHE computation.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Trivial encryption are valid under *every* [GlweSecretKey](crate::entities::GlweSecretKey)
|
||||
/// using `params`. Note that homomorphic operations require every operand to use the same
|
||||
/// `params` and secret key.
|
||||
///
|
||||
/// `plaintext_bits` describes how many of the most-significant bits of the [`Torus`]
|
||||
/// polynomial coefficients will contain the message. `val` should not exceed
|
||||
/// `2^plaintext_bits.0`. Truncation of `val`'s most-significant bits will occur otherwise.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` are invalid.
|
||||
/// If `plaintext_bits >= 64`.
|
||||
///
|
||||
/// # Security
|
||||
/// Trivial encryptions provide no cryptographic security.
|
||||
pub fn trivial_glwe(
|
||||
pt: &PolynomialRef<u64>,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> GlweCiphertext<u64> {
|
||||
let mut result = GlweCiphertext::new(params);
|
||||
|
||||
for (b_out, b_in) in result
|
||||
.b_mut(params)
|
||||
.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(pt.coeffs().iter())
|
||||
{
|
||||
*b_out = Torus::encode(*b_in, plaintext_bits);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Create a [`GgswCiphertext`] encrypting `msg` under
|
||||
/// [GlweSecretKey](crate::entities::GlweSecretKey) `sk`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This encrypts `msg` as a constant coefficient polynomial. While [`GgswCiphertext`]s
|
||||
/// theoretically support encrypting arbitrary polynomial messages, such ciphertexts have
|
||||
/// no known uses in TFHE.
|
||||
///
|
||||
/// Typically, you'll want to set `plaintext_bits` to 1 and encrypt a binary
|
||||
/// `msg`. This allows you to use the [`GgswCiphertext`] as the select input to
|
||||
/// a [`cmux`](super::evaluation::cmux) operation.
|
||||
///
|
||||
/// `params` should match those under which you generated `sk`.
|
||||
///
|
||||
/// `radix` defines how many decompositions to include in the result. Subsequent
|
||||
/// [`cmux`](super::evaluation::cmux) operations using this result must use the
|
||||
/// same `radix`.
|
||||
///
|
||||
/// [`GgswCiphertext`]s are not immediate useful outside of serialization. You must
|
||||
/// first take its Fourier transform using [`fft_ggsw`](super::fft::fft_ggsw) before
|
||||
/// using it in a [`cmux`](super::evaluation::cmux) operation.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` or `radix` are invalid.
|
||||
/// If `params` doesn't correspond to `sk`.
|
||||
/// If `plaintext_bits >= 64`.
|
||||
pub fn encrypt_ggsw(
|
||||
msg: u64,
|
||||
sk: &GlweSecretKeyRef<u64>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> GgswCiphertext<u64> {
|
||||
let mut result = GgswCiphertext::new(params, radix);
|
||||
|
||||
encrypt_ggsw_ciphertext_scalar(&mut result, msg, sk, params, radix, plaintext_bits);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Decrypt a [`GgswCiphertext`] encrypted under `sk`.
|
||||
/// Since GGSW ciphertexts generally contain binary, you should
|
||||
/// usually set `plaintext_bits` to 1.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` should correspond with `ct` and `sk`.
|
||||
/// `radix` should correspond with `ct`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` or `radix` are invalid.
|
||||
/// If `params` or `radix` don't correspond to `ct`
|
||||
/// If `params` don't correspond to `sk`.
|
||||
pub fn decrypt_ggsw(
|
||||
ct: &GgswCiphertextRef<u64>,
|
||||
sk: &GlweSecretKeyRef<u64>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
_plaintext_bits: PlaintextBits,
|
||||
) -> Polynomial<u64> {
|
||||
let mut msg = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
|
||||
crate::ops::encryption::decrypt_ggsw_ciphertext(&mut msg, ct, sk, params, radix);
|
||||
|
||||
msg.map(|x| x.inner())
|
||||
}
|
||||
}
|
||||
|
||||
/// Operations for producing Fourier-transformed versions of entities.
|
||||
pub mod fft {
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
entities::{
|
||||
BootstrapKeyFft, BootstrapKeyRef, GgswCiphertextFft, GgswCiphertextRef,
|
||||
GlweCiphertextFft, GlweCiphertextRef,
|
||||
},
|
||||
GlweDef, LweDef, RadixDecomposition,
|
||||
};
|
||||
|
||||
/// Take the fourier transform of a [`GlweCiphertext`](crate::entities::GlweCiphertext).
|
||||
///
|
||||
/// # Remarks
|
||||
/// `params` must be the same parameters that produce `ct`.
|
||||
///
|
||||
/// # Panics
|
||||
/// `params` is invalid.
|
||||
/// `params` doesn't correspond with `ct`.
|
||||
pub fn fft_glwe(
|
||||
ct: &GlweCiphertextRef<u64>,
|
||||
params: &GlweDef,
|
||||
) -> GlweCiphertextFft<Complex<f64>> {
|
||||
let mut fft = GlweCiphertextFft::new(params);
|
||||
|
||||
ct.fft(&mut fft, params);
|
||||
|
||||
fft
|
||||
}
|
||||
|
||||
/// Take the fourier transform of a [`GgswCiphertext`](crate::entities::GgswCiphertext).
|
||||
///
|
||||
/// # Remarks
|
||||
/// `glwe` and `radix` must be the same parameters that produced `ggsw`.
|
||||
///
|
||||
/// For [`GgswCiphertext`](crate::entities::GgswCiphertext)s that result from a
|
||||
/// [`circuit_bootstrap`](super::evaluation::circuit_bootstrap) operation, these
|
||||
/// must match `glwe_1` and `cbs_radix` respectively.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `glwe` and `radix` don't correspond with `ggsw`.
|
||||
/// If `glwe` or `radix` are invalid.
|
||||
pub fn fft_ggsw(
|
||||
ggsw: &GgswCiphertextRef<u64>,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GgswCiphertextFft<Complex<f64>> {
|
||||
let mut fft = GgswCiphertextFft::new(glwe, radix);
|
||||
|
||||
ggsw.fft(&mut fft, glwe, radix);
|
||||
|
||||
fft
|
||||
}
|
||||
|
||||
/// Take the fourier transform of a [BootstrapKey](crate::entities::BootstrapKey).
|
||||
/// The resulting [`BootstrapKeyFft`] may be used in
|
||||
/// [`univariate_programmable_bootstrap`](super::evaluation::univariate_programmable_bootstrap)
|
||||
/// and [`circuit_bootstrap`](super::evaluation::circuit_bootstrap)
|
||||
/// operations.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `glwe` and `radix` must be the same parameters that produced `bsk`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `glwe` and `radix` don't correspond with `bsk`.
|
||||
/// If `glwe` or `radix` are invalid.
|
||||
pub fn fft_bootstrap_key(
|
||||
bsk: &BootstrapKeyRef<u64>,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> BootstrapKeyFft<Complex<f64>> {
|
||||
let mut bsk_fft = BootstrapKeyFft::new(lwe, glwe, radix);
|
||||
|
||||
bsk.fft(&mut bsk_fft, glwe, radix);
|
||||
|
||||
bsk_fft
|
||||
}
|
||||
}
|
||||
|
||||
/// TFHE operations for performing computation.
|
||||
pub mod evaluation {
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
entities::{
|
||||
BootstrapKeyFft, BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef,
|
||||
GgswCiphertext, GgswCiphertextFftRef, GlweCiphertext, GlweCiphertextRef, LweCiphertext,
|
||||
LweCiphertextRef, LweKeyswitchKeyRef, UnivariateLookupTableRef,
|
||||
},
|
||||
GlweDef, LweDef, RadixDecomposition,
|
||||
};
|
||||
|
||||
/// Perform a multiplexing operation. When `b_fft` encrypts a zero polynomial,
|
||||
/// the resulting [`GlweCiphertext`] will the same message as `d_0`. When `b_fft`
|
||||
/// encrypts the 1 polynomial, the result will contain the same message as `d_1`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `b_fft`, `d_0`, and `d_1` must all be encrypted under the same
|
||||
/// [`GlweSecretKey`](crate::entities::GlweSecretKey). This implies `params` must
|
||||
/// correspond with all three values.
|
||||
///
|
||||
/// Additionally, `radix` must correspond to `b_fft`.
|
||||
///
|
||||
/// For
|
||||
/// [`GgswCiphertext`] resulting from [`circuit_bootstrap`] operations,
|
||||
/// `radix` must be the same as `cbs_radix` and `params` must be the same as
|
||||
/// `glwe_1`.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `params` doesn't correspond with `b_fft`, `d_0`, `d_1`.
|
||||
/// If `radix` doesn't correspond with `b_fft`.
|
||||
/// If `radix` or `params` are invalid.
|
||||
pub fn cmux(
|
||||
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
|
||||
d_0: &GlweCiphertextRef<u64>,
|
||||
d_1: &GlweCiphertextRef<u64>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlweCiphertext<u64> {
|
||||
let mut result = GlweCiphertext::new(params);
|
||||
|
||||
crate::ops::fft_ops::cmux(&mut result, d_0, d_1, b_fft, params, radix);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Perform a programmable bootstrapping operation. Bootstrapping takes
|
||||
/// `input` and produces a new ciphertext with a fixed noise level, applying
|
||||
/// a univariate function defined by `lut` in the process.
|
||||
///
|
||||
/// This new ciphertext is encrypted under a (usually) different
|
||||
/// [`LweSecretKey`](crate::entities::LweSecretKey) defined by `glwe` interpreted
|
||||
/// as an [`LweDef`].
|
||||
///
|
||||
/// See also [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable).
|
||||
///
|
||||
/// To switch the message back to the original key, you need to perform an LWE
|
||||
/// keyswitch operation. TODO: hotlink
|
||||
///
|
||||
/// # Remarks
|
||||
/// `input` must be valid under the `lwe` parameters.
|
||||
/// `lwe`, `glwe`, and `radix` parameters must be the same as those used when
|
||||
/// first creating the `bsk`.
|
||||
/// `lut` must be valid under `glwe` parameters.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `lwe`, `glwe`, or `radix` parameters are invalid.
|
||||
/// If `input` doesn't correspond to `lwe` parameters.
|
||||
/// If `bsk` doesn't correspond to `lwe`, `glwe`, `radix` parameters.
|
||||
/// If `lut` doesn't correspond to `glwe` parameters.
|
||||
pub fn univariate_programmable_bootstrap(
|
||||
input: &LweCiphertextRef<u64>,
|
||||
lut: &UnivariateLookupTableRef<u64>,
|
||||
bsk: &BootstrapKeyFft<Complex<f64>>,
|
||||
lwe: &LweDef,
|
||||
glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> LweCiphertext<u64> {
|
||||
let mut out = LweCiphertext::new(&glwe.as_lwe_def());
|
||||
|
||||
crate::ops::bootstrapping::programmable_bootstrap(
|
||||
&mut out, input, lut, bsk, lwe, glwe, radix,
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Perform a circuit bootstrapping operation. Circuit bootstrapping takes
|
||||
/// `input` [LweCiphertext] encrypted under a [LweSecretKey](crate::entities::LweSecretKey)
|
||||
/// constructed with `lwe_0` parameters and produces a [GgswCiphertext] encrypted
|
||||
/// under a [GlweSecretKey](crate::entities::GlweSecretKey) constructed with `glwe_1`
|
||||
/// parameters.
|
||||
///
|
||||
/// See also [generate_bootstrapping_key](super::keygen::generate_bootstrapping_key) and
|
||||
/// [generate_cbs_pfks](super::keygen::generate_cbs_ksk) for how to generate the required
|
||||
/// keys.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Internally, circuit bootstrapping occurs in 2 steps. First we perform programmable
|
||||
/// bootstraps (PBS) from `lwe_0` parameters to `glwe_2` reinterpreted as LWE parameters. We
|
||||
/// do this `cbs_radix.count` times using univariate functions that map the message in
|
||||
/// input to its corresponding radix decomposition.
|
||||
///
|
||||
/// For step 2, we use private functional keyswitching (PFKS) to transform the
|
||||
/// `cbs_radix.count` [LweCiphertext]s encrypted under `glwe_2` into
|
||||
/// `cbs_radix.count * glwe_2.size + 1` [GlweCiphertext]s. The PFKS operations multiply each
|
||||
/// [GlevCiphertext](crate::entities::GlevCiphertext) by the corresponding polynomial in
|
||||
/// the `glwe_1` [GlweSecretKey](crate::entities::GlweSecretKey) to create a valid
|
||||
/// [GgswCiphertext].
|
||||
///
|
||||
/// To summarize, we use PBS to turn an `lwe_0` LWE ciphertext into `glwe_2` LWE ciphertexts.
|
||||
/// We then use PFKS to turn the `glwe_2` LWE ciphertexts into `glwe_1` GLWE ciphertexts
|
||||
/// arranged as a valid [GgswCiphertext] encrypting the same value as `input` as a constant
|
||||
/// coefficient polynomial.
|
||||
///
|
||||
/// See [crate::ops::bootstrapping::circuit_bootstrap] for more details.
|
||||
///
|
||||
/// `pbs_radix` parameterizes the bootstrapping operation (step 1).
|
||||
///
|
||||
/// `pfks_radix` parameterizes the PFKS operation (step 2). These should
|
||||
///
|
||||
/// `cbs_radix` parameterizes the final decomposition of the resulting [`GgswCiphertext`]. This
|
||||
/// should match the radix used when creating the
|
||||
/// [CircuitBootstrappingKeyswitchKeys](crate::entities::CircuitBootstrappingKeyswitchKeys)
|
||||
///
|
||||
///
|
||||
/// # Panics
|
||||
/// If `pbs_radix`, `cbs_radix`, `pfksk_radix`, `lwe_0`, `glwe_2`, or `glwe_1` are invalid.
|
||||
/// If `pbs_radix`, `lwe_0`, `glwe_2` don't match the `radix`, `lwe`, `glwe` (respectively) used to generate `bsk`.
|
||||
/// If `pfks_radix`, `glwe_2.as_lwe_def()`, `glwe_1` parameters don't match the `radix`, `from_lwe`, `to_lwe`
|
||||
/// (respectively) used to generate `cbsksk`.
|
||||
pub fn circuit_bootstrap(
|
||||
input: &LweCiphertextRef<u64>,
|
||||
bsk: &BootstrapKeyFftRef<Complex<f64>>,
|
||||
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<u64>,
|
||||
lwe_0: &LweDef,
|
||||
glwe_1: &GlweDef,
|
||||
glwe_2: &GlweDef,
|
||||
pbs_radix: &RadixDecomposition,
|
||||
cbs_radix: &RadixDecomposition,
|
||||
pfks_radix: &RadixDecomposition,
|
||||
) -> GgswCiphertext<u64> {
|
||||
let mut out = GgswCiphertext::new(glwe_1, cbs_radix);
|
||||
|
||||
crate::ops::bootstrapping::circuit_bootstrap(
|
||||
&mut out, input, bsk, cbsksk, lwe_0, glwe_1, glwe_2, pbs_radix, cbs_radix, pfks_radix,
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Perform LWE keyswitching to produce a new [`LweCiphertext`] encrypted
|
||||
/// under a different [`LweSecretKey`](crate::entities::LweSecretKey).
|
||||
///
|
||||
/// # Remarks
|
||||
/// When creating `ksk` with [`generate_ksk`](super::keygen::generate_ksk),
|
||||
/// you pass 2 different [`LweSecretKey`](crate::entities::LweSecretKey)
|
||||
/// values: a `from_sk` and `to_sk`. `ct` should be encrypted under the
|
||||
/// same `from_sk` and ciphertext this function returns will be encrypted
|
||||
/// under `to_sk`.
|
||||
pub fn keyswitch_lwe_to_lwe(
|
||||
ct: &LweCiphertextRef<u64>,
|
||||
ksk: &LweKeyswitchKeyRef<u64>,
|
||||
from_lwe: &LweDef,
|
||||
to_lwe: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> LweCiphertext<u64> {
|
||||
let mut new_ct = LweCiphertext::new(to_lwe);
|
||||
crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe(
|
||||
&mut new_ct,
|
||||
ct,
|
||||
ksk,
|
||||
from_lwe,
|
||||
to_lwe,
|
||||
radix,
|
||||
);
|
||||
|
||||
new_ct
|
||||
}
|
||||
}
|
||||
35
sunscreen_tfhe/src/lib.rs
Normal file
35
sunscreen_tfhe/src/lib.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
//! TFHE low-level library
|
||||
|
||||
#![deny(missing_docs)]
|
||||
#![deny(rustdoc::broken_intra_doc_links)]
|
||||
|
||||
#[macro_use]
|
||||
mod dst;
|
||||
|
||||
/// The entities module contains the main data structures used in the library.
|
||||
pub mod entities;
|
||||
|
||||
/// Higher level operations in TFHE such as programmable bootstrapping, keyswitching, etc.
|
||||
pub mod ops;
|
||||
|
||||
/// Parameters that define a TFHE scheme.
|
||||
pub mod params;
|
||||
pub use params::*;
|
||||
|
||||
/// Math operations on various math primitives such as polynomials.
|
||||
pub mod math;
|
||||
pub use math::*;
|
||||
|
||||
mod macros;
|
||||
|
||||
/// Random number generation.
|
||||
pub mod rand;
|
||||
mod scratch;
|
||||
|
||||
/// A high-level API for interfacing with TFHE. Allocates, computes with and returns
|
||||
/// objects as you would expect from a Rust API.
|
||||
pub mod high_level;
|
||||
|
||||
/// Zero Knowledge proofs for TFHE.
|
||||
#[cfg(feature = "logproof")]
|
||||
pub mod zkp;
|
||||
108
sunscreen_tfhe/src/macros.rs
Normal file
108
sunscreen_tfhe/src/macros.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
macro_rules! impl_binary_op {
|
||||
($op:ident, $type:ty, ($($t_bounds:ty),* $(,)? )) => {
|
||||
paste::paste! {
|
||||
|
||||
// Ex: AddAssign for LweSecretKey
|
||||
impl<S> std::ops::[<$op Assign>] for $type<S>
|
||||
where
|
||||
S: $($t_bounds)*
|
||||
{
|
||||
fn [<$op:lower _assign>](&mut self, rhs: Self) {
|
||||
self.data.iter_mut().zip(rhs.data.iter()).for_each(|(a, b)| {
|
||||
*a = num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a, b);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Ex: Add for LweSecretKey
|
||||
// Calls Add for &LweSecretKeyRef
|
||||
impl<S> std::ops::$op for $type<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn [<$op:lower >](self, rhs: Self) -> Self::Output {
|
||||
std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
// Ex: WrappingAdd for LweSecretKey
|
||||
// Calls Add for &LweSecretKeyRef
|
||||
impl<S> num::traits::[<Wrapping $op>] for $type<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
fn [<wrapping_ $op:lower>](&self, rhs: &Self) -> Self {
|
||||
std::ops::$op::[< $op:lower >](self.as_ref(), rhs.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
// Ex: Add for &LweSecretKeyRef
|
||||
// Calls AddAssign for LweSecretKey
|
||||
impl<S> std::ops::$op for &[<$type Ref>]<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = $type<S>;
|
||||
|
||||
fn [< $op:lower >](self, rhs: Self) -> Self::Output {
|
||||
let mut a = self.to_owned();
|
||||
std::ops::[< $op Assign >]::[< $op:lower _assign>](&mut a, rhs.to_owned());
|
||||
a
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_unary_op {
|
||||
($op:ident, $type:ty) => {
|
||||
paste::paste! {
|
||||
|
||||
// Ex: Neg for LweSecretKey
|
||||
// Calls Neg for &LweSecretKeyRef
|
||||
impl<S> std::ops::$op for $type<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn [<$op:lower>](self) -> Self::Output {
|
||||
std::ops::$op::[<$op:lower>](self.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
// Ex: Neg for &LweSecretKeyRef
|
||||
impl<S> std::ops::$op for &[<$type Ref>]<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = $type<S>;
|
||||
|
||||
fn [<$op:lower>](self) -> Self::Output {
|
||||
// We call the wrapping trait instead of using the dot
|
||||
// syntax because the dot syntax can dereference the value
|
||||
// and can cause problems with Deref.
|
||||
let data = self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)).collect();
|
||||
|
||||
$type { data }
|
||||
}
|
||||
}
|
||||
|
||||
// Ex: WrappingNeg for LweSecretKey
|
||||
// Calls Neg for &LweSecretKeyRef
|
||||
impl<S> num::traits::[<Wrapping $op>] for $type<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
fn [<wrapping_ $op:lower>](&self) -> Self {
|
||||
std::ops::$op::[<$op:lower>](self.as_ref())
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use impl_binary_op;
|
||||
pub(crate) use impl_unary_op;
|
||||
64
sunscreen_tfhe/src/math/basic.rs
Normal file
64
sunscreen_tfhe/src/math/basic.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
/// An integer type that supports rounding division.
|
||||
pub trait RoundedDiv {
|
||||
/// Divides two numbers and rounds the result to the nearest integer.
|
||||
fn div_rounded(&self, divisor: Self) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! div_rounded {
|
||||
($t:ty) => {
|
||||
impl RoundedDiv for $t {
|
||||
#[inline(always)]
|
||||
fn div_rounded(&self, divisor: $t) -> $t {
|
||||
// There are a few ways to do this, but we chose the following
|
||||
// because it allows the entire range of a type to be used. The
|
||||
// other common method is to add half the divisor to the
|
||||
// numerator, but that effectively cuts the size of the possible
|
||||
// inputs in half before overflow.
|
||||
let q = self / divisor;
|
||||
let r = self % divisor;
|
||||
if r >= divisor / 2 {
|
||||
q + 1
|
||||
} else {
|
||||
q
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
div_rounded!(u8);
|
||||
div_rounded!(u16);
|
||||
div_rounded!(u32);
|
||||
div_rounded!(u64);
|
||||
div_rounded!(u128);
|
||||
div_rounded!(usize);
|
||||
div_rounded!(i8);
|
||||
div_rounded!(i16);
|
||||
div_rounded!(i32);
|
||||
div_rounded!(i64);
|
||||
div_rounded!(i128);
|
||||
div_rounded!(isize);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_div_rounded() {
|
||||
for _ in 0..1_000 {
|
||||
let a = thread_rng().next_u64();
|
||||
let mut b = thread_rng().next_u64();
|
||||
|
||||
if b == 0 {
|
||||
b = 1;
|
||||
}
|
||||
|
||||
let expected = ((a as f64) / (b as f64)).round() as u64;
|
||||
let actual = a.div_rounded(b);
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
}
|
||||
}
|
||||
184
sunscreen_tfhe/src/math/fft/cyclic/mod.rs
Normal file
184
sunscreen_tfhe/src/math/fft/cyclic/mod.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
use std::ops::{Add, Mul, Neg};
|
||||
use std::sync::Arc;
|
||||
|
||||
use num::{complex::Complex, Float};
|
||||
use realfft::{ComplexToReal, FftNum, RealFftPlanner, RealToComplex};
|
||||
use sunscreen_math::{One, Zero as SunscreenZero};
|
||||
|
||||
use crate::{FrequencyTransform, Inverse, Pow, RootOfUnity};
|
||||
|
||||
/// A struct that can perform a real FFT.
|
||||
pub struct RealFft<T>
|
||||
where
|
||||
T: FftNum,
|
||||
{
|
||||
pub(crate) fplan: Arc<dyn RealToComplex<T>>,
|
||||
pub(crate) rplan: Arc<dyn ComplexToReal<T>>,
|
||||
pub(crate) scale: T,
|
||||
}
|
||||
|
||||
impl<T> RealFft<T>
|
||||
where
|
||||
T: FftNum + Float,
|
||||
{
|
||||
/// Create a new [RealFft] with the given size.
|
||||
pub fn new(n: usize) -> Self {
|
||||
let mut plan = RealFftPlanner::<T>::new();
|
||||
let fplan = plan.plan_fft_forward(n);
|
||||
let rplan = plan.plan_fft_inverse(n);
|
||||
|
||||
let scale = T::from(1.0).unwrap() / T::from(n).unwrap();
|
||||
|
||||
Self {
|
||||
fplan,
|
||||
rplan,
|
||||
scale,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FrequencyTransform for RealFft<T>
|
||||
where
|
||||
T: FftNum + Float,
|
||||
{
|
||||
type BaseRepr = T;
|
||||
type FrequencyRepr = Complex<T>;
|
||||
|
||||
fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) {
|
||||
assert_eq!(data.len() / 2 + 1, output.len());
|
||||
|
||||
self.fplan.process(&mut data.to_owned(), output).unwrap();
|
||||
}
|
||||
|
||||
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) {
|
||||
self.rplan.process(&mut data.to_owned(), output).unwrap();
|
||||
|
||||
output.iter_mut().for_each(|x| {
|
||||
*x = *x * self.scale;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct that can perform a NTT using a naive algorithm.
|
||||
#[allow(unused)]
|
||||
pub struct NaiveNtt<T> {
|
||||
twiddle: Vec<T>,
|
||||
inv_twiddle: Vec<T>,
|
||||
n_inv: T,
|
||||
}
|
||||
|
||||
impl<T> NaiveNtt<T>
|
||||
where
|
||||
T: RootOfUnity
|
||||
+ Copy
|
||||
+ SunscreenZero
|
||||
+ One
|
||||
+ Mul<T, Output = T>
|
||||
+ Add<T, Output = T>
|
||||
+ Neg<Output = T>
|
||||
+ Pow<u64>
|
||||
+ From<u64>
|
||||
+ Inverse,
|
||||
{
|
||||
/// Create a new [NaiveNtt] with the given size.
|
||||
#[allow(unused)]
|
||||
pub fn new(n: usize) -> Self {
|
||||
let mut twiddle = vec![];
|
||||
let mut inv_twiddle = vec![];
|
||||
let root = T::nth_root_of_unity(n as u64);
|
||||
let inv_root = root.inverse();
|
||||
|
||||
for i in 0..n as u64 {
|
||||
twiddle.push(root.pow(i));
|
||||
inv_twiddle.push(inv_root.pow(i));
|
||||
}
|
||||
|
||||
Self {
|
||||
twiddle,
|
||||
inv_twiddle,
|
||||
n_inv: T::from(n as u64).inverse(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use num::complex::ComplexFloat;
|
||||
|
||||
use crate::FrequencyTransform;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_roundtrip_real_fft() {
|
||||
let n = 256;
|
||||
|
||||
let fft = RealFft::<f64>::new(n);
|
||||
let input = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut points = vec![Complex::from(0.0); input.len() / 2 + 1];
|
||||
let mut result = vec![0.0; input.len()];
|
||||
|
||||
fft.forward(&input, &mut points);
|
||||
fft.reverse(&points, &mut result);
|
||||
|
||||
for (l, r) in input.iter().zip(result.iter()) {
|
||||
assert!((l - r).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negacyclic_gives_odd_harmonics() {
|
||||
let n = 512;
|
||||
let nega_len = n / 2;
|
||||
|
||||
let fft = RealFft::<f64>::new(n);
|
||||
let input = (0..n)
|
||||
.enumerate()
|
||||
.map(|(i, x)| {
|
||||
let val = x % nega_len;
|
||||
|
||||
if i < nega_len {
|
||||
val as f64
|
||||
} else {
|
||||
-(val as f64)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut points = vec![Complex::from(0.0); input.len() / 2 + 1];
|
||||
|
||||
fft.forward(&input, &mut points);
|
||||
|
||||
for (i, x) in points.iter().enumerate() {
|
||||
if i % 2 == 0 {
|
||||
assert!(x.re().abs() < 1e-12);
|
||||
assert!(x.im().abs() < 1e-12);
|
||||
} else {
|
||||
assert!(x.re().abs() > 1.0);
|
||||
assert!(x.im().abs() > 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_cyclic_convolution() {
|
||||
let n = 4;
|
||||
|
||||
let fft = RealFft::<f64>::new(n);
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut actual = vec![0.0; n];
|
||||
let mut y = vec![Complex::from(0.0); n / 2 + 1];
|
||||
|
||||
fft.forward(&x, &mut y);
|
||||
|
||||
let z = y.iter().map(|y| y * y).collect::<Vec<_>>();
|
||||
|
||||
fft.reverse(&z, &mut actual);
|
||||
|
||||
let expected = [10.0, 12.0, 10.0, 4.0];
|
||||
|
||||
for (e, a) in expected.iter().zip(actual.iter()) {
|
||||
assert!((e - a).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
}
|
||||
5
sunscreen_tfhe/src/math/fft/mod.rs
Normal file
5
sunscreen_tfhe/src/math/fft/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
/// FFT based operations over real numbers.
|
||||
pub mod cyclic;
|
||||
|
||||
/// FFT based operations over twisted cyclotomics.
|
||||
pub mod negacyclic;
|
||||
176
sunscreen_tfhe/src/math/fft/negacyclic/mod.rs
Normal file
176
sunscreen_tfhe/src/math/fft/negacyclic/mod.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use std::{
|
||||
f64::consts::PI,
|
||||
sync::{Arc, Once},
|
||||
};
|
||||
|
||||
use num::{Complex, Float, One};
|
||||
use realfft::FftNum;
|
||||
use rustfft::{Fft, FftPlanner};
|
||||
|
||||
use crate::{scratch::allocate_scratch, FrequencyTransform};
|
||||
|
||||
static FFT_CACHE_INIT: Once = Once::new();
|
||||
static mut FFT_CACHE: Vec<TwistedFft<f64>> = vec![];
|
||||
|
||||
/// Get a [TwistedFft] for a given log N.
|
||||
pub fn get_fft(log_n: usize) -> &'static TwistedFft<f64> {
|
||||
// Can FFT powers of 2 from N=1 up to 4096.
|
||||
assert!(log_n < 13);
|
||||
|
||||
FFT_CACHE_INIT.call_once(|| {
|
||||
for i in 0..13 {
|
||||
unsafe { FFT_CACHE.push(TwistedFft::new(0x1 << i)) };
|
||||
}
|
||||
});
|
||||
|
||||
unsafe { &FFT_CACHE[log_n] }
|
||||
}
|
||||
|
||||
/// Perform FFT with a twist so points can be used for
|
||||
/// negacyclic convolution.
|
||||
///
|
||||
/// # Remarks
|
||||
/// See `<https://jeremykun.com/2022/12/09/negacyclic-polynomial-multiplication/>` for algorithm.
|
||||
pub struct TwistedFft<T>
|
||||
where
|
||||
T: FftNum,
|
||||
{
|
||||
fwd: Arc<dyn Fft<T>>,
|
||||
rev: Arc<dyn Fft<T>>,
|
||||
|
||||
twist: Vec<Complex<T>>,
|
||||
twist_inv: Vec<Complex<T>>,
|
||||
}
|
||||
|
||||
impl<T> TwistedFft<T>
|
||||
where
|
||||
T: FftNum + Float,
|
||||
{
|
||||
/// Create a new [TwistedFft] with the given size.
|
||||
pub fn new(n: usize) -> Self {
|
||||
assert!(n.is_power_of_two());
|
||||
|
||||
let n_2 = T::from(n * 2).unwrap();
|
||||
let k = n / 2;
|
||||
|
||||
// The true length of the negacyclic sequence is 2N
|
||||
let mut plan = FftPlanner::new();
|
||||
let fwd = plan.plan_fft_forward(k);
|
||||
let rev = plan.plan_fft_inverse(k);
|
||||
|
||||
let two_pi = T::from(PI).unwrap() * T::from(2.0).unwrap();
|
||||
|
||||
let w_2n = (Complex::from(two_pi / n_2) * Complex::i()).exp();
|
||||
|
||||
let twist = (0..k)
|
||||
.map(|x| w_2n.powf(T::from(x).unwrap()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let twist_inv = twist
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|t| t.powf(-T::one()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
debug_assert!(twist.iter().zip(twist_inv.iter()).all(|(a, b)| {
|
||||
let a_a_inv = a * b;
|
||||
let err = a_a_inv - Complex::one();
|
||||
|
||||
err.re.abs() < T::from(1e-12).unwrap() && err.im.abs() < T::from(1e-12).unwrap()
|
||||
}));
|
||||
|
||||
Self {
|
||||
fwd,
|
||||
rev,
|
||||
twist,
|
||||
twist_inv,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FrequencyTransform for TwistedFft<T>
|
||||
where
|
||||
T: FftNum + Float,
|
||||
{
|
||||
type BaseRepr = T;
|
||||
type FrequencyRepr = Complex<T>;
|
||||
|
||||
fn forward(&self, x: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]) {
|
||||
assert_eq!(x.len(), self.fwd.len() * 2);
|
||||
|
||||
let n_div_2 = x.len() / 2;
|
||||
|
||||
for i in 0..n_div_2 {
|
||||
output[i] = Complex::new(x[i], x[i + n_div_2]) * self.twist[i];
|
||||
}
|
||||
|
||||
let mut scratch = allocate_scratch(self.fwd.get_inplace_scratch_len());
|
||||
let scratch_slice = scratch.as_mut_slice();
|
||||
|
||||
self.fwd.process_with_scratch(output, scratch_slice);
|
||||
}
|
||||
|
||||
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]) {
|
||||
assert_eq!(data.len(), self.rev.len());
|
||||
|
||||
let mut ifft = allocate_scratch(data.len());
|
||||
let ifft_slice = ifft.as_mut_slice();
|
||||
ifft_slice.copy_from_slice(data);
|
||||
|
||||
let mut scratch = allocate_scratch(self.rev.get_inplace_scratch_len());
|
||||
let scratch_slice = scratch.as_mut_slice();
|
||||
|
||||
self.rev.process_with_scratch(ifft_slice, scratch_slice);
|
||||
|
||||
let n_inv = T::one() / T::from(data.len()).unwrap();
|
||||
|
||||
for (i, x) in ifft_slice.iter().enumerate() {
|
||||
let tmp = *x * n_inv * self.twist_inv[i];
|
||||
|
||||
output[i] = tmp.re.round();
|
||||
output[i + data.len()] = tmp.im.round();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_roundtrip_negacyclic_fft() {
|
||||
let n = 8;
|
||||
|
||||
let plan = TwistedFft::<f64>::new(n);
|
||||
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut y = vec![Complex::from(0.0); x.len() / 2];
|
||||
let mut actual = vec![0.0; x.len()];
|
||||
|
||||
plan.forward(&x, &mut y);
|
||||
plan.reverse(&y, &mut actual);
|
||||
|
||||
for (l, r) in actual.iter().zip(x.iter()) {
|
||||
assert!((l - r).abs() < 1e-12);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_negacyclic_conv() {
|
||||
let n = 4;
|
||||
|
||||
let plan = TwistedFft::<f64>::new(n);
|
||||
|
||||
let x = (0..n).map(|x| x as f64).collect::<Vec<_>>();
|
||||
let mut y = vec![Complex::from(0.0); x.len() / 2];
|
||||
let mut actual = vec![0.0; x.len()];
|
||||
|
||||
plan.forward(&x, &mut y);
|
||||
|
||||
let z = y.iter().map(|y| y * y).collect::<Vec<_>>();
|
||||
|
||||
plan.reverse(&z, &mut actual);
|
||||
|
||||
assert_eq!(actual, vec![-10.0, -12.0, -8.0, 4.0]);
|
||||
}
|
||||
}
|
||||
359
sunscreen_tfhe/src/math/goldilocks_field.rs
Normal file
359
sunscreen_tfhe/src/math/goldilocks_field.rs
Normal file
@@ -0,0 +1,359 @@
|
||||
use std::ops::{Add, Mul, Neg, Sub};
|
||||
|
||||
use num::traits::{WrappingAdd, WrappingMul, WrappingNeg, WrappingSub};
|
||||
use sunscreen_math::{refify_binary_op, One, Zero};
|
||||
|
||||
use crate::{Inverse, Pow, RootOfUnity};
|
||||
|
||||
/// 2^64 - 2^32 + 1
|
||||
pub const GOLDILOCKS_PRIME: u64 = 0xFFFFFFFF00000001;
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
#[repr(transparent)]
|
||||
/// A value in the Goldilocks field (F_p where p = 2^64 - 2^32 + 1).
|
||||
/// See
|
||||
/// https://cp4space.hatsya.com/2021/09/01/an-efficient-prime-for-number-theoretic-transforms/
|
||||
/// for why this field is so magical.
|
||||
pub struct Fg(u64);
|
||||
|
||||
impl RootOfUnity for Fg {
|
||||
fn nth_root_of_unity(n: u64) -> Self {
|
||||
assert!(
|
||||
n.is_power_of_two(),
|
||||
"Goldilocks prime requires power of 2 < 2^32 for n when computing nth root of unity."
|
||||
);
|
||||
|
||||
// See https://nufhe.readthedocs.io/en/latest/implementation_details.html for where this constant comes
|
||||
// from.
|
||||
const C: Fg = Fg(12037493425763644479);
|
||||
|
||||
let exp = 0x1_0000_0000u64 / n;
|
||||
|
||||
C.pow(exp)
|
||||
}
|
||||
}
|
||||
|
||||
impl Fg {
|
||||
/// Returns `x % GOLDILOCKS_PRIME`
|
||||
pub fn new(x: u64) -> Self {
|
||||
if x > GOLDILOCKS_PRIME {
|
||||
Self(x - GOLDILOCKS_PRIME)
|
||||
} else {
|
||||
Self(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unreduced_add(self, rhs: Self) -> Fg96 {
|
||||
let (c, carry) = self.0.overflowing_add(rhs.0);
|
||||
|
||||
Fg96 {
|
||||
lo: c,
|
||||
hi: carry.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unreduced_sub(self, rhs: Self) -> Fg96 {
|
||||
self.unreduced_add(Fg(GOLDILOCKS_PRIME - rhs.0))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn unreduced_mul(self, rhs: Self) -> Fg159 {
|
||||
let res = self.0 as u128 * rhs.0 as u128;
|
||||
|
||||
res.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Compute `self * b + c` and don't reduce the result.
|
||||
pub fn unreduced_mad(self, b: Self, c: Self) -> Fg159 {
|
||||
let res = self.0 as u128 * b.0 as u128 + c.0 as u128;
|
||||
|
||||
res.into()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
/// Compute `self * b + c`.
|
||||
pub fn mad(self, b: Self, c: Self) -> Self {
|
||||
self.unreduced_mad(b, c).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u64> for Fg {
|
||||
fn from(value: u64) -> Self {
|
||||
Self::new(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl Add<&Fg> for &Fg {
|
||||
type Output = Fg;
|
||||
|
||||
#[inline]
|
||||
fn add(self, rhs: &Fg) -> Self::Output {
|
||||
let (res, c) = self.0.overflowing_add(rhs.0);
|
||||
|
||||
// If overflow occurred, then we have 1 carry bit. This gives us
|
||||
// s = 1 * 2^64 + res = (2^32 - 1) + res.
|
||||
//
|
||||
// If overflow didn't occur, but res >= g, then we need to compute
|
||||
// s - g = s + (2^32 + 1). Note that -g (mod 2^64) = 2^32 - 1 (mod 2^64).
|
||||
// Thus, we can add (2^32 + 1) in both cases!
|
||||
if c || res >= GOLDILOCKS_PRIME {
|
||||
Fg(res.wrapping_add(0xFFFFFFFFu64))
|
||||
} else {
|
||||
Fg(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl Mul<&Fg> for &Fg {
|
||||
type Output = Fg;
|
||||
|
||||
#[inline]
|
||||
fn mul(self, rhs: &Fg) -> Self::Output {
|
||||
self.unreduced_mul(*rhs).reduce()
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl Sub<&Fg> for &Fg {
|
||||
type Output = Fg;
|
||||
|
||||
#[inline]
|
||||
fn sub(self, rhs: &Fg) -> Self::Output {
|
||||
let (res, c) = self.0.overflowing_sub(rhs.0);
|
||||
|
||||
let offset = 0u32.wrapping_sub(u32::from(c)) as u64;
|
||||
|
||||
// If we underflow, then we need to add g.
|
||||
// Note that g (mod 2^64) = -(2^32 + 1) (mod 32)
|
||||
// Thus, this is equivalent to subtracting -(2^32 + 1).
|
||||
Fg(res.wrapping_sub(offset))
|
||||
}
|
||||
}
|
||||
|
||||
impl Neg for Fg {
|
||||
type Output = Fg;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self(GOLDILOCKS_PRIME - self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl WrappingAdd for Fg {
|
||||
#[inline]
|
||||
fn wrapping_add(&self, v: &Self) -> Self {
|
||||
self + v
|
||||
}
|
||||
}
|
||||
|
||||
impl WrappingMul for Fg {
|
||||
#[inline]
|
||||
fn wrapping_mul(&self, v: &Self) -> Self {
|
||||
self * v
|
||||
}
|
||||
}
|
||||
|
||||
impl WrappingSub for Fg {
|
||||
#[inline]
|
||||
fn wrapping_sub(&self, v: &Self) -> Self {
|
||||
self - v
|
||||
}
|
||||
}
|
||||
|
||||
impl WrappingNeg for Fg {
|
||||
#[inline]
|
||||
fn wrapping_neg(&self) -> Self {
|
||||
-*self
|
||||
}
|
||||
}
|
||||
|
||||
impl Zero for Fg {
|
||||
#[inline]
|
||||
fn vartime_is_zero(&self) -> bool {
|
||||
self.0 == 0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn zero() -> Self {
|
||||
Fg(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl One for Fg {
|
||||
#[inline]
|
||||
fn one() -> Self {
|
||||
Fg(1)
|
||||
}
|
||||
}
|
||||
|
||||
impl Inverse for Fg {
|
||||
fn inverse(&self) -> Self {
|
||||
<Fg as Pow<u64>>::pow(self, GOLDILOCKS_PRIME - 2)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
/// A number of the form hi << 64 + C, where B is 32-bit and C = 64-bit.
|
||||
pub struct Fg96 {
|
||||
lo: u64,
|
||||
hi: u64,
|
||||
}
|
||||
|
||||
impl Fg96 {
|
||||
#[inline]
|
||||
/// Reduce the number mod [`GOLDILOCKS_PRIME`].
|
||||
///
|
||||
/// /// # Remarks
|
||||
/// Note that 2^64 = 2^32 - 1 (mod g)
|
||||
///
|
||||
/// Thus, we can rewrite as (2^32 - 1) * B + C.
|
||||
pub fn reduce(self) -> Fg {
|
||||
let prod = (self.hi << 32) - self.hi;
|
||||
let mut res = prod.wrapping_add(self.lo);
|
||||
|
||||
if res < prod || res >= GOLDILOCKS_PRIME {
|
||||
res = res.wrapping_sub(GOLDILOCKS_PRIME);
|
||||
}
|
||||
|
||||
Fg(res)
|
||||
}
|
||||
}
|
||||
|
||||
/// An unreduced value of 159 or fewer bits, where lo is 64-bit, mid is 32-bit and hi is 63-bit.
|
||||
/// We can represent this value as `lo + mid << 64 + hi << 96`.
|
||||
pub struct Fg159 {
|
||||
lo: u64,
|
||||
mid: u64,
|
||||
hi: u64,
|
||||
}
|
||||
|
||||
impl Fg159 {
|
||||
#[inline]
|
||||
/// Reduce the number mod [`GOLDILOCKS_PRIME`].
|
||||
///
|
||||
/// /// # Remarks
|
||||
/// Note that
|
||||
/// * 2^64 = 2^32 - 1 (mod g)
|
||||
/// * 2^96 = -1
|
||||
///
|
||||
/// Thus, we can rewrite as (2^32 - 1) * B + (C - A).
|
||||
pub fn reduce(self) -> Fg {
|
||||
let mut lo2 = self.lo.wrapping_sub(self.hi);
|
||||
if self.hi > self.lo {
|
||||
lo2 = lo2.wrapping_add(GOLDILOCKS_PRIME);
|
||||
}
|
||||
|
||||
let prod = (self.mid << 32) - self.mid;
|
||||
let mut res = lo2.wrapping_add(prod);
|
||||
|
||||
if res < prod || res >= GOLDILOCKS_PRIME {
|
||||
res = res.wrapping_sub(GOLDILOCKS_PRIME);
|
||||
}
|
||||
|
||||
Fg(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for Fg159 {
|
||||
#[inline(always)]
|
||||
fn from(res: u128) -> Self {
|
||||
Fg159 {
|
||||
lo: res as u64,
|
||||
mid: (res >> 64) as u32 as u64,
|
||||
hi: (res >> 96) as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_add_fg() {
|
||||
for _ in 0..1000 {
|
||||
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
|
||||
let c = a + b;
|
||||
let expected = ((a.0 as u128 + b.0 as u128) % GOLDILOCKS_PRIME as u128) as u64;
|
||||
|
||||
assert_eq!(c, Fg(expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_fg() {
|
||||
for _ in 0..1000 {
|
||||
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
|
||||
let c = a - b;
|
||||
let expected = ((a.0 as u128 + (GOLDILOCKS_PRIME - b.0) as u128)
|
||||
% GOLDILOCKS_PRIME as u128) as u64;
|
||||
|
||||
assert_eq!(c, Fg(expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_neg_fg() {
|
||||
for _ in 0..1000 {
|
||||
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
|
||||
let c = -a;
|
||||
let expected = Fg(0) - a;
|
||||
|
||||
assert_eq!(c, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mul_fg() {
|
||||
fn test_case(a: u64, b: u64) {
|
||||
let c = Fg(a) * Fg(b);
|
||||
let expected = ((a as u128 * b as u128) % GOLDILOCKS_PRIME as u128) as u64;
|
||||
|
||||
assert_eq!(c, Fg(expected));
|
||||
}
|
||||
|
||||
test_case(GOLDILOCKS_PRIME - 1, GOLDILOCKS_PRIME - 1);
|
||||
|
||||
for _ in 0..1000 {
|
||||
test_case(
|
||||
thread_rng().next_u64() % GOLDILOCKS_PRIME,
|
||||
thread_rng().next_u64() % GOLDILOCKS_PRIME,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_mad_fg() {
|
||||
for _ in 0..1000 {
|
||||
let a = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
let b = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
let c = Fg(thread_rng().next_u64() % GOLDILOCKS_PRIME);
|
||||
|
||||
let acutal = a.mad(b, c);
|
||||
let expected =
|
||||
((a.0 as u128 * b.0 as u128 + c.0 as u128) % GOLDILOCKS_PRIME as u128) as u64;
|
||||
|
||||
assert_eq!(acutal, Fg(expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nth_root_of_unity() {
|
||||
for i in 1..16u64 {
|
||||
let root = Fg::nth_root_of_unity(0x1 << i);
|
||||
|
||||
assert_eq!(root.pow(0x1u64 << i), Fg::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
165
sunscreen_tfhe/src/math/mod.rs
Normal file
165
sunscreen_tfhe/src/math/mod.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use std::ops::{Add, BitAnd, Mul, Shr};
|
||||
|
||||
pub use sunscreen_math::{One, Zero};
|
||||
|
||||
/// FFT based operations.
|
||||
pub mod fft;
|
||||
|
||||
mod goldilocks_field;
|
||||
|
||||
/// Math operations on polynomials.
|
||||
pub mod polynomial;
|
||||
|
||||
/// Operations for performing radix decompositions.
|
||||
pub mod radix;
|
||||
mod torus;
|
||||
pub use torus::*;
|
||||
|
||||
mod basic;
|
||||
pub use basic::*;
|
||||
|
||||
/// Types where the roots of unity in the given field can be found.
|
||||
pub trait RootOfUnity
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// Find the n-th root of unity for the given field.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `w` is an n-th root of unity if `w^n = 1`.
|
||||
///
|
||||
/// # Panics
|
||||
/// Implementers may choose to panic if no n-th root of unity
|
||||
/// exists. You should take care in choosing your field modulus
|
||||
/// to ensure the root does exist.
|
||||
fn nth_root_of_unity(n: u64) -> Self;
|
||||
}
|
||||
|
||||
/// Numbers that have an inverse element.
|
||||
pub trait Inverse
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// Find the inverse of the given number.
|
||||
fn inverse(&self) -> Self;
|
||||
}
|
||||
|
||||
/// Numbers that can be raised to a power.
|
||||
pub trait Pow<T>
|
||||
where
|
||||
Self: Sized,
|
||||
T: Shr<u32, Output = T> + BitAnd<T, Output = T> + One + Eq,
|
||||
{
|
||||
/// Raise the number to the given power.
|
||||
fn pow(&self, exp: T) -> Self;
|
||||
}
|
||||
|
||||
impl<T, U> Pow<T> for U
|
||||
where
|
||||
T: Shr<u32, Output = T> + BitAnd<T, Output = T> + One + Eq + Copy,
|
||||
U: sunscreen_math::One + Copy + Add<U, Output = U> + Mul<U, Output = U>,
|
||||
{
|
||||
fn pow(&self, exp: T) -> U {
|
||||
let mut result = U::one();
|
||||
let mut power = *self;
|
||||
|
||||
for i in 0..64 {
|
||||
if (exp >> i) & T::one() == T::one() {
|
||||
result = result * power
|
||||
}
|
||||
|
||||
power = power * power;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Reinterpret the bits of an unsigned value as signed.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This is different than a cast. For example, UINT_MAX becomes -1
|
||||
/// when interpreted as a 2's complement signed value.
|
||||
pub trait ReinterpretAsSigned {
|
||||
/// The output type of the reinterpretation.
|
||||
type Output: ToF64;
|
||||
|
||||
/// Reinterpret the bits of an unsigned value as signed.
|
||||
fn reinterpret_as_signed(self) -> Self::Output;
|
||||
}
|
||||
|
||||
macro_rules! impl_reinterpret_signed {
|
||||
($ut:ty, $st:ty) => {
|
||||
impl ReinterpretAsSigned for $ut {
|
||||
type Output = $st;
|
||||
|
||||
#[inline(always)]
|
||||
fn reinterpret_as_signed(self) -> Self::Output {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_reinterpret_signed!(u8, i8);
|
||||
impl_reinterpret_signed!(u16, i16);
|
||||
impl_reinterpret_signed!(u32, i32);
|
||||
impl_reinterpret_signed!(u64, i64);
|
||||
impl_reinterpret_signed!(u128, i128);
|
||||
|
||||
/// Reinterpret the bits of a signed value as unsigned.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This is different than a cast. For example, -1 becomes UINT_MAX
|
||||
/// when interpreted as an unsigned integer.
|
||||
pub trait ReinterpretAsUnsigned {
|
||||
/// The output type of the reinterpretation.
|
||||
type Output;
|
||||
|
||||
/// Reinterpret the bits of a signed value as unsigned.
|
||||
fn reinterpret_as_unsigned(self) -> Self::Output;
|
||||
}
|
||||
|
||||
macro_rules! impl_reinterpret_unsigned {
|
||||
($st:ty, $ut:ty) => {
|
||||
impl ReinterpretAsUnsigned for $st {
|
||||
type Output = $ut;
|
||||
|
||||
#[inline(always)]
|
||||
fn reinterpret_as_unsigned(self) -> Self::Output {
|
||||
unsafe { std::mem::transmute(self) }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_reinterpret_unsigned!(i8, u8);
|
||||
impl_reinterpret_unsigned!(i16, u16);
|
||||
impl_reinterpret_unsigned!(i32, u32);
|
||||
impl_reinterpret_unsigned!(i64, u64);
|
||||
impl_reinterpret_unsigned!(i128, u128);
|
||||
|
||||
/// A type that can be converted from and to the fourier domain.
|
||||
pub trait FrequencyTransform {
|
||||
/// Original domain representation.
|
||||
type BaseRepr;
|
||||
|
||||
/// Fourier domain representation.
|
||||
type FrequencyRepr;
|
||||
|
||||
/// Perform a fourier transform.
|
||||
fn forward(&self, data: &[Self::BaseRepr], output: &mut [Self::FrequencyRepr]);
|
||||
|
||||
/// Perform an inverse fourier transform.
|
||||
fn reverse(&self, data: &[Self::FrequencyRepr], output: &mut [Self::BaseRepr]);
|
||||
}
|
||||
|
||||
/// A trait that allows types to specify how many leading zeros or ones are in
|
||||
/// the value.
|
||||
pub trait LeadingBits {
|
||||
/// Count the number of leading zeros in the value.
|
||||
fn leading_zeros(self) -> u32;
|
||||
|
||||
/// Count the number of leading ones in the value.
|
||||
fn leading_ones(self) -> u32;
|
||||
}
|
||||
353
sunscreen_tfhe/src/math/polynomial.rs
Normal file
353
sunscreen_tfhe/src/math/polynomial.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
use std::{
|
||||
num::Wrapping,
|
||||
ops::{Add, AddAssign, Mul, Neg, Sub, SubAssign},
|
||||
};
|
||||
|
||||
use num::traits::MulAdd;
|
||||
|
||||
use crate::{
|
||||
dst::FromMutSlice, entities::PolynomialRef, scratch::allocate_scratch, ToF64, Torus, TorusOps,
|
||||
};
|
||||
|
||||
/// Polynomial subtraction in place. This is equivalent to `a -= b` for each
|
||||
/// coefficient in the polynomial.
|
||||
pub fn polynomial_sub_assign<S>(lhs: &mut PolynomialRef<S>, rhs: &PolynomialRef<S>)
|
||||
where
|
||||
S: SubAssign + Copy,
|
||||
{
|
||||
for (a, b) in lhs
|
||||
.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(rhs.coeffs().iter().copied())
|
||||
{
|
||||
*a -= b;
|
||||
}
|
||||
}
|
||||
|
||||
/// Negate a polynomial in place. This is equivalent to `a = -a` for each
|
||||
/// coefficient in the polynomial.
|
||||
pub fn polynomial_negate<S>(c: &mut PolynomialRef<S>)
|
||||
where
|
||||
S: Clone + Copy + Neg<Output = S>,
|
||||
{
|
||||
for c in c.coeffs_mut().iter_mut() {
|
||||
*c = -*c;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c = a * s` where `s` is scalar.
|
||||
pub fn polynomial_scalar_mul<S, T, U>(c: &mut PolynomialRef<S>, a: &PolynomialRef<T>, s: U)
|
||||
where
|
||||
S: Clone,
|
||||
T: Clone + Copy + Mul<U, Output = S>,
|
||||
U: Clone + Copy,
|
||||
{
|
||||
for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) {
|
||||
*c = *a * s
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c += a * s`, where `s` is scalar.
|
||||
pub fn polynomial_scalar_mad<S, T, U>(c: &mut PolynomialRef<S>, a: &PolynomialRef<T>, s: U)
|
||||
where
|
||||
S: Clone + Copy + Add<S, Output = S>,
|
||||
T: Clone + Copy + MulAdd<U, S, Output = S>,
|
||||
U: Clone + Copy,
|
||||
{
|
||||
for (c, a) in c.coeffs_mut().iter_mut().zip(a.coeffs().iter()) {
|
||||
*c = a.mul_add(s, *c);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c = a + b` where a, b, and c are polynomials.
|
||||
pub fn polynomial_add<S>(c: &mut PolynomialRef<S>, a: &PolynomialRef<S>, b: &PolynomialRef<S>)
|
||||
where
|
||||
S: Clone + Copy + Add<S, Output = S>,
|
||||
{
|
||||
assert_eq!(c.len(), a.len());
|
||||
assert_eq!(c.len(), b.len());
|
||||
|
||||
for (c, (a, b)) in c
|
||||
.as_mut_slice()
|
||||
.iter_mut()
|
||||
.zip(a.as_slice().iter().zip(b.as_slice().iter()))
|
||||
{
|
||||
*c = *a + *b;
|
||||
}
|
||||
}
|
||||
|
||||
/// Polynomial addition in place. This is equivalent to `a += b` for each
|
||||
/// coefficient in the polynomial.
|
||||
pub fn polynomial_add_assign<S>(lhs: &mut PolynomialRef<S>, rhs: &PolynomialRef<S>)
|
||||
where
|
||||
S: AddAssign + Copy,
|
||||
{
|
||||
for (a, b) in lhs
|
||||
.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(rhs.coeffs().iter().copied())
|
||||
{
|
||||
*a += b;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c = a - b` where a, b, and c are polynomials.
|
||||
pub fn polynomial_sub<S>(c: &mut PolynomialRef<S>, a: &PolynomialRef<S>, b: &PolynomialRef<S>)
|
||||
where
|
||||
S: Clone + Copy + Sub<S, Output = S>,
|
||||
{
|
||||
assert_eq!(c.len(), a.len());
|
||||
assert_eq!(c.len(), b.len());
|
||||
|
||||
for (c, (a, b)) in c
|
||||
.as_mut_slice()
|
||||
.iter_mut()
|
||||
.zip(a.as_slice().iter().zip(b.as_slice().iter()))
|
||||
{
|
||||
*c = *a - *b;
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c += a \[*\] b` where `a` in Z\[X\]/f and `c, b` in T\[X\]/f, and
|
||||
/// \[*\] is the external product between these rings.
|
||||
pub fn polynomial_external_mad<S>(
|
||||
c: &mut PolynomialRef<Torus<S>>,
|
||||
a: &PolynomialRef<Torus<S>>,
|
||||
b: &PolynomialRef<S>,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
polynomial_mad_impl(c, a, b);
|
||||
}
|
||||
|
||||
/// Compute `c += a * b` where `*` the multiplication of the two polynomials of
|
||||
/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and
|
||||
/// hence has O(N^2) time.
|
||||
pub fn polynomial_mad<S>(
|
||||
c: &mut PolynomialRef<Wrapping<S>>,
|
||||
a: &PolynomialRef<Wrapping<S>>,
|
||||
b: &PolynomialRef<Wrapping<S>>,
|
||||
) where
|
||||
S: TorusOps,
|
||||
Wrapping<S>: Sub<Wrapping<S>, Output = Wrapping<S>>
|
||||
+ Add<Wrapping<S>, Output = Wrapping<S>>
|
||||
+ Mul<Wrapping<S>, Output = Wrapping<S>>,
|
||||
{
|
||||
polynomial_mad_impl(c, a, b);
|
||||
}
|
||||
|
||||
/// Compute `c += a * b` where `*` the multiplication of the two polynomials of
|
||||
/// degree (N - 1) modulo (X^N + 1). This is done with the naive algorithm, and
|
||||
/// hence has O(N^2) time.
|
||||
fn polynomial_mad_impl<S, T, U>(
|
||||
c: &mut PolynomialRef<S>,
|
||||
a: &PolynomialRef<T>,
|
||||
b: &PolynomialRef<U>,
|
||||
) where
|
||||
U: Clone + Copy + ToF64,
|
||||
T: Mul<U, Output = S> + Clone + Copy + ToF64,
|
||||
S: Sub<S, Output = S> + Add<S, Output = S> + Clone + Copy,
|
||||
{
|
||||
assert!(a.len().is_power_of_two());
|
||||
assert_eq!(a.len(), b.len());
|
||||
assert_eq!(a.len(), c.len());
|
||||
|
||||
let len = a.len();
|
||||
|
||||
// Polynomial's length is a power of 2, so use mask to perform modulus.
|
||||
let mask = len - 1;
|
||||
|
||||
let coeffs = c.coeffs_mut();
|
||||
|
||||
let mut poly_f64 = allocate_scratch::<f64>(a.len());
|
||||
let poly_f64_ref = PolynomialRef::from_mut_slice(poly_f64.as_mut_slice());
|
||||
|
||||
a.map_into(poly_f64_ref, |x| x.to_f64());
|
||||
|
||||
for (i, l) in a.coeffs().iter().copied().enumerate() {
|
||||
for (j, r) in b.coeffs().iter().copied().enumerate() {
|
||||
if i + j >= len {
|
||||
// Reduce mod len and subtract due to negacyclic
|
||||
// polynomials.
|
||||
let index = (i + j) & mask;
|
||||
coeffs[index] = coeffs[index] - l * r;
|
||||
} else {
|
||||
let index = i + j;
|
||||
coeffs[index] = coeffs[index] + l * r;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[derive(BarrettConfig)]
|
||||
#[barrett_config(modulus = "18446744073709551616", num_limbs = 2)]
|
||||
pub struct U64Config;
|
||||
|
||||
pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>;
|
||||
|
||||
use crate::{
|
||||
entities::{Polynomial, PolynomialFft},
|
||||
normalized_torus_distance, ReinterpretAsSigned, ReinterpretAsUnsigned,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use num::{Complex, Zero};
|
||||
use rand::{thread_rng, RngCore};
|
||||
use sunscreen_math::{
|
||||
poly::Polynomial as BasePolynomial,
|
||||
ring::{BarrettBackend, Zq},
|
||||
BarrettConfig, One,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn can_multiply_polynomials() {
|
||||
// Compare our negacyclic polynomial implementation against
|
||||
// sunscreen_math computing (a * b) mod (X^N + 1).
|
||||
fn case(a: &PolynomialRef<Torus<u64>>, b: &PolynomialRef<u64>) {
|
||||
let actual = a * b;
|
||||
|
||||
let mut f = vec![<Zq64 as sunscreen_math::Zero>::zero(); a.len() + 1];
|
||||
|
||||
let len = a.len();
|
||||
|
||||
f[0] = Zq64::one();
|
||||
f[a.len()] = Zq64::one();
|
||||
|
||||
let f = BasePolynomial { coeffs: f };
|
||||
|
||||
let a = BasePolynomial {
|
||||
coeffs: a
|
||||
.coeffs()
|
||||
.iter()
|
||||
.map(|x| Zq64::from(x.inner()))
|
||||
.collect::<Vec<_>>(),
|
||||
};
|
||||
|
||||
let b = BasePolynomial {
|
||||
coeffs: b.coeffs().iter().map(|x| Zq64::from(*x)).collect(),
|
||||
};
|
||||
|
||||
let expected = a * b;
|
||||
|
||||
let (_, expected) = expected.vartime_div_rem_restricted_rhs(&f);
|
||||
|
||||
let expected = expected
|
||||
.coeffs
|
||||
.iter()
|
||||
.take(len)
|
||||
.map(|x| Torus::from(x.val.as_words()[0]))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let expected = Polynomial::new(&expected);
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
|
||||
for _ in 0..50 {
|
||||
let len = thread_rng().next_u64() % 8 + 1;
|
||||
let len = 0x1 << len;
|
||||
|
||||
let a = (0..len)
|
||||
.map(|_| Torus::from(thread_rng().next_u64()))
|
||||
.collect::<Vec<_>>();
|
||||
let a = Polynomial::new(&a);
|
||||
|
||||
let b = (0..len)
|
||||
.map(|_| thread_rng().next_u64())
|
||||
.collect::<Vec<_>>();
|
||||
let b = Polynomial::new(&b);
|
||||
|
||||
case(&a, &b);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_roundtrip_polynomial() {
|
||||
let poly = (0..1024u64).collect::<Vec<_>>();
|
||||
let poly = Polynomial::new(&poly);
|
||||
|
||||
let mut actual = Polynomial::<u64>::zero(poly.len());
|
||||
|
||||
let mut out = PolynomialFft::new(&vec![Complex::zero(); poly.len() / 2]);
|
||||
|
||||
poly.fft(&mut out);
|
||||
out.ifft(&mut actual);
|
||||
|
||||
assert_eq!(poly, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_multiply_polynomials_fft() {
|
||||
for _ in 0..100 {
|
||||
let a = (0..1024u64)
|
||||
.map(|x| Torus::from(x % 0x8000_0000))
|
||||
.collect::<Vec<_>>();
|
||||
let a = Polynomial::new(&a);
|
||||
let b = (0..1024u64).map(|x| x % 16).collect::<Vec<_>>();
|
||||
let b = Polynomial::new(&b);
|
||||
|
||||
let mut expected = Polynomial::<Torus<u64>>::zero(a.len());
|
||||
|
||||
polynomial_external_mad(&mut expected, &a, &b);
|
||||
|
||||
let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]);
|
||||
let mut b_fft = a_fft.clone();
|
||||
let mut c_fft = a_fft.clone();
|
||||
|
||||
a.fft(&mut a_fft);
|
||||
b.fft(&mut b_fft);
|
||||
|
||||
c_fft.multiply_add(&a_fft, &b_fft);
|
||||
|
||||
let mut actual = Polynomial::<Torus<u64>>::zero(a.len());
|
||||
c_fft.ifft(&mut actual);
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_approx_multiply_large_polynomials_fft() {
|
||||
let n = 1024;
|
||||
|
||||
for _ in 0..100 {
|
||||
// a is uniform torus elements, b is "small" as we'll encounter
|
||||
// during radix decomposition.
|
||||
let a = (0..n)
|
||||
.map(|_| Torus::from(rand::thread_rng().next_u64()))
|
||||
.collect::<Vec<_>>();
|
||||
let a = Polynomial::new(&a);
|
||||
let b = (0..n)
|
||||
.map(|_| {
|
||||
let signed = (rand::thread_rng().next_u64() % 32).reinterpret_as_signed() - 16;
|
||||
|
||||
signed.reinterpret_as_unsigned()
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let b = Polynomial::new(&b);
|
||||
|
||||
let mut expected = Polynomial::<Torus<u64>>::zero(a.len());
|
||||
|
||||
polynomial_external_mad(&mut expected, &a, &b);
|
||||
|
||||
let mut a_fft = PolynomialFft::new(&vec![Complex::zero(); a.len() / 2]);
|
||||
let mut b_fft = a_fft.clone();
|
||||
let mut c_fft = a_fft.clone();
|
||||
|
||||
a.fft(&mut a_fft);
|
||||
b.fft(&mut b_fft);
|
||||
|
||||
c_fft.multiply_add(&a_fft, &b_fft);
|
||||
|
||||
let mut actual = Polynomial::<Torus<u64>>::zero(a.len());
|
||||
c_fft.ifft(&mut actual);
|
||||
|
||||
for (a, e) in actual.coeffs().iter().zip(expected.coeffs().iter()) {
|
||||
let err = normalized_torus_distance(a, e).abs();
|
||||
assert!(err < 1e-12);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
399
sunscreen_tfhe/src/math/radix.rs
Normal file
399
sunscreen_tfhe/src/math/radix.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
use crate::{
|
||||
entities::{PolynomialIterator, PolynomialRef},
|
||||
math::{Torus, TorusOps},
|
||||
polynomial::polynomial_scalar_mad,
|
||||
RadixCount, RadixDecomposition, RadixLog,
|
||||
};
|
||||
|
||||
// Needed by allow_scratch_ref
|
||||
#[allow(unused_imports)]
|
||||
use crate::dst::FromMutSlice;
|
||||
|
||||
/// An iterator from least to most significant radix decomposition of a value.
|
||||
pub struct ScalarRadixIterator<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
cur: S,
|
||||
level: usize,
|
||||
radix: RadixDecomposition,
|
||||
}
|
||||
|
||||
impl<S: TorusOps> ScalarRadixIterator<S> {
|
||||
/// Creates a new [`ScalarRadixIterator`] for the given [Torus] value.
|
||||
#[inline(always)]
|
||||
pub fn new(val: Torus<S>, radix: &RadixDecomposition) -> Self {
|
||||
Self {
|
||||
cur: round(val, radix),
|
||||
level: 0,
|
||||
radix: *radix,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn get_next_digit<S: TorusOps>(cur: &mut S, radix_log: usize) -> S {
|
||||
let mask = S::from_u64((0x1u64 << radix_log) - 1);
|
||||
|
||||
// Interpreting the digits over [-B/2,B/2) reduces noise by half a bit on average.
|
||||
let mut digit = *cur & mask;
|
||||
*cur = *cur >> radix_log;
|
||||
let carry = digit >> (radix_log - 1);
|
||||
*cur = *cur + carry;
|
||||
digit = digit.wrapping_sub(&(carry << radix_log));
|
||||
|
||||
digit
|
||||
}
|
||||
|
||||
impl<S: TorusOps> Iterator for ScalarRadixIterator<S> {
|
||||
type Item = S;
|
||||
|
||||
#[inline(always)]
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.level == self.radix.count.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let digit = get_next_digit(&mut self.cur, self.radix.radix_log.0);
|
||||
|
||||
self.level += 1;
|
||||
|
||||
Some(digit)
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator from least to most significant radix decomposition of the coefficients
|
||||
/// of a polynomial.
|
||||
pub struct PolynomialRadixIterator<'a, S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
scratch: &'a mut PolynomialRef<S>,
|
||||
level: usize,
|
||||
radix: RadixDecomposition,
|
||||
}
|
||||
|
||||
impl<'a, S> PolynomialRadixIterator<'a, S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
/// Creates a new [`PolynomialRadixIterator`] for the given polynomial.
|
||||
pub fn new(
|
||||
poly: &PolynomialRef<Torus<S>>,
|
||||
scratch: &'a mut PolynomialRef<S>,
|
||||
radix: &RadixDecomposition,
|
||||
) -> Self {
|
||||
assert!(radix.radix_log.0 * radix.count.0 < S::BITS as usize);
|
||||
assert_ne!(radix.radix_log.0 * radix.count.0, 0);
|
||||
|
||||
poly.map_into(scratch, |x| round(*x, radix));
|
||||
|
||||
Self {
|
||||
scratch,
|
||||
level: 0,
|
||||
radix: radix.to_owned(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Writes the next polynomial decomposition to `dst` and returns `Some(())` if there is a next digit.
|
||||
pub fn write_next(&mut self, dst: &mut PolynomialRef<S>) -> Option<()> {
|
||||
if self.level == self.radix.count.0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
self.level += 1;
|
||||
|
||||
for (s, r) in self
|
||||
.scratch
|
||||
.coeffs_mut()
|
||||
.iter_mut()
|
||||
.zip(dst.coeffs_mut().iter_mut())
|
||||
{
|
||||
*r = get_next_digit(s, self.radix.radix_log.0);
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Recomposes a polynomial from its `digits` decomposition and adds it to `dst`.
|
||||
///
|
||||
/// # Remarks
|
||||
/// The digits should iterate from least to most significant.
|
||||
pub fn recompose_and_add<S>(
|
||||
dst: &mut PolynomialRef<Torus<S>>,
|
||||
digits: &mut PolynomialIterator<S>,
|
||||
radix: RadixLog,
|
||||
count: RadixCount,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let shift_amount = S::BITS as usize - radix.0 * count.0;
|
||||
let mut cur_radix = S::from_u64(0x1 << shift_amount);
|
||||
let mut actual_count = 0;
|
||||
|
||||
for d in digits {
|
||||
polynomial_scalar_mad(dst, d.as_torus(), cur_radix);
|
||||
|
||||
actual_count += 1;
|
||||
cur_radix = cur_radix << radix.0;
|
||||
}
|
||||
|
||||
assert_eq!(count.0, actual_count);
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Multiply `val` by q / B^(j + 1).
|
||||
pub fn scale_by_decomposition_factor<S: TorusOps>(
|
||||
val: S,
|
||||
j: usize,
|
||||
radix: &RadixDecomposition,
|
||||
) -> S {
|
||||
let shift = S::BITS as usize - radix.radix_log.0 * (j + 1);
|
||||
let factor = S::one() << shift;
|
||||
|
||||
val.wrapping_mul(&factor)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Rounds the given [`Torus`] element and returns the value interpreted as an integer.
|
||||
fn round<S: TorusOps>(x: Torus<S>, radix: &RadixDecomposition) -> S {
|
||||
let shift = S::BITS as usize - radix.radix_log.0 * radix.count.0;
|
||||
let round_bit = (x.inner() >> (shift - 1)) & S::from_u64(0x1);
|
||||
|
||||
(x.inner() >> shift).wrapping_add(&round_bit)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{Polynomial, PolynomialList},
|
||||
rand::uniform_torus,
|
||||
scratch::allocate_scratch_ref,
|
||||
PlaintextBits, PolynomialDegree,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_round_values() {
|
||||
assert_eq!(
|
||||
round(
|
||||
Torus::from(0x12348FFF_FFFFFFFFu64),
|
||||
&RadixDecomposition {
|
||||
radix_log: RadixLog(4),
|
||||
count: RadixCount(4)
|
||||
}
|
||||
),
|
||||
0x1235
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
round(
|
||||
Torus::from(0x12347FFF_FFFFFFFFu64),
|
||||
&RadixDecomposition {
|
||||
radix_log: RadixLog(4),
|
||||
count: RadixCount(4)
|
||||
}
|
||||
),
|
||||
0x1234
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decompose() {
|
||||
let x = Polynomial::new(&[Torus::encode(7u64, PlaintextBits(4))]);
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
|
||||
|
||||
let mut radix_iter = PolynomialRadixIterator::new(
|
||||
&x,
|
||||
scratch,
|
||||
&RadixDecomposition {
|
||||
radix_log: RadixLog(2),
|
||||
count: RadixCount(2),
|
||||
},
|
||||
);
|
||||
|
||||
let mut dst = Polynomial::zero(1);
|
||||
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1));
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(2));
|
||||
assert_eq!(radix_iter.write_next(&mut dst), None);
|
||||
|
||||
let x = Polynomial::new(&[Torus::encode(1u64, PlaintextBits(1))]);
|
||||
|
||||
let radix_log = 4;
|
||||
let mut radix_iter = PolynomialRadixIterator::new(
|
||||
&x,
|
||||
scratch,
|
||||
&RadixDecomposition {
|
||||
radix_log: RadixLog(radix_log),
|
||||
count: RadixCount(3),
|
||||
},
|
||||
);
|
||||
|
||||
let mut dst = Polynomial::zero(1);
|
||||
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(dst.coeffs()[0], 0u64);
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(dst.coeffs()[0], 0u64);
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
|
||||
// 2^(beta - 1) - 2^beta
|
||||
assert_eq!(dst.coeffs()[0], 0u64.wrapping_sub(1 << (radix_log - 1)));
|
||||
assert_eq!(radix_iter.write_next(&mut dst), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decompose_polynomial() {
|
||||
let x = Polynomial::new(&[
|
||||
// Decomposes to 1 [for 2^1] + 0 [for 2^2], or 1 + 0
|
||||
Torus::encode(1u64, PlaintextBits(4)),
|
||||
// Decomposes to -2 [for 2^1] + 1 [for 2^2], or -2 + 4
|
||||
Torus::encode(2u64, PlaintextBits(4)),
|
||||
// Decomposes to -1 [for 2^1] + 1 [for 2^2], or -1 + 4
|
||||
Torus::encode(3u64, PlaintextBits(4)),
|
||||
// Decomposes to 0 [for 2^1] + 1 [for 2^2], or 0 + 4
|
||||
Torus::encode(4u64, PlaintextBits(4)),
|
||||
]);
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
|
||||
|
||||
let mut radix_iter = PolynomialRadixIterator::new(
|
||||
&x,
|
||||
scratch,
|
||||
&RadixDecomposition {
|
||||
radix_log: RadixLog(2),
|
||||
count: RadixCount(2),
|
||||
},
|
||||
);
|
||||
|
||||
let mut dst = Polynomial::zero(4);
|
||||
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(
|
||||
dst.coeffs(),
|
||||
[1u64, 0u64.wrapping_sub(2), 0u64.wrapping_sub(1), 0]
|
||||
);
|
||||
assert_eq!(radix_iter.write_next(&mut dst), Some(()));
|
||||
assert_eq!(dst.coeffs(), [0, 1, 1, 1]);
|
||||
assert_eq!(radix_iter.write_next(&mut dst), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decompose_recompose() {
|
||||
let d = PolynomialDegree(8);
|
||||
|
||||
for _ in 0..50 {
|
||||
let radix = (thread_rng().next_u32() % 7) + 1;
|
||||
let radix = RadixLog(radix as usize);
|
||||
|
||||
let count = loop {
|
||||
let count = (thread_rng().next_u32() % 7) + 1;
|
||||
|
||||
if (radix.0 * count as usize) < u64::BITS as usize {
|
||||
break count;
|
||||
}
|
||||
};
|
||||
|
||||
let count = RadixCount(count as usize);
|
||||
|
||||
let x = (0..d.0).map(|_| uniform_torus::<u64>()).collect::<Vec<_>>();
|
||||
let x = Polynomial::new(&x);
|
||||
|
||||
let expected = x.map(|c| {
|
||||
let radix_bits = radix.0 * count.0;
|
||||
let lsb = u64::BITS as usize - radix_bits;
|
||||
let round_loc = lsb - 1;
|
||||
|
||||
let round = (c.inner() >> round_loc) & 0x1;
|
||||
let mask = 0xFFFFFFFF_FFFFFFFFu64 << lsb;
|
||||
Torus::from((c.inner() & mask).wrapping_add(round << lsb))
|
||||
});
|
||||
|
||||
let mut digits = PolynomialList::new(d, count.0);
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<u64>, (PolynomialDegree(x.len())));
|
||||
|
||||
let mut radix_iter = PolynomialRadixIterator::new(
|
||||
&x,
|
||||
scratch,
|
||||
&RadixDecomposition {
|
||||
radix_log: radix,
|
||||
count,
|
||||
},
|
||||
);
|
||||
|
||||
for d in digits.iter_mut(d) {
|
||||
radix_iter.write_next(d);
|
||||
}
|
||||
|
||||
let mut result = Polynomial::zero(x.len());
|
||||
|
||||
recompose_and_add(&mut result, &mut digits.iter(d), radix, count);
|
||||
|
||||
assert_eq!(expected, result);
|
||||
}
|
||||
}
|
||||
|
||||
fn random_radix() -> RadixDecomposition {
|
||||
let radix = (thread_rng().next_u32() % 7) + 1;
|
||||
let radix = RadixLog(radix as usize);
|
||||
|
||||
let count = loop {
|
||||
let count = (thread_rng().next_u32() % 7) + 1;
|
||||
|
||||
if (radix.0 * count as usize) < u64::BITS as usize {
|
||||
break count;
|
||||
}
|
||||
};
|
||||
|
||||
let count = RadixCount(count as usize);
|
||||
|
||||
RadixDecomposition {
|
||||
radix_log: radix,
|
||||
count,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decompose_scalar() {
|
||||
for _ in 0..50 {
|
||||
let radix = random_radix();
|
||||
let val = Torus::from(thread_rng().next_u64());
|
||||
/*let radix = RadixDecomposition {
|
||||
count: RadixCount(3),
|
||||
radix_log: RadixLog(4),
|
||||
};
|
||||
let val = Torus::from(0xDEADBEEF_FEEDF00Du64);
|
||||
*/
|
||||
|
||||
let decomp = ScalarRadixIterator::new(val, &radix);
|
||||
let mut digits = vec![];
|
||||
|
||||
for digit in decomp {
|
||||
digits.push(digit);
|
||||
}
|
||||
|
||||
let actual = digits.iter().enumerate().fold(0u64, |s, (i, d)| {
|
||||
let shift_amount = u64::BITS as usize - radix.radix_log.0 * (radix.count.0 - i);
|
||||
let cur_radix = 0x1u64 << shift_amount;
|
||||
|
||||
(cur_radix.wrapping_mul(*d)).wrapping_add(s)
|
||||
});
|
||||
|
||||
// Round shifts our value down to the LSB places, so move it back up to compare
|
||||
// against a torus element.
|
||||
let expected =
|
||||
round(val, &radix) << (u64::BITS as usize - radix.radix_log.0 * (radix.count.0));
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
619
sunscreen_tfhe/src/math/torus.rs
Normal file
619
sunscreen_tfhe/src/math/torus.rs
Normal file
@@ -0,0 +1,619 @@
|
||||
use bytemuck::{Pod as BytemuckPod, Zeroable};
|
||||
use num::traits::{
|
||||
Bounded, MulAdd, Num, WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr,
|
||||
WrappingSub,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{Binary, Debug, LowerHex, UpperHex},
|
||||
num::Wrapping,
|
||||
ops::{Add, AddAssign, BitAnd, Deref, Mul, Neg, Shl, Shr, Sub, SubAssign},
|
||||
};
|
||||
use sunscreen_math::{refify_binary_op, Zero};
|
||||
|
||||
use crate::{
|
||||
math::{ReinterpretAsSigned, ReinterpretAsUnsigned},
|
||||
scratch::Pod,
|
||||
PlaintextBits,
|
||||
};
|
||||
|
||||
/// Number of bits used in the representation of a type.
|
||||
pub trait NumBits {
|
||||
/// The number of bits used in the representation of a type.
|
||||
const BITS: u32;
|
||||
}
|
||||
|
||||
impl NumBits for u32 {
|
||||
const BITS: u32 = u32::BITS;
|
||||
}
|
||||
|
||||
impl NumBits for u64 {
|
||||
const BITS: u32 = u64::BITS;
|
||||
}
|
||||
|
||||
/// Convert a type into a 64-bit floating point number.
|
||||
pub trait ToF64 {
|
||||
/// Approximately convert the value to an f64.
|
||||
fn to_f64(self) -> f64;
|
||||
}
|
||||
|
||||
/// Convert a 64-bit floating point number into a type.
|
||||
pub trait FromF64
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// Approximately convert an f64 into a value.
|
||||
fn from_f64(x: f64) -> Self;
|
||||
}
|
||||
|
||||
/// A type that supports operations on a Torus.
|
||||
pub trait TorusOps:
|
||||
BitAnd<Self, Output = Self>
|
||||
+ WrappingAdd
|
||||
+ WrappingSub
|
||||
+ WrappingMul
|
||||
+ WrappingShl
|
||||
+ WrappingShr
|
||||
+ WrappingNeg
|
||||
+ BitAnd
|
||||
+ ReinterpretAsSigned
|
||||
+ Num
|
||||
+ NumBits
|
||||
+ From<u32>
|
||||
+ TryFrom<u64>
|
||||
+ FromU64
|
||||
+ Clone
|
||||
+ Copy
|
||||
+ Binary
|
||||
+ LowerHex
|
||||
+ UpperHex
|
||||
+ std::fmt::Debug
|
||||
+ Ord
|
||||
+ Zero
|
||||
+ Pod
|
||||
+ BytemuckPod
|
||||
+ Bounded
|
||||
+ ToF64
|
||||
+ FromF64
|
||||
+ ToU64
|
||||
+ NumBits
|
||||
{
|
||||
}
|
||||
|
||||
// Sound since Torus is a transparent wrapper and `S` impl `Pod`
|
||||
unsafe impl<S: TorusOps> Zeroable for Torus<S> {}
|
||||
unsafe impl<S: TorusOps> BytemuckPod for Torus<S> {}
|
||||
|
||||
/// Convert a type into a 64-bit unsigned integer.
|
||||
pub trait ToU64 {
|
||||
/// Convert the value to a 64-bit unsigned integer.
|
||||
fn to_u64(self) -> u64;
|
||||
}
|
||||
|
||||
impl ToU64 for u32 {
|
||||
fn to_u64(self) -> u64 {
|
||||
self as u64
|
||||
}
|
||||
}
|
||||
|
||||
impl ToU64 for u64 {
|
||||
fn to_u64(self) -> u64 {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ToU64 for Torus<T>
|
||||
where
|
||||
T: TorusOps,
|
||||
{
|
||||
fn to_u64(self) -> u64 {
|
||||
self.0.to_u64()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a 64-bit unsigned integer into a type.
|
||||
pub trait FromU64 {
|
||||
/// For the given 64-bit value, take the N most significant bits, where
|
||||
/// N is the bitlength of this type.
|
||||
fn from_u64(val: u64) -> Self;
|
||||
}
|
||||
|
||||
impl FromU64 for u32 {
|
||||
fn from_u64(val: u64) -> Self {
|
||||
(val & 0xFFFFFFFF) as Self
|
||||
}
|
||||
}
|
||||
|
||||
impl FromU64 for u64 {
|
||||
fn from_u64(val: u64) -> Self {
|
||||
val
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_tof64 {
|
||||
($t:ty) => {
|
||||
impl ToF64 for $t {
|
||||
fn to_f64(self) -> f64 {
|
||||
self as f64
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_tof64!(i8);
|
||||
impl_tof64!(i16);
|
||||
impl_tof64!(i32);
|
||||
impl_tof64!(i64);
|
||||
impl_tof64!(i128);
|
||||
impl_tof64!(u8);
|
||||
impl_tof64!(u16);
|
||||
impl_tof64!(u32);
|
||||
impl_tof64!(u64);
|
||||
impl_tof64!(u128);
|
||||
|
||||
impl<T> ToF64 for Wrapping<T>
|
||||
where
|
||||
T: ToF64,
|
||||
{
|
||||
fn to_f64(self) -> f64 {
|
||||
self.0.to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ToF64 for Torus<T>
|
||||
where
|
||||
T: TorusOps,
|
||||
{
|
||||
fn to_f64(self) -> f64 {
|
||||
self.0.to_f64()
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_unsigned_fromf64 {
|
||||
($t:ty,$st:ty) => {
|
||||
impl FromF64 for $t {
|
||||
#[inline(always)]
|
||||
fn from_f64(x: f64) -> $t {
|
||||
let x = x as $st;
|
||||
|
||||
x.reinterpret_as_unsigned()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_unsigned_fromf64!(u128, i128);
|
||||
impl_unsigned_fromf64!(u64, i64);
|
||||
impl_unsigned_fromf64!(u32, i32);
|
||||
impl_unsigned_fromf64!(u16, i16);
|
||||
|
||||
impl<T> FromF64 for Torus<T>
|
||||
where
|
||||
T: TorusOps,
|
||||
{
|
||||
fn from_f64(x: f64) -> Self {
|
||||
Self(T::from_f64(x))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FromF64 for Wrapping<T>
|
||||
where
|
||||
T: FromF64,
|
||||
{
|
||||
fn from_f64(x: f64) -> Self {
|
||||
Wrapping(T::from_f64(x))
|
||||
}
|
||||
}
|
||||
|
||||
impl TorusOps for u64 {}
|
||||
impl TorusOps for u32 {}
|
||||
|
||||
/// A wrapper around a type that supports Torus operations.
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct Torus<S: TorusOps = u64>(S);
|
||||
|
||||
/// Compute the distance between two Torus values, normalized to the unit torus.
|
||||
/// The first argument is taken as the reference point.
|
||||
///
|
||||
/// For example, if a is at a normalized position of 0.2, and b is at a
|
||||
/// normalized position of 0.8, then there are two distances between them:
|
||||
/// 0.6 and -0.4. This function will return the shorter of the two distances,
|
||||
/// -0.4.
|
||||
///
|
||||
/// In other words:
|
||||
///
|
||||
/// ```text
|
||||
/// b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b) (mod 1.0)
|
||||
/// ```
|
||||
pub fn normalized_torus_distance<S: TorusOps>(a: &Torus<S>, b: &Torus<S>) -> f64 {
|
||||
let a_minus_b = a - b;
|
||||
let b_minus_a = b - a;
|
||||
|
||||
let modulus = 2_f64.powi(S::BITS as i32);
|
||||
|
||||
let difference = if a_minus_b < b_minus_a {
|
||||
-a_minus_b.to_f64()
|
||||
} else {
|
||||
b_minus_a.to_f64()
|
||||
};
|
||||
|
||||
difference / modulus
|
||||
}
|
||||
|
||||
impl<S: TorusOps> Deref for Torus<S> {
|
||||
type Target = S;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> NumBits for Torus<S> {
|
||||
const BITS: u32 = S::BITS;
|
||||
}
|
||||
|
||||
impl<S: TorusOps> Torus<S> {
|
||||
/// Compute the normalized position of this [Torus] value on the unit torus [0.0, 1.0).
|
||||
pub fn normalized_torus(&self) -> f64 {
|
||||
let modulus: f64 = 2_f64.powi(S::BITS as i32);
|
||||
let val: f64 = self.0.to_f64();
|
||||
|
||||
val / modulus
|
||||
}
|
||||
|
||||
/// Compute the distance between this [Torus] value and another, normalized to
|
||||
/// the unit torus.
|
||||
pub fn normalized_torus_distance(&self, other: &Self) -> f64 {
|
||||
normalized_torus_distance(self, other)
|
||||
}
|
||||
|
||||
/// Encode a value into a [Torus] that supports up to plain_bits of values.
|
||||
/// This encodes the value on the equispaced `2^plaintext_bits` positions on
|
||||
/// a larger torus.
|
||||
pub fn encode(val: S, plain_bits: PlaintextBits) -> Self {
|
||||
assert!(plain_bits.0 < S::BITS);
|
||||
|
||||
let encoded = val.wrapping_shl(S::BITS - plain_bits.0);
|
||||
|
||||
Self(encoded)
|
||||
}
|
||||
|
||||
/// Decode a value from a [Torus] that supports up to plain_bits of values.
|
||||
pub fn decode(&self, plain_bits: PlaintextBits) -> S {
|
||||
assert!(plain_bits.0 < S::BITS);
|
||||
|
||||
let round_bit = self.0.wrapping_shr(S::BITS - plain_bits.0 - 1) & S::from(0x1);
|
||||
let mask = S::from((0x1 << plain_bits.0) - 1);
|
||||
|
||||
(self.0.wrapping_shr(S::BITS - plain_bits.0) + round_bit) & mask
|
||||
}
|
||||
|
||||
/// Scale a Torus element to a different modulus. Assumes that the two moduli are
|
||||
/// powers of 2.
|
||||
pub fn switch_modulus_smaller<T>(&self) -> Torus<T>
|
||||
where
|
||||
T: TorusOps + TryFrom<S>,
|
||||
<T as TryFrom<S>>::Error: Debug,
|
||||
{
|
||||
let shift = (S::BITS - T::BITS) as usize;
|
||||
|
||||
// We don't wrap the bits we don't need
|
||||
let y = self.0 >> shift;
|
||||
|
||||
Torus(T::try_from(y).expect("impossible error on switch_modulus_smaller"))
|
||||
}
|
||||
|
||||
/// Return the underlying type of the Torus.
|
||||
#[inline(always)]
|
||||
pub fn inner(&self) -> S {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> From<S> for Torus<S> {
|
||||
#[inline(always)]
|
||||
fn from(value: S) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> Zero for Torus<S> {
|
||||
fn zero() -> Self {
|
||||
Self(S::from(0))
|
||||
}
|
||||
|
||||
fn vartime_is_zero(&self) -> bool {
|
||||
self.inner() == <S as sunscreen_math::Zero>::zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> Neg for Torus<S> {
|
||||
type Output = Self;
|
||||
|
||||
fn neg(self) -> Self::Output {
|
||||
Self::Output::from(self.0.wrapping_neg())
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> WrappingNeg for Torus<S> {
|
||||
fn wrapping_neg(&self) -> Self {
|
||||
Self::from(self.0.wrapping_neg())
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> Add<&Torus<S>> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn add(self, rhs: &Torus<S>) -> Self::Output {
|
||||
Self::Output::from(self.0.wrapping_add(&rhs.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> WrappingAdd for Torus<S> {
|
||||
fn wrapping_add(&self, rhs: &Self) -> Self {
|
||||
self + rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> Sub<&Torus<S>> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn sub(self, rhs: &Torus<S>) -> Self::Output {
|
||||
Self::Output::from(self.0.wrapping_sub(&rhs.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> WrappingSub for Torus<S> {
|
||||
fn wrapping_sub(&self, rhs: &Self) -> Self {
|
||||
self - rhs
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> Mul<&S> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn mul(self, rhs: &S) -> Self::Output {
|
||||
Self::Output::from(self.wrapping_mul(rhs))
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> BitAnd<&Torus<S>> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn bitand(self, rhs: &Torus<S>) -> Self::Output {
|
||||
Torus::from(self.0 & rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> Shr<&usize> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn shr(self, rhs: &usize) -> Self::Output {
|
||||
Torus::from(self.0 >> *rhs)
|
||||
}
|
||||
}
|
||||
|
||||
#[refify_binary_op]
|
||||
impl<S: TorusOps> Shl<&usize> for &Torus<S> {
|
||||
type Output = Torus<S>;
|
||||
|
||||
fn shl(self, rhs: &usize) -> Self::Output {
|
||||
Torus::from(self.0 << *rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> AddAssign<Self> for Torus<S> {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> AddAssign<&Self> for Torus<S> {
|
||||
fn add_assign(&mut self, rhs: &Self) {
|
||||
*self = *self + rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> SubAssign<Self> for Torus<S> {
|
||||
fn sub_assign(&mut self, rhs: Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> SubAssign<&Self> for Torus<S> {
|
||||
fn sub_assign(&mut self, rhs: &Self) {
|
||||
*self = *self - rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> std::iter::Sum for Torus<S> {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(Self::zero(), |acc, x| acc + x)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> ReinterpretAsSigned for Torus<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
type Output = <S as ReinterpretAsSigned>::Output;
|
||||
|
||||
#[inline(always)]
|
||||
fn reinterpret_as_signed(self) -> Self::Output {
|
||||
self.0.reinterpret_as_signed()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> num::Zero for Torus<S> {
|
||||
fn zero() -> Self {
|
||||
Self(<S as num::Zero>::zero())
|
||||
}
|
||||
|
||||
fn is_zero(&self) -> bool {
|
||||
self.0.is_zero()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: TorusOps> MulAdd<S, Self> for Torus<S> {
|
||||
type Output = Self;
|
||||
|
||||
#[inline(always)]
|
||||
fn mul_add(self, a: S, b: Self) -> Self::Output {
|
||||
self * a + b
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_negate() {
|
||||
let x = Torus::<u64>::from(0);
|
||||
assert_eq!(-x, Torus::from(0));
|
||||
|
||||
let x = Torus::<u64>::from(1);
|
||||
assert_eq!(-x, Torus::from(u64::MAX));
|
||||
|
||||
let x = Torus::<u64>::from(u64::MAX);
|
||||
assert_eq!(-x, Torus::from(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_encode_decode() {
|
||||
assert_eq!(
|
||||
Torus::<u64>::encode(7, PlaintextBits(4)).0,
|
||||
0x70000000_00000000
|
||||
);
|
||||
|
||||
let x = Torus::<u64>::from(0x70000000_00000000);
|
||||
|
||||
assert_eq!(x.decode(PlaintextBits(4)), 7);
|
||||
|
||||
let x = Torus::<u64>::from(0x7FFFFFFF_FFFFFFFF);
|
||||
|
||||
assert_eq!(x.decode(PlaintextBits(4)), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decode_off_center() {
|
||||
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.6) as u64);
|
||||
let r = t.decode(PlaintextBits(1));
|
||||
assert_eq!(r, 1);
|
||||
|
||||
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.3) as u64);
|
||||
let r = t.decode(PlaintextBits(1));
|
||||
assert_eq!(r, 1);
|
||||
|
||||
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.8) as u64);
|
||||
let r = t.decode(PlaintextBits(1));
|
||||
assert_eq!(r, 0);
|
||||
|
||||
let t = Torus::<u64>::from(((u64::MAX as f64) * 0.2) as u64);
|
||||
let r = t.decode(PlaintextBits(1));
|
||||
assert_eq!(r, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_normalize() {
|
||||
let x = Torus::<u64>::from(0);
|
||||
assert_eq!(x.normalized_torus(), 0.0);
|
||||
|
||||
let x = Torus::<u64>::from(u64::MAX / 4);
|
||||
assert_eq!(x.normalized_torus(), 0.25);
|
||||
|
||||
let x = Torus::<u64>::from(u64::MAX / 2);
|
||||
assert_eq!(x.normalized_torus(), 0.5);
|
||||
|
||||
let x = Torus::<u64>::from(u64::MAX / 4 * 3);
|
||||
assert_eq!(x.normalized_torus(), 0.75);
|
||||
|
||||
let x = Torus::<u64>::from(u64::MAX / 8 * 7);
|
||||
assert_eq!(x.normalized_torus(), 0.875);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_compute_distance() {
|
||||
let a = Torus::<u64>::from(0);
|
||||
let b = Torus::<u64>::from(u64::MAX / 4);
|
||||
|
||||
assert_eq!(normalized_torus_distance(&a, &b), 0.25);
|
||||
assert_eq!(normalized_torus_distance(&b, &a), -0.25);
|
||||
|
||||
let a = Torus::<u64>::from(u64::MAX / 4);
|
||||
let b = Torus::<u64>::from(u64::MAX / 2);
|
||||
|
||||
assert_eq!(normalized_torus_distance(&a, &b), 0.25);
|
||||
assert_eq!(normalized_torus_distance(&b, &a), -0.25);
|
||||
|
||||
let a = Torus::<u64>::from(0);
|
||||
let b = Torus::<u64>::from(u64::MAX / 4 * 3);
|
||||
|
||||
assert_eq!(normalized_torus_distance(&a, &b), -0.25);
|
||||
assert_eq!(normalized_torus_distance(&b, &a), 0.25);
|
||||
|
||||
let a = Torus::<u64>::from(u64::MAX / 8);
|
||||
let b = Torus::<u64>::from(u64::MAX / 4 * 3);
|
||||
|
||||
assert_eq!(normalized_torus_distance(&a, &b), -0.375);
|
||||
assert_eq!(normalized_torus_distance(&b, &a), 0.375);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalized_relation() {
|
||||
// Tests the relation:
|
||||
// b.normalized_torus() = a.normalized_torus() + normalized_torus_distance(a, b)
|
||||
|
||||
for _ in 0..100 {
|
||||
let a = Torus::<u64>::from(thread_rng().next_u64());
|
||||
let b = Torus::<u64>::from(thread_rng().next_u64());
|
||||
|
||||
let a_norm = a.normalized_torus();
|
||||
let b_norm = b.normalized_torus();
|
||||
|
||||
let dist = normalized_torus_distance(&a, &b);
|
||||
|
||||
let b_norm_from_dist = a_norm + dist;
|
||||
|
||||
let b_norm_from_dist = if b_norm_from_dist < 0.0 {
|
||||
b_norm_from_dist + 1.0
|
||||
} else if b_norm_from_dist >= 1.0 {
|
||||
b_norm_from_dist - 1.0
|
||||
} else {
|
||||
b_norm_from_dist
|
||||
};
|
||||
|
||||
let diff = (b_norm - b_norm_from_dist).abs();
|
||||
|
||||
assert!(
|
||||
diff < 1e-12,
|
||||
"Normalized torus relation test failed: a = {:?}, b = {:?}, a_norm = {}, b_norm = {}, dist: {}, b_norm_from_dist = {}, diff = {}",
|
||||
a,
|
||||
b,
|
||||
a_norm,
|
||||
b_norm,
|
||||
dist,
|
||||
b_norm_from_dist,
|
||||
diff
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_modulus_switch() {
|
||||
let x = Torus::<u64>::from(0x12345678_9ABCDEF0);
|
||||
|
||||
let y = x.switch_modulus_smaller::<u32>();
|
||||
|
||||
assert_eq!(y.0, 0x12345678);
|
||||
}
|
||||
}
|
||||
473
sunscreen_tfhe/src/ops/bootstrapping/blind_rotation.rs
Normal file
473
sunscreen_tfhe/src/ops/bootstrapping/blind_rotation.rs
Normal file
@@ -0,0 +1,473 @@
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{BlindRotationShiftFftRef, GgswCiphertext, GlweCiphertextRef, GlweSecretKeyRef},
|
||||
ops::{encryption::encrypt_ggsw_ciphertext_scalar, fft_ops::cmux},
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, PlaintextBits, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/// Rotate the given ciphertext message polynomial by the given amount as if it
|
||||
/// had been multiplied by a monomial with either positive or negative degree.
|
||||
/// Since this is a negacyclic rotation, a rotation to the left negates the last
|
||||
/// `rotation` coefficients, while a rotation to the right negates the first
|
||||
/// `rotation` coefficients.
|
||||
///
|
||||
/// Mathematically, this is equivalent to multiplying the underlying polynomial
|
||||
/// message by X^{rotation} mod (X^N + 1), where N is the polynomial degree.
|
||||
/// There are some convenient relations to remember over any integer $k$ and
|
||||
/// $i$:
|
||||
///
|
||||
/// - Rotations are modulus 2N: $X^{k*2N} = 1$ and $X^{k*2N ± i} = X^{± i}$.
|
||||
/// - Rotations that are equivalent to a shift by -N are equivalent to negating
|
||||
/// the polynomial: $X^{k*2N + N} = -1$.
|
||||
/// - A negative rotation has an equivalent positive rotation: $X^{-i} = -X^{N - i}$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use sunscreen_tfhe::{
|
||||
/// high_level::{keygen, encryption},
|
||||
/// entities::{GlweCiphertext, Polynomial},
|
||||
/// ops::bootstrapping::rotate_glwe_monomial_negacyclic,
|
||||
/// params::{
|
||||
/// GlweDef,
|
||||
/// GlweSize,
|
||||
/// GlweDimension,
|
||||
/// PlaintextBits,
|
||||
/// PolynomialDegree,
|
||||
/// },
|
||||
/// rand::Stddev,
|
||||
/// };
|
||||
///
|
||||
/// // Define the GLWE parameters
|
||||
/// let params = GlweDef {
|
||||
/// dim: GlweDimension {
|
||||
/// size: GlweSize(1),
|
||||
/// polynomial_degree: PolynomialDegree(8),
|
||||
/// },
|
||||
/// std: Stddev(0.0000000444778278004718),
|
||||
/// };
|
||||
/// let plaintext_bits = PlaintextBits(4);
|
||||
///
|
||||
/// // Generate the GLWE secret key
|
||||
/// let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
///
|
||||
/// // Define and encrypt a message
|
||||
/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
/// let ct = encryption::encrypt_glwe(&msg, &sk, ¶ms, plaintext_bits);
|
||||
///
|
||||
/// // Rotate the message polynomial by 1 to the right
|
||||
/// let mut rotated_ct = GlweCiphertext::new(¶ms);
|
||||
/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, 1, ¶ms);
|
||||
///
|
||||
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits);
|
||||
///
|
||||
/// assert_eq!(decrypted_msg, Polynomial::new(&[8, 1, 2, 3, 4, 5, 6, 7]));
|
||||
///
|
||||
/// // Rotate the message polynomial by 1 to the left
|
||||
/// let mut rotated_ct = GlweCiphertext::new(¶ms);
|
||||
/// rotate_glwe_monomial_negacyclic(&mut rotated_ct, &ct, -1, ¶ms);
|
||||
///
|
||||
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits);
|
||||
///
|
||||
/// // Since this is a negacyclic rotation, the element moved to the end is
|
||||
/// // negated.
|
||||
/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15]));
|
||||
/// ```
|
||||
pub fn rotate_glwe_monomial_negacyclic<S>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
rotation: isize,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (output_a, output_b) = output.a_b_mut(params);
|
||||
let (ct_a, ct_b) = ct.a_b(params);
|
||||
|
||||
let output_all_coefficients = output_a.chain(std::iter::once(output_b));
|
||||
let ct_all_coefficients = ct_a.chain(std::iter::once(ct_b));
|
||||
|
||||
for (o, a) in output_all_coefficients.zip(ct_all_coefficients) {
|
||||
o.clone_from_ref(a);
|
||||
o.mul_by_monomial_negacyclic(rotation);
|
||||
}
|
||||
}
|
||||
|
||||
/// Rotate the given ciphertext message polynomial by the given amount as if it
|
||||
/// had been multiplied by a monomial. This is equivalent to shifting
|
||||
/// all the coefficients left by `rotation` and negating the last `rotation`
|
||||
/// coefficients. Mathematically, this is equivalent to multiplying by
|
||||
/// X^{-rotation} mod (X^N + 1), or equivalently -X^{N - rotation}.
|
||||
///
|
||||
/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both
|
||||
/// positive and negative rotations, plus an example.
|
||||
pub fn rotate_glwe_negative_monomial_negacyclic<S>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
rotation: usize,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
rotate_glwe_monomial_negacyclic(output, ct, -(rotation as isize), params)
|
||||
}
|
||||
|
||||
/// Rotate the given ciphertext message polynomial by the given amount as if it
|
||||
/// had been multiplied by a positive monomial. This is equivalent to shifting
|
||||
/// all the coefficients right by `rotation` and negating the first `rotation`
|
||||
/// coefficients. Mathematically, this is equivalent to multiplying by
|
||||
/// X^{rotation} mod (X^N + 1).
|
||||
///
|
||||
/// See [`rotate_glwe_monomial_negacyclic`] for the case that handles both
|
||||
/// positive and negative rotations, plus an example.
|
||||
pub fn rotate_glwe_positive_monomial_negacyclic<S>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
rotation: usize,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
rotate_glwe_monomial_negacyclic(output, ct, rotation as isize, params)
|
||||
}
|
||||
|
||||
/// Rotate the given ciphertext message polynomial by negative encrypted shift.
|
||||
/// In practice this is not often used on its own; bootstrapping performs a
|
||||
/// different procedure.
|
||||
///
|
||||
/// See
|
||||
/// - [`generate_blind_rotation_shift`] for a way to encrypt a rotation amount.
|
||||
/// - [`rotate_glwe_monomial_negacyclic`] for how negacyclic rotation works
|
||||
/// when the rotation amount is public.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use sunscreen_tfhe::{
|
||||
/// high_level::{keygen, encryption},
|
||||
/// entities::{GlweCiphertext, Polynomial},
|
||||
/// ops::bootstrapping::{blind_rotation, generate_blind_rotation_shift},
|
||||
/// params::{
|
||||
/// GlweDef,
|
||||
/// GlweSize,
|
||||
/// GlweDimension,
|
||||
/// PlaintextBits,
|
||||
/// PolynomialDegree,
|
||||
/// RadixDecomposition,
|
||||
/// RadixCount,
|
||||
/// RadixLog,
|
||||
/// },
|
||||
/// rand::Stddev,
|
||||
/// };
|
||||
///
|
||||
/// // Define the GLWE parameters
|
||||
/// let params = GlweDef {
|
||||
/// dim: GlweDimension {
|
||||
/// size: GlweSize(1),
|
||||
/// polynomial_degree: PolynomialDegree(8),
|
||||
/// },
|
||||
/// std: Stddev(0.0000000444778278004718),
|
||||
/// };
|
||||
/// let radix = RadixDecomposition {
|
||||
/// count: RadixCount(3),
|
||||
/// radix_log: RadixLog(4),
|
||||
/// };
|
||||
/// let plaintext_bits = PlaintextBits(4);
|
||||
///
|
||||
/// // Generate the GLWE secret key
|
||||
/// let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
///
|
||||
/// // Define and encrypt a message
|
||||
/// let msg = Polynomial::new(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
/// let ct = encryption::encrypt_glwe(&msg, &sk, ¶ms, plaintext_bits);
|
||||
///
|
||||
/// // Generate a blind rotation amount
|
||||
/// let mut blind_rotation_index = sunscreen_tfhe::entities::BlindRotationShiftFft::new(¶ms, &radix);
|
||||
/// generate_blind_rotation_shift(&mut blind_rotation_index, 1, &sk, ¶ms, &radix, plaintext_bits);
|
||||
///
|
||||
/// // Rotate the message polynomial by the blind rotation amount
|
||||
/// let mut rotated_ct = GlweCiphertext::new(¶ms);
|
||||
/// blind_rotation(&mut rotated_ct, &blind_rotation_index, &ct, ¶ms, &radix);
|
||||
///
|
||||
/// let decrypted_msg = sk.decrypt_decode_glwe(&rotated_ct, ¶ms, plaintext_bits);
|
||||
///
|
||||
/// assert_eq!(decrypted_msg, Polynomial::new(&[2, 3, 4, 5, 6, 7, 8, 15]));
|
||||
/// ```
|
||||
pub fn blind_rotation<S>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
blind_rotation_index: &BlindRotationShiftFftRef<Complex<f64>>,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
// Initialize with the unrotated message m
|
||||
output.clone_from_ref(ct);
|
||||
allocate_scratch_ref!(rotated_ct, GlweCiphertextRef<S>, (params.dim));
|
||||
|
||||
for (i, index_select) in blind_rotation_index.rows(params, radix).enumerate() {
|
||||
let rotation = 1 << i;
|
||||
|
||||
rotate_glwe_negative_monomial_negacyclic(rotated_ct, output, rotation, params);
|
||||
|
||||
let tmp = output.to_owned();
|
||||
cmux(output, &tmp, rotated_ct, index_select, params, radix);
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt an amount to rotate the message polynomial by.
|
||||
///
|
||||
/// This function is mostly provided as a convenience. Bootstrapping will rotate
|
||||
/// a message without encrypting using a cmux tree, so this function is not
|
||||
/// strictly necessary.
|
||||
pub fn generate_blind_rotation_shift<S>(
|
||||
bootstrap_key: &mut BlindRotationShiftFftRef<Complex<f64>>,
|
||||
rotation: usize,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let degree = params.dim.polynomial_degree.0;
|
||||
assert!(rotation < degree);
|
||||
|
||||
for (i, ggsw_fft) in bootstrap_key.rows_mut(params, radix).enumerate() {
|
||||
let bit = ((rotation >> i) & 1) as u64;
|
||||
let mut ct = GgswCiphertext::new(params, radix);
|
||||
|
||||
encrypt_ggsw_ciphertext_scalar(
|
||||
&mut ct,
|
||||
S::from_u64(bit),
|
||||
sk,
|
||||
params,
|
||||
radix,
|
||||
plaintext_bits,
|
||||
);
|
||||
|
||||
ct.fft(ggsw_fft, params, radix);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use blind_rotation::generate_blind_rotation_shift;
|
||||
|
||||
use crate::{
|
||||
entities::{
|
||||
BlindRotationShiftFft, GgswCiphertext, GlweCiphertext, GlweSecretKey, Polynomial,
|
||||
},
|
||||
high_level::{TEST_GLWE_DEF_1, TEST_RADIX},
|
||||
ops::{
|
||||
bootstrapping::{blind_rotation, rotate_glwe_monomial_negacyclic},
|
||||
encryption::decrypt_ggsw_ciphertext,
|
||||
},
|
||||
polynomial::polynomial_external_mad,
|
||||
GlweDef, GlweDimension, GlweSize, PlaintextBits, PolynomialDegree, Torus,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn can_rotate() {
|
||||
let params = GlweDef {
|
||||
dim: GlweDimension {
|
||||
polynomial_degree: PolynomialDegree(8),
|
||||
size: GlweSize(2),
|
||||
},
|
||||
..TEST_GLWE_DEF_1
|
||||
};
|
||||
let plaintext_bits = PlaintextBits(4);
|
||||
|
||||
let modulus = 1 << plaintext_bits.0;
|
||||
let degree = params.dim.polynomial_degree.0;
|
||||
|
||||
let sk = GlweSecretKey::<u64>::generate_binary(¶ms);
|
||||
|
||||
let msg_coeffs = (0..degree)
|
||||
.map(|i| (i % modulus) as u64)
|
||||
.collect::<Vec<_>>();
|
||||
let msg = Polynomial::new(&msg_coeffs);
|
||||
|
||||
let ct = sk.encode_encrypt_glwe(&msg, ¶ms, plaintext_bits);
|
||||
|
||||
for rotation in (-2i64 * (degree as i64))..=(2i64 * (degree as i64)) {
|
||||
println!("Rotation: {}", rotation);
|
||||
let mut rotation_polynomial = vec![Torus::from(0u64); degree];
|
||||
|
||||
let direction = if rotation < 0 { -1 } else { 1 };
|
||||
let original_rotation = rotation;
|
||||
let rotation = rotation.unsigned_abs() as usize;
|
||||
|
||||
#[allow(clippy::collapsible_else_if)]
|
||||
if direction == 1 {
|
||||
// Positive rotation
|
||||
if rotation == 0 || rotation == 2 * degree {
|
||||
rotation_polynomial[0] = Torus::from(1);
|
||||
} else if rotation < degree {
|
||||
rotation_polynomial[rotation] = Torus::from(1);
|
||||
} else if rotation == degree {
|
||||
rotation_polynomial[0] = -Torus::from(1);
|
||||
} else {
|
||||
rotation_polynomial[rotation % degree] = -Torus::from(1);
|
||||
}
|
||||
} else {
|
||||
// Negative rotation
|
||||
if rotation == 0 || rotation == 2 * degree {
|
||||
rotation_polynomial[0] = Torus::from(1);
|
||||
} else if rotation < degree {
|
||||
rotation_polynomial[(degree - rotation) % degree] = -Torus::from(1);
|
||||
} else if rotation == degree {
|
||||
rotation_polynomial[0] = -Torus::from(1);
|
||||
} else {
|
||||
rotation_polynomial[(2 * degree - rotation) % degree] = Torus::from(1);
|
||||
}
|
||||
}
|
||||
|
||||
let rotation_polynomial = Polynomial::new(&rotation_polynomial);
|
||||
|
||||
let mut expected = Polynomial::<Torus<u64>>::zero(degree);
|
||||
polynomial_external_mad(&mut expected, &rotation_polynomial, &msg);
|
||||
|
||||
// Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually.
|
||||
let expected = expected.map(|x| x.inner() % (modulus as u64));
|
||||
|
||||
// Perform encrypted rotation
|
||||
let mut output_ct = GlweCiphertext::new(¶ms);
|
||||
|
||||
rotate_glwe_monomial_negacyclic(
|
||||
&mut output_ct,
|
||||
&ct,
|
||||
original_rotation as isize,
|
||||
¶ms,
|
||||
);
|
||||
|
||||
let output_msg = sk.decrypt_decode_glwe(&output_ct, ¶ms, plaintext_bits);
|
||||
|
||||
assert_eq!(output_msg, expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rotation_shift_encrypted_properly() {
|
||||
let params = GlweDef {
|
||||
dim: GlweDimension {
|
||||
polynomial_degree: PolynomialDegree(8),
|
||||
size: GlweSize(2),
|
||||
},
|
||||
..TEST_GLWE_DEF_1
|
||||
};
|
||||
let radix = TEST_RADIX;
|
||||
let degree = params.dim.polynomial_degree.0;
|
||||
|
||||
let sk = GlweSecretKey::<u64>::generate_binary(¶ms);
|
||||
|
||||
for rotation in 0..(degree - 1) {
|
||||
let mut ggsw_index = BlindRotationShiftFft::new(¶ms, &radix);
|
||||
generate_blind_rotation_shift(
|
||||
&mut ggsw_index,
|
||||
rotation,
|
||||
&sk,
|
||||
¶ms,
|
||||
&radix,
|
||||
PlaintextBits(4),
|
||||
);
|
||||
|
||||
let mut encrypted_rotation = 0u64;
|
||||
for (i, bit_fft) in ggsw_index.rows(¶ms, &radix).enumerate() {
|
||||
let mut bit = GgswCiphertext::<u64>::new(¶ms, &radix);
|
||||
bit_fft.ifft(&mut bit, ¶ms, &radix);
|
||||
|
||||
let mut pt = Polynomial::zero(degree);
|
||||
decrypt_ggsw_ciphertext(&mut pt, &bit, &sk, ¶ms, &radix);
|
||||
|
||||
encrypted_rotation |= pt.coeffs()[0].inner() << i;
|
||||
}
|
||||
|
||||
assert_eq!(encrypted_rotation, rotation as u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_blind_rotate() {
|
||||
let params = GlweDef {
|
||||
dim: GlweDimension {
|
||||
polynomial_degree: PolynomialDegree(8),
|
||||
size: GlweSize(2),
|
||||
},
|
||||
..TEST_GLWE_DEF_1
|
||||
};
|
||||
let radix = TEST_RADIX;
|
||||
let plaintext_bits = PlaintextBits(4);
|
||||
|
||||
let modulus = 1 << plaintext_bits.0;
|
||||
let degree = params.dim.polynomial_degree.0;
|
||||
let num_bits = (degree as u64).ilog2() as usize;
|
||||
|
||||
let sk = GlweSecretKey::<u64>::generate_binary(¶ms);
|
||||
|
||||
let msg_coeffs = (0..degree)
|
||||
.map(|i| (i % modulus) as u64)
|
||||
.collect::<Vec<_>>();
|
||||
let msg = Polynomial::new(&msg_coeffs);
|
||||
|
||||
let ct = sk.encode_encrypt_glwe(&msg, ¶ms, plaintext_bits);
|
||||
|
||||
#[allow(clippy::needless_range_loop)]
|
||||
for rotation in 0..=(degree - 1) {
|
||||
let mut expected = Polynomial::<Torus<u64>>::new(msg.map(|x| Torus::from(*x)).coeffs());
|
||||
|
||||
for i in 0..num_bits {
|
||||
let bit = ((rotation >> i) & 1) as u64;
|
||||
|
||||
// We don't perform this rotation
|
||||
if bit == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut rotation_polynomial = vec![Torus::from(0u64); degree];
|
||||
|
||||
rotation_polynomial[degree - (1 << i)] = -Torus::<u64>::from(1);
|
||||
|
||||
let rotation_polynomial = Polynomial::new(&rotation_polynomial);
|
||||
|
||||
let tmp = expected.map(|x| x.inner());
|
||||
expected = Polynomial::<Torus<u64>>::zero(degree);
|
||||
|
||||
polynomial_external_mad(&mut expected, &rotation_polynomial, &tmp);
|
||||
}
|
||||
|
||||
// Polynomial multiply doesn't reduce modulo the modulus, so we need to do it manually.
|
||||
let expected = expected.map(|x| x.inner() % (modulus as u64));
|
||||
|
||||
// Perform encrypted rotation
|
||||
let mut ggsw_index = BlindRotationShiftFft::new(¶ms, &radix);
|
||||
generate_blind_rotation_shift(
|
||||
&mut ggsw_index,
|
||||
rotation,
|
||||
&sk,
|
||||
¶ms,
|
||||
&radix,
|
||||
plaintext_bits,
|
||||
);
|
||||
let mut output_ct = GlweCiphertext::new(¶ms);
|
||||
blind_rotation(&mut output_ct, &ggsw_index, &ct, ¶ms, &radix);
|
||||
let output_msg = sk.decrypt_decode_glwe(&output_ct, ¶ms, plaintext_bits);
|
||||
|
||||
// Make sure the zero point is rotated the correct amount.
|
||||
assert_eq!(output_msg.coeffs()[(degree - rotation) % degree], 0);
|
||||
|
||||
// Make sure we have moved the element in the rotation position to the zero position.
|
||||
assert_eq!(output_msg.coeffs()[0], msg_coeffs[rotation]);
|
||||
|
||||
assert_eq!(
|
||||
&output_msg, &expected,
|
||||
"CT encrypted message: {:?}, expected message: {:?}",
|
||||
&output_msg, &expected
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
469
sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs
Normal file
469
sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs
Normal file
@@ -0,0 +1,469 @@
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{
|
||||
BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef,
|
||||
LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
|
||||
},
|
||||
ops::{
|
||||
bootstrapping::programmable_bootstrap, homomorphisms::rotate,
|
||||
keyswitch::private_functional_keyswitch::private_functional_keyswitch,
|
||||
},
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, LweDef, PlaintextBits, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus,
|
||||
TorusOps,
|
||||
};
|
||||
|
||||
/// Bootstraps a LWE ciphertext to a GGSW ciphertext.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Transform [`LweCiphertextRef`] `input` encrypted under parameters `lwe_0` into
|
||||
/// the [`GgswCiphertextRef`] `output` encrypted under parameters `glwe_1` with
|
||||
/// radix decomposition `cbs_radix`. This resets the noise in `output` in the
|
||||
/// process.
|
||||
///
|
||||
/// [`GgswCiphertext`](crate::entities::GgswCiphertext)s can be used as select
|
||||
/// inputs for [`cmux`](crate::ops::fft_ops::cmux) operations.
|
||||
///
|
||||
/// # Remarks
|
||||
/// The following diagram illustrates how circuit bootstrapping works
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// We perform `cbs_radix.count` programmable bootstrapping (PBS) operations to
|
||||
/// decompose the original message m under radix `2^cbs_radix.radix_log`. These PBS
|
||||
/// operations use a bootstrapping key encrypting the level 0 LWE secret key under
|
||||
/// the level 2 GLWE secret key and internally perform their own radix decomposition
|
||||
/// parameterized by `pbs_radix`. After performing bootstrapping, we now have
|
||||
/// `cbs_radix.count` LWE ciphertexts encrypted under the level 2 GLWE secret key
|
||||
/// (reinterpreted as an LWE key).
|
||||
///
|
||||
/// Next, we take each of these `cbs_radix.count` level 2 LWE ciphertexts and
|
||||
/// perform `glwe_1.dim.size + 1` private functional keyswitching operations (`
|
||||
/// (glwe_1.dim.size + 1) * cbs_radix.count` in total). For the first `glwe_1.dim.
|
||||
/// size` rows of the [`GgswCiphertextRef`] output, this multiplies the radix
|
||||
/// decomposed message by the negative corresponding secret key. For the last
|
||||
/// row, we simply multiply our radix decomposed messages by 1.
|
||||
///
|
||||
/// Recall that [`private_functional_keyswitch`] (PFKS) transforms a list of LWE
|
||||
/// ciphertexts into a [`GlweCiphertext`](crate::entities::GlweCiphertext). In
|
||||
/// our case, this list contains a single
|
||||
/// [`LweCiphertext`](crate::entities::LweCiphertext) for each PFKS operation.
|
||||
/// Each row of the output [`GgswCiphertext`](crate::entities::GgswCiphertext)
|
||||
/// corresponds to a different PFKS key, encapsulated in `cbsksk`.
|
||||
///
|
||||
/// These PFKS operations switch from a key under parameters `glwe_2` (interpreted
|
||||
/// as LWE) to `glwe_1` with [`RadixDecomposition`] `pfks_radix`.
|
||||
///
|
||||
/// # Panics
|
||||
/// * If `bsk` is not valid for bootrapping from parameters `lwe_0` to `glwe_2`
|
||||
/// (reinterpreted as LWE) with radix decomposition `pbs_radix`.
|
||||
/// * If `cbsksk` is not a valid keyswitch key set for switching from `glwe_2`
|
||||
/// (reintrerpreted as LWE) to `glwe_1` with `glwe_1.dim.size` entries and radix
|
||||
/// decomposition `pfks_radix`.
|
||||
/// * If `output` is not the correct length for a GGSW ciphertext under `glwe_1`
|
||||
/// parameters with `cbs_radix` decomposition.
|
||||
/// * If `input` is not a valid LWE ciphertext under `lwe_0` parameters.
|
||||
/// * If `lwe_0`, `glwe_1`, `glwe_2`, `cbs_radix`, `pfks_radix`, `pbs_radix` are
|
||||
/// invalid.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use sunscreen_tfhe::{
|
||||
/// high_level,
|
||||
/// high_level::{keygen, encryption, fft},
|
||||
/// entities::GgswCiphertext,
|
||||
/// ops::bootstrapping::circuit_bootstrap,
|
||||
/// params::{
|
||||
/// GLWE_5_256_80,
|
||||
/// GLWE_1_1024_80,
|
||||
/// LWE_512_80,
|
||||
/// PlaintextBits,
|
||||
/// RadixDecomposition,
|
||||
/// RadixCount,
|
||||
/// RadixLog
|
||||
/// }
|
||||
/// };
|
||||
///
|
||||
/// let pbs_radix = RadixDecomposition {
|
||||
/// count: RadixCount(2),
|
||||
/// radix_log: RadixLog(16),
|
||||
/// };
|
||||
/// let cbs_radix = RadixDecomposition {
|
||||
/// count: RadixCount(2),
|
||||
/// radix_log: RadixLog(5),
|
||||
/// };
|
||||
/// let pfks_radix = RadixDecomposition {
|
||||
/// count: RadixCount(3),
|
||||
/// radix_log: RadixLog(11),
|
||||
/// };
|
||||
///
|
||||
/// let level_2_params = GLWE_5_256_80;
|
||||
/// let level_1_params = GLWE_1_1024_80;
|
||||
/// let level_0_params = LWE_512_80;
|
||||
///
|
||||
/// let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
|
||||
/// let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
|
||||
/// let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
|
||||
///
|
||||
/// let bsk = keygen::generate_bootstrapping_key(
|
||||
/// &sk_0,
|
||||
/// &sk_2,
|
||||
/// &level_0_params,
|
||||
/// &level_2_params,
|
||||
/// &pbs_radix,
|
||||
/// );
|
||||
/// let bsk =
|
||||
/// high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
|
||||
///
|
||||
/// let cbsksk = keygen::generate_cbs_ksk(
|
||||
/// sk_2.to_lwe_secret_key(),
|
||||
/// &sk_1,
|
||||
/// &level_2_params.as_lwe_def(),
|
||||
/// &level_1_params,
|
||||
/// &pfks_radix,
|
||||
/// );
|
||||
///
|
||||
/// let val = 1;
|
||||
/// let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
|
||||
///
|
||||
/// let mut ggsw = GgswCiphertext::new(&level_1_params, &cbs_radix);
|
||||
///
|
||||
/// // ggsw will contain `val`
|
||||
/// circuit_bootstrap(
|
||||
/// &mut ggsw,
|
||||
/// &ct,
|
||||
/// &bsk,
|
||||
/// &cbsksk,
|
||||
/// &level_0_params,
|
||||
/// &level_1_params,
|
||||
/// &level_2_params,
|
||||
/// &pbs_radix,
|
||||
/// &cbs_radix,
|
||||
/// &pfks_radix,
|
||||
/// );
|
||||
/// ```
|
||||
pub fn circuit_bootstrap<S: TorusOps>(
|
||||
output: &mut GgswCiphertextRef<S>,
|
||||
input: &LweCiphertextRef<S>,
|
||||
bsk: &BootstrapKeyFftRef<Complex<f64>>,
|
||||
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<S>,
|
||||
lwe_0: &LweDef,
|
||||
glwe_1: &GlweDef,
|
||||
glwe_2: &GlweDef,
|
||||
pbs_radix: &RadixDecomposition,
|
||||
cbs_radix: &RadixDecomposition,
|
||||
pfks_radix: &RadixDecomposition,
|
||||
) {
|
||||
glwe_1.assert_valid();
|
||||
glwe_2.assert_valid();
|
||||
lwe_0.assert_valid();
|
||||
pbs_radix.assert_valid::<S>();
|
||||
cbs_radix.assert_valid::<S>();
|
||||
pfks_radix.assert_valid::<S>();
|
||||
cbsksk.assert_valid(&glwe_2.as_lwe_def(), glwe_1, pfks_radix);
|
||||
bsk.assert_valid(lwe_0, glwe_2, pbs_radix);
|
||||
output.assert_valid(glwe_1, cbs_radix);
|
||||
input.assert_valid(lwe_0);
|
||||
|
||||
// Step 1, for each l in cbs_radix.count, use bootstrapping to base decompose the
|
||||
// plaintext in input. We bootstrap from level 0 -> level 2.
|
||||
allocate_scratch_ref!(
|
||||
level_2_lwes,
|
||||
LweCiphertextListRef<S>,
|
||||
(glwe_2.as_lwe_def().dim, cbs_radix.count.0)
|
||||
);
|
||||
|
||||
level_0_to_level_2(
|
||||
level_2_lwes,
|
||||
input,
|
||||
bsk,
|
||||
lwe_0,
|
||||
glwe_2,
|
||||
pbs_radix,
|
||||
cbs_radix,
|
||||
);
|
||||
|
||||
level_2_to_level1(
|
||||
output,
|
||||
level_2_lwes,
|
||||
cbsksk,
|
||||
glwe_2,
|
||||
glwe_1,
|
||||
pfks_radix,
|
||||
cbs_radix,
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[inline(always)]
|
||||
fn level_0_to_level_2<S: TorusOps>(
|
||||
lwes_2: &mut LweCiphertextListRef<S>,
|
||||
input: &LweCiphertextRef<S>,
|
||||
bsk: &BootstrapKeyFftRef<Complex<f64>>,
|
||||
lwe_0: &LweDef,
|
||||
glwe_2: &GlweDef,
|
||||
pbs_radix: &RadixDecomposition,
|
||||
cbs_radix: &RadixDecomposition,
|
||||
) {
|
||||
allocate_scratch_ref!(lut, UnivariateLookupTableRef<S>, (glwe_2.dim));
|
||||
allocate_scratch_ref!(lwe_rotated, LweCiphertextRef<S>, (lwe_0.dim));
|
||||
allocate_scratch_ref!(
|
||||
lwe_bootstrapped,
|
||||
LweCiphertextRef<S>,
|
||||
(glwe_2.as_lwe_def().dim)
|
||||
);
|
||||
|
||||
// Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on
|
||||
// -q/4.
|
||||
rotate(
|
||||
lwe_rotated,
|
||||
input,
|
||||
Torus::encode(S::one(), PlaintextBits(2)),
|
||||
lwe_0,
|
||||
);
|
||||
|
||||
for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() {
|
||||
let cur_level = i + 1;
|
||||
|
||||
// Treat value as a T_{b^l+1} with one extra place for rounding as the last
|
||||
// step.
|
||||
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32);
|
||||
|
||||
// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
|
||||
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
|
||||
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
|
||||
// value will map to 1 and a 0 will map to -1.
|
||||
let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();
|
||||
|
||||
lut.fill_with_constant(minus_one, glwe_2, plaintext_bits);
|
||||
|
||||
programmable_bootstrap(
|
||||
lwe_bootstrapped,
|
||||
lwe_rotated,
|
||||
lut,
|
||||
bsk,
|
||||
lwe_0,
|
||||
glwe_2,
|
||||
pbs_radix,
|
||||
);
|
||||
|
||||
// Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits).
|
||||
// This will overflow -1 to 0 and cause 1 to wrap to 2.
|
||||
rotate(
|
||||
lwe_2,
|
||||
lwe_bootstrapped,
|
||||
Torus::encode(S::one(), plaintext_bits),
|
||||
&glwe_2.as_lwe_def(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext.
|
||||
pub fn level_2_to_level1<S: TorusOps>(
|
||||
result: &mut GgswCiphertextRef<S>,
|
||||
lwes_2: &LweCiphertextListRef<S>,
|
||||
cbsksk: &CircuitBootstrappingKeyswitchKeysRef<S>,
|
||||
glwe_2: &GlweDef,
|
||||
glwe_1: &GlweDef,
|
||||
pfks_radix: &RadixDecomposition,
|
||||
cbs_radix: &RadixDecomposition,
|
||||
) {
|
||||
for (glev, pfksk) in result.rows_mut(glwe_1, cbs_radix).zip(cbsksk.keys(
|
||||
&glwe_2.as_lwe_def(),
|
||||
glwe_1,
|
||||
pfks_radix,
|
||||
)) {
|
||||
for (decomp, glwe) in lwes_2
|
||||
.ciphertexts(&glwe_2.as_lwe_def())
|
||||
.zip(glev.glwe_ciphertexts_mut(glwe_1))
|
||||
{
|
||||
private_functional_keyswitch(
|
||||
glwe,
|
||||
&[decomp],
|
||||
pfksk,
|
||||
&glwe_2.as_lwe_def(),
|
||||
glwe_1,
|
||||
pfks_radix,
|
||||
&PrivateFunctionalKeyswitchLweCount(1),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{GgswCiphertext, LweCiphertextList},
|
||||
high_level::{self, encryption, fft, keygen, TEST_LWE_DEF_1},
|
||||
PlaintextBits, RadixCount, RadixDecomposition, RadixLog, GLWE_1_1024_80, GLWE_5_256_80,
|
||||
LWE_512_80,
|
||||
};
|
||||
|
||||
use super::{circuit_bootstrap, level_0_to_level_2};
|
||||
|
||||
#[test]
|
||||
fn can_level_0_to_level_2() {
|
||||
let pbs_radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(16),
|
||||
};
|
||||
let cbs_radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(5),
|
||||
};
|
||||
|
||||
let glwe_params = GLWE_5_256_80;
|
||||
|
||||
let mut level_2 =
|
||||
LweCiphertextList::<u64>::new(&glwe_params.as_lwe_def(), cbs_radix.count.0);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
let bsk = keygen::generate_bootstrapping_key(
|
||||
&sk,
|
||||
&glwe_sk,
|
||||
&TEST_LWE_DEF_1,
|
||||
&glwe_params,
|
||||
&pbs_radix,
|
||||
);
|
||||
let bsk = fft::fft_bootstrap_key(&bsk, &TEST_LWE_DEF_1, &glwe_params, &pbs_radix);
|
||||
|
||||
let lwe = sk.encrypt(0, &TEST_LWE_DEF_1, PlaintextBits(1)).0;
|
||||
|
||||
level_0_to_level_2(
|
||||
&mut level_2,
|
||||
&lwe,
|
||||
&bsk,
|
||||
&TEST_LWE_DEF_1,
|
||||
&glwe_params,
|
||||
&pbs_radix,
|
||||
&cbs_radix,
|
||||
);
|
||||
|
||||
for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() {
|
||||
let cur_level = i + 1;
|
||||
|
||||
let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32);
|
||||
|
||||
let actual =
|
||||
glwe_sk
|
||||
.to_lwe_secret_key()
|
||||
.decrypt(lwe_2, &glwe_params.as_lwe_def(), bits);
|
||||
|
||||
assert_eq!(actual, 0);
|
||||
}
|
||||
|
||||
let lwe = sk.encrypt(1, &TEST_LWE_DEF_1, PlaintextBits(1)).0;
|
||||
|
||||
level_0_to_level_2(
|
||||
&mut level_2,
|
||||
&lwe,
|
||||
&bsk,
|
||||
&TEST_LWE_DEF_1,
|
||||
&glwe_params,
|
||||
&pbs_radix,
|
||||
&cbs_radix,
|
||||
);
|
||||
|
||||
for (i, lwe_2) in level_2.ciphertexts(&glwe_params.as_lwe_def()).enumerate() {
|
||||
let cur_level = i + 1;
|
||||
|
||||
let bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level) as u32);
|
||||
|
||||
let actual =
|
||||
glwe_sk
|
||||
.to_lwe_secret_key()
|
||||
.decrypt(lwe_2, &glwe_params.as_lwe_def(), bits);
|
||||
|
||||
assert_eq!(actual, 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_circuit_bootstrap() {
|
||||
let pbs_radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(16),
|
||||
};
|
||||
let cbs_radix = RadixDecomposition {
|
||||
count: RadixCount(2),
|
||||
radix_log: RadixLog(5),
|
||||
};
|
||||
let pfks_radix = RadixDecomposition {
|
||||
count: RadixCount(3),
|
||||
radix_log: RadixLog(11),
|
||||
};
|
||||
|
||||
let level_2_params = GLWE_5_256_80;
|
||||
let level_1_params = GLWE_1_1024_80;
|
||||
let level_0_params = LWE_512_80;
|
||||
|
||||
let sk_0 = keygen::generate_binary_lwe_sk(&level_0_params);
|
||||
let sk_1 = keygen::generate_binary_glwe_sk(&level_1_params);
|
||||
let sk_2 = keygen::generate_binary_glwe_sk(&level_2_params);
|
||||
|
||||
let bsk = keygen::generate_bootstrapping_key(
|
||||
&sk_0,
|
||||
&sk_2,
|
||||
&level_0_params,
|
||||
&level_2_params,
|
||||
&pbs_radix,
|
||||
);
|
||||
let bsk =
|
||||
high_level::fft::fft_bootstrap_key(&bsk, &level_0_params, &level_2_params, &pbs_radix);
|
||||
|
||||
let cbsksk = keygen::generate_cbs_ksk(
|
||||
sk_2.to_lwe_secret_key(),
|
||||
&sk_1,
|
||||
&level_2_params.as_lwe_def(),
|
||||
&level_1_params,
|
||||
&pfks_radix,
|
||||
);
|
||||
|
||||
for _ in 0..1 {
|
||||
let val = thread_rng().next_u64() % 2;
|
||||
|
||||
let ct = encryption::encrypt_lwe_secret(val, &sk_0, &level_0_params, PlaintextBits(1));
|
||||
|
||||
let mut actual = GgswCiphertext::new(&level_1_params, &cbs_radix);
|
||||
|
||||
circuit_bootstrap(
|
||||
&mut actual,
|
||||
&ct,
|
||||
&bsk,
|
||||
&cbsksk,
|
||||
&level_0_params,
|
||||
&level_1_params,
|
||||
&level_2_params,
|
||||
&pbs_radix,
|
||||
&cbs_radix,
|
||||
&pfks_radix,
|
||||
);
|
||||
|
||||
let expected =
|
||||
encryption::encrypt_ggsw(val, &sk_1, &level_1_params, &cbs_radix, PlaintextBits(1));
|
||||
|
||||
for (a, e) in actual
|
||||
.rows(&level_1_params, &cbs_radix)
|
||||
.zip(expected.rows(&level_1_params, &cbs_radix))
|
||||
{
|
||||
for (i, (a, e)) in a
|
||||
.glwe_ciphertexts(&level_1_params)
|
||||
.zip(e.glwe_ciphertexts(&level_1_params))
|
||||
.enumerate()
|
||||
{
|
||||
let plaintext_bits = (i + 1) * cbs_radix.radix_log.0;
|
||||
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
|
||||
|
||||
let a = encryption::decrypt_glwe(a, &sk_1, &level_1_params, plaintext_bits);
|
||||
let e = encryption::decrypt_glwe(e, &sk_1, &level_1_params, plaintext_bits);
|
||||
|
||||
assert_eq!(a, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
8
sunscreen_tfhe/src/ops/bootstrapping/mod.rs
Normal file
8
sunscreen_tfhe/src/ops/bootstrapping/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
mod blind_rotation;
|
||||
pub use blind_rotation::*;
|
||||
|
||||
mod circuit_bootstrapping;
|
||||
pub use circuit_bootstrapping::*;
|
||||
|
||||
mod programmable_bootstrapping;
|
||||
pub use programmable_bootstrapping::*;
|
||||
@@ -0,0 +1,824 @@
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{
|
||||
BivariateLookupTableRef, BootstrapKeyFftRef, BootstrapKeyRef, GlweCiphertextRef,
|
||||
GlweSecretKeyRef, LweCiphertextRef, LweSecretKeyRef, Polynomial, PolynomialRef,
|
||||
UnivariateLookupTableRef,
|
||||
},
|
||||
ops::{
|
||||
bootstrapping::rotate_glwe_positive_monomial_negacyclic,
|
||||
ciphertext::{add_lwe_inplace, modulus_switch, sample_extract, scalar_mul_ciphertext_mad},
|
||||
encryption::encrypt_ggsw_ciphertext_scalar,
|
||||
fft_ops::cmux,
|
||||
},
|
||||
scratch::allocate_scratch_ref,
|
||||
CarryBits, GlweDef, LweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::rotate_glwe_negative_monomial_negacyclic;
|
||||
|
||||
/// Generate a bootstrap key from a LWE secret key to a GLWE secret key.
|
||||
///
|
||||
/// Mathematically, this key is a list of GGSW ciphertexts, one for each bit of
|
||||
/// the secret key being encrypted.
|
||||
///
|
||||
/// See
|
||||
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap)
|
||||
/// for an example of how to use this key.
|
||||
pub fn generate_bootstrap_key<S>(
|
||||
bootstrap_key: &mut BootstrapKeyRef<S>,
|
||||
sk_to_encrypt: &LweSecretKeyRef<S>,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
sk.assert_valid(params);
|
||||
radix.assert_valid::<S>();
|
||||
|
||||
for (s_i, ggsw) in sk_to_encrypt
|
||||
.s()
|
||||
.iter()
|
||||
.zip(bootstrap_key.rows_mut(params, radix))
|
||||
{
|
||||
encrypt_ggsw_ciphertext_scalar(ggsw, *s_i, sk, params, radix, PlaintextBits(1));
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a negacyclic LUT for bootstrapping. Another name for this structure
|
||||
/// is a test polynomial.
|
||||
///
|
||||
/// The map function passed in must have the following negacyclic property,
|
||||
/// where N is the size of the polynomial:
|
||||
///
|
||||
/// ```text
|
||||
/// map(N + i) = -map(i)
|
||||
/// ```
|
||||
#[allow(dead_code)]
|
||||
fn generate_negacyclic_lut<S, F>(
|
||||
output: &mut Polynomial<Torus<S>>,
|
||||
map: F,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
F: Fn(u64) -> u64,
|
||||
{
|
||||
let p = (1 << plaintext_bits.0) as u64;
|
||||
let n = params.dim.polynomial_degree.0 as u64;
|
||||
|
||||
let stride = 2 * n / p;
|
||||
|
||||
let delta = S::BITS - plaintext_bits.0;
|
||||
|
||||
let c = output.coeffs_mut();
|
||||
|
||||
// Written out this way because when we get to programmable boot strapping,
|
||||
// this will involve replacing p_i with f(p_i)
|
||||
for (j, p_i_unmapped) in (0..=p / 2).enumerate() {
|
||||
let j = j as u64;
|
||||
|
||||
let p_i = map(p_i_unmapped);
|
||||
assert!(p_i < p, "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i);
|
||||
|
||||
let p_i = p_i << delta;
|
||||
|
||||
if j == 0 {
|
||||
for k in 0..(stride / 2) {
|
||||
c[k as usize] = Torus::from(S::from_u64(p_i));
|
||||
}
|
||||
} else if j == p / 2 {
|
||||
for k in (n - (stride / 2))..n {
|
||||
c[k as usize] = Torus::from(S::from_u64(p_i));
|
||||
}
|
||||
} else {
|
||||
for k in (stride / 2 + (j - 1) * stride)..(stride / 2 + j * stride) {
|
||||
c[k as usize] = Torus::from(S::from_u64(p_i));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a lookup table (LUT) to be used with bootstrapping. This LUT is
|
||||
/// not negacyclic, and hence must be used with LWE inputs that have at least
|
||||
/// one padding bit.
|
||||
///
|
||||
/// The input `map` is used for generating programmable bootstrapping LUTs. This
|
||||
/// function takes an element in the plaintext space and must produce another
|
||||
/// element in the plaintext space.
|
||||
pub(crate) fn generate_lut<S, F>(
|
||||
output: &mut PolynomialRef<Torus<S>>,
|
||||
map: F,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
F: Fn(u64) -> u64,
|
||||
{
|
||||
let p = (1 << plaintext_bits.0) as usize;
|
||||
let n = params.dim.polynomial_degree.0;
|
||||
|
||||
assert!(n >= p);
|
||||
|
||||
let stride = n / p;
|
||||
|
||||
let delta = S::BITS - plaintext_bits.0;
|
||||
|
||||
let c = output.coeffs_mut();
|
||||
|
||||
for (j, p_i_unmapped) in (0..=p - 1).enumerate() {
|
||||
let p_i = map(p_i_unmapped as u64);
|
||||
|
||||
assert!(p_i < (p as u64), "The map function must produce a value less than p. Map produced the relation ({} -> {})", p_i_unmapped, p_i);
|
||||
|
||||
let p_i = p_i << delta;
|
||||
|
||||
// Insert a stride amount into the LUT
|
||||
c[j * stride..(j + 1) * stride].iter_mut().for_each(|c| {
|
||||
*c = Torus::from(S::from_u64(p_i));
|
||||
});
|
||||
}
|
||||
|
||||
// Negate the first half of p_0 in the LUT in preparation for it to be
|
||||
// rotated.
|
||||
c[0..stride / 2].iter_mut().for_each(|c| {
|
||||
*c = num::traits::WrappingNeg::wrapping_neg(c);
|
||||
});
|
||||
|
||||
c.rotate_left(stride / 2);
|
||||
}
|
||||
|
||||
/// Programmable bootstrapping with a univariate function.
|
||||
///
|
||||
/// The LUT this is a table that maps two inputs into a single output. For
|
||||
/// example, say we want to encode the negation function `f(x) = (x + 1) % 2`
|
||||
/// into a lookup table. We would create a
|
||||
/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable) that
|
||||
/// implements this function and then execute it on the input ciphertexts.
|
||||
///
|
||||
/// Important note: This function does not perform key switching. The output
|
||||
/// ciphertext will be encrypted under the LWE key extracted from the GLWE
|
||||
/// secret key used for the bootstrapping key. To perform a keyswitch, use
|
||||
/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe)
|
||||
/// after the bootstrapping operation.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use sunscreen_tfhe::{
|
||||
/// high_level::{keygen, encryption, fft},
|
||||
/// entities::{UnivariateLookupTable, LweCiphertext},
|
||||
/// ops::bootstrapping::programmable_bootstrap,
|
||||
/// params::{
|
||||
/// GLWE_1_1024_80,
|
||||
/// LWE_512_80,
|
||||
/// CarryBits,
|
||||
/// PlaintextBits,
|
||||
/// RadixDecomposition,
|
||||
/// RadixCount,
|
||||
/// RadixLog
|
||||
/// },
|
||||
/// };
|
||||
///
|
||||
/// // Parameters defining the scheme we are using
|
||||
/// let lwe_params = LWE_512_80;
|
||||
/// let glwe_params = GLWE_1_1024_80;
|
||||
/// let radix = RadixDecomposition {
|
||||
/// count: RadixCount(3),
|
||||
/// radix_log: RadixLog(4),
|
||||
/// };
|
||||
///
|
||||
/// // We will be showing a binary univariate function. Note that for
|
||||
/// // programmable bootstrapping to work in general, you will need to include at
|
||||
/// // least one padding bit to the input.
|
||||
/// let plaintext_bits = PlaintextBits(1);
|
||||
/// let carry_bits = CarryBits(1);
|
||||
/// let plaintext_bits_carry = PlaintextBits(2);
|
||||
///
|
||||
/// // The univariate function we want to evaluate, encoded as a lookup table.
|
||||
/// let negate = |x| (x + 1) % (1 << plaintext_bits.0);
|
||||
/// let lut = UnivariateLookupTable::trivial_from_fn(
|
||||
/// &negate,
|
||||
/// &glwe_params,
|
||||
/// plaintext_bits,
|
||||
/// );
|
||||
///
|
||||
/// // Generate the secret keys and the bootstrapping key
|
||||
/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
///
|
||||
/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix);
|
||||
/// let bsk =
|
||||
/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix);
|
||||
///
|
||||
/// // Specify the inputs
|
||||
/// let input_plain = 0;
|
||||
///
|
||||
/// // Encrypt the inputs. Note we are adding carry bits to the inputs.
|
||||
/// let input = encryption::encrypt_lwe_secret(
|
||||
/// input_plain,
|
||||
/// &lwe_sk,
|
||||
/// &lwe_params,
|
||||
/// plaintext_bits_carry
|
||||
/// );
|
||||
///
|
||||
/// // Perform the programmable bootstrapping
|
||||
/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def());
|
||||
/// programmable_bootstrap(
|
||||
/// &mut result,
|
||||
/// &input,
|
||||
/// &lut,
|
||||
/// &bsk,
|
||||
/// &lwe_params,
|
||||
/// &glwe_params,
|
||||
/// &radix,
|
||||
/// );
|
||||
///
|
||||
/// // Check the result matches our plaintext function.
|
||||
/// let decrypted = encryption::decrypt_lwe(
|
||||
/// &result,
|
||||
/// &glwe_sk.to_lwe_secret_key(),
|
||||
/// &glwe_params.as_lwe_def(),
|
||||
/// plaintext_bits,
|
||||
/// );
|
||||
///
|
||||
/// let expected = negate(input_plain);
|
||||
/// assert_eq!(expected, decrypted);
|
||||
/// ```
|
||||
///
|
||||
/// # See also
|
||||
///
|
||||
/// For the bivariate version of programmable bootstrapping, see
|
||||
/// [`programmable_bootstrap_bivariate`](programmable_bootstrap_bivariate) and
|
||||
/// its associated LUT
|
||||
/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable).
|
||||
pub fn programmable_bootstrap<S>(
|
||||
output: &mut LweCiphertextRef<S>,
|
||||
input: &LweCiphertextRef<S>,
|
||||
lut: &UnivariateLookupTableRef<S>,
|
||||
bootstrap_key: &BootstrapKeyFftRef<Complex<f64>>,
|
||||
lwe_params: &LweDef,
|
||||
glwe_params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
// Steps:
|
||||
// 1. Modulus switch the ciphertext to 2N.
|
||||
// 2. Use a cmux tree to blind rotate V using the elements of the bootstrap key (the input LWE secret key bits).
|
||||
// 3. Sample extract.
|
||||
// 4. (Optional, done outside of this method) Key switch to the output LWE
|
||||
// secret key (should be the one extracted from the GLWE key).
|
||||
|
||||
let degree = glwe_params.dim.polynomial_degree.0;
|
||||
let two_n = degree.ilog2() + 1;
|
||||
|
||||
// 1. Modulus switch the ciphertext to 2N.
|
||||
let mut ct = input.to_owned();
|
||||
modulus_switch(&mut ct, S::BITS, two_n, lwe_params);
|
||||
|
||||
let (ct_a, ct_b) = ct.a_b(lwe_params);
|
||||
|
||||
// 2. Use a cmux tree to blind rotate V using the elements of the bootstrap
|
||||
// key (the input LWE secret key bits).
|
||||
|
||||
// Perform V_0 ^ X^{-b}
|
||||
allocate_scratch_ref!(cmux_output, GlweCiphertextRef<S>, (glwe_params.dim));
|
||||
cmux_output.clear();
|
||||
|
||||
rotate_glwe_negative_monomial_negacyclic(
|
||||
cmux_output,
|
||||
lut.glwe(),
|
||||
ct_b.inner().to_u64() as usize,
|
||||
glwe_params,
|
||||
);
|
||||
|
||||
allocate_scratch_ref!(rotated_ct, GlweCiphertextRef<S>, (glwe_params.dim));
|
||||
|
||||
// Perform the cmux tree from the bootstrap key with the relation
|
||||
// V_n = V_{n-1} ^ X^{a_{n-1} s_{n-1}}
|
||||
for (a_i, index_select) in ct_a.iter().zip(bootstrap_key.rows(glwe_params, radix)) {
|
||||
let tmp = cmux_output.to_owned();
|
||||
|
||||
// This operation performs a copy so the rotated_ct doesn't need to be
|
||||
// cleared.
|
||||
rotate_glwe_positive_monomial_negacyclic(
|
||||
rotated_ct,
|
||||
cmux_output,
|
||||
a_i.inner().to_u64() as usize,
|
||||
glwe_params,
|
||||
);
|
||||
|
||||
cmux(
|
||||
cmux_output,
|
||||
&tmp,
|
||||
rotated_ct,
|
||||
index_select,
|
||||
glwe_params,
|
||||
radix,
|
||||
);
|
||||
}
|
||||
|
||||
// 3. Sample extract.
|
||||
sample_extract(output, cmux_output, 0, glwe_params);
|
||||
}
|
||||
|
||||
/// Evaluate a bivariate function on a packed input.
|
||||
fn bivariate_function<F>(map: F, input: u64, plaintext_bits: PlaintextBits) -> u64
|
||||
where
|
||||
F: Fn(u64, u64) -> u64,
|
||||
{
|
||||
let modulus = 1 << plaintext_bits.0;
|
||||
let lhs = (input / modulus) % modulus;
|
||||
let rhs = input % modulus;
|
||||
|
||||
let result = map(lhs, rhs);
|
||||
|
||||
assert!(
|
||||
result < modulus,
|
||||
"The result of the bivariate function must be less than the plaintext modulus"
|
||||
);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Generate a lookup table that takes two inputs and produces a single output.
|
||||
pub(crate) fn generate_bivariate_lut<S, F>(
|
||||
output: &mut PolynomialRef<Torus<S>>,
|
||||
map: F,
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
carry_bits: CarryBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
F: Fn(u64, u64) -> u64,
|
||||
{
|
||||
assert!(
|
||||
plaintext_bits.0 <= carry_bits.0,
|
||||
"The number of plaintext bits must be less than or equal to the number of carry bits"
|
||||
);
|
||||
|
||||
let wrapped_func = |input: u64| bivariate_function(&map, input, plaintext_bits);
|
||||
|
||||
generate_lut(
|
||||
output,
|
||||
wrapped_func,
|
||||
params,
|
||||
PlaintextBits(plaintext_bits.0 + carry_bits.0),
|
||||
);
|
||||
}
|
||||
|
||||
/// Programmable bootstrapping with a bivariate function.
|
||||
///
|
||||
/// The LUT this is a table that maps two inputs into a single output.
|
||||
/// For example, say we want to encode the xor function `f(x, y) = (x + y) % 2`
|
||||
/// into a lookup table. We would create a
|
||||
/// [`BivariateLookupTable`](crate::entities::BivariateLookupTable) that
|
||||
/// implements this function and then execute it on the input ciphertexts.
|
||||
///
|
||||
/// Important note: This function does not perform key switching. The output
|
||||
/// ciphertext will be encrypted under the LWE key extracted from the GLWE
|
||||
/// secret key used for the bootstrapping key. To perform a keyswitch, use
|
||||
/// [`keyswitch_lwe_to_lwe`](crate::ops::keyswitch::lwe_keyswitch::keyswitch_lwe_to_lwe)
|
||||
/// after the bootstrapping operation.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use sunscreen_tfhe::{
|
||||
/// high_level::{keygen, encryption, fft},
|
||||
/// entities::{BivariateLookupTable, LweCiphertext},
|
||||
/// ops::bootstrapping::programmable_bootstrap_bivariate,
|
||||
/// params::{
|
||||
/// GLWE_1_1024_80,
|
||||
/// LWE_512_80,
|
||||
/// CarryBits,
|
||||
/// PlaintextBits,
|
||||
/// RadixDecomposition,
|
||||
/// RadixCount,
|
||||
/// RadixLog
|
||||
/// },
|
||||
/// };
|
||||
///
|
||||
/// // Parameters defining the scheme we are using
|
||||
/// let lwe_params = LWE_512_80;
|
||||
/// let glwe_params = GLWE_1_1024_80;
|
||||
/// let radix = RadixDecomposition {
|
||||
/// count: RadixCount(3),
|
||||
/// radix_log: RadixLog(4),
|
||||
/// };
|
||||
///
|
||||
/// // We will be showing a binary bivariate function, but bivariate
|
||||
/// // bootstrapping can be done on more plaintext bits. Note that the effective
|
||||
/// // number of plaintext bits used is twice the number of plaintext bits
|
||||
/// // specified because the inputs are packed into one ciphertext inside
|
||||
/// // `programmable_bootstrap_bivariate`. The number of carry bits must always
|
||||
/// // be greater than or equal to the number of plaintext bits.
|
||||
/// let plaintext_bits = PlaintextBits(1);
|
||||
/// let plaintext_bits_carry = PlaintextBits(2);
|
||||
/// let carry_bits = CarryBits(1);
|
||||
///
|
||||
/// // The bivariate function we want to evaluate, encoded as a lookup table.
|
||||
/// let xor = |x, y| (x + y) % (1 << plaintext_bits.0);
|
||||
/// let lut = BivariateLookupTable::trivial_from_fn(
|
||||
/// &xor,
|
||||
/// &glwe_params,
|
||||
/// plaintext_bits,
|
||||
/// carry_bits
|
||||
/// );
|
||||
///
|
||||
/// // Generate the secret keys and the bootstrapping key
|
||||
/// let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
/// let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
///
|
||||
/// let bsk = keygen::generate_bootstrapping_key(&lwe_sk, &glwe_sk, &lwe_params, &glwe_params, &radix);
|
||||
/// let bsk =
|
||||
/// fft::fft_bootstrap_key(&bsk, &lwe_params, &glwe_params, &radix);
|
||||
///
|
||||
/// // Specify the inputs
|
||||
/// let left_input_plain = 0;
|
||||
/// let right_input_plain = 1;
|
||||
///
|
||||
/// // Encrypt the inputs. Note we are adding carry bits to the inputs.
|
||||
/// let left_input = encryption::encrypt_lwe_secret(
|
||||
/// left_input_plain,
|
||||
/// &lwe_sk,
|
||||
/// &lwe_params,
|
||||
/// plaintext_bits_carry
|
||||
/// );
|
||||
/// let right_input = encryption::encrypt_lwe_secret(
|
||||
/// right_input_plain,
|
||||
/// &lwe_sk,
|
||||
/// &lwe_params,
|
||||
/// plaintext_bits_carry
|
||||
/// );
|
||||
///
|
||||
/// // Perform the programmable bootstrapping
|
||||
/// let mut result = LweCiphertext::new(&glwe_params.as_lwe_def());
|
||||
/// programmable_bootstrap_bivariate(
|
||||
/// &mut result,
|
||||
/// &left_input,
|
||||
/// &right_input,
|
||||
/// &lut,
|
||||
/// &bsk,
|
||||
/// &lwe_params,
|
||||
/// &glwe_params,
|
||||
/// plaintext_bits,
|
||||
/// &radix,
|
||||
/// );
|
||||
///
|
||||
/// // Check the result matches our plaintext function.
|
||||
/// let decrypted = encryption::decrypt_lwe_with_carry(
|
||||
/// &result,
|
||||
/// &glwe_sk.to_lwe_secret_key(),
|
||||
/// &glwe_params.as_lwe_def(),
|
||||
/// plaintext_bits,
|
||||
/// carry_bits
|
||||
/// );
|
||||
///
|
||||
/// let expected = xor(left_input_plain, right_input_plain);
|
||||
/// assert_eq!(expected, decrypted);
|
||||
/// ```
|
||||
///
|
||||
/// # See also
|
||||
///
|
||||
/// For the univariate version of programmable bootstrapping, see
|
||||
/// [`programmable_bootstrap`](programmable_bootstrap) and its associated LUT
|
||||
/// [`UnivariateLookupTable`](crate::entities::UnivariateLookupTable).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn programmable_bootstrap_bivariate<S>(
|
||||
output: &mut LweCiphertextRef<S>,
|
||||
left_input: &LweCiphertextRef<S>,
|
||||
right_input: &LweCiphertextRef<S>,
|
||||
lut: &BivariateLookupTableRef<S>,
|
||||
bootstrap_key: &BootstrapKeyFftRef<Complex<f64>>,
|
||||
lwe_params: &LweDef,
|
||||
glwe_params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
// The general operation for a bivariate PBS is
|
||||
//
|
||||
// 1. Ensure that the number of carry bits is equal to the size of the
|
||||
// message or greater.
|
||||
// 2. Define a LUT where the function takes in one input and decomposes that
|
||||
// input into the higher n plaintext and lower n plaintext bits. The
|
||||
// higher n bits are the left input to the bivariate function, while the
|
||||
// lower n bits are the right input to the bivariate function.
|
||||
// 3. Encrypt the two input ciphertexts using the number of carry bits and
|
||||
// the plaintext bits, with padding.
|
||||
// 4. On the left encrypted input, shift it up by the number of plaintext
|
||||
// bits by multiplying the ciphertext by the plaintext modulus.
|
||||
// 5. Add the left and right encrypted inputs together.
|
||||
// 6. Perform the programmable bootstrapping with this combined input.
|
||||
|
||||
let shift = (1 << plaintext_bits.0) as u64;
|
||||
|
||||
allocate_scratch_ref!(pbs_input, LweCiphertextRef<S>, (lwe_params.dim));
|
||||
pbs_input.clear();
|
||||
|
||||
// (left * modulus) + right to pack the two inputs into a single LWE
|
||||
scalar_mul_ciphertext_mad(pbs_input, &S::from_u64(shift), left_input, lwe_params);
|
||||
add_lwe_inplace(pbs_input, right_input, lwe_params);
|
||||
|
||||
programmable_bootstrap(
|
||||
output,
|
||||
pbs_input,
|
||||
lut.as_univariate(),
|
||||
bootstrap_key,
|
||||
lwe_params,
|
||||
glwe_params,
|
||||
radix,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
entities::{
|
||||
BivariateLookupTable, BootstrapKey, BootstrapKeyFft, LweCiphertext, LweKeyswitchKey,
|
||||
UnivariateLookupTable,
|
||||
},
|
||||
high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
|
||||
ops::{
|
||||
encryption::{decrypt_ggsw_ciphertext, encrypt_lwe_ciphertext},
|
||||
keyswitch::lwe_keyswitch_key::generate_keyswitch_key_lwe,
|
||||
},
|
||||
RoundedDiv, GLWE_1_1024_80,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
fn generate_negacyclic_lut_from_formula<S>(
|
||||
params: &GlweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Polynomial<Torus<S>>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let mut output = Polynomial::<Torus<S>>::zero(params.dim.polynomial_degree.0);
|
||||
|
||||
let p = (1 << plaintext_bits.0) as u64;
|
||||
let n = params.dim.polynomial_degree.0 as u64;
|
||||
|
||||
let divisor = 2 * n;
|
||||
|
||||
for (j, c) in output.coeffs_mut().iter_mut().enumerate() {
|
||||
let v_i = ((p * (j as u64)).div_rounded(divisor)) % p;
|
||||
let v_i = v_i << (S::BITS - plaintext_bits.0);
|
||||
*c = Torus::from(S::from_u64(v_i));
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_generate_negacyclic_lut() {
|
||||
let p = PlaintextBits(4);
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
|
||||
let mut poly = Polynomial::<Torus<u64>>::zero(params.dim.polynomial_degree.0);
|
||||
generate_negacyclic_lut(&mut poly, |x| x, ¶ms, p);
|
||||
|
||||
let expected = generate_negacyclic_lut_from_formula(¶ms, p);
|
||||
|
||||
assert_eq!(expected, poly);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_generate_bootstrap_key() {
|
||||
let lwe_params = TEST_LWE_DEF_1;
|
||||
let glwe_params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
|
||||
let mut bootstrap_key = BootstrapKey::new(&lwe_params, &glwe_params, &radix);
|
||||
generate_bootstrap_key(&mut bootstrap_key, &sk, &glwe_sk, &glwe_params, &radix);
|
||||
|
||||
let mut count = 0;
|
||||
for (s_i, ct) in sk.s().iter().zip(bootstrap_key.rows(&glwe_params, &radix)) {
|
||||
let mut msg = Polynomial::<Torus<u64>>::zero(glwe_params.dim.polynomial_degree.0);
|
||||
decrypt_ggsw_ciphertext(&mut msg, ct, &glwe_sk, &glwe_params, &radix);
|
||||
|
||||
assert_eq!(msg.coeffs()[0].inner(), *s_i);
|
||||
|
||||
count += 1
|
||||
}
|
||||
|
||||
assert_eq!(count, sk.s().len());
|
||||
}
|
||||
|
||||
fn bootstrap_helper(map: impl Fn(u64) -> u64) {
|
||||
let bits = PlaintextBits(3);
|
||||
let lwe = TEST_LWE_DEF_1;
|
||||
let glwe = GLWE_1_1024_80;
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let original_sk = keygen::generate_binary_lwe_sk(&lwe);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
// We want to switch from the sample extracted key to the new key.
|
||||
let mut ksk = LweKeyswitchKey::<u64>::new(&glwe.as_lwe_def(), &lwe, &radix);
|
||||
generate_keyswitch_key_lwe(
|
||||
&mut ksk,
|
||||
glwe_sk.to_lwe_secret_key(),
|
||||
&original_sk,
|
||||
&lwe,
|
||||
&radix,
|
||||
);
|
||||
|
||||
let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix);
|
||||
generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix);
|
||||
|
||||
let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix);
|
||||
bsk_nonfft.fft(&mut bsk, &glwe, &radix);
|
||||
|
||||
// Generate the LUT
|
||||
let lut = UnivariateLookupTable::trivial_from_fn(&map, &glwe, bits);
|
||||
|
||||
let mut failed = Vec::new();
|
||||
for msg in 0..(1 << bits.0) {
|
||||
let mut original_ct = LweCiphertext::new(&lwe);
|
||||
|
||||
// Adding a padding bit
|
||||
let encoded_msg = msg << (64 - bits.0 - 1);
|
||||
encrypt_lwe_ciphertext(
|
||||
&mut original_ct,
|
||||
&original_sk,
|
||||
Torus::from(encoded_msg),
|
||||
&lwe,
|
||||
);
|
||||
|
||||
let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def());
|
||||
|
||||
programmable_bootstrap(&mut new_ct, &original_ct, &lut, &bsk, &lwe, &glwe, &radix);
|
||||
|
||||
let decoded = glwe_sk
|
||||
.to_lwe_secret_key()
|
||||
.decrypt(&new_ct, &glwe.as_lwe_def(), bits);
|
||||
|
||||
let result = map(msg);
|
||||
if result != decoded {
|
||||
failed.push((result, decoded));
|
||||
}
|
||||
}
|
||||
|
||||
if !failed.is_empty() {
|
||||
panic!(
|
||||
"Failed to decrypt the following messages and decrypted values: {:?}",
|
||||
failed
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_bootstrap() {
|
||||
bootstrap_helper(|x| x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_bootstrap_with_map() {
|
||||
bootstrap_helper(|x| (x + 3) % 8);
|
||||
}
|
||||
|
||||
fn bivariate_bootstrap_helper(map: impl Fn(u64, u64) -> u64) {
|
||||
let lwe = TEST_LWE_DEF_1;
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let _radix = TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let carry_bits = CarryBits(1);
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let original_sk = keygen::generate_binary_lwe_sk(&lwe);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
// We want to switch from the sample extracted key to the new key.
|
||||
let mut ksk = LweKeyswitchKey::<u64>::new(&glwe.as_lwe_def(), &lwe, &radix);
|
||||
generate_keyswitch_key_lwe(
|
||||
&mut ksk,
|
||||
glwe_sk.to_lwe_secret_key(),
|
||||
&original_sk,
|
||||
&lwe,
|
||||
&radix,
|
||||
);
|
||||
|
||||
let mut bsk_nonfft = BootstrapKey::new(&lwe, &glwe, &radix);
|
||||
generate_bootstrap_key(&mut bsk_nonfft, &original_sk, &glwe_sk, &glwe, &radix);
|
||||
|
||||
let mut bsk = BootstrapKeyFft::new(&lwe, &glwe, &radix);
|
||||
bsk_nonfft.fft(&mut bsk, &glwe, &radix);
|
||||
|
||||
// Generate the LUT
|
||||
let lut = BivariateLookupTable::trivial_from_fn(&map, &glwe, bits, carry_bits);
|
||||
|
||||
let mut failed = Vec::new();
|
||||
let mut succeeded = Vec::new();
|
||||
|
||||
for left_msg in 0..(1 << bits.0) {
|
||||
for right_msg in 0..(1 << bits.0) {
|
||||
let mut left_ct = LweCiphertext::new(&lwe);
|
||||
let mut right_ct = LweCiphertext::new(&lwe);
|
||||
|
||||
// Adding a padding bit, hence the - 1
|
||||
let encoded_left_msg = left_msg << (64 - bits.0 - carry_bits.0 - 1);
|
||||
let encoded_right_msg = right_msg << (64 - bits.0 - carry_bits.0 - 1);
|
||||
|
||||
encrypt_lwe_ciphertext(
|
||||
&mut left_ct,
|
||||
&original_sk,
|
||||
Torus::from(encoded_left_msg),
|
||||
&lwe,
|
||||
);
|
||||
|
||||
encrypt_lwe_ciphertext(
|
||||
&mut right_ct,
|
||||
&original_sk,
|
||||
Torus::from(encoded_right_msg),
|
||||
&lwe,
|
||||
);
|
||||
|
||||
let mut new_ct = LweCiphertext::new(&glwe.as_lwe_def());
|
||||
|
||||
programmable_bootstrap_bivariate(
|
||||
&mut new_ct,
|
||||
&left_ct,
|
||||
&right_ct,
|
||||
&lut,
|
||||
&bsk,
|
||||
&lwe,
|
||||
&glwe,
|
||||
bits,
|
||||
&radix,
|
||||
);
|
||||
|
||||
let decrypted = glwe_sk
|
||||
.to_lwe_secret_key()
|
||||
.decrypt_without_decode(&new_ct, &glwe.as_lwe_def());
|
||||
|
||||
// We manually decode here because the
|
||||
let plain_bits = bits;
|
||||
|
||||
let round_bit = decrypted
|
||||
.inner()
|
||||
.wrapping_shr(64 - plain_bits.0 - carry_bits.0 - 1)
|
||||
& 0x1;
|
||||
let mask = (0x1 << plain_bits.0) - 1;
|
||||
|
||||
let decoded = (decrypted
|
||||
.inner()
|
||||
.wrapping_shr(64 - plain_bits.0 - carry_bits.0)
|
||||
+ round_bit)
|
||||
& mask;
|
||||
|
||||
let result = map(left_msg, right_msg);
|
||||
if result != decoded {
|
||||
failed.push(((left_msg, right_msg), result, decoded));
|
||||
} else {
|
||||
succeeded.push(((left_msg, right_msg), result, decoded));
|
||||
}
|
||||
}
|
||||
}
|
||||
if !failed.is_empty() {
|
||||
panic!(
|
||||
"Failed to decrypt the following messages and decrypted values (as ((left input, right_input), expected, decrypted)): {:?}. However, the following messages and decrypted values succeeded: {:?}",
|
||||
failed, succeeded
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn bivariate_test_function(left: u64, right: u64) -> u64 {
|
||||
(left + right) % 2
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_bootstrap_with_bivariate_map() {
|
||||
bivariate_bootstrap_helper(bivariate_test_function);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_decompose_bivariate_map() {
|
||||
let plaintext_bits = PlaintextBits(2);
|
||||
let modulus = 1 << plaintext_bits.0;
|
||||
|
||||
let map = &bivariate_test_function;
|
||||
|
||||
for left in 0u64..(plaintext_bits.0 as u64) {
|
||||
for right in 0u64..(plaintext_bits.0 as u64) {
|
||||
let left_shifted = left * modulus;
|
||||
let input = left_shifted + right;
|
||||
let result = bivariate_function(map, input, plaintext_bits);
|
||||
|
||||
assert_eq!(result, map(left, right));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
58
sunscreen_tfhe/src/ops/ciphertext/glev_ciphertext_ops.rs
Normal file
58
sunscreen_tfhe/src/ops/ciphertext/glev_ciphertext_ops.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
use crate::{
|
||||
entities::{GlevCiphertextRef, GlweCiphertextRef, Polynomial},
|
||||
radix::{PolynomialRadixIterator, ScalarRadixIterator},
|
||||
GlweDef, TorusOps,
|
||||
};
|
||||
|
||||
use super::{glwe_polynomial_mad, glwe_scalar_mad};
|
||||
|
||||
/// Compute `c += (G^-1 * a) \[*\] b`, where
|
||||
/// * `G^-1 * a`` is the radix decomposition of `a`
|
||||
/// * `b` is a GLEV ciphertext.
|
||||
/// * `c` is a GLWE ciphertext.
|
||||
/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials
|
||||
///
|
||||
/// # Remarks
|
||||
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
|
||||
pub fn decomposed_polynomial_glev_mad<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
mut a: PolynomialRadixIterator<S>,
|
||||
b: &GlevCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
// a = decomp(a_i)
|
||||
// b = r
|
||||
|
||||
let b_glwe = b.glwe_ciphertexts(params);
|
||||
let mut cur_radix: Polynomial<S> = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
|
||||
// The decomposition of
|
||||
// <Decomp^{beta, l}(gamma), GLEV>
|
||||
// can be performed using
|
||||
// sum_{j = 1}^l gamma_j * C_j
|
||||
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
|
||||
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
|
||||
// returns the decomposed values in the opposite order.
|
||||
for b in b_glwe.rev() {
|
||||
a.write_next(&mut cur_radix);
|
||||
glwe_polynomial_mad(c, b, &cur_radix, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c += (G^-1 * a) \[*\] b`, where
|
||||
/// * `G^-1 * a`` is the radix decomposition of `a`
|
||||
/// * `b` is a GLEV ciphertext.
|
||||
pub fn decomposed_scalar_glev_mad<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: ScalarRadixIterator<S>,
|
||||
b: &GlevCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
for (b, a) in b.glwe_ciphertexts(params).rev().zip(a) {
|
||||
glwe_scalar_mad(c, b, a, params);
|
||||
}
|
||||
}
|
||||
549
sunscreen_tfhe/src/ops/ciphertext/glwe_ciphertext_ops.rs
Normal file
549
sunscreen_tfhe/src/ops/ciphertext/glwe_ciphertext_ops.rs
Normal file
@@ -0,0 +1,549 @@
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{
|
||||
GgswCiphertextRef, GlweCiphertext, GlweCiphertextRef, LweCiphertextRef, PolynomialRef,
|
||||
},
|
||||
ops::ciphertext::decomposed_polynomial_glev_mad,
|
||||
polynomial::{
|
||||
polynomial_add, polynomial_external_mad, polynomial_negate, polynomial_scalar_mad,
|
||||
polynomial_sub,
|
||||
},
|
||||
radix::PolynomialRadixIterator,
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/**
|
||||
* Extract a specific coefficient in a message M in a GLWE ciphertext as a LWE
|
||||
* ciphertext under the LWE extracted secret key (extracted from the GLWE secret
|
||||
* key).
|
||||
*
|
||||
* # Arguments
|
||||
*
|
||||
* * `output` - The output LWE ciphertext
|
||||
* * `glwe` - The input GLWE ciphertext
|
||||
* * `h` - The index of the coefficient to extract
|
||||
*
|
||||
* # Remarks
|
||||
* For a GLWE ciphertext of size k and dimension N, the output LWE ciphertext
|
||||
* passed in must have size k*N.
|
||||
*/
|
||||
pub fn sample_extract<S>(
|
||||
output: &mut LweCiphertextRef<S>,
|
||||
glwe: &GlweCiphertextRef<S>,
|
||||
h: usize,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
// We are copying parts of the GLWE ciphertext out according to the following rule:
|
||||
// a_{N*i + j} = a_{i, h - j} for 0 <= i < k, 0 <= j <= h
|
||||
// a_{N*i + j} = -a_{i, h - j + n} for 0 <= i < k, n + 1 <= j < N
|
||||
// b = b_n
|
||||
|
||||
#[allow(non_snake_case)]
|
||||
let N = params.dim.polynomial_degree.0;
|
||||
let k = params.dim.size.0;
|
||||
|
||||
let lwe_size = k * N;
|
||||
|
||||
let (a_lwe, b_lwe) = output.a_b_mut(¶ms.as_lwe_def());
|
||||
|
||||
// Make sure that the correctly sized LWE was passed in.
|
||||
assert_eq!(lwe_size, a_lwe.len());
|
||||
|
||||
let (a_glwe, b_glwe) = glwe.a_b(params);
|
||||
|
||||
for (i, a_gwe_i) in a_glwe.enumerate() {
|
||||
#[allow(non_snake_case)]
|
||||
let Ni = N * i;
|
||||
let a_glwe_i_coeffs = a_gwe_i.coeffs();
|
||||
|
||||
for j in 0..=h {
|
||||
a_lwe[Ni + j] = a_glwe_i_coeffs[h - j];
|
||||
}
|
||||
|
||||
for j in (h + 1)..N {
|
||||
// Note we add N to h first, otherwise h - j might underflow.
|
||||
a_lwe[Ni + j] = num::traits::WrappingNeg::wrapping_neg(&a_glwe_i_coeffs[h + N - j]);
|
||||
}
|
||||
}
|
||||
|
||||
*b_lwe = b_glwe.coeffs()[h];
|
||||
}
|
||||
|
||||
/// Add two GLWE ciphertexts together, storing the result in `c`.
|
||||
pub fn add_glwe_ciphertexts<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
let (b_a, b_b) = b.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), a_a.len());
|
||||
assert_eq!(c_a.len(), b_a.len());
|
||||
|
||||
for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) {
|
||||
polynomial_add(c, a, b);
|
||||
}
|
||||
|
||||
polynomial_add(c_b, a_b, b_b);
|
||||
}
|
||||
|
||||
/// Subtract two GLWE ciphertexts together, storing the result in `c`.
|
||||
pub fn sub_glwe_ciphertexts<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
let (b_a, b_b) = b.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), a_a.len());
|
||||
assert_eq!(c_a.len(), b_a.len());
|
||||
|
||||
for (c, (a, b)) in c_a.zip(a_a.zip(b_a)) {
|
||||
polynomial_sub(c, a, b);
|
||||
}
|
||||
|
||||
polynomial_sub(c_b, a_b, b_b);
|
||||
}
|
||||
|
||||
/// Homomorphically compute -ct.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This operation is noiseless.
|
||||
pub fn glwe_negate_inplace<S>(ct: &mut GlweCiphertextRef<S>, params: &GlweDef)
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
for a in ct.a_mut(params) {
|
||||
polynomial_negate(a);
|
||||
}
|
||||
|
||||
polynomial_negate(ct.b_mut(params));
|
||||
}
|
||||
|
||||
/// Compute c += a \[*\] b where \[*\] is the external product between a GLWE
|
||||
/// ciphertext and an polynomial in Z\[X\]/(X^N + 1).
|
||||
///
|
||||
/// # Remarks
|
||||
/// For this to produce the correct result, degree(b) must be "small" (ideally
|
||||
/// 0) and the coefficient must be small (i.e. less than the message size).
|
||||
pub fn glwe_polynomial_mad<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b: &PolynomialRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), params.dim.size.0);
|
||||
assert_eq!(a_a.len(), params.dim.size.0);
|
||||
|
||||
for (c, a) in c_a.zip(a_a) {
|
||||
polynomial_external_mad(c, a, b);
|
||||
}
|
||||
|
||||
polynomial_external_mad(c_b, a_b, b);
|
||||
}
|
||||
|
||||
/// Compute `c += a \[*\] b`` where
|
||||
/// * `a` is a GLWE ciphertext
|
||||
/// * `b` is a scalar
|
||||
/// * `\[*\]` is the external product operator GLWE \[*\] Z -> GLWE
|
||||
pub fn glwe_scalar_mad<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b: S,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
for (c, a) in c.a_mut(params).zip(a.a(params)) {
|
||||
polynomial_scalar_mad(c, a, b);
|
||||
}
|
||||
|
||||
polynomial_scalar_mad(c.b_mut(params), a.b(params), b);
|
||||
}
|
||||
|
||||
/// Compute `c += a \[*\] b`` where
|
||||
/// * `a` is a GLWE ciphertext
|
||||
/// * `b` is a GGSW cipheetext
|
||||
/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE`
|
||||
pub fn glwe_ggsw_mad<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b: &GgswCiphertextRef<S>,
|
||||
glwe_def: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a_a, a_b) = a.a_b(glwe_def);
|
||||
let rows = b.rows(glwe_def, radix);
|
||||
|
||||
// Generate an iterator that includes a_a and a_b
|
||||
let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b));
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<S>, (glwe_def.dim.polynomial_degree));
|
||||
|
||||
// Performs the external operation
|
||||
//
|
||||
// GGSW ⊡ GLWE = sum_i=0^k <Decomp^{beta, l}(AB_i), C_i>
|
||||
//
|
||||
// where
|
||||
// * `beta` is the decomposition base
|
||||
// * `l` is the decomposition level
|
||||
// * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B}
|
||||
// * `C_i` is the i-th row of the GGSW ciphertext
|
||||
for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) {
|
||||
// For each a polynomial, compute the external product with the GLEV ciphertext
|
||||
// at index i in the GGSW ciphertext.
|
||||
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
|
||||
|
||||
decomposed_polynomial_glev_mad(c, decomp, r, glwe_def);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute external product of a GLWE ciphertext and a GGSW ciphertext.
|
||||
/// GGSW ⊡ GLWE -> GLWE
|
||||
pub fn external_product_ggsw_glwe<S>(
|
||||
ggsw: &GgswCiphertextRef<S>,
|
||||
glwe: &GlweCiphertextRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) -> GlweCiphertext<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
// Zero GLWE to sum over.
|
||||
let mut result = GlweCiphertext::new(params);
|
||||
glwe_ggsw_mad(&mut result, glwe, ggsw, params, radix);
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
entities::{GgswCiphertext, LweCiphertext, Polynomial},
|
||||
high_level::*,
|
||||
high_level::{keygen, TEST_GLWE_DEF_1},
|
||||
ops::encryption::{
|
||||
decrypt_ggsw_ciphertext, encrypt_ggsw_ciphertext, encrypt_glwe_ciphertext_secret,
|
||||
trivially_encrypt_glwe_ciphertext,
|
||||
},
|
||||
polynomial::polynomial_mad,
|
||||
PlaintextBits, Torus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
#[test]
|
||||
fn polynomial_iteration_mut() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
|
||||
let mut sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
assert_eq!(sk.s(&glwe).count(), glwe.dim.size.0);
|
||||
|
||||
for s_i in sk.s_mut(&glwe) {
|
||||
assert_eq!(s_i.len(), glwe.dim.polynomial_degree.0);
|
||||
|
||||
for s in s_i.coeffs() {
|
||||
assert!(*s == 0 || *s == 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_glwe_ciphertexts() {
|
||||
let bits = PlaintextBits(4);
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
let plaintext = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 4)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
|
||||
let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
|
||||
|
||||
let c = a + b;
|
||||
|
||||
let dec = sk.decrypt_decode_glwe(&c, &glwe, bits);
|
||||
|
||||
for (i, c) in dec.coeffs().iter().enumerate() {
|
||||
assert_eq!(*c, 2 * (i as u64 % 4));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_sub_glwe_ciphertexts() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
let plaintext = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 4)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let a = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
|
||||
let b = sk.encode_encrypt_glwe(&plaintext, &glwe, bits);
|
||||
|
||||
let c = a - b;
|
||||
|
||||
let dec = sk.decrypt_decode_glwe(&c, &glwe, bits);
|
||||
|
||||
for c in dec.coeffs() {
|
||||
assert_eq!(*c, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_internal_product_glwe_polynomial() {
|
||||
let bits = PlaintextBits(4);
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
let large_poly = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 4)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let small_poly = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| if x < 1 { 3 } else { 0 })
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
// Do the external product with an encryption of the large polynomial.
|
||||
let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits);
|
||||
let mut c = GlweCiphertext::new(&glwe);
|
||||
|
||||
glwe_polynomial_mad(&mut c, &a, &small_poly, &glwe);
|
||||
|
||||
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
|
||||
|
||||
let mut expected = Polynomial::<u64>::zero(glwe.dim.polynomial_degree.0);
|
||||
|
||||
polynomial_mad(
|
||||
expected.as_wrapping_mut(),
|
||||
large_poly.as_wrapping(),
|
||||
small_poly.as_wrapping(),
|
||||
);
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
|
||||
// Now do reverse the large and small polynomials' roles.
|
||||
let a = sk.encode_encrypt_glwe(&small_poly, &glwe, bits);
|
||||
let mut c = GlweCiphertext::new(&glwe);
|
||||
|
||||
glwe_polynomial_mad(&mut c, &a, &large_poly, &glwe);
|
||||
|
||||
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_external_product_ggsw_glwe() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
// Constant polynomial: [1, 0, 0, ...]
|
||||
let ggsw_plaintext_polynomial = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| if x < 1 { 1 } else { 0 })
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let ggsw_plaintext_polynomial_torus = ggsw_plaintext_polynomial.map(|x| Torus::from(*x));
|
||||
|
||||
let mut ggsw_ct = GgswCiphertext::new(&glwe, &radix);
|
||||
encrypt_ggsw_ciphertext(
|
||||
&mut ggsw_ct,
|
||||
&ggsw_plaintext_polynomial,
|
||||
&sk,
|
||||
&glwe,
|
||||
&radix,
|
||||
bits,
|
||||
);
|
||||
|
||||
let mut ggsw_decrypt = Polynomial::zero(glwe.dim.polynomial_degree.0);
|
||||
decrypt_ggsw_ciphertext(&mut ggsw_decrypt, &ggsw_ct, &sk, &glwe, &radix);
|
||||
assert_eq!(
|
||||
ggsw_decrypt, ggsw_plaintext_polynomial_torus,
|
||||
"GGSW decrypt does not match plaintext"
|
||||
);
|
||||
|
||||
// [1, 2, ...]
|
||||
let glwe_plaintext_polynomial = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % (1 << bits.0))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let glwe_plaintext_polynomial_encoded =
|
||||
glwe_plaintext_polynomial.map(|x| Torus::encode(*x, bits));
|
||||
|
||||
// Do the external product with an encryption of the large polynomial.
|
||||
let mut glwe_ct = GlweCiphertext::new(&glwe);
|
||||
encrypt_glwe_ciphertext_secret(
|
||||
&mut glwe_ct,
|
||||
&glwe_plaintext_polynomial_encoded,
|
||||
&sk,
|
||||
&glwe,
|
||||
);
|
||||
|
||||
let glwe_decrypt = sk.decrypt_decode_glwe(&glwe_ct, &glwe, bits);
|
||||
assert_eq!(
|
||||
glwe_decrypt, glwe_plaintext_polynomial,
|
||||
"GLWE decrypt does not match plaintext"
|
||||
);
|
||||
|
||||
let encrypted_result = external_product_ggsw_glwe(&ggsw_ct, &glwe_ct, &glwe, &radix);
|
||||
let result = sk.decrypt_decode_glwe(&encrypted_result, &glwe, bits);
|
||||
|
||||
// expected is the polynomial multiplication of
|
||||
// ggsw_plaintext_polynomial and glwe_plaintext_polynomial
|
||||
let mut expected = Polynomial::<u64>::zero(glwe.dim.polynomial_degree.0);
|
||||
|
||||
polynomial_mad(
|
||||
expected.as_wrapping_mut(),
|
||||
ggsw_plaintext_polynomial.as_wrapping(),
|
||||
glwe_plaintext_polynomial.as_wrapping(),
|
||||
);
|
||||
|
||||
// Reduce modulo 2^BITS
|
||||
let expected = expected.map(|x| x % (1 << bits.0));
|
||||
|
||||
// Verify that the ciphertext multiplication is correct
|
||||
assert_eq!(
|
||||
result, expected,
|
||||
"External product does not match polynomial multiplication"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_add_glwe_with_trivial_glwe() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
let delta = 1u64 << (64 - bits.0);
|
||||
|
||||
let large_poly = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 4)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let small_poly = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| if x < 1 { 3 } else { 0 })
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let small_poly_scaled = small_poly.map(|x| Torus::from(x * delta));
|
||||
|
||||
let a = sk.encode_encrypt_glwe(&large_poly, &glwe, bits);
|
||||
|
||||
let mut b = GlweCiphertext::new(&glwe);
|
||||
trivially_encrypt_glwe_ciphertext(&mut b, &small_poly_scaled, &glwe);
|
||||
|
||||
let c = a.as_ref() + b.as_ref();
|
||||
|
||||
let actual = sk.decrypt_decode_glwe(&c, &glwe, bits);
|
||||
let expected = small_poly + large_poly;
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
|
||||
let c2 = b.as_ref() + a.as_ref();
|
||||
|
||||
let actual = sk.decrypt_decode_glwe(&c2, &glwe, bits);
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sample_extract() {
|
||||
let bits = PlaintextBits(2);
|
||||
let glwe_params = TEST_GLWE_DEF_1;
|
||||
let lwe_params = glwe_params.as_lwe_def();
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
let extracted_lwe_sk = sk.to_lwe_secret_key();
|
||||
|
||||
let large_poly = Polynomial::new(
|
||||
&(0..glwe_params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 4)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let glwe = sk.encode_encrypt_glwe(&large_poly, &glwe_params, bits);
|
||||
|
||||
for h in 0..glwe_params.dim.polynomial_degree.0 {
|
||||
let mut lwe = LweCiphertext::new(&lwe_params);
|
||||
sample_extract(&mut lwe, &glwe, h, &glwe_params);
|
||||
|
||||
let lwe_msg = extracted_lwe_sk.decrypt(&lwe, &lwe_params, bits);
|
||||
|
||||
let expected = large_poly.coeffs()[h];
|
||||
|
||||
assert_eq!(expected, lwe_msg);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_glwe_scalar_mad() {
|
||||
for _ in 0..20 {
|
||||
let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
|
||||
|
||||
let plaintext_bits = (thread_rng().next_u64()) % 8 + 1;
|
||||
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
|
||||
|
||||
let scalar = thread_rng().next_u64() % 64;
|
||||
|
||||
let pt = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0)
|
||||
.map(|_| thread_rng().next_u64() % plaintext_bits.0 as u64)
|
||||
.collect::<Vec<_>>();
|
||||
let pt = Polynomial::new(&pt);
|
||||
|
||||
let ct = sk.encode_encrypt_glwe(&pt, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
|
||||
let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1);
|
||||
|
||||
glwe_scalar_mad(&mut result, &ct, scalar, &TEST_GLWE_DEF_1);
|
||||
|
||||
let actual = sk.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
|
||||
for (pt, actual) in pt.coeffs().iter().zip(actual.coeffs()) {
|
||||
let expected = pt.wrapping_mul(scalar) % (0x1u64 << plaintext_bits.0);
|
||||
|
||||
assert_eq!(expected, *actual);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
42
sunscreen_tfhe/src/ops/ciphertext/lev_ciphertext_ops.rs
Normal file
42
sunscreen_tfhe/src/ops/ciphertext/lev_ciphertext_ops.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use crate::{
|
||||
entities::{LevCiphertextRef, LweCiphertextRef, Polynomial},
|
||||
radix::PolynomialRadixIterator,
|
||||
LweDef, TorusOps,
|
||||
};
|
||||
|
||||
use super::scalar_mul_ciphertext_mad;
|
||||
|
||||
/// Compute `c += (G^-1 * a) \[*\] b`, where
|
||||
/// * `G^-1 * a`` is the radix decomposition of `a`
|
||||
/// * `b` is a LEV ciphertext.
|
||||
/// * `c` is a LWE ciphertext.
|
||||
/// * \[*\] is the external product between a LEV ciphertext and the decomposed
|
||||
/// LWE ciphertext.
|
||||
///
|
||||
/// # Remarks
|
||||
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
|
||||
pub fn decomposed_scalar_lev_mad<S>(
|
||||
c: &mut LweCiphertextRef<S>,
|
||||
mut a: PolynomialRadixIterator<S>,
|
||||
b: &LevCiphertextRef<S>,
|
||||
params: &LweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let b_lwe = b.lwe_ciphertexts(params);
|
||||
let mut cur_radix: Polynomial<S> = Polynomial::zero(1);
|
||||
|
||||
// The decomposition of
|
||||
// <Decomp^{beta, l}(gamma), GLEV>
|
||||
// can be performed using
|
||||
// sum_{j = 1}^l gamma_j * C_j
|
||||
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
|
||||
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
|
||||
// returns the decomposed values in the opposite order.
|
||||
for b in b_lwe.rev() {
|
||||
a.write_next(&mut cur_radix);
|
||||
let radix = cur_radix.coeffs()[0];
|
||||
|
||||
scalar_mul_ciphertext_mad(c, &radix, b, params);
|
||||
}
|
||||
}
|
||||
93
sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs
Normal file
93
sunscreen_tfhe/src/ops/ciphertext/lwe_ciphertext_ops.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use crate::{entities::LweCiphertextRef, LweDef, RoundedDiv, Torus, TorusOps};
|
||||
|
||||
/// Add the coefficients of a to the coefficients of c in place.
|
||||
pub fn add_lwe_inplace<S>(c: &mut LweCiphertextRef<S>, a: &LweCiphertextRef<S>, params: &LweDef)
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), a_a.len());
|
||||
|
||||
for (c, a) in c_a.iter_mut().zip(a_a.iter()) {
|
||||
*c = num::traits::WrappingAdd::wrapping_add(c, a);
|
||||
}
|
||||
|
||||
*c_b = num::traits::WrappingAdd::wrapping_add(c_b, a_b);
|
||||
}
|
||||
|
||||
/// Subtract one LWE ciphertext from another, storing the result in the provided
|
||||
/// output variable. Mostly meant to be used reduce the number of allocations
|
||||
/// and with functions like [allocate_scratch_ref].
|
||||
pub(crate) fn sub_lwe_ciphertexts<S>(
|
||||
c: &mut LweCiphertextRef<S>,
|
||||
a: &LweCiphertextRef<S>,
|
||||
b: &LweCiphertextRef<S>,
|
||||
params: &LweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
let (b_a, b_b) = b.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), a_a.len());
|
||||
assert_eq!(c_a.len(), b_a.len());
|
||||
|
||||
for (c, (a, b)) in c_a.iter_mut().zip(a_a.iter().zip(b_a.iter())) {
|
||||
*c = num::traits::WrappingSub::wrapping_sub(a, b);
|
||||
}
|
||||
|
||||
*c_b = num::traits::WrappingSub::wrapping_sub(a_b, b_b);
|
||||
}
|
||||
|
||||
/// Multiplies an LWE ciphertext by a scalar, storing the result in the provided
|
||||
/// output variable. Mostly meant to be used reduce the number of allocations
|
||||
/// and with functions like [allocate_scratch_ref].
|
||||
pub(crate) fn scalar_mul_ciphertext_mad<S>(
|
||||
c: &mut LweCiphertextRef<S>,
|
||||
scalar: &S,
|
||||
a: &LweCiphertextRef<S>,
|
||||
params: &LweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), a_a.len());
|
||||
|
||||
for (c, a) in c_a.iter_mut().zip(a_a.iter()) {
|
||||
*c += a * scalar;
|
||||
}
|
||||
|
||||
*c_b += a_b * scalar;
|
||||
}
|
||||
|
||||
/// Perform modulus switching on a ciphertext. We are assuming that moduli are
|
||||
/// both powers of two, and that the original number of bits is greater than the
|
||||
/// new number of bits.
|
||||
pub fn modulus_switch<S>(
|
||||
ct: &mut LweCiphertextRef<S>,
|
||||
original_bits: u32,
|
||||
new_bits: u32,
|
||||
params: &LweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (c_a, c_b) = ct.a_b_mut(params);
|
||||
|
||||
// We specifically want to zero out the MSBs instead of shifting them back
|
||||
// around.
|
||||
for a in c_a {
|
||||
let c = a.inner().to_u64() as u128;
|
||||
let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128);
|
||||
*a = Torus::from(S::from_u64(res as u64));
|
||||
}
|
||||
|
||||
let c = c_b.inner().to_u64() as u128;
|
||||
let res = (c * (1 << new_bits)).div_rounded(1 << original_bits as u128);
|
||||
|
||||
*c_b = Torus::from(S::from_u64(res as u64));
|
||||
}
|
||||
11
sunscreen_tfhe/src/ops/ciphertext/mod.rs
Normal file
11
sunscreen_tfhe/src/ops/ciphertext/mod.rs
Normal file
@@ -0,0 +1,11 @@
|
||||
mod lwe_ciphertext_ops;
|
||||
pub use lwe_ciphertext_ops::*;
|
||||
|
||||
mod lev_ciphertext_ops;
|
||||
pub use lev_ciphertext_ops::*;
|
||||
|
||||
mod glev_ciphertext_ops;
|
||||
pub use glev_ciphertext_ops::*;
|
||||
|
||||
mod glwe_ciphertext_ops;
|
||||
pub use glwe_ciphertext_ops::*;
|
||||
403
sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs
Normal file
403
sunscreen_tfhe/src/ops/encryption/ggsw_encryption.rs
Normal file
@@ -0,0 +1,403 @@
|
||||
use num::Zero;
|
||||
|
||||
use crate::{
|
||||
entities::{GgswCiphertextRef, GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef},
|
||||
polynomial::{polynomial_external_mad, polynomial_scalar_mad, polynomial_scalar_mul},
|
||||
GlweDef, PlaintextBits, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
use super::{
|
||||
decrypt_glwe_ciphertext, encrypt_glwe_ciphertext_secret,
|
||||
trivially_encrypt_glwe_with_sk_argument,
|
||||
};
|
||||
|
||||
/// Perform a ggsw encryption. This is generic in case a trivial GGSW encryption
|
||||
/// is wanted (for example, for testing purposes).
|
||||
pub(crate) fn encrypt_ggsw_ciphertext_generic<S>(
|
||||
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
|
||||
msg: &PolynomialRef<S>,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
encrypt: impl Fn(
|
||||
&mut GlweCiphertextRef<S>,
|
||||
&PolynomialRef<Torus<S>>,
|
||||
&GlweSecretKeyRef<S>,
|
||||
&GlweDef,
|
||||
),
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let max_val = S::from_u64(0x1 << plaintext_bits.0);
|
||||
assert!(msg.coeffs().iter().all(|x| *x < max_val));
|
||||
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
let polynomial_degree = params.dim.polynomial_degree.0;
|
||||
let glwe_size = params.dim.size.0;
|
||||
|
||||
// k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption
|
||||
// of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of
|
||||
// M/B^{j+1}
|
||||
for (i, row) in ggsw_ciphertext.rows_mut(params, radix).enumerate() {
|
||||
let mut m_times_s = Polynomial::<Torus<S>>::zero(polynomial_degree);
|
||||
|
||||
let m_times_s = if i < glwe_size {
|
||||
// The message is composed of the negated secret key and the message
|
||||
// for all but the last row.
|
||||
let s = glwe_secret_key.s(params).nth(i).unwrap();
|
||||
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
|
||||
|
||||
// Negate the product.
|
||||
for c in m_times_s.coeffs_mut().iter_mut() {
|
||||
// Have to call the trait directly because deref is implemented on Torus
|
||||
*c = num::traits::WrappingNeg::wrapping_neg(c);
|
||||
}
|
||||
|
||||
&m_times_s
|
||||
} else {
|
||||
// Last row isn't multiplied by secret key.
|
||||
msg.as_torus()
|
||||
};
|
||||
|
||||
for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() {
|
||||
let mut scaled_msg = Polynomial::zero(polynomial_degree);
|
||||
|
||||
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
|
||||
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
|
||||
let decomp_factor =
|
||||
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
|
||||
|
||||
polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);
|
||||
|
||||
encrypt(col, &scaled_msg, glwe_secret_key, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt a GGSW ciphertext with a given message polynomial and secret key.
|
||||
pub fn encrypt_ggsw_ciphertext<S>(
|
||||
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
|
||||
msg: &PolynomialRef<S>,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
encrypt_ggsw_ciphertext_generic(
|
||||
ggsw_ciphertext,
|
||||
msg,
|
||||
glwe_secret_key,
|
||||
params,
|
||||
radix,
|
||||
plaintext_bits,
|
||||
encrypt_glwe_ciphertext_secret,
|
||||
);
|
||||
}
|
||||
|
||||
/// Encrypt a GGSW ciphertext with a given message polynomial and secret key.
|
||||
/// This is a trivial encryption that doesn't use the secret key and is not
|
||||
/// secure.
|
||||
pub fn trivially_encrypt_ggsw_ciphertext<S>(
|
||||
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
|
||||
msg: &PolynomialRef<S>,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
encrypt_ggsw_ciphertext_generic(
|
||||
ggsw_ciphertext,
|
||||
msg,
|
||||
glwe_secret_key,
|
||||
params,
|
||||
radix,
|
||||
plaintext_bits,
|
||||
trivially_encrypt_glwe_with_sk_argument,
|
||||
);
|
||||
}
|
||||
|
||||
/// Encrypt scalar (i.e. degree 0 polynomial) msg as a GGSW ciphertext.
|
||||
pub fn encrypt_ggsw_ciphertext_scalar<S>(
|
||||
ggsw_ciphertext: &mut GgswCiphertextRef<S>,
|
||||
msg: S,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
glwe_def: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let max_val = S::from_u64(0x1 << plaintext_bits.0);
|
||||
assert!(msg < max_val);
|
||||
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
let polynomial_degree = glwe_def.dim.polynomial_degree.0;
|
||||
let glwe_size = glwe_def.dim.size.0;
|
||||
|
||||
// k + 1 rows with l columns of glwe ciphertexts. Element (i,j) is a glwe encryption
|
||||
// of -M/B^{i+1} * s_j, except for j=k+1, where it's simply an encryption of
|
||||
// M/B^{j+1}
|
||||
for (i, row) in ggsw_ciphertext.rows_mut(glwe_def, radix).enumerate() {
|
||||
let mut m_times_s = Polynomial::<Torus<S>>::zero(polynomial_degree);
|
||||
let m_times_s = if i < glwe_size {
|
||||
let s = glwe_secret_key.s(glwe_def).nth(i).unwrap();
|
||||
polynomial_scalar_mad(&mut m_times_s, s.as_torus(), msg);
|
||||
&m_times_s
|
||||
} else {
|
||||
// Last row isn't multiplied by secret key.
|
||||
m_times_s.clear();
|
||||
m_times_s.coeffs_mut()[0] = Torus::from(msg);
|
||||
&m_times_s
|
||||
};
|
||||
|
||||
for (j, col) in row.glwe_ciphertexts_mut(glwe_def).enumerate() {
|
||||
let mut scaled_msg = Polynomial::zero(polynomial_degree);
|
||||
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
|
||||
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
|
||||
let decomp_factor =
|
||||
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
|
||||
|
||||
if i < glwe_size {
|
||||
let decomp_factor = decomp_factor.wrapping_neg();
|
||||
|
||||
polynomial_scalar_mul(&mut scaled_msg, m_times_s, decomp_factor);
|
||||
} else {
|
||||
scaled_msg.coeffs_mut()[0] = Torus::from(msg.wrapping_mul(&decomp_factor));
|
||||
|
||||
for c in scaled_msg.coeffs_mut().iter_mut().skip(1) {
|
||||
*c = Torus::zero();
|
||||
}
|
||||
}
|
||||
|
||||
encrypt_glwe_ciphertext_secret(col, &scaled_msg, glwe_secret_key, glwe_def);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decrypt_glwe_in_ggsw<S>(
|
||||
msg: &mut PolynomialRef<Torus<S>>,
|
||||
ggsw_ciphertext: &GgswCiphertextRef<S>,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
row: usize,
|
||||
column: usize,
|
||||
) -> Option<()>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
|
||||
// To decrypt a GGSW ciphertext, it suffices to decrypt the first GLWE ciphertext in
|
||||
// the last row and divide by its decomposition factor.
|
||||
let glev = ggsw_ciphertext.rows(params, radix).nth(row)?;
|
||||
let glwe = glev.glwe_ciphertexts(params).nth(column)?;
|
||||
|
||||
// Decrypt that specific GLWE ciphertext, which should have a message of
|
||||
// q / beta ^ {column + 1} * SM, where SM is the message times the secret
|
||||
// every row but the last (-SM) and M for the last row.
|
||||
decrypt_glwe_ciphertext(msg, glwe, glwe_secret_key, params);
|
||||
|
||||
let mask = (0x1 << decomposition_radix_log) - 1;
|
||||
|
||||
for c in msg.coeffs_mut() {
|
||||
let val = c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1));
|
||||
let r = (c.inner() >> (S::BITS as usize - decomposition_radix_log * (column + 1) - 1))
|
||||
& S::from_u64(0x1);
|
||||
|
||||
*c = Torus::from((val + r) & S::from_u64(mask));
|
||||
}
|
||||
|
||||
Some(())
|
||||
}
|
||||
|
||||
/// Decrypt a GGSW ciphertext with a given secret key.
|
||||
pub fn decrypt_ggsw_ciphertext<S>(
|
||||
msg: &mut PolynomialRef<Torus<S>>,
|
||||
ggsw_ciphertext: &GgswCiphertextRef<S>,
|
||||
glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let row = params.dim.size.0;
|
||||
|
||||
decrypt_glwe_in_ggsw(msg, ggsw_ciphertext, glwe_secret_key, params, radix, row, 0).unwrap();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{entities::GgswCiphertext, high_level::TEST_GLWE_DEF_1, high_level::*};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt_gsw_const_coeff() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let radix = &TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let msg = 1;
|
||||
|
||||
let ct = encryption::encrypt_ggsw(msg, &sk, ¶ms, radix, bits);
|
||||
let pt = encryption::decrypt_ggsw(&ct, &sk, ¶ms, radix, bits);
|
||||
|
||||
assert_eq!(pt.coeffs()[0], msg);
|
||||
|
||||
for c in pt.coeffs().iter().skip(1) {
|
||||
assert_eq!(*c, 0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that each of the rows in the GGSW ciphertext is a GLWE ciphertext that encodes the
|
||||
/// appropriate message (usually the decomposed message times the secret key)
|
||||
#[test]
|
||||
fn can_decrypt_all_elements_ggsw() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let coeffs = (0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let msg = Polynomial::new(&coeffs);
|
||||
|
||||
let mut ct = GgswCiphertext::new(¶ms, &radix);
|
||||
encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, ¶ms, &radix, bits);
|
||||
|
||||
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, ¶ms, &radix);
|
||||
let pt = pt.map(|x| x.inner());
|
||||
|
||||
// Ensure that the basic decryption works.
|
||||
assert_eq!(pt, msg);
|
||||
|
||||
let n_rows = ct.rows(¶ms, &radix).len();
|
||||
let n_cols = ct
|
||||
.rows(¶ms, &radix)
|
||||
.next()
|
||||
.unwrap()
|
||||
.glwe_ciphertexts(¶ms)
|
||||
.len();
|
||||
|
||||
// Beta
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
|
||||
for i in 0..n_rows {
|
||||
let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
let m_times_s = if i < params.dim.size.0 {
|
||||
// The message is composed of the negated secret key and the message
|
||||
// for all but the last row.
|
||||
let s = sk.s(¶ms).nth(i).unwrap();
|
||||
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
|
||||
|
||||
// Negate the product.
|
||||
for c in m_times_s.coeffs_mut().iter_mut() {
|
||||
// Have to call the trait directly because deref is implemented on Torus
|
||||
*c = num::traits::WrappingNeg::wrapping_neg(c);
|
||||
}
|
||||
|
||||
&m_times_s
|
||||
} else {
|
||||
// Last row isn't multiplied by secret key.
|
||||
msg.as_torus()
|
||||
};
|
||||
|
||||
for j in 0..n_cols {
|
||||
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
let mut msg = m_times_s.to_owned();
|
||||
|
||||
let mask = (0x1 << decomposition_radix_log) - 1;
|
||||
|
||||
for c in msg.coeffs_mut() {
|
||||
*c = Torus::from(c.inner() & mask);
|
||||
}
|
||||
|
||||
decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, ¶ms, &radix, i, j).unwrap();
|
||||
|
||||
assert_eq!(pt, msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_trivially_decrypy_ggsw() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let coeffs = (0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let msg = Polynomial::new(&coeffs);
|
||||
|
||||
let mut ct = GgswCiphertext::new(¶ms, &radix);
|
||||
trivially_encrypt_ggsw_ciphertext(&mut ct, &msg, &sk, ¶ms, &radix, bits);
|
||||
|
||||
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
decrypt_ggsw_ciphertext(&mut pt, &ct, &sk, ¶ms, &radix);
|
||||
let pt = pt.map(|x| x.inner());
|
||||
|
||||
// Ensure that the basic decryption works.
|
||||
assert_eq!(pt, msg);
|
||||
|
||||
let n_rows = ct.rows(¶ms, &radix).len();
|
||||
let n_cols = ct
|
||||
.rows(¶ms, &radix)
|
||||
.next()
|
||||
.unwrap()
|
||||
.glwe_ciphertexts(¶ms)
|
||||
.len();
|
||||
|
||||
// Beta
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
|
||||
for i in 0..n_rows {
|
||||
let mut m_times_s = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
let m_times_s = if i < params.dim.size.0 {
|
||||
// The message is composed of the negated secret key and the message
|
||||
// for all but the last row.
|
||||
let s = sk.s(¶ms).nth(i).unwrap();
|
||||
polynomial_external_mad(&mut m_times_s, msg.as_torus(), s);
|
||||
|
||||
// Negate the product.
|
||||
for c in m_times_s.coeffs_mut().iter_mut() {
|
||||
// Have to call the trait directly because deref is implemented on Torus
|
||||
*c = num::traits::WrappingNeg::wrapping_neg(c);
|
||||
}
|
||||
|
||||
&m_times_s
|
||||
} else {
|
||||
// Last row isn't multiplied by secret key.
|
||||
msg.as_torus()
|
||||
};
|
||||
|
||||
for j in 0..n_cols {
|
||||
let mut pt = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
let mut msg = m_times_s.to_owned();
|
||||
|
||||
let mask = (0x1 << decomposition_radix_log) - 1;
|
||||
|
||||
for c in msg.coeffs_mut() {
|
||||
*c = Torus::from(c.inner() & mask);
|
||||
}
|
||||
|
||||
decrypt_glwe_in_ggsw(&mut pt, &ct, &sk, ¶ms, &radix, i, j).unwrap();
|
||||
|
||||
assert_eq!(pt, msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
184
sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs
Normal file
184
sunscreen_tfhe/src/ops/encryption/glwe_encryption.rs
Normal file
@@ -0,0 +1,184 @@
|
||||
use num::Zero;
|
||||
|
||||
use crate::{
|
||||
entities::{GlweCiphertextRef, GlweSecretKeyRef, Polynomial, PolynomialRef},
|
||||
polynomial::{polynomial_add_assign, polynomial_external_mad, polynomial_sub_assign},
|
||||
rand::{normal_torus, uniform_torus},
|
||||
GlweDef, Torus, TorusOps,
|
||||
};
|
||||
|
||||
pub(crate) fn trivially_encrypt_glwe_with_sk_argument<S>(
|
||||
glwe_ciphertext: &mut GlweCiphertextRef<S>,
|
||||
msg: &PolynomialRef<Torus<S>>,
|
||||
_glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
trivially_encrypt_glwe_ciphertext(glwe_ciphertext, msg, params);
|
||||
}
|
||||
|
||||
/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.`
|
||||
pub fn encrypt_glwe_ciphertext_secret_generic<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
msg: &PolynomialRef<Torus<S>>,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let mut tmp = Polynomial::zero(params.dim.polynomial_degree.0);
|
||||
|
||||
let (a, b) = c.a_b_mut(params);
|
||||
|
||||
// tmp = A_i * S_i
|
||||
for (a_i, s_i) in a.zip(sk.s(params)) {
|
||||
// Fill a_i with uniform data
|
||||
for c in a_i.coeffs_mut() {
|
||||
*c = uniform_torus();
|
||||
}
|
||||
|
||||
polynomial_external_mad(&mut tmp, a_i, s_i);
|
||||
}
|
||||
|
||||
// b = A * S
|
||||
polynomial_add_assign(b, &tmp);
|
||||
|
||||
// b = A * S + m
|
||||
polynomial_add_assign(b, msg);
|
||||
|
||||
let e = Polynomial::new(
|
||||
&(0..msg.len())
|
||||
.map(|_| normal_torus::<S>(params.std))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
// b = A * S + m + e
|
||||
polynomial_add_assign(b, &e);
|
||||
}
|
||||
|
||||
/// Encrypt `msg` into a into the given GLWE ciphertext `c` using the secret key `sk.`
|
||||
pub fn encrypt_glwe_ciphertext_secret<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
msg: &PolynomialRef<Torus<S>>,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params)
|
||||
}
|
||||
|
||||
/// Generate a trivial GLWE encryption. Note that the caller will need to scale
|
||||
/// the message appropriately; a factor like delta is not automatically applied.
|
||||
pub fn trivially_encrypt_glwe_ciphertext<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
msg: &PolynomialRef<Torus<S>>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a, b) = c.a_b_mut(params);
|
||||
|
||||
// tmp = A_i * S_i
|
||||
for a_i in a {
|
||||
// Fill a_i with zero data
|
||||
for c in a_i.coeffs_mut() {
|
||||
*c = Torus::zero();
|
||||
}
|
||||
}
|
||||
|
||||
// b = m
|
||||
b.clone_from_ref(msg);
|
||||
}
|
||||
|
||||
/// Decrypt GLWE ciphertext `ct` into `msg` using secret key `sk`.
|
||||
pub fn decrypt_glwe_ciphertext<S>(
|
||||
msg: &mut PolynomialRef<Torus<S>>,
|
||||
ct: &GlweCiphertextRef<S>,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a, b) = ct.a_b(params);
|
||||
|
||||
let mut tmp = Polynomial::zero(b.len());
|
||||
|
||||
// msg = b
|
||||
msg.clone_from_ref(b);
|
||||
|
||||
// tmp = A_i * S_i
|
||||
for (a_i, s_i) in a.zip(sk.s(params)) {
|
||||
polynomial_external_mad(&mut tmp, a_i, s_i);
|
||||
}
|
||||
|
||||
// msg = b - A * S = m + e
|
||||
polynomial_sub_assign(msg, &tmp);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{high_level::*, PlaintextBits};
|
||||
|
||||
use super::*;
|
||||
|
||||
// Encryption
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let plaintext = Polynomial::new(
|
||||
&(0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let ct = encryption::encrypt_glwe(&plaintext, &sk, ¶ms, bits);
|
||||
let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(dec, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt_uniform() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let plaintext = Polynomial::new(
|
||||
&(0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let ct = encryption::encrypt_glwe(&plaintext, &sk, ¶ms, bits);
|
||||
let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(dec, plaintext);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trivial_glwe_decrypts() {
|
||||
let params = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let plaintext = Polynomial::new(
|
||||
&(0..params.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let ct = encryption::trivial_glwe(&plaintext, ¶ms, bits);
|
||||
let dec = encryption::decrypt_glwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(dec, plaintext);
|
||||
}
|
||||
}
|
||||
114
sunscreen_tfhe/src/ops/encryption/lwe_encryption.rs
Normal file
114
sunscreen_tfhe/src/ops/encryption/lwe_encryption.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
entities::{LweCiphertextRef, LweSecretKeyRef},
|
||||
math::{Torus, TorusOps},
|
||||
rand::{normal_torus, uniform_torus},
|
||||
LweDef, PlaintextBits,
|
||||
};
|
||||
|
||||
/// Generate a trivial GLWE encryption. Note that the caller will need to scale
|
||||
/// the message appropriately; a factor like delta is not automatically applied.
|
||||
pub fn trivially_encrypt_lwe_ciphertext<S>(
|
||||
c: &mut LweCiphertextRef<S>,
|
||||
msg: &Torus<S>,
|
||||
params: &LweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a, b) = c.a_b_mut(params);
|
||||
|
||||
// tmp = A_i * S_i
|
||||
for a_i in a {
|
||||
*a_i = Torus::zero();
|
||||
}
|
||||
|
||||
// b = m
|
||||
*b = *msg;
|
||||
}
|
||||
|
||||
/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the
|
||||
/// randomness used to generate the ciphertext.
|
||||
pub fn encrypt_lwe_ciphertext<S>(
|
||||
ct: &mut LweCiphertextRef<S>,
|
||||
sk: &LweSecretKeyRef<S>,
|
||||
msg: Torus<S>,
|
||||
params: &LweDef,
|
||||
) -> Torus<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a, b) = ct.a_b_mut(params);
|
||||
|
||||
for (a_i, d_i) in a.iter_mut().zip(sk.as_slice().iter()) {
|
||||
*a_i = uniform_torus::<S>();
|
||||
*b += *a_i * d_i;
|
||||
}
|
||||
|
||||
let e = normal_torus(params.std);
|
||||
*b += msg + e;
|
||||
|
||||
e
|
||||
}
|
||||
|
||||
/// Encrypts the given message under sk, writing the ciphertext to ct. Returns the
|
||||
/// randomness used to generate the ciphertext.
|
||||
pub fn encode_and_encrypt_lwe_ciphertext<S>(
|
||||
ct: &mut LweCiphertextRef<S>,
|
||||
sk: &LweSecretKeyRef<S>,
|
||||
msg: S,
|
||||
params: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Torus<S>
|
||||
where
|
||||
S: TorusOps,
|
||||
{
|
||||
let msg = Torus::<S>::encode(msg, plaintext_bits);
|
||||
|
||||
encrypt_lwe_ciphertext(ct, sk, msg, params)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{high_level::*, PlaintextBits};
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt() {
|
||||
let params = TEST_LWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let ct = encryption::encrypt_lwe_secret(4, &sk, ¶ms, bits);
|
||||
let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(pt, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_encrypt_decrypt_uniform() {
|
||||
let params = TEST_LWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_uniform_lwe_sk(¶ms);
|
||||
|
||||
let ct = encryption::encrypt_lwe_secret(4, &sk, ¶ms, bits);
|
||||
let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(pt, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_trivially_decrypt() {
|
||||
let params = TEST_LWE_DEF_1;
|
||||
let bits = PlaintextBits(4);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let ct = encryption::trivial_lwe(4, ¶ms, bits);
|
||||
let pt = encryption::decrypt_lwe(&ct, &sk, ¶ms, bits);
|
||||
|
||||
assert_eq!(pt, 4);
|
||||
}
|
||||
}
|
||||
8
sunscreen_tfhe/src/ops/encryption/mod.rs
Normal file
8
sunscreen_tfhe/src/ops/encryption/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
mod lwe_encryption;
|
||||
pub use lwe_encryption::*;
|
||||
|
||||
mod glwe_encryption;
|
||||
pub use glwe_encryption::*;
|
||||
|
||||
mod ggsw_encryption;
|
||||
pub use ggsw_encryption::*;
|
||||
285
sunscreen_tfhe/src/ops/fft_ops.rs
Normal file
285
sunscreen_tfhe/src/ops/fft_ops.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
use num::Complex;
|
||||
|
||||
use crate::{
|
||||
dst::{FromMutSlice, OverlaySize},
|
||||
entities::{
|
||||
GgswCiphertextFftRef, GlevCiphertextFftRef, GlweCiphertextFftRef, GlweCiphertextRef,
|
||||
PolynomialFftRef, PolynomialRef,
|
||||
},
|
||||
ops::ciphertext::{add_glwe_ciphertexts, sub_glwe_ciphertexts},
|
||||
radix::PolynomialRadixIterator,
|
||||
scratch::{allocate_scratch, allocate_scratch_ref},
|
||||
GlweDef, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/// Compute `c += a \[*] b`` where
|
||||
/// * `a` is a GLWE ciphertext
|
||||
/// * `b` is a GGSW cipheetext
|
||||
/// * `\[*\]` is the external product operator GGSW \[*\] GLWE -> GLWE`
|
||||
pub fn glwe_ggsw_mad<S>(
|
||||
c_fft: &mut GlweCiphertextFftRef<Complex<f64>>,
|
||||
a: &GlweCiphertextRef<S>,
|
||||
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
let rows = b_fft.rows(params, radix);
|
||||
|
||||
// Generate an iterator that includes a_a and a_b
|
||||
let a_then_b_glwe_polynomials = a_a.chain(std::iter::once(a_b));
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<S>, (params.dim.polynomial_degree));
|
||||
|
||||
// Performs the external operation
|
||||
//
|
||||
// GGSW ⊡ GLWE = sum_i=0^k <Decomp^{beta, l}(AB_i), C_i>
|
||||
//
|
||||
// where
|
||||
// * `beta` is the decomposition base
|
||||
// * `l` is the decomposition level
|
||||
// * `AB_i` is the i-th polynomial of the GLWE ciphertext with B at the end {A, B}
|
||||
// * `C_i` is the i-th row of the GGSW ciphertext
|
||||
for (a_i, r) in a_then_b_glwe_polynomials.zip(rows) {
|
||||
// For each a polynomial, compute the external product with the GLEV ciphertext
|
||||
// at index i in the GGSW ciphertext.
|
||||
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
|
||||
|
||||
decomposed_polynomial_glev_mad(c_fft, decomp, r, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute `c += (G^-1 * a) \[*\] b`, where
|
||||
/// * `G^-1 * a`` is the radix decomposition of `a`
|
||||
/// * `b` is a GLEV ciphertext.
|
||||
/// * `c` is a GLWE ciphertext.
|
||||
/// * \[*\] is the external product between a GLEV ciphertext and `l` polynomials
|
||||
///
|
||||
/// # Remarks
|
||||
/// This functions takes a PolynomialRadixIterator to perform the decomposition.
|
||||
pub fn decomposed_polynomial_glev_mad<S>(
|
||||
c: &mut GlweCiphertextFftRef<Complex<f64>>,
|
||||
mut a: PolynomialRadixIterator<S>,
|
||||
b: &GlevCiphertextFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let b_glwe = b.glwe_ciphertexts(params);
|
||||
|
||||
let mut cur_radix = allocate_scratch::<S>(params.dim.polynomial_degree.0);
|
||||
let cur_radix = PolynomialRef::from_mut_slice(cur_radix.as_mut_slice());
|
||||
|
||||
let mut decomp_fft = allocate_scratch(PolynomialFftRef::<Complex<f64>>::size(
|
||||
params.dim.polynomial_degree,
|
||||
));
|
||||
let decomp_fft = PolynomialFftRef::from_mut_slice(decomp_fft.as_mut_slice());
|
||||
|
||||
// The decomposition of
|
||||
// <Decomp^{beta, l}(gamma), GLEV>
|
||||
// can be performed using
|
||||
// sum_{j = 1}^l gamma_j * C_j
|
||||
// where gamma_j is the polynomial to decompose multiplied by q/B^{j+1}
|
||||
// Note the reverse of the GLWE ciphertexts here! The decomposition iterator
|
||||
// returns the decomposed values in the opposite order.
|
||||
for b in b_glwe.rev() {
|
||||
a.write_next(cur_radix);
|
||||
cur_radix.fft(decomp_fft);
|
||||
|
||||
glwe_polynomial_mad(c, b, decomp_fft, params);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute c += a \[*\] b where \[*\] is the external product
|
||||
/// between a GLWE ciphertext and an polynomial in Z\[X\]/(X^N + 1).
|
||||
///
|
||||
/// # Remarks
|
||||
/// For this to produce the correct result, degree(b) must be "small"
|
||||
/// (ideally 0) and the coefficient must be small (i.e. less than the
|
||||
/// message size).
|
||||
pub fn glwe_polynomial_mad(
|
||||
c: &mut GlweCiphertextFftRef<Complex<f64>>,
|
||||
a: &GlweCiphertextFftRef<Complex<f64>>,
|
||||
b: &PolynomialFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
) {
|
||||
let (c_a, c_b) = c.a_b_mut(params);
|
||||
let (a_a, a_b) = a.a_b(params);
|
||||
|
||||
assert_eq!(c_a.len(), params.dim.size.0);
|
||||
assert_eq!(a_a.len(), params.dim.size.0);
|
||||
|
||||
for (c, a) in c_a.zip(a_a) {
|
||||
c.multiply_add(a, b);
|
||||
}
|
||||
|
||||
c_b.multiply_add(a_b, b);
|
||||
}
|
||||
|
||||
/// Performs a CMUX operation, which enables one of two GLWE ciphertexts
|
||||
/// to be selected from an encrypted boolean GGSW ciphertext. The result
|
||||
/// is stored in `c`.
|
||||
///
|
||||
/// Conceptually, this can be seen as the following operation in Rust:
|
||||
///
|
||||
/// ```text
|
||||
/// let c = if b_fft { d_1 } else { d_0 }
|
||||
/// ```
|
||||
///
|
||||
/// where the output `c` is a different encryption than either of the initial
|
||||
/// inputs. Note that this will result in higher noise than in the original
|
||||
/// ciphertexts.
|
||||
pub fn cmux<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
d_0: &GlweCiphertextRef<S>,
|
||||
d_1: &GlweCiphertextRef<S>,
|
||||
b_fft: &GgswCiphertextFftRef<Complex<f64>>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
allocate_scratch_ref!(diff, GlweCiphertextRef<S>, (params.dim));
|
||||
|
||||
sub_glwe_ciphertexts(diff, d_1, d_0, params);
|
||||
|
||||
allocate_scratch_ref!(prod_fft, GlweCiphertextFftRef<Complex<f64>>, (params.dim));
|
||||
|
||||
prod_fft.clear();
|
||||
|
||||
glwe_ggsw_mad(prod_fft, diff, b_fft, params, radix);
|
||||
|
||||
allocate_scratch_ref!(prod, GlweCiphertextRef<S>, (params.dim));
|
||||
|
||||
prod_fft.ifft(prod, params);
|
||||
|
||||
add_glwe_ciphertexts(c, prod, d_0, params);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{GgswCiphertextFft, GlweCiphertext, GlweCiphertextFft, Polynomial},
|
||||
high_level::*,
|
||||
PlaintextBits, Torus,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_fft_external_product_glwe_ggsw() {
|
||||
let glwe_params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
|
||||
for _ in 0..100 {
|
||||
let sel = thread_rng().next_u64() % 2;
|
||||
|
||||
let ggsw = encryption::encrypt_ggsw(sel, &sk, &glwe_params, &radix, bits);
|
||||
|
||||
let glwe_pt = (0..glwe_params.dim.polynomial_degree.0)
|
||||
.map(|_| thread_rng().next_u64() % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let glwe_pt = Polynomial::new(&glwe_pt);
|
||||
|
||||
let glwe = encryption::encrypt_glwe(&glwe_pt, &sk, &glwe_params, bits);
|
||||
|
||||
let mut ggsw_fft = GgswCiphertextFft::new(&glwe_params, &radix);
|
||||
let mut res_fft = GlweCiphertextFft::new(&glwe_params);
|
||||
let mut res = GlweCiphertext::<u64>::new(&glwe_params);
|
||||
|
||||
ggsw.fft(&mut ggsw_fft, &glwe_params, &radix);
|
||||
|
||||
glwe_ggsw_mad(&mut res_fft, &glwe, &ggsw_fft, &glwe_params, &radix);
|
||||
|
||||
res_fft.ifft(&mut res, &glwe_params);
|
||||
|
||||
let actual = encryption::decrypt_glwe(&res, &sk, &glwe_params, bits);
|
||||
|
||||
if sel == 1 {
|
||||
assert_eq!(actual, glwe_pt);
|
||||
} else {
|
||||
assert_eq!(
|
||||
actual,
|
||||
Polynomial::zero(glwe_params.dim.polynomial_degree.0)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_cmux_fft() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
let radix = TEST_RADIX;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
for _ in 0..100 {
|
||||
let sel = thread_rng().next_u64() % 2;
|
||||
|
||||
let sel_ct = encryption::encrypt_ggsw(sel, &sk, &glwe, &radix, bits);
|
||||
|
||||
let a = (0..glwe.dim.polynomial_degree.0)
|
||||
.map(|_| thread_rng().next_u64() % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let a = Polynomial::new(&a);
|
||||
|
||||
let a_ct = encryption::encrypt_glwe(&a, &sk, &glwe, bits);
|
||||
|
||||
let b = (0..glwe.dim.polynomial_degree.0)
|
||||
.map(|_| thread_rng().next_u64() % 2)
|
||||
.collect::<Vec<_>>();
|
||||
let b = Polynomial::new(&b);
|
||||
|
||||
let b_ct = encryption::encrypt_glwe(&b, &sk, &glwe, bits);
|
||||
|
||||
let sel_fft = fft::fft_ggsw(&sel_ct, &glwe, &radix);
|
||||
|
||||
let mut res_ct = GlweCiphertext::new(&glwe);
|
||||
|
||||
cmux(&mut res_ct, &a_ct, &b_ct, &sel_fft, &glwe, &radix);
|
||||
|
||||
let actual = encryption::decrypt_glwe(&res_ct, &sk, &glwe, bits);
|
||||
|
||||
if sel == 1 {
|
||||
assert_eq!(actual, b);
|
||||
} else {
|
||||
assert_eq!(actual, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cmux_trivial_ciphertexts_yields_nontrivial() {
|
||||
let sk = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
|
||||
|
||||
let plaintext_bits = crate::PlaintextBits(1);
|
||||
|
||||
let a = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Polynomial<_>>();
|
||||
let a = encryption::trivial_glwe(&a, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
let b = (0..TEST_GLWE_DEF_1.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| (x + 1) % 2)
|
||||
.collect::<Polynomial<_>>();
|
||||
let b = encryption::trivial_glwe(&b, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
|
||||
let sel = encryption::encrypt_ggsw(1, &sk, &TEST_GLWE_DEF_1, &TEST_RADIX, plaintext_bits);
|
||||
|
||||
let sel = fft::fft_ggsw(&sel, &TEST_GLWE_DEF_1, &TEST_RADIX);
|
||||
|
||||
let res = evaluation::cmux(&sel, &a, &b, &TEST_GLWE_DEF_1, &TEST_RADIX);
|
||||
|
||||
for a in res.a(&TEST_GLWE_DEF_1) {
|
||||
let zero = Polynomial::<Torus<u64>>::zero(TEST_GLWE_DEF_1.dim.polynomial_degree.0);
|
||||
|
||||
assert_ne!(a.to_owned(), zero);
|
||||
}
|
||||
}
|
||||
}
|
||||
67
sunscreen_tfhe/src/ops/homomorphisms/lwe.rs
Normal file
67
sunscreen_tfhe/src/ops/homomorphisms/lwe.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use crate::{entities::LweCiphertextRef, LweDef, Torus, TorusOps};
|
||||
|
||||
/// Add `amount` to each torus element (mod q) in the ciphertext.
|
||||
/// This shifts where messages lie on the torus and adds no noise.
|
||||
///
|
||||
/// # Remark
|
||||
/// Suppose we have plaintexts 0 and 1 that lie centered at 0 and q/2 respectively.
|
||||
/// If we rotate by q/4, then the 0 lies centered at q/4 and 1 lies at 3q/4 == -q/4.
|
||||
pub fn rotate<S: TorusOps>(
|
||||
output: &mut LweCiphertextRef<S>,
|
||||
input: &LweCiphertextRef<S>,
|
||||
amount: Torus<S>,
|
||||
lwe: &LweDef,
|
||||
) {
|
||||
output.assert_valid(lwe);
|
||||
input.assert_valid(lwe);
|
||||
|
||||
output.a_mut(lwe).clone_from_slice(input.a(lwe));
|
||||
*output.b_mut(lwe) = input.b(lwe) + amount;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
entities::LweCiphertext,
|
||||
high_level::{keygen, TEST_LWE_DEF_1},
|
||||
PlaintextBits, Torus,
|
||||
};
|
||||
|
||||
use super::rotate;
|
||||
|
||||
#[test]
|
||||
fn can_rotate() {
|
||||
let lwe_params = TEST_LWE_DEF_1;
|
||||
|
||||
for _ in 0..100 {
|
||||
let sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
let val = sk.encrypt(0, &lwe_params, PlaintextBits(1)).0;
|
||||
|
||||
let mut res = LweCiphertext::new(&lwe_params);
|
||||
|
||||
rotate(
|
||||
&mut res,
|
||||
&val,
|
||||
Torus::encode(1, PlaintextBits(2)),
|
||||
&lwe_params,
|
||||
);
|
||||
|
||||
let t = sk.decrypt_without_decode(&res, &lwe_params);
|
||||
|
||||
assert!(t.inner() < 0x1u64 << 63);
|
||||
|
||||
let val = sk.encrypt(1, &lwe_params, PlaintextBits(1)).0;
|
||||
|
||||
rotate(
|
||||
&mut res,
|
||||
&val,
|
||||
Torus::encode(1, PlaintextBits(2)),
|
||||
&lwe_params,
|
||||
);
|
||||
|
||||
let t = sk.decrypt_without_decode(&res, &lwe_params);
|
||||
|
||||
assert!(t.inner() > 0x1u64 << 63);
|
||||
}
|
||||
}
|
||||
}
|
||||
2
sunscreen_tfhe/src/ops/homomorphisms/mod.rs
Normal file
2
sunscreen_tfhe/src/ops/homomorphisms/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
mod lwe;
|
||||
pub use lwe::*;
|
||||
107
sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch.rs
Normal file
107
sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{GlweCiphertext, GlweCiphertextRef, GlweKeyswitchKeyRef, PolynomialRef},
|
||||
ops::{
|
||||
ciphertext::{decomposed_polynomial_glev_mad, sub_glwe_ciphertexts},
|
||||
encryption::trivially_encrypt_glwe_ciphertext,
|
||||
},
|
||||
radix::PolynomialRadixIterator,
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/// Switches a ciphertext under the original key to a ciphertext under the new
|
||||
/// key using a keyswitch key.
|
||||
///
|
||||
/// # Remark
|
||||
///
|
||||
/// This performs the following operation:
|
||||
///
|
||||
/// ```text
|
||||
/// switched_ciphertext = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
|
||||
/// ```
|
||||
///
|
||||
/// where `trivial_encrypt` is the encryption of the body of the original
|
||||
/// ciphertext.
|
||||
pub fn keyswitch_glwe_to_glwe<S>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
ciphertext_under_original_key: &GlweCiphertextRef<S>,
|
||||
keyswitch_key: &GlweKeyswitchKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(params);
|
||||
|
||||
let keyswitch_glevs = keyswitch_key.rows(params, radix);
|
||||
|
||||
let mut a_i_decomp_sum = GlweCiphertext::new(params);
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<S>, (params.dim.polynomial_degree));
|
||||
|
||||
// sum_i(<decomp(ciphertext_a_i), glev_i>)
|
||||
for (a_i, glev_i) in ciphertext_a.zip(keyswitch_glevs) {
|
||||
let decomp = PolynomialRadixIterator::new(a_i, scratch, radix);
|
||||
|
||||
decomposed_polynomial_glev_mad(&mut a_i_decomp_sum, decomp, glev_i, params);
|
||||
}
|
||||
|
||||
// trivial_encrypt(ciphertext_b)
|
||||
let mut trivial_b = GlweCiphertext::new(params);
|
||||
trivially_encrypt_glwe_ciphertext(&mut trivial_b, ciphertext_b, params);
|
||||
|
||||
// output = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
|
||||
sub_glwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, params);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
entities::{GlweCiphertext, GlweKeyswitchKey, Polynomial},
|
||||
high_level::*,
|
||||
ops::keyswitch::{
|
||||
glwe_keyswitch::keyswitch_glwe_to_glwe, glwe_keyswitch_key::generate_keyswitch_key_glwe,
|
||||
},
|
||||
PlaintextBits,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn keyswitch_glwe() {
|
||||
let glwe = TEST_GLWE_DEF_1;
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let original_sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
let new_sk = keygen::generate_binary_glwe_sk(&glwe);
|
||||
|
||||
let mut ksk = GlweKeyswitchKey::<u64>::new(&TEST_GLWE_DEF_1, &TEST_RADIX);
|
||||
generate_keyswitch_key_glwe(
|
||||
&mut ksk,
|
||||
&original_sk,
|
||||
&new_sk,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
);
|
||||
|
||||
let msg = Polynomial::new(
|
||||
&(0..glwe.dim.polynomial_degree.0 as u64)
|
||||
.map(|x| x % 2)
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
|
||||
let original_ct = original_sk.encode_encrypt_glwe(&msg, &glwe, bits);
|
||||
|
||||
let mut new_ct = GlweCiphertext::new(&glwe);
|
||||
keyswitch_glwe_to_glwe(
|
||||
&mut new_ct,
|
||||
&original_ct,
|
||||
&ksk,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
);
|
||||
|
||||
let new_decrypted = new_sk.decrypt_decode_glwe(&new_ct, &glwe, bits);
|
||||
|
||||
assert_eq!(new_decrypted.coeffs(), msg.coeffs());
|
||||
}
|
||||
}
|
||||
150
sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch_key.rs
Normal file
150
sunscreen_tfhe/src/ops/keyswitch/glwe_keyswitch_key.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use crate::{
|
||||
entities::{
|
||||
GlweCiphertextRef, GlweKeyswitchKeyRef, GlweSecretKeyRef, Polynomial, PolynomialRef,
|
||||
},
|
||||
ops::encryption::encrypt_glwe_ciphertext_secret_generic,
|
||||
polynomial::polynomial_scalar_mul,
|
||||
GlweDef, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
fn encrypt_glwe_ciphertext_secret_with_keyswitch_noise<S>(
|
||||
c: &mut GlweCiphertextRef<S>,
|
||||
msg: &PolynomialRef<Torus<S>>,
|
||||
sk: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
encrypt_glwe_ciphertext_secret_generic(c, msg, sk, params);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a keyswitch key from the original key to the new key. The resulting
|
||||
* keyswitch key is encrypted under the new key. This function is generic over
|
||||
* the encrypt function should there be a need.
|
||||
*
|
||||
* The specific operation on each GLev ciphertext row inside a keyswitch key is
|
||||
*
|
||||
* ```text
|
||||
* KSK_i = (GLWE_{s', })
|
||||
* ```
|
||||
*/
|
||||
fn encrypt_keyswitch_key_generic<S>(
|
||||
keyswitch_key: &mut GlweKeyswitchKeyRef<S>,
|
||||
original_glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
new_glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
encrypt: impl Fn(
|
||||
&mut GlweCiphertextRef<S>,
|
||||
&PolynomialRef<Torus<S>>,
|
||||
&GlweSecretKeyRef<S>,
|
||||
&GlweDef,
|
||||
),
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
let polynomial_degree = params.dim.polynomial_degree.0;
|
||||
|
||||
for (i, row) in keyswitch_key.rows_mut(params, radix).enumerate() {
|
||||
let s = original_glwe_secret_key
|
||||
.s(params)
|
||||
.nth(i)
|
||||
.unwrap()
|
||||
.map(|x| Torus::from(*x));
|
||||
|
||||
for (j, col) in row.glwe_ciphertexts_mut(params).enumerate() {
|
||||
let mut scaled_original_key = Polynomial::zero(polynomial_degree);
|
||||
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
|
||||
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
|
||||
let decomp_factor =
|
||||
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
|
||||
|
||||
polynomial_scalar_mul(&mut scaled_original_key, &s, decomp_factor);
|
||||
|
||||
encrypt(col, &scaled_original_key, new_glwe_secret_key, params);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a keyswitch key from the original key to the new key.
|
||||
/// For use with
|
||||
/// [`keyswitch_glwe_to_glwe`](crate::ops::keyswitch::glwe_keyswitch::keyswitch_glwe_to_glwe).
|
||||
pub fn generate_keyswitch_key_glwe<S>(
|
||||
keyswitch_key: &mut GlweKeyswitchKeyRef<S>,
|
||||
original_glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
new_glwe_secret_key: &GlweSecretKeyRef<S>,
|
||||
params: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
encrypt_keyswitch_key_generic(
|
||||
keyswitch_key,
|
||||
original_glwe_secret_key,
|
||||
new_glwe_secret_key,
|
||||
params,
|
||||
radix,
|
||||
encrypt_glwe_ciphertext_secret_with_keyswitch_noise,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use crate::{
|
||||
dst::FromSlice,
|
||||
entities::{GlweKeyswitchKey, GlweKeyswitchKeyRef},
|
||||
high_level::{TEST_GLWE_DEF_1, TEST_RADIX},
|
||||
Torus,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_generate_keyswitch_key_glwe() {
|
||||
// Size of the arrary should be (k + 1) * l * k * poly_degree as we are
|
||||
// GLev encrypting all S_i with a given radix count and base.
|
||||
let ksk = GlweKeyswitchKey::<u64>::new(&TEST_GLWE_DEF_1, &TEST_RADIX);
|
||||
|
||||
let k = TEST_GLWE_DEF_1.dim.size.0;
|
||||
let l = TEST_RADIX.count.0;
|
||||
let poly_degree = TEST_GLWE_DEF_1.dim.polynomial_degree.0;
|
||||
assert_eq!(ksk.as_slice().len(), (k + 1) * l * k * poly_degree);
|
||||
|
||||
// Generate fake data to iterate through.
|
||||
let ksk_data = (0..ksk.as_slice().len())
|
||||
.map(|x| Torus::from(x as u64))
|
||||
.collect::<Vec<Torus<u64>>>();
|
||||
|
||||
let ksk = GlweKeyswitchKeyRef::<u64>::from_slice(&ksk_data);
|
||||
|
||||
// Check that the data is correct.
|
||||
let glwe_size = TEST_GLWE_DEF_1.dim.polynomial_degree.0 * (TEST_GLWE_DEF_1.dim.size.0 + 1);
|
||||
let mut count = 0;
|
||||
for row in ksk.rows(&TEST_GLWE_DEF_1, &TEST_RADIX) {
|
||||
for glwe in row.glwe_ciphertexts(&TEST_GLWE_DEF_1) {
|
||||
let (a, b) = glwe.a_b(&TEST_GLWE_DEF_1);
|
||||
|
||||
let expected_vector = (count..(glwe_size + count))
|
||||
.map(|x| Torus::from(x as u64))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let (a_expected, b_expected) = expected_vector
|
||||
.split_at(TEST_GLWE_DEF_1.dim.size.0 * TEST_GLWE_DEF_1.dim.polynomial_degree.0);
|
||||
|
||||
let a_i_expected = a_expected
|
||||
.chunks(TEST_GLWE_DEF_1.dim.polynomial_degree.0)
|
||||
.map(|x| x.to_vec())
|
||||
.collect::<Vec<Vec<Torus<u64>>>>();
|
||||
|
||||
for (a_j, a_j_expected) in a.zip(a_i_expected) {
|
||||
assert_eq!(a_j.coeffs(), a_j_expected);
|
||||
}
|
||||
|
||||
assert_eq!(b.coeffs(), b_expected);
|
||||
|
||||
count += glwe_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
91
sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs
Normal file
91
sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
use crate::{
|
||||
dst::{FromMutSlice, FromSlice},
|
||||
entities::{LweCiphertext, LweCiphertextRef, LweKeyswitchKeyRef, PolynomialRef},
|
||||
ops::{
|
||||
ciphertext::{decomposed_scalar_lev_mad, sub_lwe_ciphertexts},
|
||||
encryption::trivially_encrypt_lwe_ciphertext,
|
||||
},
|
||||
radix::PolynomialRadixIterator,
|
||||
scratch::allocate_scratch_ref,
|
||||
LweDef, PolynomialDegree, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/// Switches a ciphertext under the original key to a ciphertext under the new
|
||||
/// key using a keyswitch key.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// * output: the output ciphertext
|
||||
/// * ciphertext_under_original_key: the input ciphertext
|
||||
/// * keyswitch_key: the keyswitch key
|
||||
/// * old_params: the parameters of the original ciphertext
|
||||
/// * new_params: the parameters of the output ciphertext
|
||||
pub fn keyswitch_lwe_to_lwe<S>(
|
||||
output: &mut LweCiphertextRef<S>,
|
||||
ciphertext_under_original_key: &LweCiphertextRef<S>,
|
||||
keyswitch_key: &LweKeyswitchKeyRef<S>,
|
||||
old_params: &LweDef,
|
||||
new_params: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
keyswitch_key.assert_valid(old_params, new_params, radix);
|
||||
|
||||
let (ciphertext_a, ciphertext_b) = ciphertext_under_original_key.a_b(old_params);
|
||||
|
||||
let keyswitch_levs = keyswitch_key.rows(new_params, radix);
|
||||
|
||||
let mut a_i_decomp_sum = LweCiphertext::new(new_params);
|
||||
|
||||
allocate_scratch_ref!(scratch, PolynomialRef<S>, (PolynomialDegree(1)));
|
||||
|
||||
// sum_i(<decomp(ciphertext_a_i), lev_i>)
|
||||
for (a_i, lev_i) in ciphertext_a.iter().zip(keyswitch_levs) {
|
||||
let decomp =
|
||||
PolynomialRadixIterator::new(PolynomialRef::from_slice(&[*a_i]), scratch, radix);
|
||||
|
||||
decomposed_scalar_lev_mad(&mut a_i_decomp_sum, decomp, lev_i, new_params);
|
||||
}
|
||||
|
||||
// trivial_encrypt(ciphertext_b)
|
||||
let mut trivial_b = LweCiphertext::new(new_params);
|
||||
trivially_encrypt_lwe_ciphertext(&mut trivial_b, ciphertext_b, new_params);
|
||||
|
||||
// output = trivial_encrypt(ciphertext_b) - sum_i(<decomp(ciphertext_a_i), glev_i>)
|
||||
sub_lwe_ciphertexts(output, &trivial_b, &a_i_decomp_sum, new_params);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{high_level::*, PlaintextBits};
|
||||
|
||||
#[test]
|
||||
fn keyswitch_lwe() {
|
||||
let bits = PlaintextBits(4);
|
||||
let from_lwe = TEST_LWE_DEF_1;
|
||||
let to_lwe = TEST_LWE_DEF_2;
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
for _ in 0..50 {
|
||||
let original_sk = keygen::generate_binary_lwe_sk(&from_lwe);
|
||||
let new_sk = keygen::generate_binary_lwe_sk(&to_lwe);
|
||||
|
||||
let ksk = keygen::generate_ksk(&original_sk, &new_sk, &from_lwe, &to_lwe, &radix);
|
||||
|
||||
let msg = thread_rng().next_u64() % (1 << bits.0);
|
||||
|
||||
let original_ct = original_sk.encrypt(msg, &from_lwe, bits).0;
|
||||
|
||||
let new_ct =
|
||||
evaluation::keyswitch_lwe_to_lwe(&original_ct, &ksk, &from_lwe, &to_lwe, &radix);
|
||||
|
||||
let new_decrypted = new_sk.decrypt(&new_ct, &to_lwe, bits);
|
||||
|
||||
assert_eq!(new_decrypted, msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
41
sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch_key.rs
Normal file
41
sunscreen_tfhe/src/ops/keyswitch/lwe_keyswitch_key.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crate::{
|
||||
entities::{LweKeyswitchKeyRef, LweSecretKeyRef},
|
||||
ops::encryption::encrypt_lwe_ciphertext,
|
||||
LweDef, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
/// Generates a keyswitch key from an original LWE key to a new LWE key. The
|
||||
/// resulting keyswitch key is encrypted under the new key.
|
||||
///
|
||||
/// Arguments:
|
||||
///
|
||||
/// * keyswitch_key: the resulting keyswitch key
|
||||
/// * original_lwe_secret_key: the original LWE secret key
|
||||
/// * new_lwe_secret_key: the new LWE secret key
|
||||
/// * new_params: the parameters of the new LWE secret key
|
||||
pub fn generate_keyswitch_key_lwe<S>(
|
||||
keyswitch_key: &mut LweKeyswitchKeyRef<S>,
|
||||
original_lwe_secret_key: &LweSecretKeyRef<S>,
|
||||
new_lwe_secret_key: &LweSecretKeyRef<S>,
|
||||
new_params: &LweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
{
|
||||
let decomposition_radix_log = radix.radix_log.0;
|
||||
|
||||
for (i, row) in keyswitch_key.rows_mut(new_params, radix).enumerate() {
|
||||
let s_i = original_lwe_secret_key.s()[i];
|
||||
|
||||
for (j, col) in row.lwe_ciphertexts_mut(new_params).enumerate() {
|
||||
// The factor is q / B^{i+1}. Since B is a power of 2, this is equivalent to
|
||||
// multiplying by 2^{log2(q) - log2(B) * (i + 1)}
|
||||
let decomp_factor =
|
||||
S::from_u64(0x1 << (S::BITS as usize - decomposition_radix_log * (j + 1)));
|
||||
|
||||
let msg = decomp_factor * s_i;
|
||||
|
||||
encrypt_lwe_ciphertext(col, new_lwe_secret_key, Torus::from(msg), new_params);
|
||||
}
|
||||
}
|
||||
}
|
||||
17
sunscreen_tfhe/src/ops/keyswitch/mod.rs
Normal file
17
sunscreen_tfhe/src/ops/keyswitch/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
/// Methods for performing a private functional keyswitch (PFKS)
|
||||
pub mod private_functional_keyswitch;
|
||||
|
||||
/// Methods for performing a public functional keyswitch (PuFKS)
|
||||
pub mod public_functional_keyswitch;
|
||||
|
||||
/// Generate LWE keyswitch keys.
|
||||
pub mod lwe_keyswitch_key;
|
||||
|
||||
/// Generate GLWE keyswitch keys.
|
||||
pub mod glwe_keyswitch_key;
|
||||
|
||||
/// Methods for performing a LWE keyswitch.
|
||||
pub mod lwe_keyswitch;
|
||||
|
||||
/// Methods for performing a GLWE keyswitch.
|
||||
pub mod glwe_keyswitch;
|
||||
350
sunscreen_tfhe/src/ops/keyswitch/private_functional_keyswitch.rs
Normal file
350
sunscreen_tfhe/src/ops/keyswitch/private_functional_keyswitch.rs
Normal file
@@ -0,0 +1,350 @@
|
||||
use sunscreen_math::Zero;
|
||||
|
||||
use crate::{
|
||||
dst::FromMutSlice,
|
||||
entities::{
|
||||
CircuitBootstrappingKeyswitchKeysRef, GlweCiphertextRef, GlweSecretKeyRef,
|
||||
LweCiphertextRef, LweSecretKeyRef, PolynomialRef, PrivateFunctionalKeyswitchKeyRef,
|
||||
},
|
||||
ops::{
|
||||
ciphertext::{decomposed_scalar_glev_mad, glwe_negate_inplace},
|
||||
encryption::encrypt_glwe_ciphertext_secret,
|
||||
},
|
||||
radix::{scale_by_decomposition_factor, ScalarRadixIterator},
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, LweDef, PrivateFunctionalKeyswitchLweCount, RadixDecomposition, Torus, TorusOps,
|
||||
};
|
||||
|
||||
/// Initialize `output`, a
|
||||
/// [`PrivateFunctionalKeyswitchKey`](crate::entities::PrivateFunctionalKeyswitchKey),
|
||||
/// under the given scheme parameters for the given secret mapping `map`.
|
||||
/// Conceptually, this map transforms a list of torus plaintexts into a
|
||||
/// polynomial plaintext.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`.
|
||||
///
|
||||
/// The first parameter in `map` is the output
|
||||
/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter
|
||||
/// is initialized to 0.
|
||||
///
|
||||
/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`.
|
||||
///
|
||||
/// # Security
|
||||
/// To prevent side channels, `map` must run in constant time.
|
||||
///
|
||||
/// # Panics
|
||||
/// * If `output` is not valid for the given `from_lwe`, `to_glwe`, `radix`, `lwe_count`.
|
||||
/// * If any of `from_lwe`, `to_glwe`, `radix`, `lwe_count` are invalid.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn generate_private_functional_keyswitch_key<S, F>(
|
||||
output: &mut PrivateFunctionalKeyswitchKeyRef<S>,
|
||||
from_key: &LweSecretKeyRef<S>,
|
||||
to_key: &GlweSecretKeyRef<S>,
|
||||
map: F,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
lwe_count: &PrivateFunctionalKeyswitchLweCount,
|
||||
) where
|
||||
S: TorusOps,
|
||||
F: Fn(&mut PolynomialRef<Torus<S>>, &[Torus<S>]),
|
||||
{
|
||||
output.assert_valid(from_lwe, to_glwe, radix, lwe_count);
|
||||
radix.assert_valid::<S>();
|
||||
from_key.assert_valid(from_lwe);
|
||||
to_key.assert_valid(to_glwe);
|
||||
to_glwe.assert_valid();
|
||||
from_lwe.assert_valid();
|
||||
lwe_count.assert_valid();
|
||||
|
||||
allocate_scratch_ref!(
|
||||
pt_poly,
|
||||
PolynomialRef<Torus<S>>,
|
||||
(to_glwe.dim.polynomial_degree)
|
||||
);
|
||||
allocate_scratch_ref!(pt_touri, [Torus<S>], lwe_count.0);
|
||||
|
||||
let mut glevs = output.glevs_mut(to_glwe, radix);
|
||||
let minus_one = <S as Zero>::zero().wrapping_sub(&<S as num::One>::one());
|
||||
|
||||
for z in 0..lwe_count.0 {
|
||||
for s_i in from_key.s().iter().chain([minus_one].iter()) {
|
||||
let glev = glevs.next().unwrap();
|
||||
|
||||
for (j, glwe) in glev.glwe_ciphertexts_mut(to_glwe).enumerate() {
|
||||
let scaled_s_i = scale_by_decomposition_factor(*s_i, j, radix);
|
||||
|
||||
pt_poly.clear();
|
||||
pt_touri.iter_mut().for_each(|x| *x = Torus::zero());
|
||||
|
||||
pt_touri[z] = Torus::from(scaled_s_i);
|
||||
|
||||
map(pt_poly, pt_touri);
|
||||
|
||||
encrypt_glwe_ciphertext_secret(glwe, pt_poly, to_key, to_glwe);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a private functional keyswitch. See
|
||||
/// [`module`](crate::ops::keyswitch::private_functional_keyswitch) documentation for more
|
||||
/// details.
|
||||
pub fn private_functional_keyswitch<S: TorusOps>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
inputs: &[&LweCiphertextRef<S>],
|
||||
pfksk: &PrivateFunctionalKeyswitchKeyRef<S>,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
lwe_count: &PrivateFunctionalKeyswitchLweCount,
|
||||
) {
|
||||
output.assert_valid(to_glwe);
|
||||
pfksk.assert_valid(from_lwe, to_glwe, radix, lwe_count);
|
||||
from_lwe.assert_valid();
|
||||
to_glwe.assert_valid();
|
||||
radix.assert_valid::<S>();
|
||||
lwe_count.assert_valid();
|
||||
|
||||
assert_eq!(lwe_count.0, inputs.len());
|
||||
|
||||
let mut ksk_glevs = pfksk.glevs(to_glwe, radix);
|
||||
|
||||
for input in inputs.iter() {
|
||||
for i in 0..from_lwe.dim.0 + 1 {
|
||||
// Treating the z'th ciphertext as slice of length n + 1 allows us to iterate
|
||||
// over a || b.
|
||||
let ab = input.as_slice();
|
||||
|
||||
let glev = ksk_glevs.next().unwrap();
|
||||
let decomp = ScalarRadixIterator::new(ab[i], radix);
|
||||
|
||||
decomposed_scalar_glev_mad(output, decomp, glev, to_glwe);
|
||||
}
|
||||
}
|
||||
|
||||
// Return minus output.
|
||||
glwe_negate_inplace(output, to_glwe);
|
||||
}
|
||||
|
||||
/// Generate the keys for a private functional keyswitch.
|
||||
pub fn generate_circuit_bootstrapping_pfks_keys<S: TorusOps>(
|
||||
output: &mut CircuitBootstrappingKeyswitchKeysRef<S>,
|
||||
from_key: &LweSecretKeyRef<S>,
|
||||
to_key: &GlweSecretKeyRef<S>,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
output.assert_valid(from_lwe, to_glwe, radix);
|
||||
from_key.assert_valid(from_lwe);
|
||||
to_glwe.assert_valid();
|
||||
to_key.assert_valid(to_glwe);
|
||||
radix.assert_valid::<S>();
|
||||
from_lwe.assert_valid();
|
||||
|
||||
// Fill in k pfks keys that multiply each of the "a" GLEVs by the corresponding
|
||||
// polynomial in the GLWE secret key.
|
||||
for (pfksk, s) in output
|
||||
.keys_mut(from_lwe, to_glwe, radix)
|
||||
.zip(to_key.s(to_glwe))
|
||||
.take(to_glwe.dim.size.0)
|
||||
{
|
||||
let map = |poly: &mut PolynomialRef<Torus<S>>, x: &[Torus<S>]| {
|
||||
for (c, a) in poly.coeffs_mut().iter_mut().zip(s.coeffs().iter()) {
|
||||
*c = -x[0] * a;
|
||||
}
|
||||
};
|
||||
|
||||
generate_private_functional_keyswitch_key(
|
||||
pfksk,
|
||||
from_key,
|
||||
to_key,
|
||||
map,
|
||||
from_lwe,
|
||||
to_glwe,
|
||||
radix,
|
||||
&PrivateFunctionalKeyswitchLweCount(1),
|
||||
);
|
||||
}
|
||||
|
||||
// Now fill in the "b" GLEV.
|
||||
// TODO: We could compute this row with public key switching. Is it worth it?
|
||||
let b = output
|
||||
.keys_mut(from_lwe, to_glwe, radix)
|
||||
.nth(to_glwe.dim.size.0)
|
||||
.unwrap();
|
||||
|
||||
let map = |poly: &mut PolynomialRef<Torus<S>>, x: &[Torus<S>]| {
|
||||
poly.clear();
|
||||
poly.coeffs_mut()[0] = x[0];
|
||||
};
|
||||
|
||||
generate_private_functional_keyswitch_key(
|
||||
b,
|
||||
from_key,
|
||||
to_key,
|
||||
map,
|
||||
from_lwe,
|
||||
to_glwe,
|
||||
radix,
|
||||
&PrivateFunctionalKeyswitchLweCount(1),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{GlweCiphertext, PrivateFunctionalKeyswitchKey},
|
||||
high_level::{keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
|
||||
PlaintextBits, PrivateFunctionalKeyswitchLweCount,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_create_private_functional_keyswitch_key() {
|
||||
for _ in 0..5 {
|
||||
let lwe_count =
|
||||
PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1);
|
||||
|
||||
let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
|
||||
let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
|
||||
|
||||
let mut pfks_key = PrivateFunctionalKeyswitchKey::<u64>::new(
|
||||
&TEST_LWE_DEF_1,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
&lwe_count,
|
||||
);
|
||||
|
||||
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, inputs: &[Torus<S>]) {
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
poly.coeffs_mut()[i] = *input;
|
||||
}
|
||||
}
|
||||
|
||||
generate_private_functional_keyswitch_key(
|
||||
&mut pfks_key,
|
||||
&lwe_key,
|
||||
&glwe_key,
|
||||
map,
|
||||
&TEST_LWE_DEF_1,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
&lwe_count,
|
||||
);
|
||||
|
||||
let mut glevs = pfks_key.glevs(&TEST_GLWE_DEF_1, &TEST_RADIX);
|
||||
|
||||
let minus_one = u64::MAX;
|
||||
|
||||
for z in 0..lwe_count.0 {
|
||||
for s_i in lwe_key.s().iter().chain([minus_one].iter()) {
|
||||
let glev = glevs.next().unwrap();
|
||||
|
||||
for (j, glwe) in glev.glwe_ciphertexts(&TEST_GLWE_DEF_1).enumerate() {
|
||||
let plaintext_bits = (j + 1) * TEST_RADIX.radix_log.0;
|
||||
let plaintext_bits = PlaintextBits(plaintext_bits as u32);
|
||||
|
||||
let pt =
|
||||
glwe_key.decrypt_decode_glwe(glwe, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
|
||||
for i in 0..pt.coeffs().len() {
|
||||
if *s_i == minus_one && i == z {
|
||||
let expected = (0x1u64 << ((j + 1) * TEST_RADIX.radix_log.0)) - 1;
|
||||
let actual = pt.coeffs()[i];
|
||||
|
||||
assert_eq!(actual, expected);
|
||||
} else if i == z {
|
||||
assert_eq!(pt.coeffs()[i], *s_i);
|
||||
} else {
|
||||
assert_eq!(pt.coeffs()[i], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_private_functional_keyswitch() {
|
||||
for _ in 0..5 {
|
||||
let lwe_count =
|
||||
PrivateFunctionalKeyswitchLweCount((thread_rng().next_u64() as usize % 8) + 1);
|
||||
|
||||
let lwe_key = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
|
||||
let glwe_key = keygen::generate_binary_glwe_sk(&TEST_GLWE_DEF_1);
|
||||
|
||||
let mut pfks_key = PrivateFunctionalKeyswitchKey::<u64>::new(
|
||||
&TEST_LWE_DEF_1,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
&lwe_count,
|
||||
);
|
||||
|
||||
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, inputs: &[Torus<S>]) {
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
poly.coeffs_mut()[i] = *input;
|
||||
}
|
||||
}
|
||||
|
||||
generate_private_functional_keyswitch_key(
|
||||
&mut pfks_key,
|
||||
&lwe_key,
|
||||
&glwe_key,
|
||||
map,
|
||||
&TEST_LWE_DEF_1,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
&lwe_count,
|
||||
);
|
||||
|
||||
let plaintext_bits = PlaintextBits(4);
|
||||
|
||||
let pts = (0..lwe_count.0)
|
||||
.map(|_x| thread_rng().next_u64() % (0x1u64 << plaintext_bits.0))
|
||||
.collect::<Vec<u64>>();
|
||||
let lwe_cts = pts
|
||||
.iter()
|
||||
.map(|x| lwe_key.encrypt(*x, &TEST_LWE_DEF_1, plaintext_bits))
|
||||
.collect::<Vec<_>>();
|
||||
let mut lwe_ct_refs: Vec<&LweCiphertextRef<_>> = vec![];
|
||||
|
||||
for ct in lwe_cts.iter() {
|
||||
lwe_ct_refs.push(&ct.0);
|
||||
}
|
||||
|
||||
let mut result = GlweCiphertext::new(&TEST_GLWE_DEF_1);
|
||||
|
||||
private_functional_keyswitch(
|
||||
&mut result,
|
||||
&lwe_ct_refs,
|
||||
&pfks_key,
|
||||
&TEST_LWE_DEF_1,
|
||||
&TEST_GLWE_DEF_1,
|
||||
&TEST_RADIX,
|
||||
&lwe_count,
|
||||
);
|
||||
|
||||
let actual = glwe_key.decrypt_decode_glwe(&result, &TEST_GLWE_DEF_1, plaintext_bits);
|
||||
|
||||
for (i, (c, pt)) in actual
|
||||
.coeffs()
|
||||
.iter()
|
||||
.zip(pts.iter().cycle().take(actual.coeffs().len()))
|
||||
.enumerate()
|
||||
{
|
||||
if i < lwe_count.0 {
|
||||
assert_eq!(*c, *pt);
|
||||
} else {
|
||||
assert_eq!(*c, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
264
sunscreen_tfhe/src/ops/keyswitch/public_functional_keyswitch.rs
Normal file
264
sunscreen_tfhe/src/ops/keyswitch/public_functional_keyswitch.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use num::Complex;
|
||||
|
||||
use crate::dst::FromMutSlice;
|
||||
use crate::entities::{
|
||||
GlevCiphertextFftRef, GlweCiphertextRef, LweCiphertextRef, LweSecretKeyRef, PolynomialRef,
|
||||
};
|
||||
use crate::ops::ciphertext::glwe_negate_inplace;
|
||||
use crate::ops::encryption::encrypt_glwe_ciphertext_secret;
|
||||
use crate::ops::fft_ops::decomposed_polynomial_glev_mad;
|
||||
use crate::polynomial::polynomial_add_assign;
|
||||
use crate::radix::PolynomialRadixIterator;
|
||||
use crate::scratch::allocate_scratch;
|
||||
use crate::Torus;
|
||||
use crate::{
|
||||
entities::{GlweCiphertextFftRef, GlweSecretKeyRef, PublicFunctionalKeyswitchKeyRef},
|
||||
radix::scale_by_decomposition_factor,
|
||||
scratch::allocate_scratch_ref,
|
||||
GlweDef, LweDef, RadixDecomposition, TorusOps,
|
||||
};
|
||||
|
||||
/// Generate a public functional keyswitch key, which is used to transform a
|
||||
/// list of LWE ciphertexts into a GLWE ciphertext while applying a provided
|
||||
/// function that converts the scalars in the LWE ciphertexts to the polynomial
|
||||
/// message space in the GLWE ciphertext.
|
||||
///
|
||||
/// See
|
||||
/// [`public_functional_keyswitch`](crate::ops::keyswitch::public_functional_keyswitch)
|
||||
/// for more details.
|
||||
pub fn generate_public_functional_keyswitch_key<S: TorusOps>(
|
||||
output: &mut PublicFunctionalKeyswitchKeyRef<S>,
|
||||
from_sk: &LweSecretKeyRef<S>,
|
||||
to_sk: &GlweSecretKeyRef<S>,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) {
|
||||
from_sk.assert_valid(from_lwe);
|
||||
to_sk.assert_valid(to_glwe);
|
||||
output.assert_valid(from_lwe, to_glwe, radix);
|
||||
|
||||
allocate_scratch_ref!(pt, PolynomialRef<Torus<S>>, (to_glwe.dim.polynomial_degree));
|
||||
pt.clear();
|
||||
|
||||
for (s_i, glev) in from_sk.s().iter().zip(output.glevs_mut(to_glwe, radix)) {
|
||||
for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts_mut(to_glwe)) {
|
||||
let x = scale_by_decomposition_factor(*s_i, j, radix);
|
||||
|
||||
pt.coeffs_mut()[0] = Torus::from(x);
|
||||
|
||||
encrypt_glwe_ciphertext_secret(glwe_ct, pt, to_sk, to_glwe);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a public functional keyswitch, where a list of LWE ciphertexts are
|
||||
/// transformed into a GLWE ciphertext while applying a provided function.
|
||||
/// Conceptually, this map transforms a list of torus plaintexts into a
|
||||
/// polynomial plaintext.
|
||||
///
|
||||
/// This operation is called "public" because the function F is public; a
|
||||
/// variant of this operation called
|
||||
/// [`private_functional_keyswitch`](crate::ops::keyswitch::private_functional_keyswitch)
|
||||
/// is also available, where the function F is secret encoded in the keyswitch
|
||||
/// key.
|
||||
///
|
||||
/// # Remarks
|
||||
/// `map` must be an R-Lipschitzian morphism `T_q^p -> T_q[X]` where `p = lwe_count`.
|
||||
///
|
||||
/// The first parameter in `map` is the output
|
||||
/// [`Polynomial`](crate::entities::Polynomial) of the morphism. This parameter
|
||||
/// is initialized to 0.
|
||||
///
|
||||
/// The second argument is a [`&[Torus]`](crate::math::Torus) of length `lwe_count`.
|
||||
pub fn public_functional_keyswitch<S, F>(
|
||||
output: &mut GlweCiphertextRef<S>,
|
||||
inputs: &[&LweCiphertextRef<S>],
|
||||
pufksk: &PublicFunctionalKeyswitchKeyRef<S>,
|
||||
f: F,
|
||||
from_lwe: &LweDef,
|
||||
to_glwe: &GlweDef,
|
||||
radix: &RadixDecomposition,
|
||||
) where
|
||||
S: TorusOps,
|
||||
F: Fn(&mut PolynomialRef<Torus<S>>, &[Torus<S>]),
|
||||
{
|
||||
pufksk.assert_valid(from_lwe, to_glwe, radix);
|
||||
output.assert_valid(to_glwe);
|
||||
|
||||
for i in inputs {
|
||||
i.assert_valid(from_lwe);
|
||||
}
|
||||
|
||||
assert!(inputs.len() <= to_glwe.dim.polynomial_degree.0);
|
||||
|
||||
allocate_scratch_ref!(
|
||||
poly,
|
||||
PolynomialRef<Torus<S>>,
|
||||
(to_glwe.dim.polynomial_degree)
|
||||
);
|
||||
output.clear();
|
||||
allocate_scratch_ref!(
|
||||
decomp_scratch,
|
||||
PolynomialRef<S>,
|
||||
(to_glwe.dim.polynomial_degree)
|
||||
);
|
||||
let mut a_buf = allocate_scratch::<Torus<S>>(inputs.len());
|
||||
let lwe_vals = a_buf.as_mut_slice();
|
||||
allocate_scratch_ref!(
|
||||
glev_fft,
|
||||
GlevCiphertextFftRef<Complex<f64>>,
|
||||
(to_glwe.dim, radix.count)
|
||||
);
|
||||
allocate_scratch_ref!(
|
||||
output_fft,
|
||||
GlweCiphertextFftRef<Complex<f64>>,
|
||||
(to_glwe.dim)
|
||||
);
|
||||
|
||||
output_fft.clear();
|
||||
|
||||
// Compute all the a terms
|
||||
for (i, row) in pufksk.glevs(to_glwe, radix).enumerate() {
|
||||
for (j, a_i) in inputs.iter().map(|x| x.a(from_lwe)[i]).enumerate() {
|
||||
lwe_vals[j] = a_i;
|
||||
}
|
||||
|
||||
row.fft(glev_fft, to_glwe);
|
||||
|
||||
f(poly, lwe_vals);
|
||||
|
||||
let decomp = PolynomialRadixIterator::new(poly, decomp_scratch, radix);
|
||||
|
||||
decomposed_polynomial_glev_mad(output_fft, decomp, glev_fft, to_glwe);
|
||||
}
|
||||
|
||||
// Compute the b term
|
||||
for (j, b) in inputs.iter().map(|x| x.b(from_lwe)).enumerate() {
|
||||
lwe_vals[j] = *b;
|
||||
}
|
||||
|
||||
f(poly, lwe_vals);
|
||||
|
||||
output_fft.ifft(output, to_glwe);
|
||||
|
||||
// Compute (0, b) - output
|
||||
glwe_negate_inplace(output, to_glwe);
|
||||
polynomial_add_assign(output.b_mut(to_glwe), poly);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
entities::{GlweCiphertext, PublicFunctionalKeyswitchKey},
|
||||
high_level::{encryption, keygen, TEST_GLWE_DEF_1, TEST_LWE_DEF_1, TEST_RADIX},
|
||||
PlaintextBits,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_generate_public_functional_keyswitch_key() {
|
||||
let lwe_params = TEST_LWE_DEF_1;
|
||||
let glwe_params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX);
|
||||
|
||||
let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
|
||||
generate_public_functional_keyswitch_key(
|
||||
&mut pufksk,
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&lwe_params,
|
||||
&glwe_params,
|
||||
&radix,
|
||||
);
|
||||
|
||||
for (s_i, glev) in lwe_sk.s().iter().zip(pufksk.glevs(&glwe_params, &radix)) {
|
||||
for (j, glwe_ct) in (0..radix.count.0).zip(glev.glwe_ciphertexts(&glwe_params)) {
|
||||
let plaintext_bits = ((j + 1) * radix.radix_log.0) as u32;
|
||||
let plaintext_bits = PlaintextBits(plaintext_bits);
|
||||
|
||||
let pt = encryption::decrypt_glwe(glwe_ct, &glwe_sk, &glwe_params, plaintext_bits);
|
||||
|
||||
assert_eq!(pt.coeffs()[0], *s_i);
|
||||
|
||||
for x in pt.coeffs().iter().skip(1) {
|
||||
assert_eq!(*x, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn can_public_functional_keyswitch() {
|
||||
let lwe_params = TEST_LWE_DEF_1;
|
||||
let glwe_params = TEST_GLWE_DEF_1;
|
||||
let radix = TEST_RADIX;
|
||||
|
||||
let mut pufksk = PublicFunctionalKeyswitchKey::new(&lwe_params, &glwe_params, &TEST_RADIX);
|
||||
|
||||
let lwe_sk = keygen::generate_binary_lwe_sk(&lwe_params);
|
||||
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);
|
||||
|
||||
let plaintext_bits = PlaintextBits(4);
|
||||
|
||||
generate_public_functional_keyswitch_key(
|
||||
&mut pufksk,
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&lwe_params,
|
||||
&glwe_params,
|
||||
&radix,
|
||||
);
|
||||
|
||||
for _ in 0..10 {
|
||||
let lwe_count = thread_rng().next_u64() as usize % glwe_params.dim.polynomial_degree.0;
|
||||
|
||||
let pts = (0..lwe_count)
|
||||
.map(|_| thread_rng().next_u64() % (0x1 << plaintext_bits.0))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let lwes = pts
|
||||
.iter()
|
||||
.map(|x| encryption::encrypt_lwe_secret(*x, &lwe_sk, &lwe_params, plaintext_bits))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mut lwe_refs: Vec<&LweCiphertextRef<u64>> = vec![];
|
||||
|
||||
for x in lwes.iter() {
|
||||
lwe_refs.push(x);
|
||||
}
|
||||
|
||||
let mut output = GlweCiphertext::new(&glwe_params);
|
||||
|
||||
fn map<S: TorusOps>(poly: &mut PolynomialRef<Torus<S>>, tori: &[Torus<S>]) {
|
||||
for (c, t) in poly.coeffs_mut().iter_mut().zip(tori.iter()) {
|
||||
*c = *t;
|
||||
}
|
||||
}
|
||||
|
||||
public_functional_keyswitch(
|
||||
&mut output,
|
||||
&lwe_refs,
|
||||
&pufksk,
|
||||
map,
|
||||
&lwe_params,
|
||||
&glwe_params,
|
||||
&radix,
|
||||
);
|
||||
|
||||
let actual = encryption::decrypt_glwe(&output, &glwe_sk, &glwe_params, plaintext_bits);
|
||||
|
||||
for (a, e) in actual.coeffs().iter().zip(pts.iter()) {
|
||||
assert_eq!(a, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
19
sunscreen_tfhe/src/ops/mod.rs
Normal file
19
sunscreen_tfhe/src/ops/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
/// Ciphertext operations where one of the operands is in FFT form.
|
||||
pub mod fft_ops;
|
||||
|
||||
/// Methods for key switching a ciphertext from one key to another, potentially
|
||||
/// switching the parameters at the same time.
|
||||
pub mod keyswitch;
|
||||
|
||||
/// Methods for bootstrapping an LWE ciphertext from one key to another, while
|
||||
/// refreshing the noise in the ciphertext.
|
||||
pub mod bootstrapping;
|
||||
|
||||
/// Methods for homomorphic operations on ciphertexts.
|
||||
pub mod homomorphisms;
|
||||
|
||||
/// Methods for operating on different ciphertext types.
|
||||
pub mod ciphertext;
|
||||
|
||||
/// Methods for encrypting and decrypting to various ciphertext types.
|
||||
pub mod encryption;
|
||||
310
sunscreen_tfhe/src/params.rs
Normal file
310
sunscreen_tfhe/src/params.rs
Normal file
@@ -0,0 +1,310 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::rand::Stddev;
|
||||
use crate::TorusOps;
|
||||
|
||||
use sunscreen_math::security::{lwe_std_to_security_level, SecurityLevelResults};
|
||||
|
||||
trait SecurityLevel {
|
||||
fn security_level(&self) -> f64;
|
||||
|
||||
fn assert_security_level(&self, specified_security_level: usize) {
|
||||
// Note that the underlying approximation is accurate up to 4 bits, so
|
||||
// that is what we use here.
|
||||
let tolerance = 4.0;
|
||||
|
||||
let security_level = self.security_level();
|
||||
let security_difference = (security_level - specified_security_level as f64).abs();
|
||||
|
||||
assert!(
|
||||
security_difference <= tolerance,
|
||||
"Security level mismatch: expected {}, got {}",
|
||||
specified_security_level,
|
||||
security_level
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn generic_security_level(dimension: usize, std: f64) -> f64 {
|
||||
match lwe_std_to_security_level(dimension, std) {
|
||||
SecurityLevelResults::Level(level) => level,
|
||||
SecurityLevelResults::BelowStandardDeviationBound => {
|
||||
panic!("Standard deviation is too small for the given dimension")
|
||||
}
|
||||
SecurityLevelResults::AboveStandardDeviationBound => {
|
||||
panic!("Standard deviation is too large for the given dimension")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of torus elements in the LWE lattice.
|
||||
pub struct LweDimension(pub usize);
|
||||
|
||||
impl LweDimension {
|
||||
/// Asserts this LWE problem is well-formed.
|
||||
pub fn assert_valid(&self) {
|
||||
assert_ne!(self.0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The degree of the modulus polynomial `(x^N+1)` in a GLWE instance.
|
||||
///
|
||||
/// # Remarks
|
||||
/// GLWE encryption uses polynomials in `Z_q\[X\]/(X^N + 1)` where N is a power
|
||||
/// of two. I.e. negacyclic polynomials modulo (X^N + 1) where the coefficients
|
||||
/// are integers mod `q`.
|
||||
pub struct PolynomialDegree(pub usize);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of polynomials in a GLWE instance.
|
||||
pub struct GlweSize(pub usize);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of plaintext bits to encode into a message.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Packing too much data into messages can result in incorrect decryptions due
|
||||
/// to noise.
|
||||
///
|
||||
/// For binary, set this to one.
|
||||
pub struct PlaintextBits(pub u32);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of padding bits to include in an LWE ciphertext.
|
||||
pub struct CarryBits(pub u32);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of digits to decompose a value into.
|
||||
pub struct RadixCount(pub usize);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of bits in a digit output during base decomposition.
|
||||
pub struct RadixLog(pub usize);
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The number of [`LweCiphertext`](crate::entities::LweCiphertext)s that get
|
||||
/// mapped into a [`GlweCiphertext`](crate::entities::GlweCiphertext) during
|
||||
/// private functional keyswitching.
|
||||
pub struct PrivateFunctionalKeyswitchLweCount(pub usize);
|
||||
|
||||
impl PrivateFunctionalKeyswitchLweCount {
|
||||
#[inline(always)]
|
||||
/// Assert this [`PrivateFunctionalKeyswitchLweCount`] is valid.
|
||||
pub fn assert_valid(&self) {
|
||||
assert_ne!(self.0, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
/// The parameters defining how to do approximately perform base decomposition. I.e.
|
||||
/// decompose values into digits.
|
||||
///
|
||||
/// # Validity
|
||||
/// For [`RadixDecomposition`] parameters to be valid:
|
||||
/// * `count` must be greater than zero.
|
||||
/// * `radix_log` must be greater than zero.
|
||||
/// * `count * radix_log` must be less than equal to the number of bits in the
|
||||
/// [`Torus`](crate::Torus) element used in the operation(s) performing radix
|
||||
/// decomposition.
|
||||
///
|
||||
/// Calling [`assert_valid`](Self::assert_valid) will panic if the parameters are invalid.
|
||||
pub struct RadixDecomposition {
|
||||
/// The number of digits to decompose a value into.
|
||||
pub count: RadixCount,
|
||||
|
||||
/// The number of bits in a digit output during base decomposition.
|
||||
pub radix_log: RadixLog,
|
||||
}
|
||||
|
||||
impl RadixDecomposition {
|
||||
#[inline(always)]
|
||||
/// Panics if these [`RadixDecomposition`] parameters are invalid.
|
||||
pub fn assert_valid<S: TorusOps>(&self) {
|
||||
assert!(self.count.0 > 0);
|
||||
assert!(self.radix_log.0 > 0);
|
||||
assert!(self.count.0 * self.radix_log.0 <= S::BITS as usize);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
/// A [`PolynomialDegree`] and [`GlweSize`] in a GLWE instance.
|
||||
pub struct GlweDimension {
|
||||
/// The degree of the polynomial in a GLWE instance.
|
||||
pub polynomial_degree: PolynomialDegree,
|
||||
|
||||
/// The number of polynomials in a GLWE instance.
|
||||
pub size: GlweSize,
|
||||
}
|
||||
|
||||
impl GlweDimension {
|
||||
/// Reinterpret this GLWE problem instance as an LWE problem instance, returning
|
||||
/// the [`LweDimension`].
|
||||
pub fn as_lwe_dimension(&self) -> LweDimension {
|
||||
LweDimension(self.polynomial_degree.0 * self.size.0)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Assert these GLWE parameters are valid.
|
||||
pub fn assert_valid(&self) {
|
||||
assert!(self.polynomial_degree.0.is_power_of_two());
|
||||
assert!(self.polynomial_degree.0 > 0);
|
||||
assert!(self.size.0 > 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
/// Parameters that define an LWE problem instance.
|
||||
///
|
||||
/// # Security
|
||||
/// `dim` and `std` must be properly set to attain the desired security level. Improper
|
||||
/// values will result in an insecure scheme.
|
||||
pub struct LweDef {
|
||||
/// The dimension of the LWE lattice.
|
||||
pub dim: LweDimension,
|
||||
|
||||
/// The standard deviation of the noise in the LWE lattice.
|
||||
pub std: Stddev,
|
||||
}
|
||||
|
||||
impl LweDef {
|
||||
#[inline(always)]
|
||||
/// Asserts this LWE problem is well-formed.
|
||||
pub fn assert_valid(&self) {
|
||||
self.dim.assert_valid();
|
||||
}
|
||||
}
|
||||
|
||||
impl SecurityLevel for LweDef {
|
||||
fn security_level(&self) -> f64 {
|
||||
generic_security_level(self.dim.0, self.std.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
/// Parameters that define a GLWE problem instance.
|
||||
///
|
||||
/// # Security
|
||||
/// `dim` and `std` must be properly set to attain the desired security level. Improper
|
||||
/// values will result in an insecure scheme.
|
||||
pub struct GlweDef {
|
||||
/// The dimension of the GLWE instance.
|
||||
pub dim: GlweDimension,
|
||||
|
||||
/// The standard deviation of the noise in the GLWE instance.
|
||||
pub std: Stddev,
|
||||
}
|
||||
|
||||
impl GlweDef {
|
||||
/// Reinterpret this GLWE instance as an LWE instance of the same lattice dimension.
|
||||
pub fn as_lwe_def(&self) -> LweDef {
|
||||
LweDef {
|
||||
dim: self.dim.as_lwe_dimension(),
|
||||
std: self.std,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
/// Assert the GLWE instance is valid.
|
||||
pub fn assert_valid(&self) {
|
||||
self.dim.assert_valid();
|
||||
}
|
||||
}
|
||||
|
||||
impl SecurityLevel for GlweDef {
|
||||
fn security_level(&self) -> f64 {
|
||||
self.as_lwe_def().security_level()
|
||||
}
|
||||
}
|
||||
|
||||
/// 128-bit secure parameters for an LWE instance with a dimension of 512.
|
||||
pub const LWE_512_128: LweDef = LweDef {
|
||||
dim: LweDimension(512),
|
||||
std: Stddev(0.0004899836456140595),
|
||||
};
|
||||
|
||||
/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024.
|
||||
pub const GLWE_1_1024_128: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(1),
|
||||
polynomial_degree: PolynomialDegree(1024),
|
||||
},
|
||||
std: Stddev(0.0000000444778278004718),
|
||||
};
|
||||
|
||||
/// 128-bit secure parameters for a GLWE instance with 1 polynomial of degree 2048.
|
||||
pub const GLWE_1_2048_128: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(1),
|
||||
polynomial_degree: PolynomialDegree(2048),
|
||||
},
|
||||
std: Stddev(0.00000000000000034667670193445625),
|
||||
};
|
||||
|
||||
/// 80-bit secure parameters for an LWE instance with a dimension of 512.
|
||||
pub const LWE_512_80: LweDef = LweDef {
|
||||
dim: LweDimension(512),
|
||||
std: Stddev(0.000001842343446823844),
|
||||
};
|
||||
|
||||
/// 80-bit secure parameters for a GLWE instance with 5 polynomials of degree 256.
|
||||
pub const GLWE_5_256_80: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(5),
|
||||
polynomial_degree: PolynomialDegree(256),
|
||||
},
|
||||
std: Stddev(0.000000000000002106764669572764),
|
||||
};
|
||||
|
||||
/// 80-bit secure parameters for a GLWE instance with 1 polynomial of degree 1024.
|
||||
pub const GLWE_1_1024_80: GlweDef = GlweDef {
|
||||
dim: GlweDimension {
|
||||
size: GlweSize(1),
|
||||
polynomial_degree: PolynomialDegree(1024),
|
||||
},
|
||||
std: Stddev(0.0000000000010900242107812643),
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
use sunscreen_math::security::lwe_security_level_to_std;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn check_security_levels() {
|
||||
let actual_lwe_std = lwe_security_level_to_std(512, 128.0);
|
||||
println!("LWE 512 128: {}", actual_lwe_std);
|
||||
LWE_512_128.assert_security_level(128);
|
||||
|
||||
let actual_glwe_std = lwe_security_level_to_std(1024, 128.0);
|
||||
println!("GLWE 1 1024 128: {}", actual_glwe_std);
|
||||
GLWE_1_1024_128.assert_security_level(128);
|
||||
|
||||
let actual_glwe_std = lwe_security_level_to_std(2048, 128.0);
|
||||
println!("GLWE 1 2048 128: {}", actual_glwe_std);
|
||||
GLWE_1_2048_128.assert_security_level(128);
|
||||
|
||||
let actual_lwe_std = lwe_security_level_to_std(512, 80.0);
|
||||
println!("LWE 512 80: {}", actual_lwe_std);
|
||||
LWE_512_80.assert_security_level(80);
|
||||
|
||||
let actual_glwe_std = lwe_security_level_to_std(256 * 5, 80.0);
|
||||
println!("GLWE 5 256 80: {}", actual_glwe_std);
|
||||
GLWE_5_256_80.assert_security_level(80);
|
||||
|
||||
let actual_glwe_std = lwe_security_level_to_std(1024, 80.0);
|
||||
println!("GLWE 1 1024 80: {}", actual_glwe_std);
|
||||
GLWE_1_1024_80.assert_security_level(80);
|
||||
}
|
||||
}
|
||||
97
sunscreen_tfhe/src/rand.rs
Normal file
97
sunscreen_tfhe/src/rand.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use rand::{thread_rng, Rng, RngCore};
|
||||
use rand_distr::Normal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::math::{Torus, TorusOps};
|
||||
|
||||
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
/// The standard deviation of a Gaussian distribution normalized over the torus
|
||||
/// `T_q`.
|
||||
pub struct Stddev(pub f64);
|
||||
|
||||
/// Sample a random torus element from the a normal distribution
|
||||
/// with a mean of 0 and the given stddev
|
||||
pub fn normal_torus<S: TorusOps>(std: Stddev) -> Torus<S> {
|
||||
let dist = Normal::new(0., std.0).unwrap();
|
||||
|
||||
let e_0 = thread_rng().sample(dist);
|
||||
let q = (S::BITS as f64).exp2();
|
||||
|
||||
let e = f64::round(e_0 * q) as i64;
|
||||
let e: u64 = unsafe { std::mem::transmute(e) };
|
||||
|
||||
Torus::from(S::from_u64(e))
|
||||
}
|
||||
|
||||
/// Generate a random torus element uniformly
|
||||
pub fn uniform_torus<S: TorusOps>() -> Torus<S> {
|
||||
Torus::from(S::from_u64(thread_rng().next_u64()))
|
||||
}
|
||||
|
||||
/// Generate a random binary torus element
|
||||
pub fn binary<S: TorusOps>() -> S {
|
||||
S::from_u64(thread_rng().next_u64() % 2)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::mem::transmute_copy;
|
||||
|
||||
use crate::math::ToF64;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_produce_random_torus() {
|
||||
pub fn case<S, I>()
|
||||
where
|
||||
S: TorusOps,
|
||||
I: ToF64 + Copy + Debug,
|
||||
<S as TryFrom<u64>>::Error: Debug,
|
||||
{
|
||||
let q = (S::BITS as f64).exp2();
|
||||
let n: i32 = 100_000;
|
||||
|
||||
let dev = Stddev(0.000_448_516_698_238_696_5);
|
||||
|
||||
let data = (0..n)
|
||||
.map(|_| {
|
||||
let t = normal_torus::<S>(dev).inner();
|
||||
unsafe { transmute_copy::<S, I>(&t) }
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Reinterpreting the torus points as i64 values should give a mean of approximately zero.
|
||||
let mean = data
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| x.to_f64())
|
||||
.fold(0., |s, x| s + x)
|
||||
/ (q * n as f64);
|
||||
|
||||
assert!(mean < 1e-5);
|
||||
|
||||
// Scale the integer values back to the range [0, 1) and compute the stddev
|
||||
let measured_std = data
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|x| {
|
||||
let val = (x.to_f64() / q) - mean;
|
||||
|
||||
val * val
|
||||
})
|
||||
.fold(0f64, |s, x| s + x)
|
||||
/ n as f64;
|
||||
|
||||
let measured_std = measured_std.sqrt();
|
||||
|
||||
assert!((measured_std - dev.0).abs() < 0.00001);
|
||||
}
|
||||
|
||||
case::<u32, i32>();
|
||||
case::<u64, i64>();
|
||||
}
|
||||
}
|
||||
427
sunscreen_tfhe/src/scratch.rs
Normal file
427
sunscreen_tfhe/src/scratch.rs
Normal file
@@ -0,0 +1,427 @@
|
||||
use linked_list::{Cursor, LinkedList};
|
||||
use num::{Complex, Float};
|
||||
use rustfft::FftNum;
|
||||
use std::{
|
||||
alloc::Layout,
|
||||
cell::RefCell,
|
||||
marker::PhantomData,
|
||||
mem::{size_of, transmute},
|
||||
};
|
||||
|
||||
use crate::{Torus, TorusOps};
|
||||
|
||||
thread_local! {
|
||||
static SCRATCH: RefCell<Option<Scratch>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
macro_rules! allocate_scratch_ref {
|
||||
($out_ident:ident,$ref_t:ident<$t_arg:ty>, ($($args:expr),*)) => {
|
||||
let mut tmp = crate::scratch::allocate_scratch(<$ref_t<$t_arg> as crate::dst::OverlaySize>::size(($($args),*)));
|
||||
let $out_ident = <$ref_t<$t_arg>>::from_mut_slice(tmp.as_mut_slice());
|
||||
};
|
||||
($out_ident:ident,[$ref_t:ident<$t_arg:ty>], $len:expr) => {
|
||||
let mut tmp = crate::scratch::allocate_scratch::<$ref_t<$t_arg>>(<[$ref_t<$t_arg>] as crate::dst::OverlaySize>::size($len));
|
||||
let $out_ident = tmp.as_mut_slice();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) use allocate_scratch_ref;
|
||||
|
||||
/// Indicates this is a "Plain Old Data" type. For `T` qualify as such,
|
||||
/// all bit patterns must be considered a properly initialized instance of
|
||||
/// `T`.
|
||||
///
|
||||
/// # Safety
|
||||
/// Implementing this trait on types that don't meet the above requirements
|
||||
/// may result in undefined behavior.
|
||||
pub unsafe trait Pod {}
|
||||
|
||||
unsafe impl Pod for u8 {}
|
||||
unsafe impl Pod for u16 {}
|
||||
unsafe impl Pod for u32 {}
|
||||
unsafe impl Pod for u64 {}
|
||||
unsafe impl Pod for u128 {}
|
||||
unsafe impl Pod for i8 {}
|
||||
unsafe impl Pod for i16 {}
|
||||
unsafe impl Pod for i32 {}
|
||||
unsafe impl Pod for i64 {}
|
||||
unsafe impl Pod for i128 {}
|
||||
unsafe impl Pod for f32 {}
|
||||
unsafe impl Pod for f64 {}
|
||||
unsafe impl<T> Pod for Complex<T> where T: Float + FftNum {}
|
||||
unsafe impl<S> Pod for Torus<S> where S: TorusOps {}
|
||||
|
||||
/// Allocate a scratch buffer in a cache efficient manner. Freed scratch
|
||||
/// buffers are reused in subsequent allocations.
|
||||
///
|
||||
/// # Remarks
|
||||
/// The returned [`ScratchBuffer`] will be aligned to `align_of::<T>()` and
|
||||
/// have a length equal to count.
|
||||
///
|
||||
/// # Panics
|
||||
/// If `T` is a zero-sized type (e.g. `()`).
|
||||
pub fn allocate_scratch<T>(count: usize) -> ScratchBuffer<'static, T>
|
||||
where
|
||||
T: Pod,
|
||||
{
|
||||
SCRATCH.with(|s| {
|
||||
if s.borrow().is_none() {
|
||||
let new_scratch = Scratch::new();
|
||||
*s.borrow_mut() = Some(new_scratch);
|
||||
}
|
||||
|
||||
(*s.borrow_mut()).as_mut().unwrap().allocate::<T>(count)
|
||||
})
|
||||
}
|
||||
|
||||
/// An "allocator" designed for allocating scratch memory.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Internally, this data structure is a [`LinkedList`] of
|
||||
/// [`Vec`]s treated as a stack.
|
||||
///
|
||||
/// [`Scratch`] is designed to provide a cache locality for
|
||||
/// temporary buffers by reusing allocations.
|
||||
///
|
||||
/// Upon allocation, we update the "top" of the stack to be the
|
||||
/// furthest consecutive free buffer from the actual top.
|
||||
///
|
||||
/// The only way to use this type is through the `thread_local`
|
||||
/// singleton, as this is the only way to guarantee soundness.
|
||||
/// The references doled out by allocate have `'static``
|
||||
/// lifetimes. This is needed so you can have mutable references
|
||||
/// to different allocations at the same time.
|
||||
struct Scratch {
|
||||
// Only accessed through the cursor, so compiler thinks it's unused.
|
||||
#[allow(unused)]
|
||||
stack: Box<LinkedList<Allocation>>,
|
||||
top: *mut Cursor<'static, Allocation>,
|
||||
}
|
||||
|
||||
impl Drop for Scratch {
|
||||
fn drop(&mut self) {
|
||||
let top = unsafe { Box::from_raw(self.top) };
|
||||
|
||||
std::mem::drop(top);
|
||||
}
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
/// We require this to be private. [`allocate_scratch`] should be
|
||||
/// the only way to use scratch memory, which will allocate memory
|
||||
/// using a thread_local allocator.
|
||||
fn new() -> Self {
|
||||
let mut list = Box::new(LinkedList::new());
|
||||
|
||||
let cursor = Box::new(list.cursor());
|
||||
let top = unsafe { transmute(Box::into_raw(cursor)) };
|
||||
|
||||
Self { stack: list, top }
|
||||
}
|
||||
|
||||
/// Allocate a buffer matching the given specification.
|
||||
fn allocate<T>(&mut self, count: usize) -> ScratchBuffer<'static, T>
|
||||
where
|
||||
T: Pod,
|
||||
{
|
||||
assert_ne!(size_of::<T>(), 0);
|
||||
|
||||
let top = unsafe { &mut *self.top };
|
||||
|
||||
// Push the top as far down until we hit the bottom or an allocation
|
||||
// currently in use.
|
||||
loop {
|
||||
let prev = top.peek_prev();
|
||||
|
||||
if let Some(x) = prev {
|
||||
if x.is_free {
|
||||
top.prev().unwrap();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
let layout = Layout::array::<T>(count).unwrap();
|
||||
let req_len = layout.size() + layout.align();
|
||||
|
||||
let allocation = match top.peek_next() {
|
||||
Some(d) => {
|
||||
assert!(d.is_free);
|
||||
|
||||
// Resize the allocation if needed.
|
||||
if d.data.len() < req_len {
|
||||
d.data.resize(req_len, 0u8);
|
||||
}
|
||||
|
||||
d.requested_len = count;
|
||||
top.next().unwrap()
|
||||
}
|
||||
None => {
|
||||
let data = vec![0u8; req_len];
|
||||
|
||||
let allocation = Allocation {
|
||||
requested_len: count,
|
||||
is_free: false,
|
||||
data,
|
||||
};
|
||||
|
||||
top.insert(allocation);
|
||||
top.next().unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
allocation.is_free = false;
|
||||
|
||||
ScratchBuffer {
|
||||
allocation: allocation as *mut Allocation,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Allocation {
|
||||
requested_len: usize,
|
||||
data: Vec<u8>,
|
||||
is_free: bool,
|
||||
}
|
||||
|
||||
pub struct ScratchBuffer<'a, T> {
|
||||
allocation: *mut Allocation,
|
||||
_phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<'a, T> ScratchBuffer<'a, T> {
|
||||
#[allow(unused)]
|
||||
/// Get a slice to the underlying data.
|
||||
///
|
||||
/// # Remarks
|
||||
/// While not extremely expensive, this operation does require capturing
|
||||
/// an aligned slice of data in an underlying allocation. As such,
|
||||
/// you should avoid repeated calls.
|
||||
pub fn as_slice(&self) -> &[T] {
|
||||
let count = unsafe { (*self.allocation).requested_len };
|
||||
let (_, slice, _) = unsafe { (*self.allocation).data.align_to::<T>() };
|
||||
unsafe { transmute(&slice[0..count]) }
|
||||
}
|
||||
|
||||
/// Get a mutable slice to the underlying data.
|
||||
///
|
||||
/// # Remarks
|
||||
/// While not extremely expensive, this operation does require capturing
|
||||
/// an aligned slice of data in an underlying allocation. As such,
|
||||
/// you should avoid repeated calls.
|
||||
pub fn as_mut_slice(&mut self) -> &mut [T] {
|
||||
let count = unsafe { (*self.allocation).requested_len };
|
||||
let (_pre, slice, _post) = unsafe { (*self.allocation).data.align_to_mut::<T>() };
|
||||
unsafe { transmute(&mut slice[0..count]) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for ScratchBuffer<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { (*self.allocation).is_free = true };
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::mem::align_of;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_allocate() {
|
||||
let mut scratch = Scratch::new();
|
||||
|
||||
let mut d = scratch.allocate::<u64>(64);
|
||||
|
||||
let d = d.as_mut_slice();
|
||||
assert_eq!(d.len(), 64);
|
||||
|
||||
for (i, d_i) in d.iter_mut().enumerate() {
|
||||
*d_i = i as u64;
|
||||
}
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn buffers_get_reused() {
|
||||
let mut scratch = Scratch::new();
|
||||
|
||||
let b = scratch.allocate::<u64>(64);
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
let b_slice = b.as_slice();
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
|
||||
let first_ptr = b_slice.as_ptr();
|
||||
|
||||
std::mem::drop(b);
|
||||
|
||||
let mut b = scratch.allocate::<u64>(64);
|
||||
let b_slice = b.as_mut_slice();
|
||||
assert_eq!(first_ptr, b_slice.as_ptr());
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
|
||||
for (i, b_i) in b_slice.iter_mut().enumerate() {
|
||||
*b_i = i as u64;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn reallocate_on_bigger_request() {
|
||||
let mut scratch = Scratch::new();
|
||||
|
||||
let mut b = scratch.allocate::<u64>(64);
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
let b_slice = b.as_mut_slice();
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
|
||||
let first_ptr = b_slice.as_ptr();
|
||||
|
||||
for (i, b_i) in b_slice.iter_mut().enumerate() {
|
||||
*b_i = i as u64;
|
||||
}
|
||||
|
||||
std::mem::drop(b);
|
||||
|
||||
let mut b = scratch.allocate::<u64>(16384);
|
||||
let b = b.as_mut_slice();
|
||||
assert_ne!(first_ptr, b.as_ptr());
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
assert_eq!(b.len(), 16384);
|
||||
|
||||
for (i, b_i) in b.iter_mut().enumerate() {
|
||||
*b_i = i as u64;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocate_two_buffers() {
|
||||
let mut scratch = Scratch::new();
|
||||
|
||||
let mut a = scratch.allocate::<u64>(12);
|
||||
let mut b = scratch.allocate::<u64>(12);
|
||||
|
||||
let a = a.as_mut_slice();
|
||||
let b = b.as_mut_slice();
|
||||
|
||||
assert_eq!(a.len(), 12);
|
||||
assert_eq!(b.len(), 12);
|
||||
assert_eq!(scratch.stack.len(), 2);
|
||||
assert_ne!(a.as_mut_ptr(), b.as_mut_ptr());
|
||||
|
||||
for i in 0..a.len() {
|
||||
a[i] = i as u64;
|
||||
b[i] = i as u64;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn align_16() {
|
||||
let mut scratch = Scratch::new();
|
||||
let mut buffers = (0..10)
|
||||
.map(|_| scratch.allocate::<u128>(10))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(scratch.stack.len(), 10);
|
||||
|
||||
for b in buffers.iter_mut() {
|
||||
let b = b.as_mut_slice();
|
||||
assert_eq!(b.len(), 10);
|
||||
|
||||
for (i, b_i) in b.iter_mut().enumerate() {
|
||||
*b_i = i as u128;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn align_65536() {
|
||||
// Chose an alignment larger than any reasonable OS's page size
|
||||
// to try to force the alignment algorithm into play.
|
||||
#[repr(C, align(65536))]
|
||||
#[derive(Copy, Clone)]
|
||||
struct Foo {
|
||||
x: u32,
|
||||
}
|
||||
|
||||
unsafe impl Pod for Foo {}
|
||||
|
||||
let mut scratch = Scratch::new();
|
||||
let mut b = scratch.allocate::<Foo>(15);
|
||||
|
||||
let b_slice = b.as_mut_slice();
|
||||
assert_eq!(b_slice.len(), 15);
|
||||
|
||||
for b in b_slice {
|
||||
b.x = 22;
|
||||
let ptr = b as *mut Foo;
|
||||
|
||||
// Check that each item is memory aligned to the proper
|
||||
// location.
|
||||
assert_eq!(ptr.align_offset(align_of::<Foo>()), 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stack_coalesces_correctly() {
|
||||
let mut scratch = Scratch::new();
|
||||
let a = scratch.allocate::<u64>(16);
|
||||
let mut b: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
let b_ptr = b.as_mut_slice().as_mut_ptr();
|
||||
|
||||
let c: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
let d: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
|
||||
std::mem::drop(b);
|
||||
assert_eq!(scratch.stack.len(), 4);
|
||||
|
||||
// We can't reuse b's buffer until c, d, e get dropped.
|
||||
let mut e: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
assert_ne!(b_ptr, e.as_mut_slice().as_mut_ptr());
|
||||
|
||||
assert_eq!(scratch.stack.len(), 5);
|
||||
|
||||
std::mem::drop(c);
|
||||
std::mem::drop(d);
|
||||
std::mem::drop(e);
|
||||
|
||||
// Now we can reuse b's buffer.
|
||||
let mut f = scratch.allocate::<u64>(16);
|
||||
assert_eq!(f.as_mut_slice().as_mut_ptr(), b_ptr);
|
||||
assert_eq!(scratch.stack.len(), 5);
|
||||
|
||||
std::mem::drop(a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_size_allocations() {
|
||||
let mut scratch = Scratch::new();
|
||||
let a = scratch.allocate::<u64>(2);
|
||||
let b = scratch.allocate::<u64>(0);
|
||||
|
||||
let a_slice = a.as_slice();
|
||||
let b_slice = b.as_slice();
|
||||
|
||||
assert_eq!(scratch.stack.len(), 2);
|
||||
assert_eq!(a_slice.len(), 2);
|
||||
assert_eq!(b_slice.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zst_allocations_should_panic() {
|
||||
struct Foo {}
|
||||
unsafe impl Pod for Foo {}
|
||||
|
||||
let mut scratch = Scratch::new();
|
||||
let _ = scratch.allocate::<Foo>(0x1 << 48);
|
||||
}
|
||||
}
|
||||
986
sunscreen_tfhe/src/zkp.rs
Normal file
986
sunscreen_tfhe/src/zkp.rs
Normal file
@@ -0,0 +1,986 @@
|
||||
use logproof::{
|
||||
crypto::CryptoHash,
|
||||
linear_algebra::{Matrix, PolynomialMatrix},
|
||||
math::ModSwitch,
|
||||
rings::ZqRistretto,
|
||||
Bounds, LogProofProverKnowledge, LogProofVerifierKnowledge as VerifierKnowledge,
|
||||
};
|
||||
use sunscreen_math::{
|
||||
poly::Polynomial,
|
||||
ring::{BarrettBackend, Ring, RingModulus, Zq},
|
||||
BarrettConfig, One, Zero,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
entities::{LweCiphertext, LwePublicKey, LweSecretKey, TlwePublicEncRandomness},
|
||||
math::{Torus, TorusOps},
|
||||
LweDef, PlaintextBits,
|
||||
};
|
||||
|
||||
/// Proof statements for the SDLP proof system when applied to TFHE.
|
||||
#[derive(Debug)]
|
||||
pub enum ProofStatement<'a, 'b, S: TorusOps + TorusZq> {
|
||||
/// A private key encryption statement.
|
||||
PrivateKeyEncryption {
|
||||
/// The message ID being encrypted.
|
||||
message_id: usize,
|
||||
|
||||
/// The ciphertext being encrypted.
|
||||
ciphertext: &'a LweCiphertext<S>,
|
||||
},
|
||||
|
||||
/// A public key encryption statement.
|
||||
PublicKeyEncryption {
|
||||
/// The message ID being encrypted.
|
||||
message_id: usize,
|
||||
|
||||
/// The encrypted data under the provided public key.
|
||||
ciphertext: &'a LweCiphertext<S>,
|
||||
|
||||
/// The public key being used to encrypt the message.
|
||||
public_key: &'b LwePublicKey<S>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Witness information for the SDLP proof system when applied to TFHE.
|
||||
/// This is the private information used when generating a proof.
|
||||
#[derive(Debug)]
|
||||
pub enum Witness<'a, 'b, S: TorusOps + TorusZq> {
|
||||
/// A private key encryption witness.
|
||||
PrivateKeyEncryption {
|
||||
/// The randomness used in the encryption.
|
||||
randomness: Torus<S>,
|
||||
|
||||
/// The private key used in the encryption.
|
||||
private_key: &'a LweSecretKey<S>,
|
||||
},
|
||||
|
||||
/// A public key encryption witness.
|
||||
PublicKeyEncryption {
|
||||
/// The randomness used in the encryption.
|
||||
randomness: &'b TlwePublicEncRandomness<S>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Generate LogProofProverKnowledge for the SDLP proof system.
|
||||
pub fn generate_tfhe_sdlp_prover_knowledge<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
messages: &[Torus<S>],
|
||||
witness: &[Witness<S>],
|
||||
lwe: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> LogProofProverKnowledge<S::Zq> {
|
||||
let vk = generate_tfhe_sdlp_verifier_knowledge(statements, lwe, plaintext_bits);
|
||||
|
||||
let s = compute_s(statements, witness, messages, lwe);
|
||||
|
||||
LogProofProverKnowledge { vk, s }
|
||||
}
|
||||
|
||||
/// Modulus for u32.
|
||||
#[derive(BarrettConfig)]
|
||||
#[barrett_config(modulus = "4294967296", num_limbs = 1)]
|
||||
pub struct U32Config;
|
||||
|
||||
/// Field for u32.
|
||||
pub type Zq32 = Zq<1, BarrettBackend<1, U32Config>>;
|
||||
|
||||
/// Modulus for u64.
|
||||
#[derive(BarrettConfig)]
|
||||
#[barrett_config(modulus = "18446744073709551616", num_limbs = 2)]
|
||||
pub struct U64Config;
|
||||
|
||||
/// Field for u64.
|
||||
pub type Zq64 = Zq<2, BarrettBackend<2, U64Config>>;
|
||||
|
||||
/// Torus properties on a discrete ring `Z_q`.
|
||||
pub trait TorusZq
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
/// The discrete ring `Z_q`.
|
||||
type Zq: Ring
|
||||
+ From<Self>
|
||||
+ CryptoHash
|
||||
+ ModSwitch<ZqRistretto>
|
||||
+ RingModulus<4>
|
||||
+ Ord
|
||||
+ From<u32>;
|
||||
}
|
||||
|
||||
impl TorusZq for u32 {
|
||||
type Zq = Zq32;
|
||||
}
|
||||
|
||||
impl TorusZq for u64 {
|
||||
type Zq = Zq64;
|
||||
}
|
||||
|
||||
/// Computes the public information needed to prove and verify public and private key
|
||||
/// TLWE encryptions.
|
||||
///
|
||||
/// # Remarks
|
||||
/// Using only private key encryption results in significantly faster runtime since we
|
||||
/// can characterize `Z_q[X]/f` with `f = X + 1`.
|
||||
///
|
||||
/// # Details
|
||||
/// Let `N` be the TLWE lattice dimension.
|
||||
/// let `M` be the number of messages encrypted in in the ciphertexts
|
||||
/// Note that multiple ciphertexts can encrypt the same message and SDLP can prove they contain
|
||||
/// the same message.
|
||||
///
|
||||
/// ## A matrix's structure
|
||||
/// SDLP proves a linear relation `AS=T` where
|
||||
/// * `A in Z_q[X]^{m x k}/f`
|
||||
/// * `S in Z_q[X]^{k x n}/f`
|
||||
/// * `T in Z_q[X]^{m x n}/f`
|
||||
/// * `f = X^D + 1`
|
||||
///
|
||||
/// For proving encryptions of TLWE ciphertexts, we have:
|
||||
/// * `m` is the number of proof statements.
|
||||
/// * `n = 1`.
|
||||
/// * `k = num_messages + num_public_keys * (N + 1) + num_private_keys * N + num_private_encs + num_public_encs * (N + 1)`.
|
||||
/// * See [quotient ring modulus `f`](#quotient-ring-modulus-f).
|
||||
///
|
||||
/// The matrix `A` is arranged into rows, one per proof statement whose column structure depends
|
||||
/// on whether the statement is for public or private key encryption.
|
||||
///
|
||||
/// ### Private key statements
|
||||
/// Each private key encryption statement row has the following column arrangement:
|
||||
/// ```ignore
|
||||
/// 1 at msg_idx pub_key (N/A) a at pvt_key_id * N 1 at pvt_stmt_idx e_idx (N/A)
|
||||
/// V V V V V
|
||||
/// [ 0 .. 1 .. 0 0 .. 0 0 .. a_0, a_1, .. a_N .. 0 0 .. 1 .. 0 0 .. 0]
|
||||
/// ```
|
||||
/// ### Public key statements
|
||||
/// Recall that a TLWE ciphertext consists of (a_1, ... a_N, b), where `a_i, b` in the discrete
|
||||
/// torus `T`.
|
||||
///
|
||||
/// Furthermore, the public key `P = (p_1, ... p_N)` where `p_i` is a secret key encryption of
|
||||
/// zero.
|
||||
///
|
||||
/// Each public key encryption statement row has the following column arrangement:
|
||||
/// ```ignore
|
||||
/// 1 at msg_idx p at pub_key_id * N a (N/A) b (N/A) e_idx
|
||||
/// V V V V V
|
||||
/// [ 0 .. 1 .. 0 0 .. p_0, p_1, .. p_N .. 0 0 .. 0 0 .. 0 0 .. 1 .. 0 ]
|
||||
/// ```
|
||||
/// For these entries, we reinterpret each `p_i` as a polynomial `mod (X^N + 1)`.
|
||||
///
|
||||
/// ## T's structure
|
||||
/// While SDLP support T as a matrix, we only need a vector of `m` rows.
|
||||
///
|
||||
/// The structure of each row in T depends on whether said the corresponding statement describes
|
||||
/// a public or private key encryption.
|
||||
///
|
||||
/// ### Private key statements
|
||||
/// The row's polynomial is simply `b * X^0`.
|
||||
///
|
||||
/// ### Public key statements
|
||||
/// The rows's polynomial is `a_0 * X^0, a_1 * X^1, ... a_N * X^{N - 1} b * X^N`
|
||||
///
|
||||
/// ## Quotient ring modulus `f`
|
||||
/// When `statements` contains only secret key encryption statements, `f = X + 1`. Otherwise,
|
||||
/// `f = X^{N + 1} + 1`.
|
||||
pub fn generate_tfhe_sdlp_verifier_knowledge<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> VerifierKnowledge<S::Zq> {
|
||||
// If we need to prove any public encryption statements, then `f = X^{lwe_dimension + 1} + 1`.
|
||||
// If not, we can use the more efficient X + 1.
|
||||
let mut f_coeffs = vec![S::Zq::from(<S as Zero>::zero()); f_degree(statements, lwe) + 1];
|
||||
f_coeffs[0] = S::Zq::from(S::one());
|
||||
|
||||
let last_coeff = f_coeffs.len() - 1;
|
||||
f_coeffs[last_coeff] = S::Zq::from(S::one());
|
||||
|
||||
let f = Polynomial::new(&f_coeffs);
|
||||
|
||||
let a = compute_a(statements, lwe, plaintext_bits);
|
||||
let t = compute_t(statements, lwe);
|
||||
let bounds = compute_bounds(statements, lwe, plaintext_bits);
|
||||
|
||||
VerifierKnowledge::new(a, t, f, bounds)
|
||||
}
|
||||
|
||||
fn compute_bounds<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Matrix<Bounds> {
|
||||
let (_, cols) = proof_matrix_dim(statements, lwe.dim.0);
|
||||
let offsets = compute_a_column_offsets(statements, lwe);
|
||||
|
||||
let mut bounds = Matrix::<Bounds>::new(cols, 1);
|
||||
|
||||
let num_messages = num_messages(statements);
|
||||
let num_coeffs = f_degree(statements, lwe);
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
|
||||
// Bounds for messages
|
||||
for i in 0..num_messages {
|
||||
let mut b = Bounds(vec![0; num_coeffs]);
|
||||
b.0[0] = 0x1u64 << plaintext_bits.0;
|
||||
debug_assert_eq!(bounds[(i, 0)].0, &[]);
|
||||
bounds[(i, 0)] = b;
|
||||
}
|
||||
|
||||
// Public r and e
|
||||
for i in 0..num_public(statements) {
|
||||
for j in 0..lwe_dimension {
|
||||
let mut b = Bounds(vec![0; num_coeffs]);
|
||||
|
||||
// Values of r are binary
|
||||
b.0[0] = 0x1u64 << plaintext_bits.0;
|
||||
debug_assert_eq!(
|
||||
bounds[(offsets.public_keys + i * lwe_dimension + j, 0)].0,
|
||||
&[]
|
||||
);
|
||||
bounds[(offsets.public_keys + i * lwe_dimension + j, 0)] = b;
|
||||
}
|
||||
|
||||
// e is normal distributed over the torus.
|
||||
// TODO: This bound is too high. Get a tighter bound.
|
||||
let b = Bounds(vec![0x1u64 << (60 - plaintext_bits.0); num_coeffs]);
|
||||
|
||||
debug_assert_eq!(bounds[(offsets.public_e + i, 0)].0, &[]);
|
||||
bounds[(offsets.public_e + i, 0)] = b;
|
||||
}
|
||||
|
||||
// Private s and e
|
||||
for i in 0..num_private(statements) {
|
||||
for j in 0..lwe_dimension {
|
||||
let mut b = Bounds(vec![0; num_coeffs]);
|
||||
|
||||
// Values of s are binary
|
||||
b.0[0] = 0x1u64 << plaintext_bits.0;
|
||||
debug_assert_eq!(
|
||||
bounds[(offsets.private_a + j + i * lwe_dimension, 0)].0,
|
||||
&[]
|
||||
);
|
||||
bounds[(offsets.private_a + j + i * lwe_dimension, 0)] = b;
|
||||
}
|
||||
|
||||
let mut b = Bounds(vec![0; num_coeffs]);
|
||||
|
||||
// e is normal distributed over the torus.
|
||||
// TODO: This bound is too high. Get a tighter bound.
|
||||
b.0[0] = 0x1u64 << (62 - plaintext_bits.0);
|
||||
debug_assert_eq!(bounds[(offsets.private_e + i, 0)].0, &[]);
|
||||
bounds[(offsets.private_e + i, 0)] = b;
|
||||
}
|
||||
|
||||
bounds
|
||||
}
|
||||
|
||||
fn f_degree<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>], lwe: &LweDef) -> usize {
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
|
||||
if num_public(statements) > 0 {
|
||||
lwe_dimension + 1
|
||||
} else {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
fn encoding_factor<S: TorusZq + TorusOps>(plaintext_bits: PlaintextBits) -> S::Zq {
|
||||
let x = S::from_u64(0x1u64 << (S::BITS - plaintext_bits.0));
|
||||
S::Zq::from(x)
|
||||
}
|
||||
|
||||
struct IdxOffsets {
|
||||
public_keys: usize,
|
||||
public_e: usize,
|
||||
private_a: usize,
|
||||
private_e: usize,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn compute_a_column_offsets<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe: &LweDef,
|
||||
) -> IdxOffsets {
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
let public_keys = num_messages(statements);
|
||||
let public_e = public_keys + num_public(statements) * lwe_dimension;
|
||||
let private_a = public_e + num_public(statements);
|
||||
let private_e = private_a + num_private(statements) * lwe_dimension;
|
||||
|
||||
IdxOffsets {
|
||||
public_keys,
|
||||
public_e,
|
||||
private_a,
|
||||
private_e,
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_a<S: TorusZq + TorusOps>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe: &LweDef,
|
||||
plaintext_bits: PlaintextBits,
|
||||
) -> Matrix<Polynomial<S::Zq>> {
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
let offsets = compute_a_column_offsets(statements, lwe);
|
||||
|
||||
let (rows, cols) = proof_matrix_dim(statements, lwe_dimension);
|
||||
|
||||
let mut a = Matrix::<Polynomial<S::Zq>>::new(rows, cols);
|
||||
|
||||
assert!(plaintext_bits.0 > 0);
|
||||
|
||||
let msg_encode = encoding_factor::<S>(plaintext_bits);
|
||||
|
||||
let mut cur_public = 0;
|
||||
let mut cur_private = 0;
|
||||
|
||||
for (i, s) in statements.iter().enumerate() {
|
||||
let (is_public, public_key, message_id, ciphertext) = match s {
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext,
|
||||
message_id,
|
||||
} => (false, None, *message_id, *ciphertext),
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
public_key,
|
||||
ciphertext,
|
||||
message_id,
|
||||
} => (true, Some(public_key), *message_id, *ciphertext),
|
||||
};
|
||||
|
||||
// Insert the message coefficient. For private encryptions, this goes in
|
||||
// the constant coefficient. For public, it goes in the d-1 coefficient.
|
||||
let coeffs = if !is_public {
|
||||
vec![msg_encode.clone()]
|
||||
} else {
|
||||
let mut coeffs = vec![S::Zq::from(<S as Zero>::zero()); lwe_dimension + 1];
|
||||
coeffs[lwe_dimension] = msg_encode.clone();
|
||||
coeffs
|
||||
};
|
||||
|
||||
debug_assert_eq!(a[(i, message_id)], Polynomial::zero());
|
||||
a[(i, message_id)] = Polynomial::new(&coeffs);
|
||||
|
||||
if is_public {
|
||||
// Push the public key
|
||||
let pk = public_key.unwrap();
|
||||
|
||||
for (j, z) in pk.enc_zeros(lwe).enumerate() {
|
||||
let (z_a, z_b) = z.a_b(lwe);
|
||||
let mut coeffs = z_a
|
||||
.iter()
|
||||
.map(|x| S::Zq::from(x.inner()))
|
||||
.collect::<Vec<_>>();
|
||||
coeffs.push(S::Zq::from(z_b.inner()));
|
||||
|
||||
let public_key_idx = offsets.public_keys + cur_public * lwe_dimension + j;
|
||||
|
||||
debug_assert_eq!(a[(i, public_key_idx)], Polynomial::zero());
|
||||
a[(i, public_key_idx)] = Polynomial::new(&coeffs);
|
||||
}
|
||||
|
||||
// Push the randomness
|
||||
debug_assert_eq!(a[(i, offsets.public_e + cur_public)], Polynomial::zero());
|
||||
a[(i, offsets.public_e + cur_public)] = Polynomial::one();
|
||||
cur_public += 1;
|
||||
} else {
|
||||
let (c_a, _c_b) = ciphertext.a_b(lwe);
|
||||
|
||||
// Place the a values of the cipertext in the matrix.
|
||||
for (j, a_j) in c_a.iter().enumerate() {
|
||||
let private_key_idx = offsets.private_a + j + cur_private * lwe_dimension;
|
||||
|
||||
debug_assert_eq!(a[(i, private_key_idx)], Polynomial::zero());
|
||||
a[(i, private_key_idx)] = Polynomial::new(&[S::Zq::from(a_j.inner())]);
|
||||
}
|
||||
|
||||
debug_assert_eq!(a[(i, offsets.private_e)], Polynomial::zero());
|
||||
|
||||
a[(i, offsets.private_e + i)] = Polynomial::one();
|
||||
cur_private += 1;
|
||||
}
|
||||
}
|
||||
|
||||
a
|
||||
}
|
||||
|
||||
fn compute_t<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe: &LweDef,
|
||||
) -> PolynomialMatrix<S::Zq> {
|
||||
let mut t = PolynomialMatrix::new(statements.len(), 1);
|
||||
|
||||
for (i, s) in statements.iter().enumerate() {
|
||||
match s {
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: c,
|
||||
message_id: _,
|
||||
} => {
|
||||
let (_, c_b) = c.a_b(lwe);
|
||||
|
||||
debug_assert_eq!(t[(i, 0)], Polynomial::zero());
|
||||
t[(i, 0)] = Polynomial::new(&[S::Zq::from(c_b.inner())]);
|
||||
}
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
public_key: _,
|
||||
message_id: _,
|
||||
ciphertext: c,
|
||||
} => {
|
||||
let (c_a, c_b) = c.a_b(lwe);
|
||||
|
||||
let mut coeffs = c_a
|
||||
.iter()
|
||||
.map(|x| S::Zq::from(x.inner()))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
coeffs.push(S::Zq::from(c_b.inner()));
|
||||
|
||||
debug_assert_eq!(t[(i, 0)], Polynomial::zero());
|
||||
t[(i, 0)] = Polynomial::new(&coeffs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn compute_s<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
witness: &[Witness<S>],
|
||||
messages: &[Torus<S>],
|
||||
lwe: &LweDef,
|
||||
) -> PolynomialMatrix<S::Zq> {
|
||||
assert_eq!(statements.len(), witness.len());
|
||||
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
|
||||
// If the a matrix is rows x cols, the S witness must have 'cols' rows.
|
||||
let (_, cols) = proof_matrix_dim(statements, lwe_dimension);
|
||||
|
||||
// Column offsets in the A matrix become row offsets in the S witness.
|
||||
let offsets = compute_a_column_offsets(statements, lwe);
|
||||
|
||||
let mut s = PolynomialMatrix::new(cols, 1);
|
||||
|
||||
// Put the messages into the witness
|
||||
for (i, m) in messages.iter().take(num_messages(statements)).enumerate() {
|
||||
debug_assert_eq!(s[(i, 0)], Polynomial::zero());
|
||||
s[(i, 0)] = Polynomial::new(&[S::Zq::from(m.inner())]);
|
||||
}
|
||||
|
||||
let public_entries = witness.iter().filter_map(|x| match x {
|
||||
Witness::PublicKeyEncryption { randomness } => Some(randomness),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
// Put public 'r' and 'e' randomness into the witness
|
||||
for (i, rnd) in public_entries.enumerate() {
|
||||
for (j, r) in rnd.r.iter().enumerate() {
|
||||
let public_key_index = offsets.public_keys + j + i * lwe_dimension;
|
||||
|
||||
debug_assert_eq!(s[(public_key_index, 0)], Polynomial::zero());
|
||||
s[(public_key_index, 0)] = Polynomial::new(&[S::Zq::from(*r)]);
|
||||
}
|
||||
|
||||
let mut coeffs = rnd
|
||||
.e
|
||||
.a_b(lwe)
|
||||
.0
|
||||
.iter()
|
||||
.map(|x| S::Zq::from(x.inner()))
|
||||
.collect::<Vec<_>>();
|
||||
coeffs.push(S::Zq::from(rnd.e.a_b(lwe).1.inner()));
|
||||
|
||||
let public_randomness_index = offsets.public_e + i;
|
||||
|
||||
debug_assert_eq!(s[(public_randomness_index, 0)], Polynomial::zero());
|
||||
s[(public_randomness_index, 0)] = Polynomial::new(&coeffs);
|
||||
}
|
||||
|
||||
let private_keys = witness.iter().filter_map(|x| match x {
|
||||
Witness::PrivateKeyEncryption {
|
||||
randomness,
|
||||
private_key,
|
||||
} => Some((private_key, randomness)),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
// Put the secret keys into the witness.
|
||||
for (i, (sk, e)) in private_keys.enumerate() {
|
||||
for (j, sk) in sk.s().iter().enumerate() {
|
||||
let private_key_index = offsets.private_a + j + i * lwe_dimension;
|
||||
|
||||
debug_assert_eq!(s[(private_key_index, 0)], Polynomial::zero());
|
||||
s[(private_key_index, 0)] = Polynomial::new(&[S::Zq::from(*sk)]);
|
||||
}
|
||||
|
||||
let private_randomness_index = offsets.private_e + i;
|
||||
|
||||
debug_assert_eq!(s[(private_randomness_index, 0)], Polynomial::zero());
|
||||
s[(private_randomness_index, 0)] = Polynomial::new(&[S::Zq::from(e.inner())]);
|
||||
}
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn num_private<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
|
||||
// Public key statements require 2 rows while private key statements require only one.
|
||||
statements
|
||||
.iter()
|
||||
.filter(|x| {
|
||||
matches!(
|
||||
x,
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: _,
|
||||
message_id: _
|
||||
}
|
||||
)
|
||||
})
|
||||
.count()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn num_public<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
|
||||
statements.len() - num_private(statements)
|
||||
}
|
||||
|
||||
/// Returns a tuple of the `(rows, cols)` in SDLP's `A` matrix.
|
||||
#[inline]
|
||||
fn proof_matrix_dim<S: TorusOps + TorusZq>(
|
||||
statements: &[ProofStatement<S>],
|
||||
lwe_dimension: usize,
|
||||
) -> (usize, usize) {
|
||||
let num_rows = statements.len();
|
||||
|
||||
let num_cols = num_messages(statements) // message terms
|
||||
+ num_public(statements) * lwe_dimension // public key terms
|
||||
+ num_private(statements) * lwe_dimension // a terms
|
||||
+ num_rows; // Public + private e term '1' coeffs
|
||||
|
||||
(num_rows, num_cols)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn num_messages<S: TorusOps + TorusZq>(statements: &[ProofStatement<S>]) -> usize {
|
||||
statements.iter().fold(0usize, |max, x| {
|
||||
let message_id = match x {
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
public_key: _,
|
||||
ciphertext: _,
|
||||
message_id,
|
||||
} => message_id,
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: _,
|
||||
message_id,
|
||||
} => message_id,
|
||||
};
|
||||
|
||||
usize::max(max, *message_id)
|
||||
}) + 1
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use logproof::{InnerProductVerifierKnowledge, LogProof, LogProofGenerators};
|
||||
use merlin::Transcript;
|
||||
use rand::{thread_rng, RngCore};
|
||||
|
||||
use crate::{
|
||||
high_level::*,
|
||||
zkp::{num_private, num_public},
|
||||
LweDef, LweDimension, LWE_512_80,
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn can_compute_a_column_offsets() {
|
||||
let lwe = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let lwe_dimension = lwe.dim.0;
|
||||
|
||||
let ct = LweCiphertext::<u64>::zero(&lwe);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(&lwe);
|
||||
let pk = keygen::generate_lwe_pk(&sk, &lwe);
|
||||
|
||||
let statements = [
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 1,
|
||||
},
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 0,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
public_key: &pk,
|
||||
message_id: 0,
|
||||
},
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 1,
|
||||
},
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 0,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
public_key: &pk,
|
||||
message_id: 1,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
public_key: &pk,
|
||||
message_id: 1,
|
||||
},
|
||||
];
|
||||
|
||||
let idx = compute_a_column_offsets(&statements, &lwe);
|
||||
|
||||
let num_messages = num_messages(&statements);
|
||||
let num_private = num_private(&statements);
|
||||
let num_public = num_public(&statements);
|
||||
|
||||
assert_eq!(num_messages, 2);
|
||||
assert_eq!(num_private, 4);
|
||||
assert_eq!(num_public, 3);
|
||||
|
||||
// We have 2 messages that precede this.
|
||||
assert_eq!(idx.public_keys, num_messages);
|
||||
// We have 3 different public encryptions, so there are 3 public key
|
||||
// column entries. That 2 use the same key is immaterial, as they
|
||||
// must match different randomness in the witness.
|
||||
assert_eq!(idx.public_e, idx.public_keys + num_public * lwe_dimension);
|
||||
// We have 3 public key encryptions that contribute to this offset.
|
||||
// Each 'e' is a polynomial of degree `lwe_dimension`.
|
||||
assert_eq!(idx.private_a, idx.public_e + num_public);
|
||||
// Finally, each secret_a entry takes `lwe_dimension` columns.
|
||||
assert_eq!(idx.private_e, idx.private_a + num_private * lwe_dimension);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn num_messages_works() {
|
||||
let lwe = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(&lwe);
|
||||
let pk = keygen::generate_lwe_pk(&sk, &lwe);
|
||||
|
||||
let ct = encryption::trivial_lwe(0, &lwe, PlaintextBits(1));
|
||||
|
||||
let num_messages = num_messages(&[
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 0,
|
||||
},
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
ciphertext: &ct,
|
||||
message_id: 1,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
public_key: &pk,
|
||||
ciphertext: &ct,
|
||||
message_id: 0,
|
||||
},
|
||||
]);
|
||||
|
||||
assert_eq!(num_messages, 2);
|
||||
}
|
||||
|
||||
fn prove_and_verify<S: TorusOps + TorusZq>(pk: &LogProofProverKnowledge<S::Zq>) {
|
||||
let gen: LogProofGenerators = LogProofGenerators::new(pk.vk.l() as usize);
|
||||
let u = InnerProductVerifierKnowledge::get_u();
|
||||
let mut p_t = Transcript::new(b"test");
|
||||
|
||||
let proof = LogProof::create(&mut p_t, pk, &gen.g, &gen.h, &u);
|
||||
|
||||
let mut v_t = Transcript::new(b"test");
|
||||
|
||||
proof.verify(&mut v_t, &pk.vk, &gen.g, &gen.h, &u).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_secret_key() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let (ct, rng) =
|
||||
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, PlaintextBits(1));
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&[ProofStatement::PrivateKeyEncryption {
|
||||
message_id: 0,
|
||||
ciphertext: &ct,
|
||||
}],
|
||||
&[Torus::from(1)],
|
||||
&[Witness::PrivateKeyEncryption {
|
||||
randomness: rng,
|
||||
private_key: &sk,
|
||||
}],
|
||||
¶ms,
|
||||
PlaintextBits(1),
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_secret_key() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
|
||||
let (ct0, rng0) =
|
||||
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits);
|
||||
let (ct1, rng1) =
|
||||
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits);
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&[
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
message_id: 0,
|
||||
ciphertext: &ct0,
|
||||
},
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
message_id: 1,
|
||||
ciphertext: &ct1,
|
||||
},
|
||||
],
|
||||
&[Torus::from(1), Torus::from(1)],
|
||||
&[
|
||||
Witness::PrivateKeyEncryption {
|
||||
randomness: rng0,
|
||||
private_key: &sk,
|
||||
},
|
||||
Witness::PrivateKeyEncryption {
|
||||
randomness: rng1,
|
||||
private_key: &sk,
|
||||
},
|
||||
],
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_public_key() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let (ct, rng) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits);
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&[ProofStatement::PublicKeyEncryption {
|
||||
message_id: 0,
|
||||
public_key: &pk,
|
||||
ciphertext: &ct,
|
||||
}],
|
||||
&[Torus::from(1)],
|
||||
&[Witness::PublicKeyEncryption { randomness: &rng }],
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two_public_key() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let (ct0, rng0) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits);
|
||||
let (ct1, rng1) = encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits);
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&[
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
message_id: 0,
|
||||
public_key: &pk,
|
||||
ciphertext: &ct0,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
message_id: 1,
|
||||
public_key: &pk,
|
||||
ciphertext: &ct1,
|
||||
},
|
||||
],
|
||||
&[Torus::from(1), Torus::from(1)],
|
||||
&[
|
||||
Witness::PublicKeyEncryption { randomness: &rng0 },
|
||||
Witness::PublicKeyEncryption { randomness: &rng1 },
|
||||
],
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn one_public_one_private() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let (ct_priv, rng_priv) =
|
||||
encryption::encrypt_lwe_secret_and_return_randomness(1, &sk, ¶ms, bits);
|
||||
let (ct_pub, rng_pub) =
|
||||
encryption::encrypt_lwe_and_return_randomness(1, &pk, ¶ms, bits);
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&[
|
||||
ProofStatement::PrivateKeyEncryption {
|
||||
message_id: 0,
|
||||
ciphertext: &ct_priv,
|
||||
},
|
||||
ProofStatement::PublicKeyEncryption {
|
||||
message_id: 0,
|
||||
public_key: &pk,
|
||||
ciphertext: &ct_pub,
|
||||
},
|
||||
],
|
||||
&[Torus::from(1)],
|
||||
&[
|
||||
Witness::PrivateKeyEncryption {
|
||||
randomness: rng_priv,
|
||||
private_key: &sk,
|
||||
},
|
||||
Witness::PublicKeyEncryption {
|
||||
randomness: &rng_pub,
|
||||
},
|
||||
],
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
}
|
||||
|
||||
#[ignore]
|
||||
#[test]
|
||||
fn complex_examples() {
|
||||
let params = LweDef {
|
||||
dim: LweDimension(4),
|
||||
..LWE_512_80
|
||||
};
|
||||
let bits = PlaintextBits(1);
|
||||
|
||||
let case = || {
|
||||
let sk = keygen::generate_binary_lwe_sk(¶ms);
|
||||
let pk = keygen::generate_lwe_pk(&sk, ¶ms);
|
||||
|
||||
let num_messages = thread_rng().next_u64() as usize % 7 + 1;
|
||||
let num_secret_encryptions = thread_rng().next_u64() as usize % 8;
|
||||
let num_public_encryptions = thread_rng().next_u64() as usize % 8;
|
||||
|
||||
let messages = (0..num_messages)
|
||||
.map(|_| thread_rng().next_u64() % 2)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Skip trivial cases. I don't think SDLP allows it and it's boring..
|
||||
if num_public_encryptions == 0 && num_secret_encryptions == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut statements = vec![];
|
||||
let mut witnesses = vec![];
|
||||
let mut private_info = vec![];
|
||||
let mut public_info = vec![];
|
||||
|
||||
for _ in 0..num_secret_encryptions {
|
||||
let msg_id = thread_rng().next_u64() as usize % num_messages;
|
||||
|
||||
let (ct, noise) = encryption::encrypt_lwe_secret_and_return_randomness(
|
||||
messages[msg_id],
|
||||
&sk,
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
private_info.push((ct, noise, msg_id));
|
||||
}
|
||||
|
||||
for (ct, noise, msg_id) in private_info.iter() {
|
||||
statements.push(ProofStatement::PrivateKeyEncryption {
|
||||
message_id: *msg_id,
|
||||
ciphertext: ct,
|
||||
});
|
||||
witnesses.push(Witness::PrivateKeyEncryption {
|
||||
randomness: *noise,
|
||||
private_key: &sk,
|
||||
});
|
||||
}
|
||||
|
||||
for _ in 0..num_public_encryptions {
|
||||
let msg_id = thread_rng().next_u64() as usize % num_messages;
|
||||
|
||||
let (ct, noise) = encryption::encrypt_lwe_and_return_randomness(
|
||||
messages[msg_id],
|
||||
&pk,
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
public_info.push((ct, noise, msg_id));
|
||||
}
|
||||
|
||||
for (ct, noise, msg_id) in public_info.iter() {
|
||||
statements.push(ProofStatement::PublicKeyEncryption {
|
||||
message_id: *msg_id,
|
||||
ciphertext: ct,
|
||||
public_key: &pk,
|
||||
});
|
||||
witnesses.push(Witness::PublicKeyEncryption { randomness: noise });
|
||||
}
|
||||
|
||||
let messages = messages.iter().map(|x| Torus::from(*x)).collect::<Vec<_>>();
|
||||
|
||||
let pk = generate_tfhe_sdlp_prover_knowledge(
|
||||
&statements,
|
||||
&messages,
|
||||
&witnesses,
|
||||
¶ms,
|
||||
bits,
|
||||
);
|
||||
|
||||
prove_and_verify::<u64>(&pk);
|
||||
};
|
||||
|
||||
for _ in 0..5 {
|
||||
case();
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user