add alignment

This commit is contained in:
Samir Menon
2022-04-12 09:21:10 +02:00
parent f40bfc6239
commit 94b40205af
15 changed files with 550 additions and 65 deletions

View File

@@ -12,6 +12,7 @@ default = []
[dependencies]
spiral-rs = { path = "../spiral-rs" }
rand = { version = "0.8.5" }
wasm-bindgen = "0.2.74"
# The `console_error_panic_hook` crate provides better debugging of panics by

View File

@@ -1,5 +1,5 @@
mod utils;
use rand::{rngs::ThreadRng, thread_rng};
use spiral_rs::{client::*, discrete_gaussian::*, params::*, util::*};
use wasm_bindgen::prelude::*;
@@ -19,17 +19,19 @@ macro_rules! console_log {
// Avoids a lifetime in the return signature of bound Rust functions
#[wasm_bindgen]
pub struct WrappedClient {
client: Client<'static>,
client: Client<'static, ThreadRng>,
}
// Unsafe global with a static lifetime
// Accessed unsafely only once, at load / setup
static mut PARAMS: Params = get_empty_params();
static mut RNG: Option<ThreadRng> = None;
// Very simple test to ensure random generation is not obviously biased.
fn dg_seems_okay() {
let params = get_test_params();
let mut dg = DiscreteGaussian::init(&params);
let mut rng = thread_rng();
let mut dg = DiscreteGaussian::init(&params, &mut rng);
let mut v = Vec::new();
let trials = 10000;
let mut sum = 0;
@@ -72,7 +74,8 @@ pub fn initialize(json_params: Option<String>) -> WrappedClient {
// this minimal unsafe operation is needed to initialize state
unsafe {
PARAMS = params_from_json(&cfg);
client = Client::init(&PARAMS);
RNG = Some(thread_rng());
client = Client::init(&PARAMS, RNG.as_mut().unwrap());
}
WrappedClient { client }

View File

@@ -1,3 +1,3 @@
[build]
# target = "x86_64-unknown-linux-gnu"
# rustflags = ["-C", "target-feature=+avx2"]
target = "x86_64-unknown-linux-gnu"
rustflags = ["-C", "target-feature=+avx2"]

315
spiral-rs/Cargo.lock generated
View File

@@ -2,6 +2,41 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "ahash"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47"
dependencies = [
"getrandom",
"once_cell",
"version_check",
]
[[package]]
name = "arrayvec"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd9fd44efafa8690358b7408d253adf110036b88f55672a933f01d616ad9b1b9"
dependencies = [
"nodrop",
]
[[package]]
name = "atty"
version = "0.2.14"
@@ -19,6 +54,21 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "backtrace"
version = "0.3.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "base64"
version = "0.13.0"
@@ -27,9 +77,9 @@ checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bitflags"
version = "1.3.2"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693"
[[package]]
name = "bstr"
@@ -49,6 +99,12 @@ version = "3.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4a45a46ab1f2412e53d3a0ade76ffad2025804294569aae387231a0cd6e0899"
[[package]]
name = "bytemuck"
version = "1.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdead85bdec19c194affaeeb670c0e41fe23de31459efd1c174d049269cf02cc"
[[package]]
name = "bytes"
version = "1.1.0"
@@ -103,6 +159,15 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
[[package]]
name = "cpp_demangle"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eeaa953eaad386a53111e47172c2fedba671e5684c8dd601a5f474f4f118710f"
dependencies = [
"cfg-if",
]
[[package]]
name = "criterion"
version = "0.3.5"
@@ -205,6 +270,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "debugid"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6ee87af31d84ef885378aebca32be3d682b0e0dc119d5b4860a2c5bb5046730"
dependencies = [
"uuid",
]
[[package]]
name = "either"
version = "1.6.1"
@@ -321,6 +395,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "gimli"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78cc372d058dcf6d5ecd98510e7fbc9e5aec4d21de70f65fea8fecebcd881bd4"
[[package]]
name = "h2"
version = "0.3.12"
@@ -453,6 +533,24 @@ dependencies = [
"hashbrown",
]
[[package]]
name = "inferno"
version = "0.10.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3886428c6400486522cf44b8626e7b94ad794c14390290f2a274dcf728a58f"
dependencies = [
"ahash",
"atty",
"indexmap",
"itoa 1.0.1",
"lazy_static",
"log",
"num-format",
"quick-xml",
"rgb",
"str_stack",
]
[[package]]
name = "instant"
version = "0.1.12"
@@ -510,6 +608,16 @@ version = "0.2.122"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec647867e2bf0772e28c8bcde4f0d19a9216916e890543b5a03ed8ef27b8f259"
[[package]]
name = "lock_api"
version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327fa5b6a6940e4699ec49a9beae1ea4845c6bab9314e4f84ac68742139d8c53"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.14"
@@ -531,6 +639,15 @@ version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a"
[[package]]
name = "memmap2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "057a3db23999c867821a7a59feb06a578fcb03685e983dff90daf9e7d24ac08f"
dependencies = [
"libc",
]
[[package]]
name = "memoffset"
version = "0.6.5"
@@ -546,6 +663,16 @@ version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d"
[[package]]
name = "miniz_oxide"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b"
dependencies = [
"adler",
"autocfg",
]
[[package]]
name = "mio"
version = "0.8.2"
@@ -587,6 +714,25 @@ dependencies = [
"tempfile",
]
[[package]]
name = "nix"
version = "0.20.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5e06129fb611568ef4e868c14b326274959aa70ff7776e9d55323531c374945"
dependencies = [
"bitflags",
"cc",
"cfg-if",
"libc",
"memoffset",
]
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "ntapi"
version = "0.3.7"
@@ -596,6 +742,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-format"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bafe4179722c2894288ee77a9f044f02811c86af699344c498b0840c698a2465"
dependencies = [
"arrayvec",
"itoa 0.4.8",
]
[[package]]
name = "num-traits"
version = "0.2.14"
@@ -615,6 +771,15 @@ dependencies = [
"libc",
]
[[package]]
name = "object"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.10.0"
@@ -660,6 +825,31 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "parking_lot"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99"
dependencies = [
"instant",
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d76e8e1493bcac0d2766c42737f34458f1c8c50c0d23bcb24ea953affb273216"
dependencies = [
"cfg-if",
"instant",
"libc",
"redox_syscall",
"smallvec",
"winapi",
]
[[package]]
name = "percent-encoding"
version = "2.1.0"
@@ -712,6 +902,25 @@ dependencies = [
"plotters-backend",
]
[[package]]
name = "pprof"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d78fcdebc1569625891b4fefed7ece660af53082529d03d9c6e8d01b3880ab92"
dependencies = [
"backtrace",
"criterion",
"inferno",
"lazy_static",
"libc",
"log",
"nix",
"parking_lot",
"symbolic-demangle",
"tempfile",
"thiserror",
]
[[package]]
name = "ppv-lite86"
version = "0.2.16"
@@ -727,6 +936,15 @@ dependencies = [
"unicode-xid",
]
[[package]]
name = "quick-xml"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8533f14c8382aaad0d592c812ac3b826162128b65662331e1127b45c3d18536b"
dependencies = [
"memchr",
]
[[package]]
name = "quote"
version = "1.0.15"
@@ -866,6 +1084,21 @@ dependencies = [
"winreg",
]
[[package]]
name = "rgb"
version = "0.8.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e74fdc210d8f24a7dbfedc13b04ba5764f5232754ccebfdf5fff1bad791ccbc6"
dependencies = [
"bytemuck",
]
[[package]]
name = "rustc-demangle"
version = "0.1.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342"
[[package]]
name = "rustc_version"
version = "0.4.0"
@@ -908,9 +1141,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "security-framework"
version = "2.6.1"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dc14f172faf8a0194a3aded622712b0de276821addc574fa54fc0a1167e10dc"
checksum = "23a2ac85147a3a11d77ecf1bc7166ec0b92febfa4461c37944e180f319ece467"
dependencies = [
"bitflags",
"core-foundation",
@@ -991,6 +1224,12 @@ version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9def91fd1e018fe007022791f865d0ccc9b3a0d5001e01aabb8b40e46000afb5"
[[package]]
name = "smallvec"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]]
name = "socket2"
version = "0.4.4"
@@ -1007,11 +1246,47 @@ version = "0.1.0"
dependencies = [
"criterion",
"getrandom",
"pprof",
"rand",
"reqwest",
"serde_json",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "str_stack"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb"
[[package]]
name = "symbolic-common"
version = "8.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6aac7b803adc9ee75344af7681969f76d4b38e4723c6eaacf3b28f5f1d87ff"
dependencies = [
"debugid",
"memmap2",
"stable_deref_trait",
"uuid",
]
[[package]]
name = "symbolic-demangle"
version = "8.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8143ea5aa546f86c64f9b9aafdd14223ffad4ecd2d58575c63c21335909c99a7"
dependencies = [
"cpp_demangle",
"rustc-demangle",
"symbolic-common",
]
[[package]]
name = "syn"
version = "1.0.86"
@@ -1046,6 +1321,26 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "thiserror"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
@@ -1182,12 +1477,24 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "uuid"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc5cf98d8186244414c848017f0e2676b3fcb46807f6668a97dfe67359a3c4b7"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.3.2"

View File

@@ -11,6 +11,7 @@ serde_json = "1.0"
[dev-dependencies]
criterion = "0.3"
pprof = { version = "0.4", features = ["flamegraph", "criterion"] }
[[bench]]
name = "ntt"

View File

@@ -1,4 +1,6 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use pprof::criterion::{Output, PProfProfiler};
use spiral_rs::client::*;
use spiral_rs::poly::*;
use spiral_rs::server::*;
@@ -9,9 +11,9 @@ fn criterion_benchmark(c: &mut Criterion) {
let mut group = c.benchmark_group("sample-size");
group
.sample_size(10)
.measurement_time(Duration::from_secs(10));
.measurement_time(Duration::from_secs(30));
let params = get_short_keygen_params();
let params = get_expansion_testing_params();
let v_neg1 = params.get_v_neg1();
let mut seeded_rng = get_seeded_rng();
let mut client = Client::init(&params, &mut seeded_rng);
@@ -29,6 +31,7 @@ fn criterion_benchmark(c: &mut Criterion) {
let v_w_left = public_params.v_expansion_left.unwrap();
let v_w_right = public_params.v_expansion_right.unwrap();
// note: the benchmark on AVX2 is 545ms for the c++ impl
group.bench_function("coeff exp", |b| {
b.iter(|| {
coefficient_expansion(
@@ -46,5 +49,10 @@ fn criterion_benchmark(c: &mut Criterion) {
group.finish();
}
criterion_group!(benches, criterion_benchmark);
// criterion_group!(benches, criterion_benchmark);
// Registers the `benches` group with a custom Criterion configuration: the
// default settings plus a `PProfProfiler` whose output is `Output::Flamegraph`.
// NOTE(review): the `100` passed to `PProfProfiler::new` is presumably the
// sampling frequency in Hz — confirm against the pprof documentation.
criterion_group! {
name = benches;
config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = criterion_benchmark
}
criterion_main!(benches);

View File

@@ -0,0 +1,75 @@
use std::{alloc::{alloc_zeroed, dealloc, Layout}, slice::{from_raw_parts, from_raw_parts_mut}, ops::{Index, IndexMut}, mem::size_of};
/// Alignment, in bytes, requested for SIMD-friendly buffers.
const ALIGN_SIMD: usize = 64; // enough to support AVX-512
/// [`AlignedMemory`] specialised to the `ALIGN_SIMD` (64-byte) alignment.
pub type AlignedMemory64 = AlignedMemory<ALIGN_SIMD>;
/// An owned, zero-initialised buffer of `u64` words whose first element sits
/// on an `ALIGN`-byte boundary.
///
/// `ALIGN` must be a power of two. The memory is obtained from the global
/// allocator in [`AlignedMemory::new`] and released when the value is dropped.
pub struct AlignedMemory<const ALIGN: usize> {
    /// Start of the allocation; never null, and aligned for both `ALIGN` and `u64`.
    p: *mut u64,
    /// Number of `u64` elements exposed through the slice accessors.
    sz_u64: usize,
    /// Exact layout handed to the allocator; `Drop` passes the same value to `dealloc`.
    layout: Layout,
}

impl<const ALIGN: usize> AlignedMemory<{ ALIGN }> {
    /// Allocates a zero-filled buffer holding `sz_u64` words.
    ///
    /// A request for zero words is valid and yields an empty buffer.
    ///
    /// # Panics
    ///
    /// Panics if `ALIGN` is not a power of two, or if the byte size of the
    /// request does not fit in a valid [`Layout`]. Aborts through
    /// `handle_alloc_error` if the allocator reports failure.
    pub fn new(sz_u64: usize) -> Self {
        assert!(
            ALIGN.is_power_of_two(),
            "AlignedMemory: ALIGN must be a power of two"
        );
        let sz_bytes = sz_u64
            .checked_mul(size_of::<u64>())
            .expect("AlignedMemory: requested size overflows usize");

        // The pointer is later viewed as `*mut u64`, so never ask for less
        // alignment than `u64` itself requires.
        let align = ALIGN.max(std::mem::align_of::<u64>());
        // The global allocator must not be called with a zero-sized layout;
        // reserve one word so that empty buffers still own a real, aligned
        // allocation that `Drop` can release with the stored layout.
        let alloc_bytes = sz_bytes.max(size_of::<u64>());
        let layout = Layout::from_size_align(alloc_bytes, align)
            .expect("AlignedMemory: invalid size/alignment combination");

        // SAFETY: `layout` has a non-zero size (at least one `u64`).
        let ptr = unsafe { alloc_zeroed(layout) };
        if ptr.is_null() {
            // Allocation failure must not be turned into a null-backed slice.
            std::alloc::handle_alloc_error(layout);
        }

        Self {
            p: ptr as *mut u64,
            sz_u64,
            layout,
        }
    }

    /// Borrows the whole buffer as an immutable slice of `len()` words.
    pub fn as_slice(&self) -> &[u64] {
        // SAFETY: `p` is non-null, aligned for `u64`, and points to at least
        // `sz_u64` zero-initialised words owned by `self`.
        unsafe { from_raw_parts(self.p, self.sz_u64) }
    }

    /// Borrows the whole buffer as a mutable slice of `len()` words.
    pub fn as_mut_slice(&mut self) -> &mut [u64] {
        // SAFETY: same invariants as `as_slice`; `&mut self` guarantees
        // exclusive access for the lifetime of the returned slice.
        unsafe { from_raw_parts_mut(self.p, self.sz_u64) }
    }

    /// Number of `u64` words in the buffer.
    pub fn len(&self) -> usize {
        self.sz_u64
    }

    /// Returns `true` when the buffer holds no words.
    pub fn is_empty(&self) -> bool {
        self.sz_u64 == 0
    }
}

impl<const ALIGN: usize> Drop for AlignedMemory<{ ALIGN }> {
    /// Returns the allocation to the global allocator.
    fn drop(&mut self) {
        // SAFETY: `p` came from `alloc_zeroed(self.layout)` in `new`, has not
        // been freed before, and `layout` is the value used for that call.
        unsafe {
            dealloc(self.p as *mut u8, self.layout);
        }
    }
}
/// Read access to a single word by position, with the same bounds checking
/// as indexing the slice returned by `as_slice`.
impl<const ALIGN: usize> Index<usize> for AlignedMemory<{ ALIGN }> {
    type Output = u64;

    fn index(&self, index: usize) -> &Self::Output {
        let words = self.as_slice();
        &words[index]
    }
}
/// Write access to a single word by position, with the same bounds checking
/// as indexing the slice returned by `as_mut_slice`.
impl<const ALIGN: usize> IndexMut<usize> for AlignedMemory<{ ALIGN }> {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let words = self.as_mut_slice();
        &mut words[index]
    }
}
/// Deep copy: allocates a fresh buffer of the same length and alignment and
/// copies every word into it.
impl<const ALIGN: usize> Clone for AlignedMemory<{ ALIGN }> {
    fn clone(&self) -> Self {
        let src = self.as_slice();
        let mut copy = Self::new(src.len());
        copy.as_mut_slice().copy_from_slice(src);
        copy
    }
}

View File

@@ -1,8 +1,7 @@
use crate::{
arith::*, discrete_gaussian::*, gadget::*, number_theory::*, params::*, poly::*, util::*,
};
use rand::rngs::StdRng;
use rand::{thread_rng, Rng};
use rand::{Rng};
use std::iter::once;
fn serialize_polymatrix(vec: &mut Vec<u8>, a: &PolyMatrixRaw) {
@@ -500,6 +499,8 @@ impl<'a, TRng: Rng> Client<'a, TRng> {
#[cfg(test)]
mod test {
use rand::thread_rng;
use super::*;
fn assert_first8(m: &[u64], gold: [u64; 8]) {
@@ -517,7 +518,7 @@ mod test {
let mut rng = thread_rng();
let client = Client::init(&params, &mut rng);
assert_eq!(client.stop_round, 6);
assert_eq!(client.stop_round, 5);
assert_eq!(client.g, 10);
assert_eq!(*client.params, params);
}
@@ -531,7 +532,7 @@ mod test {
let public_params = client.generate_keys();
assert_first8(
&public_params.v_conversion.unwrap()[0].data,
public_params.v_conversion.unwrap()[0].data.as_slice(),
[
253586619, 247235120, 141892996, 163163429, 15531298, 200914775, 125109567,
75889562,
@@ -539,7 +540,7 @@ mod test {
);
assert_first8(
&client.sk_gsw.data,
client.sk_gsw.data.as_slice(),
[1, 5, 0, 3, 1, 3, 66974689739603967, 3],
);
}

View File

@@ -1,7 +1,6 @@
use rand::distributions::WeightedIndex;
use rand::prelude::Distribution;
use rand::Rng;
use rand::{rngs::ThreadRng, thread_rng};
use crate::params::*;
use crate::poly::*;
@@ -53,6 +52,8 @@ impl<'a, T: Rng> DiscreteGaussian<'a, T> {
#[cfg(test)]
mod test {
use rand::thread_rng;
use super::*;
use crate::util::*;

View File

@@ -1,5 +1,3 @@
use std::primitive;
use crate::{params::*, poly::*};
pub fn get_bits_per(params: &Params, dim: usize) -> usize {
@@ -33,16 +31,17 @@ pub fn build_gadget(params: &Params, rows: usize, cols: usize) -> PolyMatrixRaw
g
}
pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
let params = inp.params;
pub fn gadget_invert_rdim<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>, rdim: usize) {
assert_eq!(out.cols, inp.cols);
let num_elems = mx / inp.rows;
let params = inp.params;
let mx = out.rows;
let num_elems = mx / rdim;
let bits_per = get_bits_per(params, num_elems);
let mask = (1u64 << bits_per) - 1;
let mut out = PolyMatrixRaw::zero(params, mx, inp.cols);
for i in 0..inp.cols {
for j in 0..inp.rows {
for j in 0..rdim {
for z in 0..params.poly_len {
let val = inp.get_poly(j, i)[z];
for k in 0..num_elems {
@@ -53,11 +52,20 @@ pub fn gadget_invert<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a
None => 0,
};
out.get_poly_mut(j + k * inp.rows, i)[z] = piece;
out.get_poly_mut(j + k * rdim, i)[z] = piece;
}
}
}
}
}
/// Writes the gadget inversion of `inp` into the caller-provided `out`.
///
/// Forwards to `gadget_invert_rdim` with `rdim = inp.rows`, i.e. every row of
/// `inp` is processed. `out` is not allocated here; the caller supplies it.
pub fn gadget_invert<'a>(out: &mut PolyMatrixRaw<'a>, inp: &PolyMatrixRaw<'a>) {
gadget_invert_rdim(out, inp, inp.rows);
}
pub fn gadget_invert_alloc<'a>(mx: usize, inp: &PolyMatrixRaw<'a>) -> PolyMatrixRaw<'a> {
let mut out = PolyMatrixRaw::zero(inp.params, mx, inp.cols);
gadget_invert(&mut out, inp);
out
}
@@ -74,7 +82,7 @@ mod test {
mat.get_poly_mut(0, 0)[37] = 3;
mat.get_poly_mut(1, 0)[37] = 6;
let log_q = params.modulus_log2 as usize;
let result = gadget_invert(2 * log_q, &mat);
let result = gadget_invert_alloc(2 * log_q, &mat);
assert_eq!(result.get_poly(0, 0)[37], 1);
assert_eq!(result.get_poly(2, 0)[37], 1);

View File

@@ -2,6 +2,7 @@ pub mod arith;
pub mod discrete_gaussian;
pub mod number_theory;
pub mod util;
pub mod aligned_memory;
pub mod gadget;
pub mod ntt;

View File

@@ -156,8 +156,8 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
// Use AVX2 here
let p_x = &mut op[j] as *mut u64;
let p_y = &mut op[j + t] as *mut u64;
let x = _mm256_loadu_si256(p_x as *const __m256i);
let y = _mm256_loadu_si256(p_y as *const __m256i);
let x = _mm256_load_si256(p_x as *const __m256i);
let y = _mm256_load_si256(p_y as *const __m256i);
let cmp_val = _mm256_set1_epi64x(two_times_modulus_small as i64);
let gt_mask = _mm256_cmpgt_epi64(x, cmp_val);
@@ -181,8 +181,8 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
let q_final_inverted = _mm256_sub_epi64(cmp_val, q_final);
let new_y = _mm256_add_epi64(curr_x, q_final_inverted);
_mm256_storeu_si256(p_x as *mut __m256i, new_x);
_mm256_storeu_si256(p_y as *mut __m256i, new_y);
_mm256_store_si256(p_x as *mut __m256i, new_x);
_mm256_store_si256(p_y as *mut __m256i, new_y);
}
}
}
@@ -194,7 +194,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
let p_x = &mut operand[i] as *mut u64;
let cmp_val1 = _mm256_set1_epi64x(two_times_modulus_small as i64);
let mut x = _mm256_loadu_si256(p_x as *const __m256i);
let mut x = _mm256_load_si256(p_x as *const __m256i);
let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1);
let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1);
x = _mm256_sub_epi64(x, to_subtract);
@@ -203,7 +203,7 @@ pub fn ntt_forward(params: &Params, operand_overall: &mut [u64]) {
gt_mask = _mm256_cmpgt_epi64(x, cmp_val2);
to_subtract = _mm256_and_si256(gt_mask, cmp_val2);
x = _mm256_sub_epi64(x, to_subtract);
_mm256_storeu_si256(p_x as *mut __m256i, x);
_mm256_store_si256(p_x as *mut __m256i, x);
}
}
}
@@ -301,8 +301,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
// Use AVX2 here
let p_x = &mut op[j] as *mut u64;
let p_y = &mut op[j + t] as *mut u64;
let x = _mm256_loadu_si256(p_x as *const __m256i);
let y = _mm256_loadu_si256(p_y as *const __m256i);
let x = _mm256_load_si256(p_x as *const __m256i);
let y = _mm256_load_si256(p_y as *const __m256i);
let modulus_vec = _mm256_set1_epi64x(modulus as i64);
let two_times_modulus_vec =
@@ -331,8 +331,8 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
let h_tmp_times_modulus = _mm256_mul_epu32(h_tmp, modulus_vec);
let new_y = _mm256_sub_epi64(w_times_t_tmp, h_tmp_times_modulus);
_mm256_storeu_si256(p_x as *mut __m256i, new_x);
_mm256_storeu_si256(p_y as *mut __m256i, new_y);
_mm256_store_si256(p_x as *mut __m256i, new_x);
_mm256_store_si256(p_y as *mut __m256i, new_y);
}
}
}
@@ -343,13 +343,31 @@ pub fn ntt_inverse(params: &Params, operand_overall: &mut [u64]) {
operand[i] -= ((operand[i] >= two_times_modulus) as u64) * two_times_modulus;
operand[i] -= ((operand[i] >= modulus) as u64) * modulus;
}
// for i in (0..n).step_by(4) {
// unsafe {
// let p_x = &mut operand[i] as *mut u64;
// let cmp_val1 = _mm256_set1_epi64x(two_times_modulus as i64);
// let mut x = _mm256_load_si256(p_x as *const __m256i);
// let mut gt_mask = _mm256_cmpgt_epi64(x, cmp_val1);
// let mut to_subtract = _mm256_and_si256(gt_mask, cmp_val1);
// x = _mm256_sub_epi64(x, to_subtract);
// let cmp_val2 = _mm256_set1_epi64x(modulus as i64);
// gt_mask = _mm256_cmpgt_epi64(x, cmp_val2);
// to_subtract = _mm256_and_si256(gt_mask, cmp_val2);
// x = _mm256_sub_epi64(x, to_subtract);
// _mm256_store_si256(p_x as *mut __m256i, x);
// }
// }
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::util::*;
use crate::{util::*, aligned_memory::AlignedMemory64};
use rand::Rng;
fn get_params() -> Params {
@@ -382,7 +400,7 @@ mod test {
#[test]
fn ntt_forward_correct() {
let params = get_params();
let mut v1 = vec![0; 2 * 2048];
let mut v1 = AlignedMemory64::new(2 * 2048);
v1[0] = 100;
v1[2048] = 100;
ntt_forward(&params, v1.as_mut_slice());
@@ -393,7 +411,10 @@ mod test {
#[test]
fn ntt_inverse_correct() {
let params = get_params();
let mut v1 = vec![100; 2 * 2048];
let mut v1 = AlignedMemory64::new(2 * 2048);
for i in 0..v1.len() {
v1[i] = 100;
}
ntt_inverse(&params, v1.as_mut_slice());
assert_eq!(v1[0], 100);
assert_eq!(v1[2048], 100);
@@ -404,7 +425,7 @@ mod test {
#[test]
fn ntt_correct() {
let params = get_params();
let mut v1 = vec![0; params.crt_count * params.poly_len];
let mut v1 = AlignedMemory64::new(params.crt_count * params.poly_len);
let mut rng = rand::thread_rng();
for i in 0..params.crt_count {
for j in 0..params.poly_len {

View File

@@ -6,10 +6,10 @@ use rand::Rng;
use std::cell::RefCell;
use std::ops::{Add, Mul, Neg};
use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*};
use crate::{arith::*, discrete_gaussian::*, ntt::*, params::*, util::*, aligned_memory::*};
const SCRATCH_SPACE: usize = 8192;
thread_local!(static SCRATCH: RefCell<Vec<u64>> = RefCell::new(vec![0u64; SCRATCH_SPACE]));
thread_local!(static SCRATCH: RefCell<AlignedMemory64> = RefCell::new(AlignedMemory64::new(SCRATCH_SPACE)));
pub trait PolyMatrix<'a> {
fn is_ntt(&self) -> bool;
@@ -59,14 +59,14 @@ pub struct PolyMatrixRaw<'a> {
pub params: &'a Params,
pub rows: usize,
pub cols: usize,
pub data: Vec<u64>,
pub data: AlignedMemory64,
}
pub struct PolyMatrixNTT<'a> {
pub params: &'a Params,
pub rows: usize,
pub cols: usize,
pub data: Vec<u64>,
pub data: AlignedMemory64,
}
impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
@@ -93,7 +93,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
}
fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
let num_coeffs = rows * cols * params.poly_len;
let data: Vec<u64> = vec![0; num_coeffs];
let data = AlignedMemory64::new(num_coeffs);
PolyMatrixRaw {
params,
rows,
@@ -143,7 +143,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixRaw<'a> {
impl<'a> PolyMatrixRaw<'a> {
pub fn identity(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixRaw<'a> {
let num_coeffs = rows * cols * params.poly_len;
let mut data: Vec<u64> = vec![0; num_coeffs];
let mut data = AlignedMemory::new(num_coeffs);
for r in 0..rows {
let c = r;
let idx = r * cols * params.poly_len + c * params.poly_len;
@@ -227,7 +227,7 @@ impl<'a> PolyMatrix<'a> for PolyMatrixNTT<'a> {
}
fn zero(params: &'a Params, rows: usize, cols: usize) -> PolyMatrixNTT<'a> {
let num_coeffs = rows * cols * params.poly_len * params.crt_count;
let data: Vec<u64> = vec![0; num_coeffs];
let data = AlignedMemory::new(num_coeffs);
PolyMatrixNTT {
params,
rows,
@@ -339,14 +339,14 @@ pub fn multiply_add_poly_avx(params: &Params, res: &mut [u64], a: &[u64], b: &[u
let p_x = &a[c * params.poly_len + i] as *const u64;
let p_y = &b[c * params.poly_len + i] as *const u64;
let p_z = &mut res[c * params.poly_len + i] as *mut u64;
let x = _mm256_loadu_si256(p_x as *const __m256i);
let y = _mm256_loadu_si256(p_y as *const __m256i);
let z = _mm256_loadu_si256(p_z as *const __m256i);
let x = _mm256_load_si256(p_x as *const __m256i);
let y = _mm256_load_si256(p_y as *const __m256i);
let z = _mm256_load_si256(p_z as *const __m256i);
let product = _mm256_mul_epu32(x, y);
let out = _mm256_add_epi64(z, product);
_mm256_storeu_si256(p_z as *mut __m256i, out);
_mm256_store_si256(p_z as *mut __m256i, out);
}
}
}
@@ -511,6 +511,21 @@ pub fn to_ntt(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
}
}
/// Converts the raw matrix `b` into NTT form, writing the result into `a`.
///
/// For every entry, the raw coefficients are copied verbatim into each of the
/// `crt_count` consecutive `poly_len`-sized slots of the destination
/// polynomial (no per-modulus reduction is applied by this function), and
/// `ntt_forward` is then run over the destination polynomial in place.
/// The loop bounds come from `a.rows` and `a.cols`.
pub fn to_ntt_no_reduce(a: &mut PolyMatrixNTT, b: &PolyMatrixRaw) {
    let params = a.params;
    let poly_len = params.poly_len;
    for row in 0..a.rows {
        for col in 0..a.cols {
            let raw = b.get_poly(row, col);
            let ntt = a.get_poly_mut(row, col);
            // One unreduced copy of the coefficients per CRT slot.
            for start in (0..params.crt_count).map(|slot| slot * poly_len) {
                ntt[start..start + poly_len].copy_from_slice(raw);
            }
            ntt_forward(params, ntt);
        }
    }
}
pub fn to_ntt_alloc<'a>(b: &PolyMatrixRaw<'a>) -> PolyMatrixNTT<'a> {
let mut a = PolyMatrixNTT::zero(b.params, b.rows, b.cols);
to_ntt(&mut a, b);

View File

@@ -1,5 +1,5 @@
use crate::arith;
use crate::gadget::gadget_invert;
use crate::gadget::*;
use crate::params::*;
use crate::poly::*;
@@ -15,6 +15,16 @@ pub fn coefficient_expansion(
) {
let poly_len = params.poly_len;
let mut ct = PolyMatrixRaw::zero(params, 2, 1);
let mut ct_auto = PolyMatrixRaw::zero(params, 2, 1);
let mut ct_auto_1 = PolyMatrixRaw::zero(params, 1, 1);
let mut ct_auto_1_ntt = PolyMatrixNTT::zero(params, 1, 1);
let mut ginv_ct_left = PolyMatrixRaw::zero(params, params.t_exp_left, 1);
let mut ginv_ct_left_ntt = PolyMatrixNTT::zero(params, params.t_exp_left, 1);
let mut ginv_ct_right = PolyMatrixRaw::zero(params, params.t_exp_right, 1);
let mut ginv_ct_right_ntt = PolyMatrixNTT::zero(params, params.t_exp_right, 1);
let mut w_times_ginv_ct = PolyMatrixNTT::zero(params, 2, 1);
for r in 0..g {
let num_in = 1 << r;
let num_out = 2 * num_in;
@@ -30,23 +40,36 @@ pub fn coefficient_expansion(
continue;
}
let (w, gadget_dim) = match i % 2 {
0 => (&v_w_left[r], params.t_exp_left),
1 | _ => (&v_w_right[r], params.t_exp_right),
let (w, _gadget_dim, gi_ct, gi_ct_ntt) = match i % 2 {
0 => (&v_w_left[r], params.t_exp_left, &mut ginv_ct_left, &mut ginv_ct_left_ntt),
1 | _ => (&v_w_right[r], params.t_exp_right, &mut ginv_ct_right, &mut ginv_ct_right_ntt),
};
// let (w, gadget_dim) = match i % 2 {
// 0 => (&v_w_left[r], params.t_exp_left),
// 1 | _ => (&v_w_right[r], params.t_exp_right),
// };
if i < num_in {
let (src, dest) = v.split_at_mut(num_in);
scalar_multiply(&mut dest[i], neg1, &src[i]);
}
let ct = from_ntt_alloc(&v[i]);
let ct_auto = automorph_alloc(&ct, t);
let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
let ginv_ct = gadget_invert(gadget_dim, &ct_auto_0);
let ginv_ct_ntt = ginv_ct.ntt();
let w_times_ginv_ct = w * &ginv_ct_ntt;
// let ct = from_ntt_alloc(&v[i]);
// let ct_auto = automorph_alloc(&ct, t);
// let ct_auto_0 = ct_auto.submatrix(0, 0, 1, 1);
// let ct_auto_1_ntt = ct_auto.submatrix(1, 0, 1, 1).ntt();
// let ginv_ct = gadget_invert_alloc(gadget_dim, &ct_auto_0);
// let ginv_ct_ntt = ginv_ct.ntt();
// let w_times_ginv_ct = w * &ginv_ct_ntt;
from_ntt(&mut ct, &v[i]);
automorph(&mut ct_auto, &ct, t);
gadget_invert_rdim(gi_ct, &ct_auto, 1);
to_ntt_no_reduce(gi_ct_ntt, &gi_ct);
ct_auto_1.data.as_mut_slice().copy_from_slice(ct_auto.get_poly(1, 0));
to_ntt(&mut ct_auto_1_ntt, &ct_auto_1);
multiply(&mut w_times_ginv_ct, w, &gi_ct_ntt);
let mut idx = 0;
for j in 0..2 {
@@ -71,7 +94,7 @@ mod test {
use super::*;
fn get_params() -> Params {
get_short_keygen_params()
get_expansion_testing_params()
}
#[test]

View File

@@ -52,6 +52,26 @@ pub fn get_short_keygen_params() -> Params {
)
}
/// Builds the fixed parameter set used by the coefficient-expansion benchmark
/// and tests by parsing an embedded configuration with `params_from_json`.
pub fn get_expansion_testing_params() -> Params {
    // The configuration is written with single quotes; they are swapped for
    // double quotes before parsing so the text is valid JSON.
    let single_quoted = r#"
{'n': 2,
'nu_1': 9,
'nu_2': 6,
'p': 256,
'q_prime_bits': 20,
's_e': 87.62938774292914,
't_GSW': 8,
't_conv': 4,
't_exp': 8,
't_exp_right': 56,
'instances': 1,
'db_item_size': 256 }
"#;
    params_from_json(&single_quoted.replace("'", "\""))
}
pub fn get_seed() -> [u8; 32] {
[
1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,