mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
chore(ntt): bring concrete-ntt in the repo as tfhe-ntt
This commit is contained in:
1
tfhe-ntt/.gitignore
vendored
Normal file
1
tfhe-ntt/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
benchmarks_parameters/
|
||||
36
tfhe-ntt/Cargo.toml
Normal file
36
tfhe-ntt/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "tfhe-ntt"
|
||||
version = "0.3.0"
|
||||
edition = "2021"
|
||||
description = "tfhe-ntt is a pure Rust high performance number theoretic transform library."
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/zama-ai/tfhe-rs"
|
||||
license = "BSD-3-Clause-Clear"
|
||||
homepage = "https://zama.ai/"
|
||||
keywords = ["ntt"]
|
||||
rust-version = "1.67"
|
||||
|
||||
|
||||
[dependencies]
|
||||
aligned-vec = { workspace = true }
|
||||
bytemuck = { workspace = true }
|
||||
pulp = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["pulp/std", "aligned-vec/std"]
|
||||
nightly = ["pulp/nightly"]
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.4"
|
||||
rand = "0.8"
|
||||
serde = "1.0.163"
|
||||
serde_json = "1.0.96"
|
||||
|
||||
[[bench]]
|
||||
name = "ntt"
|
||||
harness = false
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
all-features = true
|
||||
rustdoc-args = ["--html-in-header", "katex-header.html", "--cfg", "docsrs"]
|
||||
33
tfhe-ntt/LICENSE
Normal file
33
tfhe-ntt/LICENSE
Normal file
@@ -0,0 +1,33 @@
|
||||
BSD 3-Clause Clear License
|
||||
|
||||
Copyright © 2023 ZAMA.
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice, this
|
||||
list of conditions and the following disclaimer in the documentation and/or other
|
||||
materials provided with the distribution.
|
||||
|
||||
3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
|
||||
or promote products derived from this software without specific prior written permission.
|
||||
|
||||
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*.
|
||||
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||||
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
|
||||
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
|
||||
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
|
||||
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
|
||||
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive,
|
||||
free and non-commercial license on all patents filed in its name relating to the open-source
|
||||
code (the "Patents") for the sole purpose of evaluation, development, research, prototyping
|
||||
and experimentation.
|
||||
64
tfhe-ntt/README.md
Normal file
64
tfhe-ntt/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
tfhe-ntt is a pure Rust high performance Number Theoretic Transform library that processes
|
||||
vectors of sizes that are powers of two.
|
||||
|
||||
This library provides three kinds of NTT:
|
||||
- The prime NTT computes the transform in a field $\mathbb{Z}/p \mathbb{Z}$
|
||||
with $p$ prime, allowing for arithmetic operations on the polynomial modulo $p$.
|
||||
- The native NTT internally computes the transform of the first kind with
|
||||
several primes, allowing the simulation of arithmetic modulo the product of
|
||||
those primes, and truncates the
|
||||
result when the inverse transform is desired. The truncated result is guaranteed to be as if
|
||||
the computations were performed with wrapping arithmetic, as long as the full integer result
|
||||
would have be smaller than half the product of the primes, in absolute value. It is guaranteed
|
||||
to be suitable for multiplying two polynomials with arbitrary coefficients, and returns the
|
||||
result in wrapping arithmetic.
|
||||
- The native binary NTT is similar to the native NTT, but is optimized for the case where one
|
||||
of the operands of the multiplication has coefficients in $\lbrace 0, 1 \rbrace$.
|
||||
|
||||
# Rust requirements
|
||||
tfhe-ntt requires a Rust version >= 1.67.0.
|
||||
|
||||
# Features
|
||||
|
||||
- `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
|
||||
- `nightly`: This enables unstable Rust features to further speed up the NTT, by enabling
|
||||
AVX512 instructions on CPUs that support them. This feature requires a nightly Rust
|
||||
toolchain.
|
||||
|
||||
# Example
|
||||
|
||||
```rust
|
||||
use tfhe_ntt::prime32::Plan;
|
||||
|
||||
const N: usize = 32;
|
||||
let p = 1062862849;
|
||||
let plan = Plan::try_new(N, p).unwrap();
|
||||
|
||||
let data = [
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
|
||||
25, 26, 27, 28, 29, 30, 31,
|
||||
];
|
||||
|
||||
let mut transformed_fwd = data;
|
||||
plan.fwd(&mut transformed_fwd);
|
||||
|
||||
let mut transformed_inv = transformed_fwd;
|
||||
plan.inv(&mut transformed_inv);
|
||||
|
||||
for (&actual, expected) in transformed_inv.iter().zip(data.iter().map(|x| x * N as u32)) {
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
```
|
||||
|
||||
More examples can be found in the `examples` directory.
|
||||
|
||||
- `mul_poly_prime.rs`: Negacyclic polynomial multiplication with a prime modulus.
|
||||
Run the example with `cargo run --example mul_poly_prime`.
|
||||
- `mul_poly_native.rs`: Negacyclic polynomial multiplication with a native modulus (`2^32`, `2^64`, or `2^128`).
|
||||
Run the example with `cargo run --example mul_poly_native`.
|
||||
|
||||
# Benchmarks
|
||||
|
||||
Benchmarks can be executed with `cargo bench`. If a nightly toolchain is
|
||||
available, then AVX512 acceleration can be enabled by passing the
|
||||
`--features=nightly` flag.
|
||||
3
tfhe-ntt/benches/lib.rs
Normal file
3
tfhe-ntt/benches/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
mod ntt;
|
||||
238
tfhe-ntt/benches/ntt.rs
Normal file
238
tfhe-ntt/benches/ntt.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
use serde::Serialize;
|
||||
use std::{fs, path::PathBuf};
|
||||
|
||||
use criterion::*;
|
||||
use tfhe_ntt::{prime::largest_prime_in_arithmetic_progression64, *};
|
||||
|
||||
#[derive(Serialize)]
|
||||
enum PrimeModulus {
|
||||
// 32 bits section
|
||||
FitsIn30Bits,
|
||||
FitsIn31Bits,
|
||||
FitsIn32Bits,
|
||||
Native32,
|
||||
// 64 bits section
|
||||
FitsIn50Bits,
|
||||
FitsIn51Bits,
|
||||
FitsIn52Bits,
|
||||
FitsIn62Bits,
|
||||
FitsIn63Bits,
|
||||
FitsIn64Bits,
|
||||
Native64,
|
||||
// 128 bits section
|
||||
Native128,
|
||||
}
|
||||
|
||||
impl PrimeModulus {
|
||||
fn from_u64(p: u64) -> Self {
|
||||
if p < 1 << 30 {
|
||||
Self::FitsIn30Bits
|
||||
} else if p < 1 << 31 {
|
||||
Self::FitsIn31Bits
|
||||
} else if p < 1 << 32 {
|
||||
Self::FitsIn32Bits
|
||||
} else if p < 1 << 50 {
|
||||
Self::FitsIn50Bits
|
||||
} else if p < 1 << 51 {
|
||||
Self::FitsIn51Bits
|
||||
} else if p < 1 << 52 {
|
||||
Self::FitsIn52Bits
|
||||
} else if p < 1 << 62 {
|
||||
Self::FitsIn62Bits
|
||||
} else if p < 1 << 63 {
|
||||
Self::FitsIn63Bits
|
||||
} else {
|
||||
Self::FitsIn64Bits
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct BenchmarkParametersRecord {
|
||||
display_name: String,
|
||||
polynomial_size: usize,
|
||||
prime_modulus: PrimeModulus,
|
||||
// If this field value is set to 0 means that the number is not a prime.
|
||||
prime_number: u64,
|
||||
}
|
||||
|
||||
/// Writes benchmarks parameters to disk in JSON format.
|
||||
fn write_to_json(
|
||||
bench_id: &str,
|
||||
display_name: impl Into<String>,
|
||||
polynomial_size: usize,
|
||||
prime_modulus: PrimeModulus,
|
||||
prime_number: u64,
|
||||
) {
|
||||
let record = BenchmarkParametersRecord {
|
||||
display_name: display_name.into(),
|
||||
polynomial_size,
|
||||
prime_modulus,
|
||||
prime_number,
|
||||
};
|
||||
|
||||
let mut params_directory = ["benchmarks_parameters", bench_id]
|
||||
.iter()
|
||||
.collect::<PathBuf>();
|
||||
fs::create_dir_all(¶ms_directory).unwrap();
|
||||
params_directory.push("parameters.json");
|
||||
|
||||
fs::write(params_directory, serde_json::to_string(&record).unwrap()).unwrap();
|
||||
}
|
||||
|
||||
fn criterion_bench(c: &mut Criterion) {
|
||||
let ns = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768];
|
||||
for n in ns {
|
||||
let mut data = vec![0; n];
|
||||
for p in [
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap(),
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap(),
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 31, 1 << 32).unwrap(),
|
||||
] {
|
||||
let p_u64 = p;
|
||||
let p = p as u32;
|
||||
let plan = prime32::Plan::try_new(n, p).unwrap();
|
||||
let bench_id = format!("fwd-32-{p}-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.fwd(&mut data));
|
||||
});
|
||||
write_to_json(&bench_id, "fwd-32", n, PrimeModulus::from_u64(p_u64), p_u64);
|
||||
|
||||
let bench_id = format!("inv-32-{p}-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.inv(&mut data));
|
||||
});
|
||||
write_to_json(&bench_id, "inv-32", n, PrimeModulus::from_u64(p_u64), p_u64);
|
||||
}
|
||||
}
|
||||
|
||||
for n in ns {
|
||||
let mut data = vec![0; n];
|
||||
for p in [
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 49, 1 << 50).unwrap(),
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap(),
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62).unwrap(),
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap(),
|
||||
prime64::Solinas::P,
|
||||
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 63, u64::MAX).unwrap(),
|
||||
] {
|
||||
let plan = prime64::Plan::try_new(n, p).unwrap();
|
||||
let bench_id = format!("fwd-64-{p}-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.fwd(&mut data));
|
||||
});
|
||||
write_to_json(&bench_id, "fwd-64", n, PrimeModulus::from_u64(p), p);
|
||||
|
||||
let bench_id = format!("inv-64-{p}-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.inv(&mut data));
|
||||
});
|
||||
write_to_json(&bench_id, "inv-64", n, PrimeModulus::from_u64(p), p);
|
||||
}
|
||||
}
|
||||
|
||||
for n in ns {
|
||||
let mut prod = vec![0; n];
|
||||
let lhs = vec![0; n];
|
||||
let rhs = vec![0; n];
|
||||
|
||||
let plan = native32::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("native32-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "native32-32", n, PrimeModulus::Native32, 0);
|
||||
|
||||
let plan = native_binary32::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("nativebinary32-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "nativebinary32-32", n, PrimeModulus::Native32, 0);
|
||||
|
||||
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
|
||||
{
|
||||
if let Some(plan) = native32::Plan52::try_new(n) {
|
||||
let bench_id = format!("native32-52-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "native32-52", n, PrimeModulus::Native32, 0);
|
||||
}
|
||||
if let Some(plan) = native_binary32::Plan52::try_new(n) {
|
||||
let bench_id = format!("nativebinary32-52-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "nativebinary32-52", n, PrimeModulus::Native32, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for n in ns {
|
||||
let mut prod = vec![0; n];
|
||||
let lhs = vec![0; n];
|
||||
let rhs = vec![0; n];
|
||||
|
||||
let plan = native64::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("native64-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "native64-32", n, PrimeModulus::Native64, 0);
|
||||
|
||||
let plan = native_binary64::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("nativebinary64-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "nativebinary64-32", n, PrimeModulus::Native64, 0);
|
||||
|
||||
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
|
||||
{
|
||||
if let Some(plan) = native64::Plan52::try_new(n) {
|
||||
let bench_id = format!("native64-52-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "native64-52", n, PrimeModulus::Native64, 0);
|
||||
}
|
||||
if let Some(plan) = native_binary64::Plan52::try_new(n) {
|
||||
let bench_id = format!("nativebinary64-52-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "nativebinary64-52", n, PrimeModulus::Native64, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for n in ns {
|
||||
let mut prod = vec![0; n];
|
||||
let lhs = vec![0; n];
|
||||
let rhs = vec![0; n];
|
||||
|
||||
let plan = native128::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("native128-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(&bench_id, "native128-32", n, PrimeModulus::Native128, 0);
|
||||
|
||||
let plan = native_binary128::Plan32::try_new(n).unwrap();
|
||||
let bench_id = format!("nativebinary128-32-{n}");
|
||||
c.bench_function(&bench_id, |b| {
|
||||
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
|
||||
});
|
||||
write_to_json(
|
||||
&bench_id,
|
||||
"nativebinary128-32",
|
||||
n,
|
||||
PrimeModulus::Native128,
|
||||
0,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, criterion_bench);
|
||||
criterion_main!(benches);
|
||||
38
tfhe-ntt/examples/mul_poly_native.rs
Normal file
38
tfhe-ntt/examples/mul_poly_native.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use rand::random;
|
||||
use tfhe_ntt::native32::Plan32;
|
||||
|
||||
fn main() {
|
||||
// define suitable polynomial size. Power of two polynomial sizes up to `2^16` are supported.
|
||||
let polynomial_size = 1024;
|
||||
|
||||
let lhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>()).collect();
|
||||
let rhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>()).collect();
|
||||
|
||||
// method 1: schoolbook algorithm
|
||||
let add = |x: u32, y: u32| x.wrapping_add(y);
|
||||
let sub = |x: u32, y: u32| x.wrapping_sub(y);
|
||||
let mul = |x: u32, y: u32| x.wrapping_mul(y);
|
||||
|
||||
let mut full_convolution = vec![0; 2 * polynomial_size];
|
||||
for i in 0..polynomial_size {
|
||||
for j in 0..polynomial_size {
|
||||
full_convolution[i + j] = add(full_convolution[i + j], mul(lhs_poly[i], rhs_poly[j]));
|
||||
}
|
||||
}
|
||||
|
||||
let mut negacyclic_convolution = vec![0; polynomial_size];
|
||||
for i in 0..polynomial_size {
|
||||
negacyclic_convolution[i] = sub(full_convolution[i], full_convolution[polynomial_size + i]);
|
||||
}
|
||||
|
||||
// method 2: NTT
|
||||
let plan = Plan32::try_new(polynomial_size).unwrap();
|
||||
let mut product_poly = vec![0; polynomial_size];
|
||||
|
||||
// convert to NTT domain
|
||||
plan.negacyclic_polymul(&mut product_poly, &lhs_poly, &rhs_poly);
|
||||
|
||||
// check that method 1 and method 2 give the same result
|
||||
assert_eq!(product_poly, negacyclic_convolution);
|
||||
println!("Success!");
|
||||
}
|
||||
49
tfhe-ntt/examples/mul_poly_prime.rs
Normal file
49
tfhe-ntt/examples/mul_poly_prime.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use rand::random;
|
||||
use tfhe_ntt::prime32::Plan;
|
||||
|
||||
fn main() {
|
||||
// define suitable NTT prime and polynomial size
|
||||
let p: u32 = 1073479681;
|
||||
let polynomial_size = 1024;
|
||||
|
||||
// unwrapping is fine here because we know roots of unity exist for the combination
|
||||
// `(polynomial_size, p)`
|
||||
let lhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>() % p).collect();
|
||||
let rhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>() % p).collect();
|
||||
|
||||
// method 1: schoolbook algorithm
|
||||
let add = |x: u32, y: u32| ((x as u64 + y as u64) % p as u64) as u32;
|
||||
let sub = |x: u32, y: u32| add(x, p - y);
|
||||
let mul = |x: u32, y: u32| ((x as u64 * y as u64) % p as u64) as u32;
|
||||
|
||||
let mut full_convolution = vec![0; 2 * polynomial_size];
|
||||
for i in 0..polynomial_size {
|
||||
for j in 0..polynomial_size {
|
||||
full_convolution[i + j] = add(full_convolution[i + j], mul(lhs_poly[i], rhs_poly[j]));
|
||||
}
|
||||
}
|
||||
|
||||
let mut negacyclic_convolution = vec![0; polynomial_size];
|
||||
for i in 0..polynomial_size {
|
||||
negacyclic_convolution[i] = sub(full_convolution[i], full_convolution[polynomial_size + i]);
|
||||
}
|
||||
|
||||
// method 2: NTT
|
||||
let plan = Plan::try_new(polynomial_size, p).unwrap();
|
||||
let mut lhs_ntt = lhs_poly;
|
||||
let mut rhs_ntt = rhs_poly;
|
||||
|
||||
// convert to NTT domain
|
||||
plan.fwd(&mut lhs_ntt);
|
||||
plan.fwd(&mut rhs_ntt);
|
||||
|
||||
// perform elementwise multiplication and normalize (result is stored in `lhs_ntt`)
|
||||
plan.mul_assign_normalize(&mut lhs_ntt, &rhs_ntt);
|
||||
|
||||
// convert back to standard domain
|
||||
plan.inv(&mut lhs_ntt);
|
||||
|
||||
// check that method 1 and method 2 give the same result
|
||||
assert_eq!(lhs_ntt, negacyclic_convolution);
|
||||
println!("Success!");
|
||||
}
|
||||
15
tfhe-ntt/katex-header.html
Normal file
15
tfhe-ntt/katex-header.html
Normal file
@@ -0,0 +1,15 @@
|
||||
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/katex.min.css" integrity="sha384-KiWOvVjnN8qwAZbuQyWDIbfCLFhLXNETzBQjA/92pIowpC0d2O3nppDGQVgwd2nB" crossorigin="anonymous">
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/katex.min.js" integrity="sha384-0fdwu/T/EQMsQlrHCCHoH10pkPLlKA1jL5dFyUOvB3lfeT2540/2g6YgSi2BL14p" crossorigin="anonymous"></script>
|
||||
<script src="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/contrib/auto-render.min.js" integrity="sha384-+XBljXPPiv+OzfbB3cVmLHf4hdUFHlWNZN5spNQ7rmHTXpd7WvJum6fIACpNNfIR" crossorigin="anonymous"></script>
|
||||
<script>
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
renderMathInElement(document.body, {
|
||||
delimiters: [
|
||||
{left: "$$", right: "$$", display: true},
|
||||
{left: "\\(", right: "\\)", display: false},
|
||||
{left: "$", right: "$", display: false},
|
||||
{left: "\\[", right: "\\]", display: true}
|
||||
]
|
||||
});
|
||||
});
|
||||
</script>
|
||||
5
tfhe-ntt/rustfmt.toml
Normal file
5
tfhe-ntt/rustfmt.toml
Normal file
@@ -0,0 +1,5 @@
|
||||
unstable_features = true
|
||||
imports_granularity="Crate"
|
||||
format_code_in_doc_comments = true
|
||||
wrap_comments = true
|
||||
comment_width = 100
|
||||
196
tfhe-ntt/src/fastdiv.rs
Normal file
196
tfhe-ntt/src/fastdiv.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use crate::u256;
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) const fn mul128_u32(lowbits: u64, d: u32) -> u32 {
|
||||
((lowbits as u128 * d as u128) >> 64) as u32
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) const fn mul128_u64(lowbits: u128, d: u64) -> u64 {
|
||||
let mut bottom_half = (lowbits & 0xFFFF_FFFF_FFFF_FFFF) * d as u128;
|
||||
bottom_half >>= 64;
|
||||
let top_half = (lowbits >> 64) * d as u128;
|
||||
let both_halves = bottom_half + top_half;
|
||||
(both_halves >> 64) as u64
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) const fn mul256_u128(lowbits: u256, d: u128) -> u128 {
|
||||
lowbits.mul_u256_u128(d).1
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) const fn mul256_u64(lowbits: u256, d: u64) -> u64 {
|
||||
lowbits.mul_u256_u64(d).1
|
||||
}
|
||||
|
||||
/// Divisor representing a 32bit denominator.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Div32 {
|
||||
pub double_reciprocal: u128,
|
||||
pub single_reciprocal: u64,
|
||||
pub divisor: u32,
|
||||
}
|
||||
|
||||
/// Divisor representing a 64bit denominator.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Div64 {
|
||||
pub double_reciprocal: u256,
|
||||
pub single_reciprocal: u128,
|
||||
pub divisor: u64,
|
||||
}
|
||||
|
||||
impl Div32 {
|
||||
/// Returns the division structure holding the given divisor.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the divisor is zero or one.
|
||||
pub const fn new(divisor: u32) -> Self {
|
||||
assert!(divisor > 1);
|
||||
let single_reciprocal = (u64::MAX / divisor as u64) + 1;
|
||||
let double_reciprocal = (u128::MAX / divisor as u128) + 1;
|
||||
|
||||
Self {
|
||||
double_reciprocal,
|
||||
single_reciprocal,
|
||||
divisor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the quotient of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn div(n: u32, d: Self) -> u32 {
|
||||
mul128_u32(d.single_reciprocal, n)
|
||||
}
|
||||
|
||||
/// Returns the remainder of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn rem(n: u32, d: Self) -> u32 {
|
||||
let low_bits = d.single_reciprocal.wrapping_mul(n as u64);
|
||||
mul128_u32(low_bits, d.divisor)
|
||||
}
|
||||
|
||||
/// Returns the quotient of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn div_u64(n: u64, d: Self) -> u64 {
|
||||
mul128_u64(d.double_reciprocal, n)
|
||||
}
|
||||
|
||||
/// Returns the remainder of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn rem_u64(n: u64, d: Self) -> u32 {
|
||||
let low_bits = d.double_reciprocal.wrapping_mul(n as u128);
|
||||
mul128_u64(low_bits, d.divisor as u64) as u32
|
||||
}
|
||||
|
||||
/// Returns the internal divisor as an integer.
|
||||
#[inline(always)]
|
||||
pub const fn divisor(&self) -> u32 {
|
||||
self.divisor
|
||||
}
|
||||
}
|
||||
|
||||
impl Div64 {
|
||||
/// Returns the division structure holding the given divisor.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the divisor is zero or one.
|
||||
pub const fn new(divisor: u64) -> Self {
|
||||
assert!(divisor > 1);
|
||||
let single_reciprocal = ((u128::MAX) / divisor as u128) + 1;
|
||||
let double_reciprocal = u256::MAX
|
||||
.div_rem_u256_u64(divisor)
|
||||
.0
|
||||
.overflowing_add(u256 {
|
||||
x0: 1,
|
||||
x1: 0,
|
||||
x2: 0,
|
||||
x3: 0,
|
||||
})
|
||||
.0;
|
||||
|
||||
Self {
|
||||
double_reciprocal,
|
||||
single_reciprocal,
|
||||
divisor,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the quotient of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn div(n: u64, d: Self) -> u64 {
|
||||
mul128_u64(d.single_reciprocal, n)
|
||||
}
|
||||
|
||||
/// Returns the remainder of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn rem(n: u64, d: Self) -> u64 {
|
||||
let low_bits = d.single_reciprocal.wrapping_mul(n as u128);
|
||||
mul128_u64(low_bits, d.divisor)
|
||||
}
|
||||
|
||||
/// Returns the quotient of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn div_u128(n: u128, d: Self) -> u128 {
|
||||
mul256_u128(d.double_reciprocal, n)
|
||||
}
|
||||
|
||||
/// Returns the remainder of the division of `n` by `d`.
|
||||
#[inline(always)]
|
||||
pub const fn rem_u128(n: u128, d: Self) -> u64 {
|
||||
let low_bits = d.double_reciprocal.wrapping_mul_u256_u128(n);
|
||||
mul256_u64(low_bits, d.divisor)
|
||||
}
|
||||
|
||||
/// Returns the internal divisor as an integer.
|
||||
#[inline(always)]
|
||||
pub const fn divisor(&self) -> u64 {
|
||||
self.divisor
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use rand::random;
|
||||
|
||||
#[test]
|
||||
fn test_div64() {
|
||||
for _ in 0..1000 {
|
||||
let divisor = loop {
|
||||
let d = random();
|
||||
if d > 1 {
|
||||
break d;
|
||||
}
|
||||
};
|
||||
|
||||
let div = Div64::new(divisor);
|
||||
let n = random();
|
||||
let m = random();
|
||||
assert_eq!(Div64::div(m, div), m / divisor);
|
||||
assert_eq!(Div64::rem(m, div), m % divisor);
|
||||
assert_eq!(Div64::div_u128(n, div), n / divisor as u128);
|
||||
assert_eq!(Div64::rem_u128(n, div) as u128, n % divisor as u128);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div32() {
|
||||
for _ in 0..1000 {
|
||||
let divisor = loop {
|
||||
let d = random();
|
||||
if d > 1 {
|
||||
break d;
|
||||
}
|
||||
};
|
||||
|
||||
let div = Div32::new(divisor);
|
||||
let n = random();
|
||||
let m = random();
|
||||
assert_eq!(Div32::div(m, div), m / divisor);
|
||||
assert_eq!(Div32::rem(m, div), m % divisor);
|
||||
assert_eq!(Div32::div_u64(n, div), n / divisor as u64);
|
||||
assert_eq!(Div32::rem_u64(n, div) as u64, n % divisor as u64);
|
||||
}
|
||||
}
|
||||
}
|
||||
908
tfhe-ntt/src/lib.rs
Normal file
908
tfhe-ntt/src/lib.rs
Normal file
@@ -0,0 +1,908 @@
|
||||
//! tfhe-ntt is a pure Rust high performance number theoretic transform library that processes
|
||||
//! vectors of sizes that are powers of two.
|
||||
//!
|
||||
//! This library provides three kinds of NTT:
|
||||
//! - The prime NTT computes the transform in a field $\mathbb{Z}/p\mathbb{Z}$ with $p$ prime,
|
||||
//! allowing for arithmetic operations on the polynomial modulo $p$.
|
||||
//! - The native NTT internally computes the transform of the first kind with several primes,
|
||||
//! allowing the simulation of arithmetic modulo the product of those primes, and truncates the
|
||||
//! result when the inverse transform is desired. The truncated result is guaranteed to be as if
|
||||
//! the computations were performed with wrapping arithmetic, as long as the full integer result
|
||||
//! would have been smaller than half the product of the primes, in absolute value. It is
|
||||
//! guaranteed to be suitable for multiplying two polynomials with arbitrary coefficients, and
|
||||
//! returns the result in wrapping arithmetic.
|
||||
//! - The native binary NTT is similar to the native NTT, but is optimized for the case where one of
|
||||
//! the operands of the multiplication has coefficients in $\lbrace 0, 1 \rbrace$.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
|
||||
//! - `nightly`: This enables unstable Rust features to further speed up the NTT, by enabling AVX512
|
||||
//! instructions on CPUs that support them. This feature requires a nightly Rust toolchain.
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```
|
||||
//! use tfhe_ntt::prime32::Plan;
|
||||
//!
|
||||
//! const N: usize = 32;
|
||||
//! let p = 1062862849;
|
||||
//! let plan = Plan::try_new(N, p).unwrap();
|
||||
//!
|
||||
//! let data = [
|
||||
//! 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
|
||||
//! 25, 26, 27, 28, 29, 30, 31,
|
||||
//! ];
|
||||
//!
|
||||
//! let mut transformed_fwd = data;
|
||||
//! plan.fwd(&mut transformed_fwd);
|
||||
//!
|
||||
//! let mut transformed_inv = transformed_fwd;
|
||||
//! plan.inv(&mut transformed_inv);
|
||||
//!
|
||||
//! for (&actual, expected) in transformed_inv
|
||||
//! .iter()
|
||||
//! .zip(data.iter().map(|x| x * N as u32))
|
||||
//! {
|
||||
//! assert_eq!(expected, actual);
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
#![cfg_attr(
|
||||
all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")),
|
||||
feature(avx512_target_feature, stdarch_x86_avx512)
|
||||
)]
|
||||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![allow(clippy::too_many_arguments, clippy::let_unit_value)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
/// Implementation notes:
|
||||
///
|
||||
/// we use `NullaryFnOnce` instead of a closure because we need the `#[inline(always)]`
|
||||
/// annotation, which doesn't always work with closures for some reason.
|
||||
///
|
||||
/// Shoup modular multiplication
|
||||
/// <https://pdfs.semanticscholar.org/e000/fa109f1b2a6a3e52e04462bac4b7d58140c9.pdf>
|
||||
///
|
||||
/// Lemire modular reduction
|
||||
/// <https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/>
|
||||
///
|
||||
/// Barrett reduction
|
||||
/// <https://arxiv.org/pdf/2103.16400.pdf> Algorithm 8
|
||||
///
|
||||
/// Chinese remainder theorem solution:
|
||||
/// The art of computer programming (Donald E. Knuth), section 4.3.2
|
||||
#[allow(dead_code)]
|
||||
fn implementation_notes() {}
|
||||
|
||||
use u256_impl::u256;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
#[doc(hidden)]
|
||||
pub mod prime;
|
||||
mod roots;
|
||||
mod u256_impl;
|
||||
|
||||
/// Fast division by a constant divisor.
|
||||
pub mod fastdiv;
|
||||
/// 32bit negacyclic NTT for a prime modulus.
|
||||
pub mod prime32;
|
||||
/// 64bit negacyclic NTT for a prime modulus.
|
||||
pub mod prime64;
|
||||
|
||||
/// Negacyclic NTT for multiplying two polynomials with values less than `2^128`.
|
||||
pub mod native128;
|
||||
/// Negacyclic NTT for multiplying two polynomials with values less than `2^32`.
|
||||
pub mod native32;
|
||||
/// Negacyclic NTT for multiplying two polynomials with values less than `2^64`.
|
||||
pub mod native64;
|
||||
|
||||
/// Negacyclic NTT for multiplying a polynomial with values less than `2^128` with a binary
|
||||
/// polynomial.
|
||||
pub mod native_binary128;
|
||||
/// Negacyclic NTT for multiplying a polynomial with values less than `2^32` with a binary
|
||||
/// polynomial.
|
||||
pub mod native_binary32;
|
||||
/// Negacyclic NTT for multiplying a polynomial with values less than `2^64` with a binary
|
||||
/// polynomial.
|
||||
pub mod native_binary64;
|
||||
|
||||
pub mod product;
|
||||
|
||||
// Fn arguments are (simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
trait Butterfly<S: Copy, V: Copy>: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V) {}
|
||||
impl<F: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V), S: Copy, V: Copy> Butterfly<S, V> for F {}
|
||||
|
||||
#[inline]
|
||||
fn bit_rev(nbits: u32, i: usize) -> usize {
|
||||
i.reverse_bits() >> (usize::BITS - nbits)
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(transparent)]
|
||||
struct V3(pulp::x86::V3);
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[repr(transparent)]
|
||||
struct V4(pulp::x86::V4);
|
||||
|
||||
// SIMD token detecting every feature of `pulp::x86::V4` plus `avx512ifma`,
// which provides the 52-bit fused multiply-add intrinsics used by
// `V4IFma::widening_mul_u52x8` / `wrapping_mul_add_u52x8` below.
// Note: only `//` comments here — doc comments would become attributes
// inside the macro input.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
pulp::simd_type! {
    struct V4IFma {
        pub sse: "sse",
        pub sse2: "sse2",
        pub fxsr: "fxsr",
        pub sse3: "sse3",
        pub ssse3: "ssse3",
        pub sse4_1: "sse4.1",
        pub sse4_2: "sse4.2",
        pub popcnt: "popcnt",
        pub avx: "avx",
        pub avx2: "avx2",
        pub bmi1: "bmi1",
        pub bmi2: "bmi2",
        pub fma: "fma",
        pub lzcnt: "lzcnt",
        pub avx512f: "avx512f",
        pub avx512bw: "avx512bw",
        pub avx512cd: "avx512cd",
        pub avx512dq: "avx512dq",
        pub avx512vl: "avx512vl",
        pub avx512ifma: "avx512ifma",
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4 {
    /// Returns a `V4` token if the required CPU features are available at
    /// runtime, `None` otherwise.
    #[inline]
    pub fn try_new() -> Option<Self> {
        pulp::x86::V4::try_new().map(Self)
    }

    /// Returns separately two vectors containing the low 64 bits of the result,
    /// and the high 64 bits of the result.
    ///
    /// AVX-512 has no 64x64 -> 128 multiply, so the product is assembled from
    /// four 32x32 -> 64 partial products (schoolbook multiplication).
    #[inline(always)]
    pub fn widening_mul_u64x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
        // https://stackoverflow.com/a/28827013
        let avx = self.avx512f;
        let x = cast(a);
        let y = cast(b);

        let lo_mask = avx._mm512_set1_epi64(0x0000_0000_FFFF_FFFFu64 as _);
        // Swap the 32-bit halves of each 64-bit lane to expose the high words.
        let x_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(x);
        let y_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(y);

        // Partial products lo*lo, lo*hi, hi*lo, hi*hi (each 32x32 -> 64 bits;
        // `_mm512_mul_epu32` multiplies the even 32-bit elements of each lane).
        let z_lo_lo = avx._mm512_mul_epu32(x, y);
        let z_lo_hi = avx._mm512_mul_epu32(x, y_hi);
        let z_hi_lo = avx._mm512_mul_epu32(x_hi, y);
        let z_hi_hi = avx._mm512_mul_epu32(x_hi, y_hi);

        // Propagate the carry out of the low 32-bit column into the cross terms.
        let z_lo_lo_shift = avx._mm512_srli_epi64::<32>(z_lo_lo);

        let sum_tmp = avx._mm512_add_epi64(z_lo_hi, z_lo_lo_shift);
        let sum_lo = avx._mm512_and_si512(sum_tmp, lo_mask);
        let sum_mid = avx._mm512_srli_epi64::<32>(sum_tmp);

        let sum_mid2 = avx._mm512_add_epi64(z_hi_lo, sum_lo);
        let sum_mid2_hi = avx._mm512_srli_epi64::<32>(sum_mid2);
        let sum_hi = avx._mm512_add_epi64(z_hi_hi, sum_mid);

        // High 64 bits: hi*hi plus all carries out of the middle columns.
        let prod_hi = avx._mm512_add_epi64(sum_hi, sum_mid2_hi);
        // Low 64 bits: additions wrap mod 2^64, which is exactly what we want.
        let prod_lo = avx._mm512_add_epi64(
            avx._mm512_slli_epi64::<32>(avx._mm512_add_epi64(z_lo_hi, z_hi_lo)),
            z_lo_lo,
        );

        (cast(prod_lo), cast(prod_hi))
    }

    /// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
    #[inline(always)]
    pub fn mul_low_32_bits_u64x8(self, a: u64x8, b: u64x8) -> u64x8 {
        pulp::cast(self.avx512f._mm512_mul_epu32(pulp::cast(a), pulp::cast(b)))
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4IFma {
    /// Returns separately two vectors containing the low 52 bits of the result,
    /// and the high 52 bits of the result.
    ///
    /// Inputs are treated as 52-bit integers: the IFMA instructions multiply
    /// only the low 52 bits of each 64-bit lane.
    #[inline(always)]
    pub fn widening_mul_u52x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
        let a = cast(a);
        let b = cast(b);
        // Zero addend: the fused multiply-adds then return just the product halves.
        let zero = cast(self.splat_u64x8(0));
        (
            cast(self.avx512ifma._mm512_madd52lo_epu64(zero, a, b)),
            cast(self.avx512ifma._mm512_madd52hi_epu64(zero, a, b)),
        )
    }

    /// (a * b + c) mod 2^52 for each 52 bit integer in a, b, and c.
    #[inline(always)]
    pub fn wrapping_mul_add_u52x8(self, a: u64x8, b: u64x8, c: u64x8) -> u64x8 {
        // `madd52lo` yields `c + low52(a * b)` in full 64-bit lanes; mask back
        // down to 52 bits to get the wrapping result.
        self.and_u64x8(
            cast(
                self.avx512ifma
                    ._mm512_madd52lo_epu64(cast(c), cast(a), cast(b)),
            ),
            self.splat_u64x8((1u64 << 52) - 1),
        )
    }
}
|
||||
|
||||
// Abstraction over SIMD tokens that provide at least the `V4` feature set,
// letting `V4` and `V4IFma` share generic code paths.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
trait SupersetOfV4: Copy {
    // Extracts the plain `V4` token.
    fn get_v4(self) -> V4;
    // Runs `f` through the token's `vectorize` mechanism (see `pulp`).
    fn vectorize(self, f: impl pulp::NullaryFnOnce);
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4 {
    // A `V4` token is trivially its own `V4`.
    #[inline(always)]
    fn get_v4(self) -> V4 {
        self
    }
    #[inline(always)]
    fn vectorize(self, f: impl pulp::NullaryFnOnce) {
        // Delegate to the wrapped `pulp::x86::V4` token.
        self.0.vectorize(f);
    }
}
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4IFma {
    #[inline(always)]
    fn get_v4(self) -> V4 {
        // Goes through the `Deref<Target = V4>` impl below, discarding the
        // `avx512ifma` feature proof.
        *self
    }
    #[inline(always)]
    fn vectorize(self, f: impl pulp::NullaryFnOnce) {
        // NOTE(review): this resolves to the inherent `vectorize` generated by
        // `pulp::simd_type!` (inherent methods take priority over trait
        // methods), so this is delegation, not infinite recursion.
        self.vectorize(f);
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl V3 {
    /// Returns a `V3` token if the required CPU features are available at
    /// runtime, `None` otherwise.
    #[inline]
    pub fn try_new() -> Option<Self> {
        pulp::x86::V3::try_new().map(Self)
    }

    /// Returns separately two vectors containing the low 64 bits of the result,
    /// and the high 64 bits of the result.
    ///
    /// AVX2 has no 64x64 -> 128 multiply, so the product is assembled from
    /// four 32x32 -> 64 partial products (schoolbook multiplication).
    #[inline(always)]
    pub fn widening_mul_u64x4(self, a: u64x4, b: u64x4) -> (u64x4, u64x4) {
        // https://stackoverflow.com/a/28827013
        let avx = self.avx;
        let avx2 = self.avx2;
        let x = cast(a);
        let y = cast(b);
        let lo_mask = avx._mm256_set1_epi64x(0x0000_0000_FFFF_FFFFu64 as _);
        // Swap the 32-bit halves of each 64-bit lane to expose the high words.
        let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(x);
        let y_hi = avx2._mm256_shuffle_epi32::<0b10110001>(y);

        // Partial products lo*lo, lo*hi, hi*lo, hi*hi (each 32x32 -> 64 bits;
        // `_mm256_mul_epu32` multiplies the even 32-bit elements of each lane).
        let z_lo_lo = avx2._mm256_mul_epu32(x, y);
        let z_lo_hi = avx2._mm256_mul_epu32(x, y_hi);
        let z_hi_lo = avx2._mm256_mul_epu32(x_hi, y);
        let z_hi_hi = avx2._mm256_mul_epu32(x_hi, y_hi);

        // Propagate the carry out of the low 32-bit column into the cross terms.
        let z_lo_lo_shift = avx2._mm256_srli_epi64::<32>(z_lo_lo);

        let sum_tmp = avx2._mm256_add_epi64(z_lo_hi, z_lo_lo_shift);
        let sum_lo = avx2._mm256_and_si256(sum_tmp, lo_mask);
        let sum_mid = avx2._mm256_srli_epi64::<32>(sum_tmp);

        let sum_mid2 = avx2._mm256_add_epi64(z_hi_lo, sum_lo);
        let sum_mid2_hi = avx2._mm256_srli_epi64::<32>(sum_mid2);
        let sum_hi = avx2._mm256_add_epi64(z_hi_hi, sum_mid);

        // High 64 bits: hi*hi plus all carries out of the middle columns.
        let prod_hi = avx2._mm256_add_epi64(sum_hi, sum_mid2_hi);
        // Low 64 bits: additions wrap mod 2^64, which is exactly what we want.
        let prod_lo = avx2._mm256_add_epi64(
            avx2._mm256_slli_epi64::<32>(avx2._mm256_add_epi64(z_lo_hi, z_hi_lo)),
            z_lo_lo,
        );

        (cast(prod_lo), cast(prod_hi))
    }

    /// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
    #[inline(always)]
    pub fn mul_low_32_bits_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
        pulp::cast(self.avx2._mm256_mul_epu32(pulp::cast(a), pulp::cast(b)))
    }

    // Computes `a * (b mod 2^32) mod 2^64` for each element in a and b:
    // low(a)*low(b) plus high(a)*low(b) shifted into the upper 32 bits.
    #[inline(always)]
    pub fn wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
        let a = cast(a);
        let b = cast(b);
        let avx2 = self.avx2;
        let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(a);
        let z_lo_lo = avx2._mm256_mul_epu32(a, b);
        let z_hi_lo = avx2._mm256_mul_epu32(x_hi, b);
        cast(avx2._mm256_add_epi64(avx2._mm256_slli_epi64::<32>(z_hi_lo), z_lo_lo))
    }
}
|
||||
|
||||
// `V4` transparently exposes all `pulp::x86::V4` methods through `Deref`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4 {
    type Target = pulp::x86::V4;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
|
||||
|
||||
// Every `V4IFma` token is usable as a `V4`: rebuild a `pulp::x86::V4` from
// the shared feature-proof fields (dropping only `avx512ifma`) and
// reinterpret its reference as `&V4`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4IFma {
    type Target = V4;

    #[inline]
    fn deref(&self) -> &Self::Target {
        // Exhaustive destructuring: this fails to compile if the field sets of
        // `V4IFma` and `pulp::x86::V4` ever drift apart, keeping them in sync.
        let Self {
            sse,
            sse2,
            fxsr,
            sse3,
            ssse3,
            sse4_1,
            sse4_2,
            popcnt,
            avx,
            avx2,
            bmi1,
            bmi2,
            fma,
            lzcnt,
            avx512f,
            avx512bw,
            avx512cd,
            avx512dq,
            avx512vl,
            avx512ifma: _,
        } = *self;
        let simd_ref = (pulp::x86::V4 {
            sse,
            sse2,
            fxsr,
            sse3,
            ssse3,
            sse4_1,
            sse4_2,
            popcnt,
            avx,
            avx2,
            bmi1,
            bmi2,
            fma,
            lzcnt,
            avx512f,
            avx512bw,
            avx512cd,
            avx512dq,
            avx512vl,
        })
        .to_ref();

        // SAFETY
        // `pulp::x86::V4` and `crate::V4` have the same layout, since the latter is
        // #[repr(transparent)], so casting `&pulp::x86::V4` to `&V4` is sound.
        unsafe { &*(simd_ref as *const pulp::x86::V4 as *const V4) }
    }
}
|
||||
|
||||
// `V3` transparently exposes all `pulp::x86::V3` methods through `Deref`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl core::ops::Deref for V3 {
    type Target = pulp::x86::V3;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
|
||||
|
||||
// the magic constants are such that
|
||||
// for all x < 2^64
|
||||
// x / P_i == ((x * P_i_MAGIC) >> 64) >> P_i_MAGIC_SHIFT
|
||||
//
|
||||
// this can be used to implement the modulo operation in constant time to avoid side channel
|
||||
// attacks, can also speed up the operation x % P_i, since the compiler doesn't manage to vectorize
|
||||
// it on its own.
|
||||
//
|
||||
// how to:
|
||||
// run `cargo test generate_primes -- --nocapture`
|
||||
//
|
||||
// copy paste the generated primes in this function
|
||||
// ```
|
||||
// pub fn codegen(x: u64) -> u64 {
|
||||
// x / $PRIME
|
||||
// }
|
||||
// ```
|
||||
//
|
||||
// look at the generated assembly for codegen
|
||||
// extract primes that satisfy the desired property
|
||||
//
|
||||
// asm should look like this on x86_64
|
||||
// ```
|
||||
// mov rax, rdi
|
||||
// movabs rcx, P_MAGIC (as i64 signed value)
|
||||
// mul rcx
|
||||
// mov rax, rdx
|
||||
// shr rax, P_MAGIC_SHIFT
|
||||
// ret
|
||||
// ```
|
||||
#[allow(dead_code)]
pub(crate) mod primes32 {
    //! Ten NTT-friendly primes below `2^30` (each `≡ 1 mod 2^16`, as shown by
    //! the trailing sixteen zero bits of their binary form), together with
    //! their precomputed magic-division constants, Shoup multiplication
    //! constants, and CRT reconstruction inverses.
    //!
    //! Generated by the `generate_primes` test; see the comment above this
    //! module for the magic-constant property and how to regenerate.
    use crate::{
        fastdiv::{Div32, Div64},
        prime::exp_mod32,
    };

    // The ten primes.
    pub const P0: u32 = 0b0011_1111_0101_1010_0000_0000_0000_0001;
    pub const P1: u32 = 0b0011_1111_0101_1101_0000_0000_0000_0001;
    pub const P2: u32 = 0b0011_1111_0111_0110_0000_0000_0000_0001;
    pub const P3: u32 = 0b0011_1111_1000_0010_0000_0000_0000_0001;
    pub const P4: u32 = 0b0011_1111_1010_1100_0000_0000_0000_0001;
    pub const P5: u32 = 0b0011_1111_1010_1111_0000_0000_0000_0001;
    pub const P6: u32 = 0b0011_1111_1011_0001_0000_0000_0000_0001;
    pub const P7: u32 = 0b0011_1111_1011_1011_0000_0000_0000_0001;
    pub const P8: u32 = 0b0011_1111_1101_1110_0000_0000_0000_0001;
    pub const P9: u32 = 0b0011_1111_1111_1100_0000_0000_0000_0001;

    // Magic multipliers: for all x < 2^64,
    // x / P_i == ((x * P_i_MAGIC) >> 64) >> P_i_MAGIC_SHIFT.
    pub const P0_MAGIC: u64 = 9317778228489988551;
    pub const P1_MAGIC: u64 = 4658027473943558643;
    pub const P2_MAGIC: u64 = 1162714878353869247;
    pub const P3_MAGIC: u64 = 4647426722536610861;
    pub const P4_MAGIC: u64 = 9270903515973367219;
    pub const P5_MAGIC: u64 = 2317299382174935855;
    pub const P6_MAGIC: u64 = 9268060552616330319;
    pub const P7_MAGIC: u64 = 2315594963384859737;
    pub const P8_MAGIC: u64 = 9242552129100825291;
    pub const P9_MAGIC: u64 = 576601523622774689;

    // Shift amounts paired with the magic multipliers above.
    pub const P0_MAGIC_SHIFT: u32 = 29;
    pub const P1_MAGIC_SHIFT: u32 = 28;
    pub const P2_MAGIC_SHIFT: u32 = 26;
    pub const P3_MAGIC_SHIFT: u32 = 28;
    pub const P4_MAGIC_SHIFT: u32 = 29;
    pub const P5_MAGIC_SHIFT: u32 = 27;
    pub const P6_MAGIC_SHIFT: u32 = 29;
    pub const P7_MAGIC_SHIFT: u32 = 27;
    pub const P8_MAGIC_SHIFT: u32 = 29;
    pub const P9_MAGIC_SHIFT: u32 = 25;

    // (a * b) mod `modulus`, through a 64-bit intermediate to avoid overflow.
    const fn mul_mod(modulus: u32, a: u32, b: u32) -> u32 {
        let wide = a as u64 * b as u64;
        (wide % modulus as u64) as u32
    }

    // Modular inverse for prime `modulus` via Fermat's little theorem:
    // x^(modulus - 2) mod modulus.
    const fn inv_mod(modulus: u32, x: u32) -> u32 {
        exp_mod32(Div32::new(modulus), x, modulus - 2)
    }

    // Shoup precomputation: floor(w * 2^32 / modulus).
    const fn shoup(modulus: u32, w: u32) -> u32 {
        (((w as u64) << 32) / modulus as u64) as u32
    }

    // (a * b) mod `modulus`, through a 128-bit intermediate.
    const fn mul_mod64(modulus: u64, a: u64, b: u64) -> u64 {
        let wide = a as u128 * b as u128;
        (wide % modulus as u128) as u64
    }

    // base^pow mod `modulus`, delegating to the crate's const modular exponentiation.
    const fn exp_mod64(modulus: u64, base: u64, pow: u64) -> u64 {
        crate::prime::exp_mod64(Div64::new(modulus), base, pow)
    }

    // 64-bit Shoup precomputation: floor(w * 2^64 / modulus).
    const fn shoup64(modulus: u64, w: u64) -> u64 {
        (((w as u128) << 64) / modulus as u128) as u64
    }

    // CRT reconstruction inverses for the 5-prime chains P0..P4.
    pub const P0_INV_MOD_P1: u32 = inv_mod(P1, P0);
    pub const P0_INV_MOD_P1_SHOUP: u32 = shoup(P1, P0_INV_MOD_P1);
    pub const P01_INV_MOD_P2: u32 = inv_mod(P2, mul_mod(P2, P0, P1));
    pub const P01_INV_MOD_P2_SHOUP: u32 = shoup(P2, P01_INV_MOD_P2);
    pub const P012_INV_MOD_P3: u32 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
    pub const P012_INV_MOD_P3_SHOUP: u32 = shoup(P3, P012_INV_MOD_P3);
    pub const P0123_INV_MOD_P4: u32 =
        inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
    pub const P0123_INV_MOD_P4_SHOUP: u32 = shoup(P4, P0123_INV_MOD_P4);

    // Residues of the smaller primes modulo the later ones (Shoup form).
    pub const P0_MOD_P2_SHOUP: u32 = shoup(P2, P0);
    pub const P0_MOD_P3_SHOUP: u32 = shoup(P3, P0);
    pub const P1_MOD_P3_SHOUP: u32 = shoup(P3, P1);
    pub const P0_MOD_P4_SHOUP: u32 = shoup(P4, P0);
    pub const P1_MOD_P4_SHOUP: u32 = shoup(P4, P1);
    pub const P2_MOD_P4_SHOUP: u32 = shoup(P4, P2);

    // Pairwise inverses and prime products for two-prime recombination.
    pub const P1_INV_MOD_P2: u32 = inv_mod(P2, P1);
    pub const P1_INV_MOD_P2_SHOUP: u32 = shoup(P2, P1_INV_MOD_P2);
    pub const P3_INV_MOD_P4: u32 = inv_mod(P4, P3);
    pub const P3_INV_MOD_P4_SHOUP: u32 = shoup(P4, P3_INV_MOD_P4);
    pub const P12: u64 = P1 as u64 * P2 as u64;
    pub const P34: u64 = P3 as u64 * P4 as u64;
    // Inverse via Euler's theorem: x^(phi(P12) - 1) mod P12, where
    // phi(P12) = (P1 - 1) * (P2 - 1) since P1, P2 are distinct primes.
    pub const P0_INV_MOD_P12: u64 =
        exp_mod64(P12, P0 as u64, (P1 as u64 - 1) * (P2 as u64 - 1) - 1);
    pub const P0_INV_MOD_P12_SHOUP: u64 = shoup64(P12, P0_INV_MOD_P12);
    pub const P0_MOD_P34_SHOUP: u64 = shoup64(P34, P0 as u64);
    pub const P012_INV_MOD_P34: u64 = exp_mod64(
        P34,
        mul_mod64(P34, P0 as u64, P12),
        (P3 as u64 - 1) * (P4 as u64 - 1) - 1,
    );
    pub const P012_INV_MOD_P34_SHOUP: u64 = shoup64(P34, P012_INV_MOD_P34);

    // Adjacent-pair inverses for the 10-prime reconstruction
    // (see `reconstruct_32bit_0123456789_v2` in `native128`).
    pub const P2_INV_MOD_P3: u32 = inv_mod(P3, P2);
    pub const P2_INV_MOD_P3_SHOUP: u32 = shoup(P3, P2_INV_MOD_P3);
    pub const P4_INV_MOD_P5: u32 = inv_mod(P5, P4);
    pub const P4_INV_MOD_P5_SHOUP: u32 = shoup(P5, P4_INV_MOD_P5);
    pub const P6_INV_MOD_P7: u32 = inv_mod(P7, P6);
    pub const P6_INV_MOD_P7_SHOUP: u32 = shoup(P7, P6_INV_MOD_P7);
    pub const P8_INV_MOD_P9: u32 = inv_mod(P9, P8);
    pub const P8_INV_MOD_P9_SHOUP: u32 = shoup(P9, P8_INV_MOD_P9);

    // Products of adjacent prime pairs.
    pub const P01: u64 = P0 as u64 * P1 as u64;
    pub const P23: u64 = P2 as u64 * P3 as u64;
    pub const P45: u64 = P4 as u64 * P5 as u64;
    pub const P67: u64 = P6 as u64 * P7 as u64;
    pub const P89: u64 = P8 as u64 * P9 as u64;

    // Residues of earlier pair-products modulo later pair-products (Shoup form).
    pub const P01_MOD_P45_SHOUP: u64 = shoup64(P45, P01);
    pub const P01_MOD_P67_SHOUP: u64 = shoup64(P67, P01);
    pub const P01_MOD_P89_SHOUP: u64 = shoup64(P89, P01);

    pub const P23_MOD_P67_SHOUP: u64 = shoup64(P67, P23);
    pub const P23_MOD_P89_SHOUP: u64 = shoup64(P89, P23);

    pub const P45_MOD_P89_SHOUP: u64 = shoup64(P89, P45);

    // Mixed-radix reconstruction inverses: each is the inverse of the product
    // of all earlier pair-moduli, modulo the next pair-modulus (Euler's theorem,
    // phi(Pij) = (Pi - 1) * (Pj - 1)).
    pub const P01_INV_MOD_P23: u64 = exp_mod64(P23, P01, (P2 as u64 - 1) * (P3 as u64 - 1) - 1);
    pub const P01_INV_MOD_P23_SHOUP: u64 = shoup64(P23, P01_INV_MOD_P23);
    pub const P0123_INV_MOD_P45: u64 = exp_mod64(
        P45,
        mul_mod64(P45, P01, P23),
        (P4 as u64 - 1) * (P5 as u64 - 1) - 1,
    );
    pub const P0123_INV_MOD_P45_SHOUP: u64 = shoup64(P45, P0123_INV_MOD_P45);
    pub const P012345_INV_MOD_P67: u64 = exp_mod64(
        P67,
        mul_mod64(P67, mul_mod64(P67, P01, P23), P45),
        (P6 as u64 - 1) * (P7 as u64 - 1) - 1,
    );
    pub const P012345_INV_MOD_P67_SHOUP: u64 = shoup64(P67, P012345_INV_MOD_P67);
    pub const P01234567_INV_MOD_P89: u64 = exp_mod64(
        P89,
        mul_mod64(P89, mul_mod64(P89, mul_mod64(P89, P01, P23), P45), P67),
        (P8 as u64 - 1) * (P9 as u64 - 1) - 1,
    );
    pub const P01234567_INV_MOD_P89_SHOUP: u64 = shoup64(P89, P01234567_INV_MOD_P89);

    // Prefix products of the pair-moduli, reduced mod 2^128 (wrapping is
    // intentional: results are only used in wrapping 128-bit arithmetic).
    pub const P0123: u128 = u128::wrapping_mul(P01 as u128, P23 as u128);
    pub const P012345: u128 = u128::wrapping_mul(P0123, P45 as u128);
    pub const P01234567: u128 = u128::wrapping_mul(P012345, P67 as u128);
    pub const P0123456789: u128 = u128::wrapping_mul(P01234567, P89 as u128);
}
|
||||
|
||||
#[allow(dead_code)]
pub(crate) mod primes52 {
    //! Six NTT-friendly primes (each `≡ 1 mod 2^16`, as shown by the trailing
    //! sixteen zero bits of their binary form) with 52-bit Shoup constants
    //! (note the `<< 52` in `shoup`) and magic-division constants.
    //! Generated by the `generate_primes` test; see the comment above the
    //! `primes32` module for the magic-constant property.
    use crate::fastdiv::Div64;

    // The six primes.
    pub const P0: u64 = 0b0011_1111_1111_1111_1111_1111_1110_0111_0111_0000_0000_0000_0001;
    pub const P1: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1011_1001_0000_0000_0000_0001;
    pub const P2: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1100_1000_0000_0000_0000_0001;
    pub const P3: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1000_1011_0000_0000_0000_0001;
    pub const P4: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1011_1000_0000_0000_0000_0001;
    pub const P5: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1100_0111_0000_0000_0000_0001;

    // Magic multipliers: for all x < 2^64,
    // x / P_i == ((x * P_i_MAGIC) >> 64) >> P_i_MAGIC_SHIFT.
    pub const P0_MAGIC: u64 = 9223372247845040859;
    pub const P1_MAGIC: u64 = 4611686106205779591;
    pub const P2_MAGIC: u64 = 4611686102179247601;
    pub const P3_MAGIC: u64 = 2305843024917166187;
    pub const P4_MAGIC: u64 = 4611686037754736721;
    pub const P5_MAGIC: u64 = 4611686033728204851;

    // Shift amounts paired with the magic multipliers above.
    pub const P0_MAGIC_SHIFT: u32 = 49;
    pub const P1_MAGIC_SHIFT: u32 = 48;
    pub const P2_MAGIC_SHIFT: u32 = 48;
    pub const P3_MAGIC_SHIFT: u32 = 47;
    pub const P4_MAGIC_SHIFT: u32 = 48;
    pub const P5_MAGIC_SHIFT: u32 = 48;

    // (a * b) mod `modulus`, through a 128-bit intermediate to avoid overflow.
    const fn mul_mod(modulus: u64, a: u64, b: u64) -> u64 {
        let wide = a as u128 * b as u128;
        (wide % modulus as u128) as u64
    }

    // Modular inverse for prime `modulus` via Fermat's little theorem:
    // x^(modulus - 2) mod modulus.
    const fn inv_mod(modulus: u64, x: u64) -> u64 {
        crate::prime::exp_mod64(Div64::new(modulus), x, modulus - 2)
    }

    // 52-bit Shoup precomputation: floor(w * 2^52 / modulus).
    const fn shoup(modulus: u64, w: u64) -> u64 {
        (((w as u128) << 52) / modulus as u128) as u64
    }

    // CRT reconstruction inverses for the prime chains P0..P4.
    pub const P0_INV_MOD_P1: u64 = inv_mod(P1, P0);
    pub const P0_INV_MOD_P1_SHOUP: u64 = shoup(P1, P0_INV_MOD_P1);

    pub const P01_INV_MOD_P2: u64 = inv_mod(P2, mul_mod(P2, P0, P1));
    pub const P01_INV_MOD_P2_SHOUP: u64 = shoup(P2, P01_INV_MOD_P2);
    pub const P012_INV_MOD_P3: u64 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
    pub const P012_INV_MOD_P3_SHOUP: u64 = shoup(P3, P012_INV_MOD_P3);
    pub const P0123_INV_MOD_P4: u64 =
        inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
    pub const P0123_INV_MOD_P4_SHOUP: u64 = shoup(P4, P0123_INV_MOD_P4);

    // Residues of the smaller primes modulo the later ones (Shoup form).
    pub const P0_MOD_P2_SHOUP: u64 = shoup(P2, P0);
    pub const P0_MOD_P3_SHOUP: u64 = shoup(P3, P0);
    pub const P1_MOD_P3_SHOUP: u64 = shoup(P3, P1);
    pub const P0_MOD_P4_SHOUP: u64 = shoup(P4, P0);
    pub const P1_MOD_P4_SHOUP: u64 = shoup(P4, P1);
    pub const P2_MOD_P4_SHOUP: u64 = shoup(P4, P2);
}
|
||||
|
||||
// Zips an arbitrary number of iterators (up to 15) and flattens the nested
// tuples produced by chained `.zip` calls into a single flat tuple, similar
// to `itertools::izip!` but without the external dependency.
macro_rules! izip {
    // Internal rules: build the closure that flattens the left-nested tuple
    // `((((a, b), c), d), ...)` coming out of repeated `.zip` into
    // `(a, b, c, d, ...)`. The expressions are only used to count arguments.
    (@ __closure @ ($a:expr)) => { |a| (a,) };
    (@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) };

    // Single iterator: just convert it with `IntoIterator`.
    ( $first:expr $(,)?) => {
        {
            ::core::iter::IntoIterator::into_iter($first)
        }
    };
    // Two or more: chain `.zip` calls, then flatten with the matching closure.
    ( $first:expr, $($rest:expr),+ $(,)?) => {
        {
            ::core::iter::IntoIterator::into_iter($first)
                $(.zip($rest))*
                .map(crate::izip!(@ __closure @ ($first, $($rest),*)))
        }
    };
}
pub(crate) use izip;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::prime::largest_prime_in_arithmetic_progression64;
    use rand::random;

    // Checks the scalar Barrett reduction recipe for a 31-bit prime modulus:
    // for random a, b < p, the computed remainder must equal (a * b) mod p.
    #[test]
    fn test_barrett32() {
        let p =
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;

        // Q = bit length of p; k = floor(2^(Q + 31) / p) is the Barrett constant.
        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 31;
        let k: u32 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u32>() % p;
            let b = random::<u32>() % p;

            let d = a as u64 * b as u64;
            // Q < 31
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u32
            let c1 = (d >> (big_q - 1)) as u32;
            // c1 * k < 2^(Q+33); the top 32 bits give the quotient estimate c3
            let c3 = ((c1 as u64 * k as u64) >> 32) as u32;
            // Subtract p * c3 in wrapping arithmetic; the estimate is off by at
            // most one, so a single conditional subtraction finishes reduction.
            let c = (d as u32).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u64, d % p as u64);
        }
    }

    // Same check for a 51-bit prime and a 52-bit Barrett constant.
    #[test]
    fn test_barrett52() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();

        // Q = bit length of p; k = floor(2^(Q + 51) / p).
        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 51;
        let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u64>() % p;
            let b = random::<u64>() % p;

            let d = a as u128 * b as u128;
            // Q < 51
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u64
            let c1 = (d >> (big_q - 1)) as u64;
            // c1 * k < 2^(Q+53); the top bits above 52 give the quotient estimate
            let c3 = ((c1 as u128 * k as u128) >> 52) as u64;
            let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u128, d % p as u128);
        }
    }

    // Same check for a 63-bit prime and a 64-bit Barrett constant.
    #[test]
    fn test_barrett64() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap();

        // Q = bit length of p; k = floor(2^(Q + 63) / p).
        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 63;
        let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u64>() % p;
            let b = random::<u64>() % p;

            let d = a as u128 * b as u128;
            // Q < 63
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u64
            let c1 = (d >> (big_q - 1)) as u64;
            // c1 * k < 2^(Q+65); the top 64 bits give the quotient estimate
            let c3 = ((c1 as u128 * k as u128) >> 64) as u64;
            let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u128, d % p as u128);
        }
    }

    // primes should be of the form x * LARGEST_POLYNOMIAL_SIZE(2^16) + 1
    // primes should be < 2^30 or < 2^50, for NTT efficiency
    // primes should satisfy the magic property documented above the primes32 module
    // primes should be as large as possible
    //
    // Prints candidate primes in binary; the constants in `primes32`/`primes52`
    // were selected from this output (run with `-- --nocapture`).
    #[cfg(feature = "std")]
    #[test]
    fn generate_primes() {
        let mut p = 1u64 << 30;
        for _ in 0..100 {
            p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
            println!("{p:#034b}");
        }

        let mut p = 1u64 << 50;
        for _ in 0..100 {
            p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
            println!("{p:#054b}");
        }
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(test)]
mod x86_tests {
    use super::*;
    use rand::random as rnd;

    // Compares the SIMD widening multiplications against the scalar
    // `u128`-based reference, lane by lane, on random inputs.
    #[test]
    fn test_widening_mul() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let (lo, hi) = simd.widening_mul_u64x4(a, b);
            // Low half must match the wrapping 64-bit product.
            assert_eq!(
                lo,
                u64x4(
                    u64::wrapping_mul(a.0, b.0),
                    u64::wrapping_mul(a.1, b.1),
                    u64::wrapping_mul(a.2, b.2),
                    u64::wrapping_mul(a.3, b.3),
                ),
            );
            // High half must match the top 64 bits of the 128-bit product.
            assert_eq!(
                hi,
                u64x4(
                    ((a.0 as u128 * b.0 as u128) >> 64) as u64,
                    ((a.1 as u128 * b.1 as u128) >> 64) as u64,
                    ((a.2 as u128 * b.2 as u128) >> 64) as u64,
                    ((a.3 as u128 * b.3 as u128) >> 64) as u64,
                ),
            );
        }

        #[cfg(feature = "nightly")]
        if let Some(simd) = crate::V4::try_new() {
            let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let (lo, hi) = simd.widening_mul_u64x8(a, b);
            assert_eq!(
                lo,
                u64x8(
                    u64::wrapping_mul(a.0, b.0),
                    u64::wrapping_mul(a.1, b.1),
                    u64::wrapping_mul(a.2, b.2),
                    u64::wrapping_mul(a.3, b.3),
                    u64::wrapping_mul(a.4, b.4),
                    u64::wrapping_mul(a.5, b.5),
                    u64::wrapping_mul(a.6, b.6),
                    u64::wrapping_mul(a.7, b.7),
                ),
            );
            assert_eq!(
                hi,
                u64x8(
                    ((a.0 as u128 * b.0 as u128) >> 64) as u64,
                    ((a.1 as u128 * b.1 as u128) >> 64) as u64,
                    ((a.2 as u128 * b.2 as u128) >> 64) as u64,
                    ((a.3 as u128 * b.3 as u128) >> 64) as u64,
                    ((a.4 as u128 * b.4 as u128) >> 64) as u64,
                    ((a.5 as u128 * b.5 as u128) >> 64) as u64,
                    ((a.6 as u128 * b.6 as u128) >> 64) as u64,
                    ((a.7 as u128 * b.7 as u128) >> 64) as u64,
                ),
            );
        }
    }

    // Compares `mul_low_32_bits_*` against `(a as u32) * (b as u32)` widened
    // to 64 bits, lane by lane.
    #[test]
    fn test_mul_low_32_bits() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let res = simd.mul_low_32_bits_u64x4(a, b);
            assert_eq!(
                res,
                u64x4(
                    a.0 as u32 as u64 * b.0 as u32 as u64,
                    a.1 as u32 as u64 * b.1 as u32 as u64,
                    a.2 as u32 as u64 * b.2 as u32 as u64,
                    a.3 as u32 as u64 * b.3 as u32 as u64,
                ),
            );
        }
        #[cfg(feature = "nightly")]
        if let Some(simd) = crate::V4::try_new() {
            let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let res = simd.mul_low_32_bits_u64x8(a, b);
            assert_eq!(
                res,
                u64x8(
                    a.0 as u32 as u64 * b.0 as u32 as u64,
                    a.1 as u32 as u64 * b.1 as u32 as u64,
                    a.2 as u32 as u64 * b.2 as u32 as u64,
                    a.3 as u32 as u64 * b.3 as u32 as u64,
                    a.4 as u32 as u64 * b.4 as u32 as u64,
                    a.5 as u32 as u64 * b.5 as u32 as u64,
                    a.6 as u32 as u64 * b.6 as u32 as u64,
                    a.7 as u32 as u64 * b.7 as u32 as u64,
                ),
            );
        }
    }

    // Compares the mixed full/low-32-bit product against the scalar
    // wrapping reference `a * (b mod 2^32) mod 2^64`, lane by lane.
    #[test]
    fn test_mul_lhs_with_low_32_bits_of_rhs() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let res = simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(a, b);
            assert_eq!(
                res,
                u64x4(
                    u64::wrapping_mul(a.0, b.0 as u32 as u64),
                    u64::wrapping_mul(a.1, b.1 as u32 as u64),
                    u64::wrapping_mul(a.2, b.2 as u32 as u64),
                    u64::wrapping_mul(a.3, b.3 as u32 as u64),
                ),
            );
        }
    }
}
|
||||
448
tfhe-ntt/src/native128.rs
Normal file
448
tfhe-ntt/src/native128.rs
Normal file
@@ -0,0 +1,448 @@
|
||||
pub(crate) use crate::native64::{mul_mod32, mul_mod64};
|
||||
use aligned_vec::avec;
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 128bit polynomials.
///
/// Holds one [`crate::prime32::Plan`] per prime `primes32::P0..=P9` (see
/// `Plan32::try_new`); per-prime results are recombined with
/// `reconstruct_32bit_0123456789_v2`.
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);
|
||||
|
||||
// Reconstructs a value (mod 2^128) from its residues modulo the ten primes
// `primes32::P0..=P9`, interpreting the CRT representative as a signed value
// centered around zero.
#[inline(always)]
fn reconstruct_32bit_0123456789_v2(
    mod_p0: u32,
    mod_p1: u32,
    mod_p2: u32,
    mod_p3: u32,
    mod_p4: u32,
    mod_p5: u32,
    mod_p6: u32,
    mod_p7: u32,
    mod_p8: u32,
    mod_p9: u32,
) -> u128 {
    use crate::primes32::*;

    // Stage 1: pairwise CRT. Combine residues (p0,p1), (p2,p3), ... into
    // residues modulo the pair products P01, P23, P45, P67, P89.
    // `2 * P + x - v` keeps the subtraction non-negative before the modular
    // multiplication by the precomputed inverse.
    let mod_p01 = {
        let v0 = mod_p0;
        let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
        v0 as u64 + (v1 as u64 * P0 as u64)
    };
    let mod_p23 = {
        let v2 = mod_p2;
        let v3 = mul_mod32(P3, P2_INV_MOD_P3, 2 * P3 + mod_p3 - v2);
        v2 as u64 + (v3 as u64 * P2 as u64)
    };
    let mod_p45 = {
        let v4 = mod_p4;
        let v5 = mul_mod32(P5, P4_INV_MOD_P5, 2 * P5 + mod_p5 - v4);
        v4 as u64 + (v5 as u64 * P4 as u64)
    };
    let mod_p67 = {
        let v6 = mod_p6;
        let v7 = mul_mod32(P7, P6_INV_MOD_P7, 2 * P7 + mod_p7 - v6);
        v6 as u64 + (v7 as u64 * P6 as u64)
    };
    let mod_p89 = {
        let v8 = mod_p8;
        let v9 = mul_mod32(P9, P8_INV_MOD_P9, 2 * P9 + mod_p9 - v8);
        v8 as u64 + (v9 as u64 * P8 as u64)
    };

    // Stage 2: mixed-radix recombination of the five pair residues. Each
    // digit v_i is the current residue minus the contribution of all earlier
    // digits, multiplied by the precomputed inverse of the earlier moduli
    // product (Shoup multiplication; `mul_mod64` takes the negated modulus).
    let v01 = mod_p01;
    let v23 = mul_mod64(
        P23.wrapping_neg(),
        2 * P23 + mod_p23 - v01,
        P01_INV_MOD_P23,
        P01_INV_MOD_P23_SHOUP,
    );
    let v45 = mul_mod64(
        P45.wrapping_neg(),
        2 * P45 + mod_p45 - (v01 + mul_mod64(P45.wrapping_neg(), v23, P01, P01_MOD_P45_SHOUP)),
        P0123_INV_MOD_P45,
        P0123_INV_MOD_P45_SHOUP,
    );
    let v67 = mul_mod64(
        P67.wrapping_neg(),
        2 * P67 + mod_p67
            - (v01
                + mul_mod64(
                    P67.wrapping_neg(),
                    v23 + mul_mod64(P67.wrapping_neg(), v45, P23, P23_MOD_P67_SHOUP),
                    P01,
                    P01_MOD_P67_SHOUP,
                )),
        P012345_INV_MOD_P67,
        P012345_INV_MOD_P67_SHOUP,
    );
    let v89 = mul_mod64(
        P89.wrapping_neg(),
        2 * P89 + mod_p89
            - (v01
                + mul_mod64(
                    P89.wrapping_neg(),
                    v23 + mul_mod64(
                        P89.wrapping_neg(),
                        v45 + mul_mod64(P89.wrapping_neg(), v67, P45, P45_MOD_P89_SHOUP),
                        P23,
                        P23_MOD_P89_SHOUP,
                    ),
                    P01,
                    P01_MOD_P89_SHOUP,
                )),
        P01234567_INV_MOD_P89,
        P01234567_INV_MOD_P89_SHOUP,
    );

    // Horner-style evaluation of the mixed-radix digits, all mod 2^128:
    // v01 + v23*P01 + v45*P0123 + v67*P012345 + v89*P01234567.
    // If the most significant digit lies in the upper half of its range, the
    // value represents a negative number: subtract the full modulus
    // P0123456789 (wrapping mod 2^128) to get the centered representative.
    let sign = v89 > (P89 / 2);
    let pos = (v01 as u128)
        .wrapping_add(u128::wrapping_mul(v23 as u128, P01 as u128))
        .wrapping_add(u128::wrapping_mul(v45 as u128, P0123))
        .wrapping_add(u128::wrapping_mul(v67 as u128, P012345))
        .wrapping_add(u128::wrapping_mul(v89 as u128, P01234567));
    let neg = pos.wrapping_sub(P0123456789);

    if sign {
        neg
    } else {
        pos
    }
}
|
||||
|
||||
impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        // One sub-plan per CRT prime; every prime must support the size `n`.
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            Plan::try_new(n, P3)?,
            Plan::try_new(n, P4)?,
            Plan::try_new(n, P5)?,
            Plan::try_new(n, P6)?,
            Plan::try_new(n, P7)?,
            Plan::try_new(n, P8)?,
            Plan::try_new(n, P9)?,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Returns the inner NTT plan for the prime `P0`.
    #[inline]
    pub fn ntt_0(&self) -> &crate::prime32::Plan {
        &self.0
    }
    /// Returns the inner NTT plan for the prime `P1`.
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime32::Plan {
        &self.1
    }
    /// Returns the inner NTT plan for the prime `P2`.
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime32::Plan {
        &self.2
    }
    /// Returns the inner NTT plan for the prime `P3`.
    #[inline]
    pub fn ntt_3(&self) -> &crate::prime32::Plan {
        &self.3
    }
    /// Returns the inner NTT plan for the prime `P4`.
    #[inline]
    pub fn ntt_4(&self) -> &crate::prime32::Plan {
        &self.4
    }
    /// Returns the inner NTT plan for the prime `P5`.
    #[inline]
    pub fn ntt_5(&self) -> &crate::prime32::Plan {
        &self.5
    }
    /// Returns the inner NTT plan for the prime `P6`.
    #[inline]
    pub fn ntt_6(&self) -> &crate::prime32::Plan {
        &self.6
    }
    /// Returns the inner NTT plan for the prime `P7`.
    #[inline]
    pub fn ntt_7(&self) -> &crate::prime32::Plan {
        &self.7
    }
    /// Returns the inner NTT plan for the prime `P8`.
    #[inline]
    pub fn ntt_8(&self) -> &crate::prime32::Plan {
        &self.8
    }
    /// Returns the inner NTT plan for the prime `P9`.
    #[inline]
    pub fn ntt_9(&self) -> &crate::prime32::Plan {
        &self.9
    }

    /// Reduces each 128-bit coefficient of `value` modulo the ten CRT primes,
    /// storing the residues in `mod_p0..mod_p9`, then computes the forward
    /// NTT of each residue polynomial in place.
    pub fn fwd(
        &self,
        value: &[u128],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
        mod_p5: &mut [u32],
        mod_p6: &mut [u32],
        mod_p7: &mut [u32],
        mod_p8: &mut [u32],
        mod_p9: &mut [u32],
    ) {
        for (
            value,
            mod_p0,
            mod_p1,
            mod_p2,
            mod_p3,
            mod_p4,
            mod_p5,
            mod_p6,
            mod_p7,
            mod_p8,
            mod_p9,
        ) in crate::izip!(
            value,
            &mut *mod_p0,
            &mut *mod_p1,
            &mut *mod_p2,
            &mut *mod_p3,
            &mut *mod_p4,
            &mut *mod_p5,
            &mut *mod_p6,
            &mut *mod_p7,
            &mut *mod_p8,
            &mut *mod_p9,
        ) {
            *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
            *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
            *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
            *mod_p5 = (value % crate::primes32::P5 as u128) as u32;
            *mod_p6 = (value % crate::primes32::P6 as u128) as u32;
            *mod_p7 = (value % crate::primes32::P7 as u128) as u32;
            *mod_p8 = (value % crate::primes32::P8 as u128) as u32;
            *mod_p9 = (value % crate::primes32::P9 as u128) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
        self.3.fwd(mod_p3);
        self.4.fwd(mod_p4);
        self.5.fwd(mod_p5);
        self.6.fwd(mod_p6);
        self.7.fwd(mod_p7);
        self.8.fwd(mod_p8);
        self.9.fwd(mod_p9);
    }

    /// Computes the inverse NTT of each residue polynomial in place, then
    /// recombines the ten residues into the `u128` coefficients of `value`
    /// via CRT reconstruction.
    ///
    /// A `fwd` followed by `inv` multiplies the coefficients by the NTT size
    /// (the normalization is performed separately, in
    /// `mul_assign_normalize`).
    pub fn inv(
        &self,
        value: &mut [u128],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
        mod_p5: &mut [u32],
        mod_p6: &mut [u32],
        mod_p7: &mut [u32],
        mod_p8: &mut [u32],
        mod_p9: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);
        self.3.inv(mod_p3);
        self.4.inv(mod_p4);
        self.5.inv(mod_p5);
        self.6.inv(mod_p6);
        self.7.inv(mod_p7);
        self.8.inv(mod_p8);
        self.9.inv(mod_p9);

        for (
            value,
            &mod_p0,
            &mod_p1,
            &mod_p2,
            &mod_p3,
            &mod_p4,
            &mod_p5,
            &mod_p6,
            &mod_p7,
            &mod_p8,
            &mod_p9,
        ) in crate::izip!(
            value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4, &*mod_p5, &*mod_p6, &*mod_p7,
            &*mod_p8, &*mod_p9,
        ) {
            *value = reconstruct_32bit_0123456789_v2(
                mod_p0, mod_p1, mod_p2, mod_p3, mod_p4, mod_p5, mod_p6, mod_p7, mod_p8, mod_p9,
            );
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
    ///
    /// # Panics
    ///
    /// Panics if `prod`, `lhs` and `rhs` do not all have the same length.
    pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        // Scratch residue buffers (aligned), one per prime and per operand.
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];
        let mut lhs3 = avec![0; n];
        let mut lhs4 = avec![0; n];
        let mut lhs5 = avec![0; n];
        let mut lhs6 = avec![0; n];
        let mut lhs7 = avec![0; n];
        let mut lhs8 = avec![0; n];
        let mut lhs9 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];
        let mut rhs3 = avec![0; n];
        let mut rhs4 = avec![0; n];
        let mut rhs5 = avec![0; n];
        let mut rhs6 = avec![0; n];
        let mut rhs7 = avec![0; n];
        let mut rhs8 = avec![0; n];
        let mut rhs9 = avec![0; n];

        self.fwd(
            lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
            &mut lhs7, &mut lhs8, &mut lhs9,
        );
        self.fwd(
            rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4, &mut rhs5, &mut rhs6,
            &mut rhs7, &mut rhs8, &mut rhs9,
        );

        // Pointwise product in each prime field; `mul_assign_normalize` also
        // removes the NTT-size scaling so the final result is exact.
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
        self.4.mul_assign_normalize(&mut lhs4, &rhs4);
        self.5.mul_assign_normalize(&mut lhs5, &rhs5);
        self.6.mul_assign_normalize(&mut lhs6, &rhs6);
        self.7.mul_assign_normalize(&mut lhs7, &rhs7);
        self.8.mul_assign_normalize(&mut lhs8, &rhs8);
        self.9.mul_assign_normalize(&mut lhs9, &rhs9);

        self.inv(
            prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
            &mut lhs7, &mut lhs8, &mut lhs9,
        );
    }
}
|
||||
|
||||
#[cfg(test)]
pub mod tests {
    use super::*;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Computes the negacyclic convolution of `lhs` and `rhs` (both of length
    /// `n`) naively in O(n^2), with wrapping arithmetic modulo 2^128.
    /// Serves as the reference result for the NTT-based product.
    pub fn negacyclic_convolution(n: usize, lhs: &[u128], rhs: &[u128]) -> Vec<u128> {
        let mut full_convolution = vec![0u128; 2 * n];
        let mut negacyclic_convolution = vec![0u128; n];
        for i in 0..n {
            for j in 0..n {
                full_convolution[i + j] =
                    full_convolution[i + j].wrapping_add(lhs[i].wrapping_mul(rhs[j]));
            }
        }
        // Negacyclic fold: since X^n = -1, the upper half of the full
        // convolution is subtracted from the lower half.
        for i in 0..n {
            negacyclic_convolution[i] = full_convolution[i].wrapping_sub(full_convolution[i + n]);
        }
        negacyclic_convolution
    }

    /// Returns `(lhs, rhs, negacyclic_convolution(n, lhs, rhs))` with
    /// uniformly random `u128` operands of length `n`.
    pub fn random_lhs_rhs_with_negacyclic_convolution(
        n: usize,
    ) -> (Vec<u128>, Vec<u128>, Vec<u128>) {
        let mut lhs = vec![0u128; n];
        let mut rhs = vec![0u128; n];

        for x in &mut lhs {
            *x = random();
        }
        for x in &mut rhs {
            *x = random();
        }

        let lhs = lhs;
        let rhs = rhs;

        let negacyclic_convolution = negacyclic_convolution(n, &lhs, &rhs);
        (lhs, rhs, negacyclic_convolution)
    }

    // Checks that fwd followed by inv scales each coefficient by `n`
    // (no normalization in `inv` itself), and that the full NTT-based
    // negacyclic product matches the naive reference.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let value = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
            let mut value_roundtrip = vec![0; n];
            let mut mod_p0 = vec![0; n];
            let mut mod_p1 = vec![0; n];
            let mut mod_p2 = vec![0; n];
            let mut mod_p3 = vec![0; n];
            let mut mod_p4 = vec![0; n];
            let mut mod_p5 = vec![0; n];
            let mut mod_p6 = vec![0; n];
            let mut mod_p7 = vec![0; n];
            let mut mod_p8 = vec![0; n];
            let mut mod_p9 = vec![0; n];

            let plan = Plan32::try_new(n).unwrap();
            plan.fwd(
                &value,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
                &mut mod_p5,
                &mut mod_p6,
                &mut mod_p7,
                &mut mod_p8,
                &mut mod_p9,
            );
            plan.inv(
                &mut value_roundtrip,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
                &mut mod_p5,
                &mut mod_p6,
                &mut mod_p7,
                &mut mod_p8,
                &mut mod_p9,
            );
            for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                assert_eq!(value_roundtrip, value.wrapping_mul(n as u128));
            }

            let (lhs, rhs, negacyclic_convolution) = random_lhs_rhs_with_negacyclic_convolution(n);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }
}
|
||||
597
tfhe-ntt/src/native32.rs
Normal file
597
tfhe-ntt/src/native32.rs
Normal file
@@ -0,0 +1,597 @@
|
||||
use aligned_vec::avec;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 32bit polynomials.
///
/// Products are computed via CRT over the three 32-bit primes `P0`, `P1`,
/// `P2` from `crate::primes32`; residues are recombined with
/// [`reconstruct_32bit_012`] (or a SIMD variant when available).
#[derive(Clone, Debug)]
pub struct Plan32(
    // One NTT sub-plan per prime, in order P0, P1, P2 (see `Plan32::try_new`).
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 32bit polynomials.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
// Two 52-bit prime sub-plans (P0, P1 from `crate::primes52`) plus the
// detected AVX512-IFMA SIMD token used for the vectorized reconstruction.
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);
|
||||
|
||||
/// Returns `(a * b) mod p`, computed with a 64-bit widening product so the
/// intermediate multiplication cannot overflow.
#[inline(always)]
pub(crate) fn mul_mod32(p: u32, a: u32, b: u32) -> u32 {
    let product = u64::from(a) * u64::from(b);
    let reduced = product % u64::from(p);
    reduced as u32
}
|
||||
|
||||
/// Recombines the residues of a value modulo the three primes `P0`, `P1`,
/// `P2` into a `u32`, using mixed-radix (Garner-style) CRT reconstruction.
///
/// The combined value is treated as a signed representative modulo
/// `P0 * P1 * P2`: values in the upper half map to the wrapped negative
/// representative (arithmetic is modulo 2^32, hence the `wrapping_*` calls).
#[inline(always)]
pub(crate) fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u32 {
    use crate::primes32::*;

    // Mixed-radix digits: v0 (mod P0), v1 (mod P1), v2 (mod P2). Each step
    // subtracts the contribution of the previous digits; the `2 * P + x - v`
    // form keeps the operand non-negative before reduction.
    let v0 = mod_p0;
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );

    // Top digit in the upper half of its modulus => negative representative.
    let sign = v2 > (P2 / 2);

    // Partial products of the moduli, reduced modulo 2^32.
    const _0: u32 = P0;
    const _01: u32 = _0.wrapping_mul(P1);
    const _012: u32 = _01.wrapping_mul(P2);

    let pos = v0
        .wrapping_add(v1.wrapping_mul(_0))
        .wrapping_add(v2.wrapping_mul(_01));

    let neg = pos.wrapping_sub(_012);

    if sign {
        neg
    } else {
        pos
    }
}
|
||||
|
||||
/// Computes `a * b mod p` on 8 lanes of `u32` using Shoup modular
/// multiplication: `b_shoup` is the precomputed Shoup companion constant of
/// `b` (NOTE(review): presumably `floor(b << 32 / p)` — confirm against the
/// constant generation), allowing the quotient to be estimated with a single
/// widening multiply.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod32_avx2(
    simd: crate::V3,
    p: u32x8,
    a: u32x8,
    b: u32x8,
    b_shoup: u32x8,
) -> u32x8 {
    // Estimated quotient: the high 32 bits of a * b_shoup.
    let shoup_q = simd.widening_mul_u32x8(a, b_shoup).1;
    // t = a*b - q*p with wrapping lane arithmetic; may still exceed p by one
    // multiple, hence the final `small_mod` reduction.
    let t = simd.wrapping_sub_u32x8(
        simd.wrapping_mul_u32x8(a, b),
        simd.wrapping_mul_u32x8(shoup_q, p),
    );
    simd.small_mod_u32x8(p, t)
}
|
||||
|
||||
/// 16-lane AVX512 version of [`mul_mod32_avx2`]: Shoup modular
/// multiplication `a * b mod p`, with `b_shoup` the precomputed Shoup
/// companion constant of `b`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod32_avx512(
    simd: crate::V4IFma,
    p: u32x16,
    a: u32x16,
    b: u32x16,
    b_shoup: u32x16,
) -> u32x16 {
    // Estimated quotient: the high 32 bits of a * b_shoup.
    let shoup_q = simd.widening_mul_u32x16(a, b_shoup).1;
    // t = a*b - q*p; final `small_mod` corrects the possible off-by-p result.
    let t = simd.wrapping_sub_u32x16(
        simd.wrapping_mul_u32x16(a, b),
        simd.wrapping_mul_u32x16(shoup_q, p),
    );
    simd.small_mod_u32x16(p, t)
}
|
||||
|
||||
/// Shoup modular multiplication `a * b mod p` for 52-bit moduli, 8 lanes at
/// a time, using the AVX512-IFMA 52-bit widening multiply. `neg_p` is the
/// (wrapped) negation of `p`, so `q * neg_p + lo(a*b)` computes
/// `a*b - q*p` with a single fused multiply-add.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod52_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    neg_p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    // Estimated quotient: high 52 bits of a * b_shoup.
    let shoup_q = simd.widening_mul_u52x8(a, b_shoup).1;
    // t = lo(a*b) + q*(-p), then a final small reduction into [0, p).
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(a, b).0);
    simd.small_mod_u64x8(p, t)
}
|
||||
|
||||
/// 8-lane AVX2 version of [`reconstruct_32bit_012`]: CRT recombination of
/// residues modulo `P0`, `P1`, `P2` into signed-centered `u32` values.
/// Mirrors the scalar algorithm step for step, with Shoup constants replacing
/// the scalar `mul_mod32` reductions.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_012_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
) -> u32x8 {
    use crate::primes32::*;

    // Broadcast all scalar constants across the 8 lanes once, up front.
    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let half_p2 = simd.splat_u32x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);

    // Partial modulus products, reduced modulo 2^32 (same as the scalar
    // `_01` / `_012` constants).
    let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));
    let p012 = simd.splat_u32x8(P0.wrapping_mul(P1).wrapping_mul(P2));

    // Mixed-radix digits v0, v1, v2; the `two_p + x - v` shift keeps lanes
    // non-negative before the modular multiplies.
    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    // Lanes whose top digit exceeds P2/2 take the negative representative.
    let sign = simd.cmp_gt_u32x8(v2, half_p2);
    let pos = simd.wrapping_add_u32x8(
        simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0)),
        simd.wrapping_mul_u32x8(v2, p01),
    );
    let neg = simd.wrapping_sub_u32x8(pos, p012);

    simd.select_u32x8(sign, neg, pos)
}
|
||||
|
||||
/// 16-lane AVX512 version of [`reconstruct_32bit_012`]; identical structure
/// to `reconstruct_32bit_012_avx2`, with twice the lane count.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
) -> u32x16 {
    use crate::primes32::*;

    // Broadcast constants once across the 16 lanes.
    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let half_p2 = simd.splat_u32x16(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);

    // Partial modulus products modulo 2^32.
    let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));
    let p012 = simd.splat_u32x16(P0.wrapping_mul(P1).wrapping_mul(P2));

    // Mixed-radix digits v0, v1, v2 (see the scalar version for the math).
    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    // Centered representative selection per lane.
    let sign = simd.cmp_gt_u32x16(v2, half_p2);
    let pos = simd.wrapping_add_u32x16(
        simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0)),
        simd.wrapping_mul_u32x16(v2, p01),
    );
    let neg = simd.wrapping_sub_u32x16(pos, p012);

    simd.select_u32x16(sign, neg, pos)
}
|
||||
|
||||
/// Recombines residues modulo the two 52-bit primes `P0`, `P1` (from
/// `crate::primes52`) into signed-centered values, 8 lanes at a time, and
/// narrows the result to `u32` (`convert_u64x8_to_u32x8` keeps the low 32
/// bits of each lane, which is the full result modulo 2^32).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u32x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let half_p1 = simd.splat_u64x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);

    // Product of the two moduli, modulo 2^64.
    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));

    // Two-prime Garner step: v1 = (mod_p1 - v0) * P0^{-1} (mod P1).
    let v0 = mod_p0;
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // Top digit above P1/2 => negative (centered) representative.
    let sign = simd.cmp_gt_u64x8(v1, half_p1);

    let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
    let neg = simd.wrapping_sub_u64x8(pos, p01);

    simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_012_avx2`] across whole slices, 8 elements at
/// a time.
///
/// NOTE(review): `as_arrays` discards any tail shorter than 8 elements, so
/// trailing elements are left unwritten — callers appear to use power-of-two
/// NTT sizes >= 16, which makes the tail empty; confirm.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
    simd: crate::V3,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            // Reinterpret the slices as arrays of 8-lane chunks.
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_012_avx512`] across whole slices, 16 elements
/// at a time.
///
/// NOTE(review): any tail shorter than 16 elements is left unwritten (see
/// the AVX2 variant for the same caveat).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            // Reinterpret the slices as arrays of 16-lane chunks.
            let value = pulp::as_arrays_mut::<16, _>(value).0;
            let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_52bit_01_avx512`] across whole slices, 8 elements
/// at a time (64-bit residues in, 32-bit values out).
///
/// NOTE(review): any tail shorter than 8 elements is left unwritten (see the
/// 32-bit slice variants for the same caveat).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u64],
    mod_p1: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            // Reinterpret the slices as arrays of 8-lane chunks.
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_52bit_01_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                ));
            }
        },
    );
}
|
||||
|
||||
impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        // One sub-plan per CRT prime; all three must support size `n`.
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Returns the inner NTT plan for the prime `P0`.
    #[inline]
    pub fn ntt_0(&self) -> &crate::prime32::Plan {
        &self.0
    }
    /// Returns the inner NTT plan for the prime `P1`.
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime32::Plan {
        &self.1
    }
    /// Returns the inner NTT plan for the prime `P2`.
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime32::Plan {
        &self.2
    }

    /// Reduces each coefficient of `value` modulo `P0`, `P1`, `P2` and
    /// computes the forward NTT of each residue polynomial in place.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = value % crate::primes32::P0;
            *mod_p1 = value % crate::primes32::P1;
            *mod_p2 = value % crate::primes32::P2;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

    /// Computes the inverse NTT of each residue polynomial, then recombines
    /// the residues into `value` via CRT reconstruction, dispatching to the
    /// fastest available implementation (AVX512, then AVX2, then scalar).
    ///
    /// A `fwd` followed by `inv` multiplies the coefficients by the NTT size
    /// (normalization is done separately, in `mul_assign_normalize`).
    pub fn inv(
        &self,
        value: &mut [u32],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
        }

        // Scalar fallback when no SIMD feature set is detected.
        for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
        {
            *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
    ///
    /// # Panics
    ///
    /// Panics if `prod`, `lhs` and `rhs` do not all have the same length.
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        // Aligned scratch residue buffers, one per prime and per operand.
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2);

        // Pointwise product in each prime field; `mul_assign_normalize` also
        // removes the NTT-size scaling so the final result is exact.
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters, or if the AVX512
    /// instruction set isn't detected.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        // The SIMD token is stored in the plan so later calls don't need to
        // re-run feature detection.
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Widens the 32-bit coefficients of `value` into the two 52-bit prime
    /// residue buffers (no explicit reduction is performed — presumably the
    /// `primes52` moduli exceed `u32::MAX`, so widening is already reduced;
    /// confirm), then computes the forward NTT of each residue polynomial.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.2.vectorize(
            #[inline(always)]
            || {
                for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = *value as u64;
                    *mod_p1 = *value as u64;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Computes the inverse NTT of each residue polynomial, then recombines
    /// the residues into `value` with the vectorized 52-bit CRT
    /// reconstruction.
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        let simd = self.2;
        reconstruct_slice_52bit_01_avx512(simd, value, mod_p0, mod_p1);
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
    ///
    /// # Panics
    ///
    /// Panics if `prod`, `lhs` and `rhs` do not all have the same length.
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        // Aligned scratch residue buffers, one per prime and per operand.
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd(rhs, &mut rhs0, &mut rhs1);

        // Pointwise product per prime; also removes the NTT-size scaling.
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime32::tests::random_lhs_rhs_with_negacyclic_convolution;
    use rand::random;

    extern crate alloc;
    use alloc::{vec, vec::Vec};

    // Roundtrip (fwd then inv scales by `n`) and full product check for
    // `Plan32` against the naive negacyclic convolution.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
            let mut value_roundtrip = vec![0u32; n];
            let mut mod_p0 = vec![0u32; n];
            let mut mod_p1 = vec![0u32; n];
            let mut mod_p2 = vec![0u32; n];

            plan.fwd(&value, &mut mod_p0, &mut mod_p1, &mut mod_p2);
            plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1, &mut mod_p2);
            for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
            }

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, 0);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    // Same roundtrip and product check for the AVX512-only `Plan52`; skipped
    // silently when the instruction set isn't available.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
                let mut value_roundtrip = vec![0u32; n];
                let mut mod_p0 = vec![0u64; n];
                let mut mod_p1 = vec![0u64; n];

                plan.fwd(&value, &mut mod_p0, &mut mod_p1);
                plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1);
                for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                    assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
                }

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, 0);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }

    // Checks that the SIMD slice reconstructions agree exactly with the
    // scalar `reconstruct_32bit_012` on random in-range residues.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn reconstruct_32bit_avx() {
        for n in [16, 32, 64, 256, 1024, 2048] {
            use crate::primes32::*;

            let mut value = vec![0u32; n];
            let mut value_avx2 = vec![0u32; n];
            #[cfg(feature = "nightly")]
            let mut value_avx512 = vec![0u32; n];
            let mod_p0 = (0..n).map(|_| random::<u32>() % P0).collect::<Vec<_>>();
            let mod_p1 = (0..n).map(|_| random::<u32>() % P1).collect::<Vec<_>>();
            let mod_p2 = (0..n).map(|_| random::<u32>() % P2).collect::<Vec<_>>();

            for (value, &mod_p0, &mod_p1, &mod_p2) in
                crate::izip!(&mut value, &mod_p0, &mod_p1, &mod_p2)
            {
                *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
            }

            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, &mut value_avx2, &mod_p0, &mod_p1, &mod_p2);
                assert_eq!(value, value_avx2);
            }
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(
                    simd,
                    &mut value_avx512,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                );
                assert_eq!(value, value_avx512);
            }
        }
    }
}
|
||||
1294
tfhe-ntt/src/native64.rs
Normal file
1294
tfhe-ntt/src/native64.rs
Normal file
File diff suppressed because it is too large
Load Diff
222
tfhe-ntt/src/native_binary128.rs
Normal file
222
tfhe-ntt/src/native_binary128.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
pub(crate) use crate::native64::{mul_mod32, mul_mod64};
|
||||
use aligned_vec::avec;
|
||||
|
||||
pub struct Plan32(
|
||||
crate::prime32::Plan,
|
||||
crate::prime32::Plan,
|
||||
crate::prime32::Plan,
|
||||
crate::prime32::Plan,
|
||||
crate::prime32::Plan,
|
||||
);
|
||||
|
||||
/// Recombines the residues of a value modulo the five 32-bit primes
/// `P0..=P4` into a `u128`, using a two-level mixed-radix (Garner-style) CRT
/// reconstruction: `(P1, P2)` and `(P3, P4)` are first merged into 64-bit
/// residues, then combined with the standalone `P0` residue.
///
/// The combined value is treated as a signed representative modulo
/// `P0 * P1 * P2 * P3 * P4`; upper-half values wrap to the negative
/// representative modulo 2^128.
#[inline(always)]
fn reconstruct_32bit_01234_v2(
    mod_p0: u32,
    mod_p1: u32,
    mod_p2: u32,
    mod_p3: u32,
    mod_p4: u32,
) -> u128 {
    use crate::primes32::*;

    // Level 1: merge (P1, P2) and (P3, P4) pairwise into 64-bit residues.
    // The `2 * P + x - v` term keeps the subtraction non-negative.
    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
        v1 as u64 + (v2 as u64 * P1 as u64)
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
        v3 as u64 + (v4 as u64 * P3 as u64)
    };

    // Level 2: fold into mixed-radix digits v0 (mod P0), v12 (mod P12),
    // v34 (mod P34), with Shoup modular multiplication by precomputed
    // inverses of the preceding moduli.
    let v0 = mod_p0 as u64;
    let v12 = mul_mod64(
        P12.wrapping_neg(),
        2 * P12 + mod_p12 - v0,
        P0_INV_MOD_P12,
        P0_INV_MOD_P12_SHOUP,
    );
    let v34 = mul_mod64(
        P34.wrapping_neg(),
        2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
        P012_INV_MOD_P34,
        P012_INV_MOD_P34_SHOUP,
    );

    // Top digit in the upper half of its modulus => negative representative.
    let sign = v34 > (P34 / 2);

    // Partial modulus products, reduced modulo 2^128.
    const _0: u128 = P0 as u128;
    const _012: u128 = _0.wrapping_mul(P12 as u128);
    const _01234: u128 = _012.wrapping_mul(P34 as u128);

    let pos = (v0 as u128)
        .wrapping_add((v12 as u128).wrapping_mul(_0))
        .wrapping_add((v34 as u128).wrapping_mul(_012));
    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}
|
||||
|
||||
impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            Plan::try_new(n, P3)?,
            Plan::try_new(n, P4)?,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Reduces each `u128` coefficient modulo the five primes, then applies the
    /// forward NTT to each residue buffer in place.
    pub fn fwd(
        &self,
        value: &[u128],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
            value,
            &mut *mod_p0,
            &mut *mod_p1,
            &mut *mod_p2,
            &mut *mod_p3,
            &mut *mod_p4,
        ) {
            *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
            *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
            *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
        self.3.fwd(mod_p3);
        self.4.fwd(mod_p4);
    }

    /// Like [`Self::fwd`] but for binary (0/1) coefficients: the truncating
    /// `as u32` cast replaces the five modular reductions since the value is
    /// assumed to already be smaller than every prime.
    pub fn fwd_binary(
        &self,
        value: &[u128],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
            value,
            &mut *mod_p0,
            &mut *mod_p1,
            &mut *mod_p2,
            &mut *mod_p3,
            &mut *mod_p4,
        ) {
            *mod_p0 = *value as u32;
            *mod_p1 = *value as u32;
            *mod_p2 = *value as u32;
            *mod_p3 = *value as u32;
            *mod_p4 = *value as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
        self.3.fwd(mod_p3);
        self.4.fwd(mod_p4);
    }

    /// Applies the inverse NTT to each residue buffer, then reconstructs the
    /// `u128` coefficients via CRT recombination into `value`.
    pub fn inv(
        &self,
        value: &mut [u128],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);
        self.3.inv(mod_p3);
        self.4.inv(mod_p4);

        for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
            crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
        {
            *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`.
    ///
    /// NOTE(review): `rhs` goes through `fwd_binary`, so its coefficients are
    /// presumably expected to be binary — confirm against callers.
    pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        // scratch residue buffers, one set per prime
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];
        let mut lhs3 = avec![0; n];
        let mut lhs4 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];
        let mut rhs3 = avec![0; n];
        let mut rhs4 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
        self.fwd_binary(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);

        // pointwise product (with normalization) in each NTT domain
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
        self.4.mul_assign_normalize(&mut lhs4, &rhs4);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native128::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Checks `Plan32::negacyclic_polymul` against the reference schoolbook
    /// negacyclic convolution for several power-of-two sizes, with random
    /// `u128` LHS coefficients and binary RHS coefficients.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let lhs = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
            let rhs = (0..n).map(|_| random::<u128>() % 2).collect::<Vec<_>>();
            let negacyclic_convolution = negacyclic_convolution(n, &lhs, &rhs);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }
}
|
||||
364
tfhe-ntt/src/native_binary32.rs
Normal file
364
tfhe-ntt/src/native_binary32.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
use aligned_vec::avec;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
use crate::native32::mul_mod32;
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
/// coefficients. Backed by two 32bit prime NTT plans (`P0`, `P1` from `crate::primes32`).
#[derive(Clone, Debug)]
pub struct Plan32(crate::prime32::Plan, crate::prime32::Plan);
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
/// coefficients.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
/// Uses a single 52bit prime NTT plan plus the detected AVX512-IFMA SIMD token.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::V4IFma);
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn reconstruct_32bit_01(mod_p0: u32, mod_p1: u32) -> u32 {
|
||||
use crate::primes32::*;
|
||||
|
||||
let v0 = mod_p0;
|
||||
let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
|
||||
|
||||
let sign = v1 > (P1 / 2);
|
||||
|
||||
const _0: u32 = P0;
|
||||
const _01: u32 = _0.wrapping_mul(P1);
|
||||
|
||||
let pos = v0.wrapping_add(v1.wrapping_mul(_0));
|
||||
let neg = pos.wrapping_sub(_01);
|
||||
|
||||
if sign {
|
||||
neg
|
||||
} else {
|
||||
pos
|
||||
}
|
||||
}
|
||||
|
||||
/// AVX2 (8-lane) variant of [`reconstruct_32bit_01`]: reconstructs eight `u32`
/// values at once from their residues modulo `P0` and `P1`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_01_avx2(simd: crate::V3, mod_p0: u32x8, mod_p1: u32x8) -> u32x8 {
    use crate::{native32::mul_mod32_avx2, primes32::*};

    // splat all scalar constants across the 8 lanes
    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let half_p1 = simd.splat_u32x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);

    let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    // mixed-radix digit of P0, computed with a Shoup modular multiplication
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // center the representative: lanes with v1 > P1/2 use the negative branch
    let sign = simd.cmp_gt_u32x8(v1, half_p1);
    let pos = simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0));

    let neg = simd.wrapping_sub_u32x8(pos, p01);

    simd.select_u32x8(sign, neg, pos)
}
|
||||
|
||||
/// AVX512 (16-lane) variant of [`reconstruct_32bit_01`]: reconstructs sixteen
/// `u32` values at once from their residues modulo `P0` and `P1`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01_avx512(simd: crate::V4IFma, mod_p0: u32x16, mod_p1: u32x16) -> u32x16 {
    use crate::{native32::mul_mod32_avx512, primes32::*};

    // splat all scalar constants across the 16 lanes
    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let half_p1 = simd.splat_u32x16(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);

    let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    // mixed-radix digit of P0, computed with a Shoup modular multiplication
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // center the representative: lanes with v1 > P1/2 use the negative branch
    let sign = simd.cmp_gt_u32x16(v1, half_p1);
    let pos = simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0));

    let neg = simd.wrapping_sub_u32x16(pos, p01);

    simd.select_u32x16(sign, neg, pos)
}
|
||||
|
||||
/// Reconstructs eight `u32` values from their residues modulo the single 52bit
/// prime `P0` (AVX512): centers each residue around zero (residues above
/// `P0 / 2` become negative) and truncates to 32 bits.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_0_avx512(simd: crate::V4IFma, mod_p0: u64x8) -> u32x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let half_p0 = simd.splat_u64x8(P0 / 2);

    let v0 = mod_p0;

    // lanes above P0/2 map to their negative representative
    let sign = simd.cmp_gt_u64x8(v0, half_p0);

    let pos = v0;
    let neg = simd.wrapping_sub_u64x8(pos, p0);

    // narrowing conversion keeps the low 32 bits of each lane
    simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_01_avx2`] over whole slices, 8 elements at a time.
/// NOTE: `as_arrays` discards any tail shorter than 8 elements — callers pass
/// power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_01_avx2(
    simd: crate::V3,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_32bit_01_avx2(simd, cast(mod_p0), cast(mod_p1)));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_01_avx512`] over whole slices, 16 elements at a
/// time. NOTE: `as_arrays` discards any tail shorter than 16 elements — callers
/// pass power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<16, _>(value).0;
            let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_32bit_01_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                ));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_52bit_0_avx512`] over whole slices, 8 elements at a
/// time. NOTE: `as_arrays` discards any tail shorter than 8 elements — callers
/// pass power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_0_avx512(simd: crate::V4IFma, value: &mut [u32], mod_p0: &[u64]) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            for (value, &mod_p0) in crate::izip!(value, mod_p0) {
                *value = cast(reconstruct_52bit_0_avx512(simd, cast(mod_p0)));
            }
        },
    );
}
|
||||
|
||||
impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Reduces each coefficient modulo the two primes, then applies the forward
    /// NTT to each residue buffer in place.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
            *mod_p0 = value % crate::primes32::P0;
            *mod_p1 = value % crate::primes32::P1;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Like [`Self::fwd`] but for binary coefficients: copies the values as-is
    /// since they are assumed to already be smaller than both primes.
    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
            *mod_p0 = *value;
            *mod_p1 = *value;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Applies the inverse NTT to each residue buffer, then reconstructs the
    /// `u32` coefficients via CRT recombination into `value`, using the fastest
    /// detected SIMD path (AVX512 > AVX2 > scalar).
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01_avx512(simd, value, mod_p0, mod_p1);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01_avx2(simd, value, mod_p0, mod_p1);
                return;
            }
        }

        // scalar fallback
        for (value, &mod_p0, &mod_p1) in crate::izip!(value, &*mod_p0, &*mod_p1) {
            *value = reconstruct_32bit_01(mod_p0, mod_p1);
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`. `rhs_binary` is expected to hold binary coefficients (it goes
    /// through [`Self::fwd_binary`]).
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        // scratch residue buffers, one per prime
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);

        // pointwise product (with normalization) in each NTT domain
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters, or if the AVX512
    /// instruction set isn't detected.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, simd))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Widens each `u32` coefficient to `u64` (no modular reduction is needed:
    /// a u32 always fits below the 52bit prime), then applies the forward NTT.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.1.vectorize(
            #[inline(always)]
            || {
                for (value, mod_p0) in crate::izip!(value, &mut *mod_p0) {
                    *mod_p0 = *value as u64;
                }
            },
        );
        self.0.fwd(mod_p0);
    }

    /// Identical to [`Self::fwd`]: with a single 52bit prime the binary path
    /// needs no special handling.
    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.fwd(value, mod_p0);
    }

    /// Applies the inverse NTT, then narrows the centered residues back to
    /// `u32` coefficients in `value`.
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64]) {
        self.0.inv(mod_p0);

        let simd = self.1;
        reconstruct_slice_52bit_0_avx512(simd, value, mod_p0);
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`. `rhs_binary` is expected to hold binary coefficients.
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        let mut lhs0 = avec![0; n];
        let mut rhs0 = avec![0; n];

        self.fwd(lhs, &mut lhs0);
        self.fwd_binary(rhs_binary, &mut rhs0);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);

        self.inv(prod, &mut lhs0);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime32::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Checks `Plan32::negacyclic_polymul` against the reference negacyclic
    /// convolution for several power-of-two sizes (random LHS, binary RHS).
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let lhs = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
            let rhs = (0..n).map(|_| random::<u32>() % 2).collect::<Vec<_>>();
            let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    /// Same check through the AVX512 `Plan52` path; skipped silently when the
    /// instruction set is not available (`try_new` returns `None`).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let lhs = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
                let rhs = (0..n).map(|_| random::<u32>() % 2).collect::<Vec<_>>();
                let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }
}
|
||||
563
tfhe-ntt/src/native_binary64.rs
Normal file
563
tfhe-ntt/src/native_binary64.rs
Normal file
@@ -0,0 +1,563 @@
|
||||
use aligned_vec::avec;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
pub(crate) use crate::native32::mul_mod32;
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
pub(crate) use crate::native32::mul_mod32_avx2;
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[cfg(feature = "nightly")]
|
||||
pub(crate) use crate::native32::{mul_mod32_avx512, mul_mod52_avx512};
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 64bit polynomials (`fwd` takes
/// `&[u64]`), where the RHS contains binary coefficients. Backed by three 32bit
/// prime NTT plans (`P0`, `P1`, `P2` from `crate::primes32`).
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);
|
||||
|
||||
/// Negacyclic NTT plan for multiplying two 64bit polynomials (`fwd` takes
/// `&[u64]`), where the RHS contains binary coefficients.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
/// Uses two 52bit prime NTT plans plus the detected AVX512-IFMA SIMD token.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);
|
||||
|
||||
/// Reconstructs a `u64` value from its residues modulo the three 32bit primes
/// `P0`, `P1`, `P2` (CRT recombination): computes the mixed-radix digits
/// `v0 + v1 * P0 + v2 * P0 * P1` with wrapping arithmetic mod 2^64, and returns
/// the centered representative (negative branch when the top digit exceeds
/// `P2 / 2`).
#[inline(always)]
#[allow(dead_code)]
fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u64 {
    use crate::primes32::*;

    let v0 = mod_p0;
    // `2 * P + x - y` keeps the subtraction non-negative before reduction
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );

    let sign = v2 > (P2 / 2);

    const _0: u64 = P0 as u64;
    const _01: u64 = _0.wrapping_mul(P1 as u64);
    const _012: u64 = _01.wrapping_mul(P2 as u64);

    let pos = (v0 as u64)
        .wrapping_add((v1 as u64).wrapping_mul(_0))
        .wrapping_add((v2 as u64).wrapping_mul(_01));

    let neg = pos.wrapping_sub(_012);

    if sign {
        neg
    } else {
        pos
    }
}
|
||||
|
||||
/// AVX2 variant of [`reconstruct_32bit_012`]: reconstructs eight `u64` values
/// at once (returned as two `u64x4` halves) from their residues modulo `P0`,
/// `P1`, `P2`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
) -> [u64x4; 2] {
    use crate::primes32::*;

    // splat all scalar constants across the lanes
    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let half_p2 = simd.splat_u32x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1)
;
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));

    // mixed-radix digits, as in the scalar version
    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u32x8(v2, half_p2);
    let sign: [i32x4; 2] = pulp::cast(sign);
    // sign extend so that -1i32 becomes -1i64
    let sign0: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[0])) };
    let sign1: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[1])) };

    // widen each digit vector from 8x32 to 2 halves of 4x64
    let v0: [u32x4; 2] = pulp::cast(v0);
    let v1: [u32x4; 2] = pulp::cast(v1);
    let v2: [u32x4; 2] = pulp::cast(v2);
    let v00 = simd.convert_u32x4_to_u64x4(v0[0]);
    let v01 = simd.convert_u32x4_to_u64x4(v0[1]);
    let v10 = simd.convert_u32x4_to_u64x4(v1[0]);
    let v11 = simd.convert_u32x4_to_u64x4(v1[1]);
    let v20 = simd.convert_u32x4_to_u64x4(v2[0]);
    let v21 = simd.convert_u32x4_to_u64x4(v2[1]);

    // pos = v0 + v1 * P0 + v2 * (P0 * P1), per half
    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x4(pos0, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v20),
    );

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x4(pos1, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v21),
    );

    let neg0 = simd.wrapping_sub_u64x4(pos0, p012);
    let neg1 = simd.wrapping_sub_u64x4(pos1, p012);

    [
        simd.select_u64x4(sign0, neg0, pos0),
        simd.select_u64x4(sign1, neg1, pos1),
    ]
}
|
||||
|
||||
/// AVX512 variant of [`reconstruct_32bit_012`]: reconstructs sixteen `u64`
/// values at once (returned as two `u64x8` halves) from their residues modulo
/// `P0`, `P1`, `P2`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
) -> [u64x8; 2] {
    use crate::primes32::*;

    // splat all scalar constants across the lanes
    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let half_p2 = simd.splat_u32x16(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));

    // mixed-radix digits, as in the scalar version
    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    // AVX512 comparisons yield a mask bitfield; split it per 8-lane half
    let sign = simd.cmp_gt_u32x16(v2, half_p2).0;
    let sign0 = b8(sign as u8);
    let sign1 = b8((sign >> 8) as u8);
    // widen each digit vector from 16x32 to 2 halves of 8x64
    let v0: [u32x8; 2] = pulp::cast(v0);
    let v1: [u32x8; 2] = pulp::cast(v1);
    let v2: [u32x8; 2] = pulp::cast(v2);
    let v00 = simd.convert_u32x8_to_u64x8(v0[0]);
    let v01 = simd.convert_u32x8_to_u64x8(v0[1]);
    let v10 = simd.convert_u32x8_to_u64x8(v1[0]);
    let v11 = simd.convert_u32x8_to_u64x8(v1[1]);
    let v20 = simd.convert_u32x8_to_u64x8(v2[0]);
    let v21 = simd.convert_u32x8_to_u64x8(v2[1]);

    // pos = v0 + v1 * P0 + v2 * (P0 * P1), per half
    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p01, v20));

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p01, v21));

    let neg0 = simd.wrapping_sub_u64x8(pos0, p012);
    let neg1 = simd.wrapping_sub_u64x8(pos1, p012);

    [
        simd.select_u64x8(sign0, neg0, pos0),
        simd.select_u64x8(sign1, neg1, pos1),
    ]
}
|
||||
|
||||
/// Reconstructs eight `u64` values from their residues modulo the two 52bit
/// primes `P0`, `P1` (AVX512-IFMA): CRT recombination with a Shoup modular
/// multiplication for the second digit, returning the centered representative.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u64x8 {
    use crate::primes52::*;

    // splat all scalar constants across the lanes
    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let half_p1 = simd.splat_u64x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);

    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    // mixed-radix digit of P0; `two_p1 + mod_p1 - v0` keeps the lanes non-negative
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // center the representative: lanes with v1 > P1/2 use the negative branch
    let sign = simd.cmp_gt_u64x8(v1, half_p1);

    let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
    let neg = simd.wrapping_sub_u64x8(pos, p01);

    simd.select_u64x8(sign, neg, pos)
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_012_avx2`] over whole slices, 8 elements at a
/// time. NOTE: `as_arrays` discards any tail shorter than 8 elements — callers
/// pass power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
    simd: crate::V3,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_32bit_012_avx512`] over whole slices, 16 elements at a
/// time. NOTE: `as_arrays` discards any tail shorter than 16 elements — callers
/// pass power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<16, _>(value).0;
            let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}
|
||||
|
||||
/// Applies [`reconstruct_52bit_01_avx512`] over whole slices, 8 elements at a
/// time. NOTE: `as_arrays` discards any tail shorter than 8 elements — callers
/// pass power-of-two NTT-sized slices, so there is no remainder in practice.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u64],
    mod_p1: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_52bit_01_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                ));
            }
        },
    );
}
|
||||
|
||||
impl Plan32 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
        ))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Reduces each `u64` coefficient modulo the three primes, then applies the
    /// forward NTT to each residue buffer in place.
    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = (value % crate::primes32::P0 as u64) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u64) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u64) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }
    /// Like [`Self::fwd`] but for binary coefficients: the truncating `as u32`
    /// cast replaces the modular reductions since the value is assumed to
    /// already be smaller than every prime.
    pub fn fwd_binary(
        &self,
        value: &[u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = *value as u32;
            *mod_p1 = *value as u32;
            *mod_p2 = *value as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

    /// Applies the inverse NTT to each residue buffer, then reconstructs the
    /// `u64` coefficients via CRT recombination into `value`, using the fastest
    /// detected SIMD path (AVX512 > AVX2 > scalar).
    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
        }

        // scalar fallback
        for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
        {
            *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
        }
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
    /// `prod`. `rhs_binary` is expected to hold binary coefficients (it goes
    /// through [`Self::fwd_binary`]).
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        // scratch residue buffers, one per prime
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1, &mut rhs2);

        // pointwise product (with normalization) in each NTT domain
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
    /// suitable roots of unity can be found for the wanted parameters, or if the AVX512
    /// instruction set isn't detected.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        // Detecting AVX512-IFMA up front lets the other methods use it unconditionally.
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
    }

    /// Returns the polynomial size of the negacyclic NTT plan.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Reduces the coefficients of `value` modulo the two 52-bit primes, then computes
    /// the forward NTT of each residue polynomial in place.
    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        use crate::primes52::*;
        self.2.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = value % P0;
                    *mod_p1 = value % P1;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Same as [`Self::fwd`], but skips the modular reduction.
    ///
    /// NOTE(review): assumes each coefficient is already canonical modulo both primes
    /// (e.g. 0 or 1) — confirm against callers.
    pub fn fwd_binary(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.2.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = value;
                    *mod_p1 = value;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Inverse NTT of both residue polynomials (in place), followed by AVX512 CRT
    /// reconstruction of the `u64` coefficients into `value`.
    pub fn inv(&self, value: &mut [u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        reconstruct_slice_52bit_01_avx512(self.2, value, mod_p0, mod_p1);
    }

    /// Computes the negacyclic polynomial product of `lhs` and `rhs_binary`, and stores
    /// the result in `prod`.
    ///
    /// `rhs_binary` goes through the unreduced [`Self::fwd_binary`] path, so its
    /// coefficients are expected to be binary (0/1).
    ///
    /// # Panics
    ///
    /// Panics if `prod`, `lhs` and `rhs_binary` do not all have the same length.
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        // Scratch buffers for the residues modulo each of the two 52-bit primes.
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);

        // Pointwise product in the NTT domain; per its name, `mul_assign_normalize`
        // presumably also folds in the inverse-transform normalization — confirm in
        // the prime64 plan.
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime64::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Checks that the 32-bit CRT plan matches the reference negacyclic convolution
    /// for random operands. The rhs is binary (0/1) since `negacyclic_polymul` routes
    /// it through the unreduced `fwd_binary` path.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
            let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
            let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    /// Same check for the 52-bit plan; silently skipped when AVX512 isn't available
    /// at runtime (`Plan52::try_new` returns `None`).
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
                let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
                let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }
}
|
||||
223
tfhe-ntt/src/prime.rs
Normal file
223
tfhe-ntt/src/prime.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
use crate::fastdiv::{Div32, Div64};
|
||||
|
||||
/// Returns `(x * y) mod n`, using the precomputed fast divisor `n`.
#[inline(always)]
pub const fn mul_mod32(n: Div32, x: u32, y: u32) -> u32 {
    // The product of two `u32` values always fits in a `u64`, so this cannot overflow.
    Div32::rem_u64(x as u64 * y as u64, n)
}
|
||||
/// Returns `(x * y) mod n`, using the precomputed fast divisor `n`.
#[inline(always)]
pub const fn mul_mod64(n: Div64, x: u64, y: u64) -> u64 {
    // The product of two `u64` values always fits in a `u128`, so this cannot overflow.
    Div64::rem_u128(x as u128 * y as u128, n)
}
|
||||
|
||||
/// Returns `base^pow mod n`, by binary (square-and-multiply) exponentiation.
pub const fn exp_mod32(n: Div32, base: u32, pow: u32) -> u32 {
    if pow == 0 {
        1
    } else {
        let mut pow = pow;
        let mut y = 1;
        let mut x = base;

        // Loop invariant: the result equals (x^pow * y) mod n.
        while pow > 1 {
            if pow % 2 == 1 {
                y = mul_mod32(n, x, y);
            }
            x = mul_mod32(n, x, x);
            pow /= 2;
        }
        // Here pow == 1, so the result is (x * y) mod n.
        mul_mod32(n, x, y)
    }
}
|
||||
|
||||
/// Returns `base^pow mod n`, by binary (square-and-multiply) exponentiation.
pub const fn exp_mod64(n: Div64, base: u64, pow: u64) -> u64 {
    if pow == 0 {
        1
    } else {
        let mut pow = pow;
        let mut y = 1;
        let mut x = base;

        // Loop invariant: the result equals (x^pow * y) mod n.
        while pow > 1 {
            if pow % 2 == 1 {
                y = mul_mod64(n, x, y);
            }
            x = mul_mod64(n, x, x);
            pow /= 2;
        }
        // Here pow == 1, so the result is (x * y) mod n.
        mul_mod64(n, x, y)
    }
}
|
||||
|
||||
/// Runs one Miller-Rabin round for witness `a`, where `n - 1 = 2^s * d` with `d` odd.
///
/// Returns `true` if `n` is a strong probable prime to base `a` (i.e. `a^d == 1`, or
/// `a^(2^r * d) == n - 1` for some `r < s`), `false` if `a` witnesses compositeness.
const fn is_prime_miller_rabin_iter(n: Div64, s: u64, d: u64, a: u64) -> bool {
    let mut x = exp_mod64(n, a, d);
    let n_minus_1 = n.divisor() - 1;
    if x == 1 || x == n_minus_1 {
        true
    } else {
        // Square x up to s - 1 times, looking for n - 1.
        let mut count = 0;
        while count < s - 1 {
            x = mul_mod64(n, x, x);
            if x == n_minus_1 {
                return true;
            }
            count += 1;
        }
        false
    }
}
|
||||
|
||||
/// Returns the larger of `a` and `b` (usable in `const` contexts, unlike `Ord::max`).
const fn max64(a: u64, b: u64) -> u64 {
    if a < b {
        b
    } else {
        a
    }
}
|
||||
|
||||
/// Deterministically tests whether `n` is prime, for any `n < 2^64`.
///
/// Strategy: trial division by the primes up to 37, then a deterministic Miller-Rabin
/// test with the witness set {2, 3, ..., 37}, which is known to be exact for all
/// 64-bit integers.
pub const fn is_prime64(n: u64) -> bool {
    // 0 and 1 are not prime
    if n < 2 {
        return false;
    }

    // test divisibility by small primes
    // hand-unrolled for the compiler to optimize divisions
    #[rustfmt::skip]
    {
        if n % 2 == 0 { return n == 2; }
        if n % 3 == 0 { return n == 3; }
        if n % 5 == 0 { return n == 5; }
        if n % 7 == 0 { return n == 7; }
        if n % 11 == 0 { return n == 11; }
        if n % 13 == 0 { return n == 13; }
        if n % 17 == 0 { return n == 17; }
        if n % 19 == 0 { return n == 19; }
        if n % 23 == 0 { return n == 23; }
        if n % 29 == 0 { return n == 29; }
        if n % 31 == 0 { return n == 31; }
        if n % 37 == 0 { return n == 37; }
    };

    // deterministic miller rabin test, works for any n < 2^64
    // aside from the primes tested just before

    // https://en.wikipedia.org/wiki/Miller-Rabin_primality_test#Testing_against_small_sets_of_bases
    // Factor n - 1 as 2^s * d with d odd.
    let mut s = 0;
    let mut d = n - 1;

    while d % 2 == 0 {
        s += 1;
        d /= 2;
    }

    let (s, d) = (s, d);
    let n = Div64::new(n);
    is_prime_miller_rabin_iter(n, s, d, 2)
        && is_prime_miller_rabin_iter(n, s, d, 3)
        && is_prime_miller_rabin_iter(n, s, d, 5)
        && is_prime_miller_rabin_iter(n, s, d, 7)
        && is_prime_miller_rabin_iter(n, s, d, 11)
        && is_prime_miller_rabin_iter(n, s, d, 13)
        && is_prime_miller_rabin_iter(n, s, d, 17)
        && is_prime_miller_rabin_iter(n, s, d, 19)
        && is_prime_miller_rabin_iter(n, s, d, 23)
        && is_prime_miller_rabin_iter(n, s, d, 29)
        && is_prime_miller_rabin_iter(n, s, d, 31)
        && is_prime_miller_rabin_iter(n, s, d, 37)
}
|
||||
|
||||
/// Largest prime of the form `factor * x + offset` in the range
|
||||
/// `[lo, hi]`.
|
||||
pub const fn largest_prime_in_arithmetic_progression64(
|
||||
factor: u64,
|
||||
offset: u64,
|
||||
lo: u64,
|
||||
hi: u64,
|
||||
) -> Option<u64> {
|
||||
if lo > hi {
|
||||
return None;
|
||||
}
|
||||
|
||||
let a = factor;
|
||||
let b = offset;
|
||||
// lo <= ax + b <= hi
|
||||
// (lo - b)/a <= x <= (hi - b)/a
|
||||
|
||||
if b > hi {
|
||||
return None;
|
||||
}
|
||||
|
||||
if a == 0 {
|
||||
if lo <= b && b <= hi && is_prime64(b) {
|
||||
return Some(b);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let mut x_lo = (max64(lo, b) - b) / a;
|
||||
let rem = (max64(lo, b) - b) % a;
|
||||
if rem != 0 {
|
||||
x_lo += 1;
|
||||
}
|
||||
|
||||
let x_hi = (hi - b) / a;
|
||||
|
||||
let mut x = x_hi;
|
||||
let mut in_range = true;
|
||||
while in_range {
|
||||
let val = a * x + b;
|
||||
if is_prime64(val) {
|
||||
return Some(val);
|
||||
}
|
||||
|
||||
if x == x_lo {
|
||||
in_range = false;
|
||||
} else {
|
||||
x -= 1;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime64::Solinas;

    /// Exhaustively compares `is_prime64` against a hard-coded list of the primes
    /// below 1000, then spot-checks the 64-bit Solinas prime.
    #[test]
    fn test_is_prime() {
        let primes_under_1000 = [
            2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
            89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
            181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271,
            277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379,
            383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479,
            487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599,
            601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701,
            709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
            827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941,
            947, 953, 967, 971, 977, 983, 991, 997,
        ];

        for n in 0..1000 {
            assert_eq!(primes_under_1000.contains(&n), is_prime64(n));
        }
        assert!(is_prime64(Solinas::P));
    }

    /// Edge cases for the arithmetic-progression search: degenerate factor 0, empty
    /// ranges, range endpoints, and full-range u64 searches.
    #[test]
    #[rustfmt::skip]
    fn test_prime_search() {
        assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 1, 4), Some(2));
        assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 2, 2), Some(2));
        assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 2, 1), None);
        assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 14, 16), None);
        assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 14, 17), Some(17));
        assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 17, 18), Some(17));
        assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 14, 16), None);
        assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 14, 17), Some(17));
        assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 17, 18), Some(17));
        assert_eq!(largest_prime_in_arithmetic_progression64(6, 5, 0, u64::MAX) , Some(18446744073709551557));
        assert_eq!(largest_prime_in_arithmetic_progression64(6, 1, 0, u64::MAX) , Some(18446744073709551427));
    }
}
|
||||
1648
tfhe-ntt/src/prime32.rs
Normal file
1648
tfhe-ntt/src/prime32.rs
Normal file
File diff suppressed because it is too large
Load Diff
1726
tfhe-ntt/src/prime32/generic.rs
Normal file
1726
tfhe-ntt/src/prime32/generic.rs
Normal file
File diff suppressed because it is too large
Load Diff
616
tfhe-ntt/src/prime32/less_than_30bit.rs
Normal file
616
tfhe-ntt/src/prime32/less_than_30bit.rs
Normal file
@@ -0,0 +1,616 @@
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
/// Forward NTT butterfly (AVX512, 16 lanes) for moduli of less than 30 bits.
///
/// Multiplies `z1` by the twiddle `w` via Shoup's trick (precomputed quotient
/// `w_shoup`), using only wrapping integer ops. This is the lazy variant: it only
/// performs a conditional subtraction of `2p` on `z0`, so outputs are not canonical
/// (the final level uses `fwd_last_butterfly_avx512` instead).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;
    // Conditional subtraction of 2p.
    let z0 = simd.small_mod_u32x16(two_p, z0);
    // t = z1*w - hi(z1 * w_shoup) * p, computed with wrapping arithmetic (neg_p = -p).
    let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(z1, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );
    (
        simd.wrapping_add_u32x16(z0, t),
        // z0 - t, shifted back into the non-negative range by adding 2p.
        simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), two_p),
    )
}
|
||||
|
||||
/// Forward NTT butterfly for the final level (AVX512, 16 lanes), moduli < 30 bits.
///
/// Same computation as `fwd_butterfly_avx512`, with extra conditional subtractions so
/// that both outputs come out canonical (reduced modulo `p`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // Fully reduce z0: conditional subtraction of 2p, then of p.
    let z0 = simd.small_mod_u32x16(two_p, z0);
    let z0 = simd.small_mod_u32x16(p, z0);
    // Shoup multiplication of z1 by the twiddle, followed by a conditional subtraction.
    let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(z1, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x16(p, t);
    (
        // Both sums get one final conditional subtraction of p.
        simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(z0, t)),
        simd.small_mod_u32x16(
            p,
            simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Forward NTT butterfly (AVX2, 8 lanes) for moduli of less than 30 bits.
///
/// Lazy variant: same structure as `fwd_butterfly_avx512`, on `u32x8` lanes.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;
    // Conditional subtraction of 2p.
    let z0 = simd.small_mod_u32x8(two_p, z0);
    // Shoup multiplication of z1 by the twiddle (neg_p = -p).
    let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(z1, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );
    (
        simd.wrapping_add_u32x8(z0, t),
        simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), two_p),
    )
}
|
||||
|
||||
/// Forward NTT butterfly for the final level (AVX2, 8 lanes), moduli < 30 bits.
///
/// Same computation as `fwd_butterfly_avx2`, with extra conditional subtractions so
/// that both outputs come out canonical (reduced modulo `p`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // Fully reduce z0: conditional subtraction of 2p, then of p.
    let z0 = simd.small_mod_u32x8(two_p, z0);
    let z0 = simd.small_mod_u32x8(p, z0);
    // Shoup multiplication of z1 by the twiddle, followed by a conditional subtraction.
    let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(z1, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x8(p, t);
    (
        // Both sums get one final conditional subtraction of p.
        simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(z0, t)),
        simd.small_mod_u32x8(
            p,
            simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Forward NTT butterfly (scalar) for moduli of less than 30 bits.
///
/// Lazy variant: only a conditional subtraction of `2p` is applied to `z0`; the
/// twiddle product uses Shoup's precomputed quotient `w_shoup`.
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;
    // Conditional subtraction of 2p via wrapping arithmetic: if the subtraction
    // wraps, the original (smaller) value wins the `min`.
    let z0 = u32::min(z0, z0.wrapping_sub(two_p));
    // Shoup multiplication of z1 by the twiddle w (neg_p = -p mod 2^32).
    let wide = u64::from(z1) * u64::from(w_shoup);
    let shoup_q = (wide >> 32) as u32;
    let t = z1.wrapping_mul(w).wrapping_add(shoup_q.wrapping_mul(neg_p));
    let hi = z0.wrapping_add(t);
    let lo = z0.wrapping_sub(t).wrapping_add(two_p);
    (hi, lo)
}
|
||||
|
||||
/// Forward NTT butterfly for the final level (scalar), moduli < 30 bits.
///
/// Same computation as `fwd_butterfly_scalar`, with extra conditional subtractions so
/// both outputs come out canonical (reduced modulo `p`).
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // (Removed a leftover `let _ = p;`: unlike the lazy butterfly, this variant does
    // use `p`, for the final reductions below.)
    // Fully reduce z0: conditional subtraction of 2p, then of p.
    let z0 = z0.min(z0.wrapping_sub(two_p));
    let z0 = z0.min(z0.wrapping_sub(p));
    // Shoup multiplication of z1 by the twiddle, followed by a conditional subtraction.
    let shoup_q = ((z1 as u64 * w_shoup as u64) >> 32) as u32;
    let t = u32::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
    let t = t.min(t.wrapping_sub(p));
    let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
    // One final conditional subtraction of p per output.
    (
        res.0.min(res.0.wrapping_sub(p)),
        res.1.min(res.1.wrapping_sub(p)),
    )
}
|
||||
|
||||
/// Inverse NTT butterfly (AVX512, 16 lanes) for moduli of less than 30 bits.
///
/// Lazy variant: only the sum branch gets a conditional subtraction of `2p`; the
/// difference branch is multiplied by the twiddle via Shoup's trick without a final
/// reduction (the last level uses `inv_last_butterfly_avx512`).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;

    // Sum branch: z0 + z1, conditionally reduced by 2p.
    let y0 = simd.wrapping_add_u32x16(z0, z1);
    let y0 = simd.small_mod_u32x16(two_p, y0);
    // Difference branch, shifted into the non-negative range by adding 2p.
    let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), two_p);

    // Shoup multiplication of the difference by the twiddle (neg_p = -p).
    let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
    let y1 = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(t, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );

    (y0, y1)
}
|
||||
|
||||
/// Inverse NTT butterfly for the final level (AVX512, 16 lanes), moduli < 30 bits.
///
/// Same computation as `inv_butterfly_avx512`, but conditionally reduces both outputs
/// modulo `p` so the results are canonical.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // (Removed a leftover `let _ = p;`: unlike the lazy butterfly, this variant does
    // use `p`, for the final reductions below.)
    // Sum branch: z0 + z1, conditionally reduced by 2p.
    let y0 = simd.wrapping_add_u32x16(z0, z1);
    let y0 = simd.small_mod_u32x16(two_p, y0);
    // Difference branch, shifted into the non-negative range by adding 2p.
    let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), two_p);

    // Shoup multiplication of the difference by the twiddle (neg_p = -p).
    let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
    let y1 = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(t, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );

    // Final conditional subtractions of p: canonical outputs.
    (simd.small_mod_u32x16(p, y0), simd.small_mod_u32x16(p, y1))
}
|
||||
|
||||
/// Inverse NTT butterfly (AVX2, 8 lanes) for moduli of less than 30 bits.
///
/// Lazy variant: same structure as `inv_butterfly_avx512`, on `u32x8` lanes.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;

    // Sum branch: z0 + z1, conditionally reduced by 2p.
    let y0 = simd.wrapping_add_u32x8(z0, z1);
    let y0 = simd.small_mod_u32x8(two_p, y0);
    // Difference branch, shifted into the non-negative range by adding 2p.
    let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), two_p);

    // Shoup multiplication of the difference by the twiddle (neg_p = -p).
    let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
    let y1 = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(t, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[inline(always)]
|
||||
pub(crate) fn inv_last_butterfly_avx2(
|
||||
simd: crate::V3,
|
||||
z0: u32x8,
|
||||
z1: u32x8,
|
||||
w: u32x8,
|
||||
w_shoup: u32x8,
|
||||
p: u32x8,
|
||||
neg_p: u32x8,
|
||||
two_p: u32x8,
|
||||
) -> (u32x8, u32x8) {
|
||||
let _ = p;
|
||||
|
||||
let y0 = simd.wrapping_add_u32x8(z0, z1);
|
||||
let y0 = simd.small_mod_u32x8(two_p, y0);
|
||||
let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), two_p);
|
||||
|
||||
let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
|
||||
let y1 = simd.wrapping_add_u32x8(
|
||||
simd.wrapping_mul_u32x8(t, w),
|
||||
simd.wrapping_mul_u32x8(shoup_q, neg_p),
|
||||
);
|
||||
|
||||
(simd.small_mod_u32x8(p, y0), simd.small_mod_u32x8(p, y1))
|
||||
}
|
||||
|
||||
/// Inverse NTT butterfly (scalar) for moduli of less than 30 bits.
///
/// Lazy variant: the sum branch gets a conditional subtraction of `2p`; the
/// difference branch is multiplied by the twiddle via Shoup's trick without a final
/// reduction.
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // The canonical modulus is not needed in the lazy butterfly.
    let _ = p;

    // Sum branch: z0 + z1, conditionally reduced by 2p (the wrapped value loses `min`).
    let sum = z0.wrapping_add(z1);
    let y0 = u32::min(sum, sum.wrapping_sub(two_p));
    // Difference branch, shifted into the non-negative range by adding 2p.
    let diff = z0.wrapping_sub(z1).wrapping_add(two_p);
    // Shoup multiplication of the difference by the twiddle w (neg_p = -p mod 2^32).
    let q = ((u64::from(diff) * u64::from(w_shoup)) >> 32) as u32;
    let y1 = diff.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));
    (y0, y1)
}
|
||||
|
||||
/// Inverse NTT butterfly for the final level (scalar), moduli < 30 bits.
///
/// Same computation as `inv_butterfly_scalar`, but conditionally reduces both outputs
/// modulo `p` so the results are canonical.
#[inline(always)]
pub(crate) fn inv_last_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // (Removed a leftover `let _ = p;`: unlike the lazy butterfly, this variant does
    // use `p`, for the final reductions below.)
    // Sum branch: z0 + z1, conditionally reduced by 2p.
    let y0 = z0.wrapping_add(z1);
    let y0 = y0.min(y0.wrapping_sub(two_p));
    // Difference branch, shifted into the non-negative range by adding 2p.
    let t = z0.wrapping_sub(z1).wrapping_add(two_p);
    // Shoup multiplication of the difference by the twiddle (neg_p = -p mod 2^32).
    let shoup_q = ((t as u64 * w_shoup as u64) >> 32) as u32;
    let y1 = u32::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
    // Final conditional subtractions of p: canonical outputs.
    (y0.min(y0.wrapping_sub(p)), y1.min(y1.wrapping_sub(p)))
}
|
||||
|
||||
/// Forward NTT of `data` in place (AVX512 path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the AVX512 butterflies into the shared
/// depth-first driver.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
    simd: crate::V4,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    super::shoup::fwd_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse NTT of `data` in place (AVX512 path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the AVX512 butterflies into the shared
/// depth-first driver.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
    simd: crate::V4,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    super::shoup::inv_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward NTT of `data` in place (AVX2 path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the AVX2 butterflies into the shared
/// depth-first driver.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
    simd: crate::V3,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    super::shoup::fwd_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse NTT of `data` in place (AVX2 path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the AVX2 butterflies into the shared
/// depth-first driver.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
    simd: crate::V3,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    super::shoup::inv_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward NTT of `data` in place (scalar path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the scalar butterflies into the shared
/// depth-first driver.
pub(crate) fn fwd_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
    super::shoup::fwd_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last (no simd context: unit arg).
        #[inline(always)]
        |(), z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |(), z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse NTT of `data` in place (scalar path) with the Shoup twiddles
/// `twid`/`twid_shoup`, by plugging the scalar butterflies into the shared
/// depth-first driver.
pub(crate) fn inv_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
    super::shoup::inv_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        // Lazy butterfly for all levels but the last (no simd context: unit arg).
        #[inline(always)]
        |(), z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        // Fully-reducing butterfly for the last level.
        #[inline(always)]
        |(), z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime32::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };

    extern crate alloc;
    use alloc::vec;

    /// fwd/inv round-trip through the scalar butterflies against a reference
    /// negacyclic convolution, for an NTT-friendly prime just below 2^30.
    /// The inverse transform is unnormalized: the expected side is scaled by `n`.
    #[test]
    fn test_product() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            // Largest prime of the form k * 2^16 + 1 in [2^29, 2^30).
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap()
                as u32;

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u32; n];
            let mut twid_shoup = vec![0u32; n];
            let mut inv_twid = vec![0u32; n];
            let mut inv_twid_shoup = vec![0u32; n];
            init_negacyclic_twiddles_shoup(
                p,
                n,
                &mut twid,
                &mut twid_shoup,
                &mut inv_twid,
                &mut inv_twid_shoup,
            );

            let mut prod = vec![0u32; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
            fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
            // The last forward level must produce canonical (fully reduced) values.
            for x in &lhs_fourier {
                assert!(*x < p);
            }
            for x in &rhs_fourier {
                assert!(*x < p);
            }

            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
            let result = prod;

            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
            }
        }
    }

    /// Same round-trip check through the AVX2 path; skipped when AVX2 is not
    /// detected at runtime.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            for n in [32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30)
                    .unwrap() as u32;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u32; n];
                let mut twid_shoup = vec![0u32; n];
                let mut inv_twid = vec![0u32; n];
                let mut inv_twid_shoup = vec![0u32; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u32; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
                }
            }
        }
    }

    /// Same round-trip check through the AVX512 path; skipped when AVX512 is not
    /// detected at runtime.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30)
                    .unwrap() as u32;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u32; n];
                let mut twid_shoup = vec![0u32; n];
                let mut inv_twid = vec![0u32; n];
                let mut inv_twid_shoup = vec![0u32; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u32; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
                }
            }
        }
    }
}
|
||||
546
tfhe-ntt/src/prime32/less_than_31bit.rs
Normal file
546
tfhe-ntt/src/prime32/less_than_31bit.rs
Normal file
@@ -0,0 +1,546 @@
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
/// Forward NTT butterfly (AVX512, 16 lanes) for moduli of less than 31 bits.
///
/// With one extra modulus bit available compared to the 30-bit variant, the lazy
/// reduction works modulo `p` directly: `two_p` is unused, and the Shoup product is
/// conditionally reduced before the add/sub.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // 2p is not needed in the 31-bit variant.
    let _ = two_p;
    // Conditional subtraction of p.
    let z0 = simd.small_mod_u32x16(p, z0);
    // Shoup multiplication of z1 by the twiddle (neg_p = -p), then a conditional
    // subtraction of p.
    let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(z1, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x16(p, t);
    (
        simd.wrapping_add_u32x16(z0, t),
        // z0 - t, shifted back into the non-negative range by adding p.
        simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
    )
}
|
||||
|
||||
/// Forward NTT butterfly for the final level (AVX512, 16 lanes), moduli < 31 bits.
///
/// Same computation as the 31-bit `fwd_butterfly_avx512`, with a final conditional
/// subtraction of `p` on each output so the results are canonical.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // 2p is not needed in the 31-bit variant.
    let _ = two_p;
    // Conditional subtraction of p.
    let z0 = simd.small_mod_u32x16(p, z0);
    // Shoup multiplication of z1 by the twiddle, then a conditional subtraction of p.
    let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(z1, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x16(p, t);
    (
        // Both sums get one final conditional subtraction of p.
        simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(z0, t)),
        simd.small_mod_u32x16(
            p,
            simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Forward NTT butterfly (AVX2, 8 lanes) for moduli of less than 31 bits.
///
/// Same structure as the 31-bit `fwd_butterfly_avx512`, on `u32x8` lanes.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // 2p is not needed in the 31-bit variant.
    let _ = two_p;
    // Conditional subtraction of p.
    let z0 = simd.small_mod_u32x8(p, z0);
    // Shoup multiplication of z1 by the twiddle (neg_p = -p), then a conditional
    // subtraction of p.
    let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(z1, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x8(p, t);
    (
        simd.wrapping_add_u32x8(z0, t),
        simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Forward NTT butterfly for the *last* layer on 8 `u32` lanes, AVX2 path.
///
/// Same as `fwd_butterfly_avx2` but both outputs are fully reduced into
/// `[0, p)` with a final conditional subtraction.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // Kept only for signature uniformity.
    let _ = two_p;
    let z0 = simd.small_mod_u32x8(p, z0);
    // shoup_q ≈ floor(z1 * w / p).
    let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(z1, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u32x8(p, t);
    (
        // Fully reduced outputs — this is the final layer.
        simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(z0, t)),
        simd.small_mod_u32x8(
            p,
            simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Forward NTT butterfly (internal layers), scalar `u32` path.
///
/// Computes `t = z1 * w mod p` via Shoup multiplication: the high 32 bits of
/// `z1 * w_shoup` estimate `floor(z1 * w / p)`, and adding that quotient
/// times `neg_p` cancels the corresponding multiple of `p` modulo `2^32`.
/// Returns `(z0 + t, z0 - t + p)`, which may reach `2p`; the last layer
/// performs the final reduction.
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // Only the lazy (mod-2p) variants need `two_p`; signature kept uniform.
    let _ = two_p;

    // Conditional subtraction: the wrapped candidate is huge whenever the
    // subtraction would underflow, so `min` selects the in-range value.
    let reduce = |x: u32| x.min(x.wrapping_sub(p));

    let z0 = reduce(z0);
    // High 32 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let q = ((u64::from(z1) * u64::from(w_shoup)) >> 32) as u32;
    // z1*w - q*p (mod 2^32), then one conditional subtraction into [0, p).
    let t = reduce(z1.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p)));

    (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p))
}
|
||||
|
||||
/// Forward NTT butterfly for the *last* layer, scalar `u32` path.
///
/// Same Shoup-multiplication scheme as `fwd_butterfly_scalar`, but both
/// outputs receive a final conditional subtraction so they land in `[0, p)`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // Kept only for signature uniformity with the lazy variants.
    let _ = two_p;

    // Conditional subtraction of p (wrapped candidate is huge on underflow).
    let reduce = |x: u32| x.min(x.wrapping_sub(p));

    let z0 = reduce(z0);
    // High 32 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let q = ((u64::from(z1) * u64::from(w_shoup)) >> 32) as u32;
    let t = reduce(z1.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p)));

    let hi = z0.wrapping_add(t);
    let lo = z0.wrapping_sub(t).wrapping_add(p);
    // Final layer: fully reduce both outputs into [0, p).
    (reduce(hi), reduce(lo))
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Inverse NTT butterfly on 16 `u32` lanes, AVX-512 path.
///
/// Returns `((z0 + z1) mod p, (z0 - z1) * w mod p)`; the twiddle product is
/// done with Shoup multiplication and both outputs are reduced into `[0, p)`,
/// so this variant is used for every inverse layer including the last.
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4,
    z0: u32x16,
    z1: u32x16,
    w: u32x16,
    w_shoup: u32x16,
    p: u32x16,
    neg_p: u32x16,
    two_p: u32x16,
) -> (u32x16, u32x16) {
    // Kept only for signature uniformity with the lazy (mod-2p) variants.
    let _ = two_p;

    let y0 = simd.wrapping_add_u32x16(z0, z1);
    // Conditional subtraction into [0, p).
    let y0 = simd.small_mod_u32x16(p, y0);
    // Lift the difference by p so it is non-negative before the multiply.
    let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), p);

    // shoup_q ≈ floor(t * w / p): high half of the widening product.
    let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
    // t*w - shoup_q*p (mod 2^32), then reduce into [0, p).
    let y1 = simd.wrapping_add_u32x16(
        simd.wrapping_mul_u32x16(t, w),
        simd.wrapping_mul_u32x16(shoup_q, neg_p),
    );
    let y1 = simd.small_mod_u32x16(p, y1);

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Inverse NTT butterfly on 8 `u32` lanes, AVX2 path.
///
/// Returns `((z0 + z1) mod p, (z0 - z1) * w mod p)` with both outputs fully
/// reduced into `[0, p)`; usable for every inverse layer including the last.
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
    simd: crate::V3,
    z0: u32x8,
    z1: u32x8,
    w: u32x8,
    w_shoup: u32x8,
    p: u32x8,
    neg_p: u32x8,
    two_p: u32x8,
) -> (u32x8, u32x8) {
    // Kept only for signature uniformity.
    let _ = two_p;

    let y0 = simd.wrapping_add_u32x8(z0, z1);
    let y0 = simd.small_mod_u32x8(p, y0);
    // Lift the difference by p before the twiddle multiply.
    let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), p);

    // shoup_q ≈ floor(t * w / p).
    let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
    let y1 = simd.wrapping_add_u32x8(
        simd.wrapping_mul_u32x8(t, w),
        simd.wrapping_mul_u32x8(shoup_q, neg_p),
    );
    let y1 = simd.small_mod_u32x8(p, y1);

    (y0, y1)
}
|
||||
|
||||
/// Inverse NTT butterfly, scalar `u32` path.
///
/// Returns `((z0 + z1) mod p, (z0 - z1) * w mod p)`. The twiddle product is
/// computed with Shoup multiplication and both outputs are fully reduced into
/// `[0, p)`, so this variant also serves as the last inverse layer.
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
    z0: u32,
    z1: u32,
    w: u32,
    w_shoup: u32,
    p: u32,
    neg_p: u32,
    two_p: u32,
) -> (u32, u32) {
    // Kept only for signature uniformity with the lazy variants.
    let _ = two_p;

    // Conditional subtraction of p via min (wrapped value is huge on underflow).
    let reduce = |x: u32| x.min(x.wrapping_sub(p));

    // Sum output, reduced into [0, p).
    let y0 = reduce(z0.wrapping_add(z1));
    // Difference lifted by p so it stays non-negative before the multiply.
    let diff = z0.wrapping_sub(z1).wrapping_add(p);
    // High 32 bits of diff * w_shoup ≈ floor(diff * w / p).
    let q = ((u64::from(diff) * u64::from(w_shoup)) >> 32) as u32;
    let y1 = reduce(diff.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p)));

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Forward negacyclic NTT over `u32` data, AVX-512 path.
///
/// Thin wrapper around the generic depth-first driver: the lazy butterfly
/// handles the internal layers and the fully-reducing variant the last one.
pub(crate) fn fwd_avx512(
    simd: crate::V4,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    // Function items implement `Fn` with the exact signature the closures had,
    // and are themselves `#[inline(always)]`.
    super::shoup::fwd_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        fwd_butterfly_avx512,
        fwd_last_butterfly_avx512,
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Inverse negacyclic NTT over `u32` data, AVX-512 path.
///
/// The inverse butterfly already fully reduces its outputs, so the same
/// kernel is used for both the internal and the last layer.
pub(crate) fn inv_avx512(
    simd: crate::V4,
    p: u32,
    data: &mut [u32],
    twid: &[u32],
    twid_shoup: &[u32],
) {
    super::shoup::inv_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        0,
        0,
        inv_butterfly_avx512,
        inv_butterfly_avx512,
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
pub(crate) fn fwd_avx2(
|
||||
simd: crate::V3,
|
||||
p: u32,
|
||||
data: &mut [u32],
|
||||
twid: &[u32],
|
||||
twid_shoup: &[u32],
|
||||
) {
|
||||
super::shoup::fwd_depth_first_avx2(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
pub(crate) fn inv_avx2(
|
||||
simd: crate::V3,
|
||||
p: u32,
|
||||
data: &mut [u32],
|
||||
twid: &[u32],
|
||||
twid_shoup: &[u32],
|
||||
) {
|
||||
super::shoup::inv_depth_first_avx2(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn fwd_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
|
||||
super::shoup::fwd_depth_first_scalar(
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn inv_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
|
||||
super::shoup::inv_depth_first_scalar(
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime32::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };

    extern crate alloc;
    use alloc::vec;

    /// Round-trip test, scalar path: fwd NTT both operands, pointwise
    /// multiply, inv NTT, and compare against the reference negacyclic
    /// convolution (scaled by n, since the inverse transform is unnormalized).
    #[test]
    fn test_product() {
        for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
            // 30..31-bit NTT-friendly prime (p ≡ 1 mod 2^16).
            let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
                as u32;

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u32; n];
            let mut twid_shoup = vec![0u32; n];
            let mut inv_twid = vec![0u32; n];
            let mut inv_twid_shoup = vec![0u32; n];
            init_negacyclic_twiddles_shoup(
                p,
                n,
                &mut twid,
                &mut twid_shoup,
                &mut inv_twid,
                &mut inv_twid_shoup,
            );

            let mut prod = vec![0u32; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
            fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
            // The forward transform must produce fully reduced values.
            for x in &lhs_fourier {
                assert!(*x < p);
            }
            for x in &rhs_fourier {
                assert!(*x < p);
            }

            // Pointwise product in the NTT domain.
            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
            let result = prod;

            // inv(fwd(lhs) ∘ fwd(rhs)) == n * (lhs ⊛ rhs) (unnormalized inverse).
            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
            }
        }
    }

    /// Same round-trip test through the AVX2 path; skipped when the CPU
    /// does not support the required features.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            for n in [32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31)
                    .unwrap() as u32;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u32; n];
                let mut twid_shoup = vec![0u32; n];
                let mut inv_twid = vec![0u32; n];
                let mut inv_twid_shoup = vec![0u32; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u32; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
                }
            }
        }
    }

    /// Same round-trip test through the AVX-512 path (nightly-only);
    /// skipped when the CPU does not support the required features.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31)
                    .unwrap() as u32;

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u32; n];
                let mut twid_shoup = vec![0u32; n];
                let mut inv_twid = vec![0u32; n];
                let mut inv_twid_shoup = vec![0u32; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u32; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
                }
            }
        }
    }
}
|
||||
1481
tfhe-ntt/src/prime32/shoup.rs
Normal file
1481
tfhe-ntt/src/prime32/shoup.rs
Normal file
File diff suppressed because it is too large
Load Diff
1883
tfhe-ntt/src/prime64.rs
Normal file
1883
tfhe-ntt/src/prime64.rs
Normal file
File diff suppressed because it is too large
Load Diff
1827
tfhe-ntt/src/prime64/generic_solinas.rs
Normal file
1827
tfhe-ntt/src/prime64/generic_solinas.rs
Normal file
File diff suppressed because it is too large
Load Diff
213
tfhe-ntt/src/prime64/less_than_50bit.rs
Normal file
213
tfhe-ntt/src/prime64/less_than_50bit.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use pulp::u64x8;
|
||||
|
||||
/// Forward NTT butterfly (internal layers) for <50-bit moduli, AVX-512 IFMA.
///
/// Lazy variant: values stay in `[0, 2p)` between layers (only `two_p` is
/// used for reduction; `p` is ignored). The twiddle product uses the 52-bit
/// IFMA multiply: `shoup_q` is the high part of `z1 * w_shoup`, and
/// `shoup_q * neg_p` is folded into the low product to cancel the estimated
/// multiple of `p`.
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Lazy reduction: this layer only ever reduces modulo 2p.
    let _ = p;
    // Conditional subtraction into [0, 2p).
    let z0 = simd.small_mod_u64x8(two_p, z0);
    // High part of the 52-bit widening product ≈ floor(z1 * w / p).
    let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
    // t = z1*w - shoup_q*p via fused multiply-add on the low 52-bit halves.
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0)
;
    (
        simd.wrapping_add_u64x8(z0, t),
        // Difference lifted by 2p to stay non-negative.
        simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), two_p),
    )
}
|
||||
|
||||
/// Forward NTT butterfly for the *last* layer, <50-bit moduli, AVX-512 IFMA.
///
/// Unlike the lazy internal-layer variant, this one reduces `z0` from
/// `[0, 4p)` down to `[0, p)` (two conditional subtractions: first `two_p`,
/// then `p`) and fully reduces both outputs into `[0, p)`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Two-step reduction of the lazily-reduced input: [0, 4p) -> [0, 2p) -> [0, p).
    let z0 = simd.small_mod_u64x8(two_p, z0);
    let z0 = simd.small_mod_u64x8(p, z0);
    // shoup_q ≈ floor(z1 * w / p) via the 52-bit widening product.
    let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
    let t = simd.small_mod_u64x8(p, t);
    (
        // Final layer: both outputs fully reduced into [0, p).
        simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
        simd.small_mod_u64x8(
            p,
            simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Inverse NTT butterfly (internal layers) for <50-bit moduli, AVX-512 IFMA.
///
/// Lazy variant working modulo `2p`: `y0 = (z0 + z1) mod 2p` and
/// `y1 = (z0 - z1) * w` via 52-bit Shoup multiplication, left unreduced
/// for the next layer.
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Lazy reduction: `p` itself is not needed at internal layers.
    let _ = p;

    let y0 = simd.wrapping_add_u64x8(z0, z1);
    // Conditional subtraction into [0, 2p).
    let y0 = simd.small_mod_u64x8(two_p, y0);
    // Difference lifted by 2p so it stays non-negative before the multiply.
    let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);

    // shoup_q ≈ floor(t * w / p); fold shoup_q * neg_p into the low product.
    let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
    let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);

    (y0, y1)
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn inv_last_butterfly_avx512(
|
||||
simd: crate::V4IFma,
|
||||
z0: u64x8,
|
||||
z1: u64x8,
|
||||
w: u64x8,
|
||||
w_shoup: u64x8,
|
||||
p: u64x8,
|
||||
neg_p: u64x8,
|
||||
two_p: u64x8,
|
||||
) -> (u64x8, u64x8) {
|
||||
let _ = p;
|
||||
|
||||
let y0 = simd.wrapping_add_u64x8(z0, z1);
|
||||
let y0 = simd.small_mod_u64x8(two_p, y0);
|
||||
let y0 = simd.small_mod_u64x8(p, y0);
|
||||
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);
|
||||
|
||||
let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
|
||||
let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);
|
||||
let y1 = simd.small_mod_u64x8(p, y1);
|
||||
|
||||
(y0, y1)
|
||||
}
|
||||
|
||||
pub(crate) fn fwd_avx512(
|
||||
simd: crate::V4IFma,
|
||||
p: u64,
|
||||
data: &mut [u64],
|
||||
twid: &[u64],
|
||||
twid_shoup: &[u64],
|
||||
) {
|
||||
super::shoup::fwd_depth_first_avx512(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn inv_avx512(
|
||||
simd: crate::V4IFma,
|
||||
p: u64,
|
||||
data: &mut [u64],
|
||||
twid: &[u64],
|
||||
twid_shoup: &[u64],
|
||||
) {
|
||||
super::shoup::inv_depth_first_avx512(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime64::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };
    use alloc::vec;

    extern crate alloc;

    /// Round-trip test for the IFMA path: fwd NTT both operands, pointwise
    /// multiply, inv NTT, and compare against the reference negacyclic
    /// convolution scaled by n (the inverse transform is unnormalized).
    /// Skipped when AVX-512 IFMA is unavailable.
    #[test]
    fn test_product() {
        if let Some(simd) = crate::V4IFma::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                // 49..50-bit NTT-friendly prime (p ≡ 1 mod 2^16).
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 49, 1 << 50)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                // 52 = bit width of the Shoup precomputation (IFMA operates on
                // 52-bit limbs).
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    52,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // The forward transform must produce fully reduced values.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                // Pointwise product in the NTT domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }
}
|
||||
190
tfhe-ntt/src/prime64/less_than_51bit.rs
Normal file
190
tfhe-ntt/src/prime64/less_than_51bit.rs
Normal file
@@ -0,0 +1,190 @@
|
||||
use pulp::u64x8;
|
||||
|
||||
/// Forward NTT butterfly (internal layers) for <51-bit moduli, AVX-512 IFMA.
///
/// Strict variant: `z0` and the twiddle product `t = z1 * w mod p` are
/// reduced into `[0, p)` each layer; outputs are in `[0, 2p)`. Uses the
/// 52-bit IFMA widening multiply for the Shoup product.
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Kept only for signature uniformity with the lazy (mod-2p) variants.
    let _ = two_p;
    // Conditional subtraction into [0, p).
    let z0 = simd.small_mod_u64x8(p, z0);
    // High part of the 52-bit widening product ≈ floor(z1 * w / p).
    let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
    // t = z1*w - shoup_q*p via fused multiply-add, then reduce into [0, p).
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
    let t = simd.small_mod_u64x8(p, t);
    (
        simd.wrapping_add_u64x8(z0, t),
        simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
    )
}
|
||||
|
||||
/// Forward NTT butterfly for the *last* layer, <51-bit moduli, AVX-512 IFMA.
///
/// Same as `fwd_butterfly_avx512` but both outputs receive a final
/// conditional subtraction so they are fully reduced into `[0, p)`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Kept only for signature uniformity.
    let _ = two_p;
    let z0 = simd.small_mod_u64x8(p, z0);
    // shoup_q ≈ floor(z1 * w / p) via the 52-bit widening product.
    let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
    let t = simd.small_mod_u64x8(p, t);
    (
        // Final layer: outputs fully reduced into [0, p).
        simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
        simd.small_mod_u64x8(
            p,
            simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Inverse NTT butterfly for <51-bit moduli, AVX-512 IFMA.
///
/// Returns `((z0 + z1) mod p, (z0 - z1) * w mod p)` with both outputs fully
/// reduced into `[0, p)`, so it can serve every inverse layer including the
/// last one.
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4IFma,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Kept only for signature uniformity.
    let _ = two_p;

    let y0 = simd.wrapping_add_u64x8(z0, z1);
    let y0 = simd.small_mod_u64x8(p, y0);
    // Difference lifted by p so it stays non-negative before the multiply.
    let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), p);

    // shoup_q ≈ floor(t * w / p) via the 52-bit widening product.
    let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
    let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);
    let y1 = simd.small_mod_u64x8(p, y1);

    (y0, y1)
}
|
||||
|
||||
pub(crate) fn fwd_avx512(
|
||||
simd: crate::V4IFma,
|
||||
p: u64,
|
||||
data: &mut [u64],
|
||||
twid: &[u64],
|
||||
twid_shoup: &[u64],
|
||||
) {
|
||||
super::shoup::fwd_depth_first_avx512(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn inv_avx512(
|
||||
simd: crate::V4IFma,
|
||||
p: u64,
|
||||
data: &mut [u64],
|
||||
twid: &[u64],
|
||||
twid_shoup: &[u64],
|
||||
) {
|
||||
super::shoup::inv_breadth_first_avx512(
|
||||
simd,
|
||||
p,
|
||||
data,
|
||||
twid,
|
||||
twid_shoup,
|
||||
0,
|
||||
0,
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
#[inline(always)]
|
||||
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
|
||||
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime64::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };
    use alloc::vec;

    extern crate alloc;

    /// Round-trip test for the <51-bit IFMA path: fwd NTT both operands,
    /// pointwise multiply, inv NTT, and compare against the reference
    /// negacyclic convolution scaled by n (unnormalized inverse transform).
    /// Skipped when AVX-512 IFMA is unavailable.
    #[test]
    fn test_product() {
        if let Some(simd) = crate::V4IFma::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                // 50..51-bit NTT-friendly prime (p ≡ 1 mod 2^16).
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                // 52 = bit width of the Shoup precomputation (IFMA limb size).
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    52,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // The forward transform must produce fully reduced values.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                // Pointwise product in the NTT domain.
                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }
}
|
||||
629
tfhe-ntt/src/prime64/less_than_62bit.rs
Normal file
629
tfhe-ntt/src/prime64/less_than_62bit.rs
Normal file
@@ -0,0 +1,629 @@
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Forward NTT butterfly (internal layers) for <62-bit moduli, AVX-512.
///
/// Lazy variant: values stay in `[0, 2p)` between layers; `p` itself is not
/// needed. The twiddle product `z1 * w` is reduced with Shoup multiplication
/// (high half of `z1 * w_shoup` estimates `floor(z1*w/p)`; `shoup_q * neg_p`
/// cancels that multiple of `p` modulo `2^64`).
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Lazy reduction: only `two_p` is used at internal layers.
    let _ = p;
    // Conditional subtraction into [0, 2p).
    let z0 = simd.small_mod_u64x8(two_p, z0);
    // High 64 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
    // t = z1*w - shoup_q*p (mod 2^64).
    let t = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(z1, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    (
        simd.wrapping_add_u64x8(z0, t),
        // Difference lifted by 2p to stay non-negative.
        simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), two_p),
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Forward NTT butterfly for the *last* layer, <62-bit moduli, AVX-512.
///
/// Reduces the lazily-reduced `z0` from `[0, 4p)` down to `[0, p)` (two
/// conditional subtractions) and fully reduces both outputs into `[0, p)`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // (Removed a stale `let _ = p;`: `p` *is* used throughout this function,
    // so the marker was misleading dead code.)
    // Two-step reduction: [0, 4p) -> [0, 2p) -> [0, p).
    let z0 = simd.small_mod_u64x8(two_p, z0);
    let z0 = simd.small_mod_u64x8(p, z0);
    // shoup_q ≈ floor(z1 * w / p): high half of the widening product.
    let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
    // t = z1*w - shoup_q*p (mod 2^64), then reduce into [0, p).
    let t = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(z1, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u64x8(p, t);
    (
        // Final layer: outputs fully reduced into [0, p).
        simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
        simd.small_mod_u64x8(
            p,
            simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Forward NTT butterfly (internal layers) for <62-bit moduli, AVX2.
///
/// Lazy (mod-2p) variant on 4 `u64` lanes. AVX2 has no 64-bit full multiply,
/// so the low halves of the widening products stand in for `wrapping_mul`.
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
    simd: crate::V3,
    z0: u64x4,
    z1: u64x4,
    w: u64x4,
    w_shoup: u64x4,
    p: u64x4,
    neg_p: u64x4,
    two_p: u64x4,
) -> (u64x4, u64x4) {
    // Lazy reduction: only `two_p` is used at internal layers.
    let _ = p;
    // Conditional subtraction into [0, 2p).
    let z0 = simd.small_mod_u64x4(two_p, z0);
    // High 64 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
    // t = z1*w - shoup_q*p (mod 2^64), using low halves of widening products.
    let t = simd.wrapping_add_u64x4(
        simd.widening_mul_u64x4(z1, w).0,
        simd.widening_mul_u64x4(shoup_q, neg_p).0,
    );
    (
        simd.wrapping_add_u64x4(z0, t),
        simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), two_p),
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[inline(always)]
|
||||
pub(crate) fn fwd_last_butterfly_avx2(
|
||||
simd: crate::V3,
|
||||
z0: u64x4,
|
||||
z1: u64x4,
|
||||
w: u64x4,
|
||||
w_shoup: u64x4,
|
||||
p: u64x4,
|
||||
neg_p: u64x4,
|
||||
two_p: u64x4,
|
||||
) -> (u64x4, u64x4) {
|
||||
let _ = p;
|
||||
let z0 = simd.small_mod_u64x4(two_p, z0);
|
||||
let z0 = simd.small_mod_u64x4(p, z0);
|
||||
let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
|
||||
let t = simd.wrapping_add_u64x4(
|
||||
simd.widening_mul_u64x4(z1, w).0,
|
||||
simd.widening_mul_u64x4(shoup_q, neg_p).0,
|
||||
);
|
||||
let t = simd.small_mod_u64x4(p, t);
|
||||
(
|
||||
simd.small_mod_u64x4(p, simd.wrapping_add_u64x4(z0, t)),
|
||||
simd.small_mod_u64x4(
|
||||
p,
|
||||
simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
/// Forward NTT butterfly (internal layers) for <62-bit moduli, scalar `u64`.
///
/// Lazy variant: operates modulo `2p` (`p` itself is unused). `t = z1 * w`
/// is reduced with Shoup multiplication — the high 64 bits of
/// `z1 * w_shoup` estimate `floor(z1 * w / p)`, and adding that quotient
/// times `neg_p` cancels the corresponding multiple of `p` modulo `2^64`.
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // Lazy reduction: only `two_p` is needed at internal layers.
    let _ = p;
    // Conditional subtraction of 2p via min (wrapped value is huge on underflow).
    let z0 = z0.min(z0.wrapping_sub(two_p));
    // High 64 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let q = ((u128::from(z1) * u128::from(w_shoup)) >> 64) as u64;
    // z1*w - q*p (mod 2^64); stays in [0, 2p).
    let t = z1.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));
    (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(two_p))
}
|
||||
|
||||
/// Forward NTT butterfly for the *last* layer, <62-bit moduli, scalar `u64`.
///
/// Reduces the lazily-reduced `z0` from `[0, 4p)` down to `[0, p)` (two
/// conditional subtractions: `two_p` then `p`) and fully reduces both
/// outputs into `[0, p)`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // (Removed a stale `let _ = p;`: `p` *is* used throughout this function,
    // so the marker was misleading dead code.)
    // Two-step reduction: [0, 4p) -> [0, 2p) -> [0, p). The `min` trick works
    // because the wrapped candidate is huge whenever subtraction underflows.
    let z0 = z0.min(z0.wrapping_sub(two_p));
    let z0 = z0.min(z0.wrapping_sub(p));
    // High 64 bits of z1 * w_shoup ≈ floor(z1 * w / p).
    let shoup_q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
    // z1*w - shoup_q*p (mod 2^64), then reduce into [0, p).
    let t = u64::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
    let t = t.min(t.wrapping_sub(p));
    let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
    (
        // Final layer: outputs fully reduced into [0, p).
        res.0.min(res.0.wrapping_sub(p)),
        res.1.min(res.1.wrapping_sub(p)),
    )
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Inverse NTT butterfly (internal layers) for <62-bit moduli, AVX-512.
///
/// Lazy variant working modulo `2p`: `y0 = (z0 + z1) mod 2p`, and
/// `y1 = (z0 - z1) * w` via Shoup multiplication, left unreduced for the
/// next layer.
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // Lazy reduction: `p` itself is not needed at internal layers.
    let _ = p;

    let y0 = simd.wrapping_add_u64x8(z0, z1);
    // Conditional subtraction into [0, 2p).
    let y0 = simd.small_mod_u64x8(two_p, y0);
    // Difference lifted by 2p so it stays non-negative before the multiply.
    let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);

    // shoup_q ≈ floor(t * w / p): high half of the widening product.
    let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
    let y1 = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(t, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
/// Inverse NTT butterfly for the *last* layer, <62-bit moduli, AVX-512.
///
/// Same scheme as `inv_butterfly_avx512`, but `y0` is reduced all the way
/// into `[0, p)` (conditional subtraction of `two_p` then of `p`) and `y1`
/// also gets a final reduction into `[0, p)`.
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // (Removed a stale `let _ = p;`: `p` *is* used below for the final
    // reductions, so the marker was misleading dead code.)
    let y0 = simd.wrapping_add_u64x8(z0, z1);
    // Two-step reduction of the sum: modulo 2p, then modulo p.
    let y0 = simd.small_mod_u64x8(two_p, y0);
    let y0 = simd.small_mod_u64x8(p, y0);
    // Difference lifted by 2p so it stays non-negative before the multiply.
    let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);

    // shoup_q ≈ floor(t * w / p).
    let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
    let y1 = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(t, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    // Final layer: fully reduce y1 into [0, p).
    let y1 = simd.small_mod_u64x8(p, y1);

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
/// Inverse NTT butterfly (internal layers) for <62-bit moduli, AVX2.
///
/// Lazy (mod-2p) variant on 4 `u64` lanes; low halves of widening products
/// stand in for 64-bit `wrapping_mul` on AVX2.
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
    simd: crate::V3,
    z0: u64x4,
    z1: u64x4,
    w: u64x4,
    w_shoup: u64x4,
    p: u64x4,
    neg_p: u64x4,
    two_p: u64x4,
) -> (u64x4, u64x4) {
    // Lazy reduction: `p` itself is not needed at internal layers.
    let _ = p;

    let y0 = simd.wrapping_add_u64x4(z0, z1);
    let y0 = simd.small_mod_u64x4(two_p, y0);
    // Difference lifted by 2p before the twiddle multiply.
    let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), two_p);

    // shoup_q ≈ floor(t * w / p).
    let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
    let y1 = simd.wrapping_add_u64x4(
        simd.widening_mul_u64x4(t, w).0,
        simd.widening_mul_u64x4(shoup_q, neg_p).0,
    );

    (y0, y1)
}
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
#[inline(always)]
|
||||
pub(crate) fn inv_last_butterfly_avx2(
|
||||
simd: crate::V3,
|
||||
z0: u64x4,
|
||||
z1: u64x4,
|
||||
w: u64x4,
|
||||
w_shoup: u64x4,
|
||||
p: u64x4,
|
||||
neg_p: u64x4,
|
||||
two_p: u64x4,
|
||||
) -> (u64x4, u64x4) {
|
||||
let _ = p;
|
||||
|
||||
let y0 = simd.wrapping_add_u64x4(z0, z1);
|
||||
let y0 = simd.small_mod_u64x4(two_p, y0);
|
||||
let y0 = simd.small_mod_u64x4(p, y0);
|
||||
let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), two_p);
|
||||
|
||||
let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
|
||||
let y1 = simd.wrapping_add_u64x4(
|
||||
simd.widening_mul_u64x4(t, w).0,
|
||||
simd.widening_mul_u64x4(shoup_q, neg_p).0,
|
||||
);
|
||||
let y1 = simd.small_mod_u64x4(p, y1);
|
||||
|
||||
(y0, y1)
|
||||
}
|
||||
|
||||
/// Scalar inverse NTT butterfly (Gentleman–Sande layout) for moduli below
/// 2^62: computes `(z0 + z1, (z0 - z1) * w)` with lazy reduction — the sum
/// is only reduced conditionally by `2p`, and the product is left in the
/// `[0, 2p)` range produced by Shoup's multiplication.
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // `p` is unused in this lazy variant; signature kept uniform with the
    // other butterflies.
    let _ = p;

    // Sum branch. The subtraction wraps (becomes huge) when the value is
    // already below 2p, so `min` performs a conditional reduction.
    let sum = z0.wrapping_add(z1);
    let y0 = sum.min(sum.wrapping_sub(two_p));

    // Difference branch shifted into non-negative range, then multiplied by
    // the twiddle using Shoup's precomputed quotient.
    let diff = z0.wrapping_sub(z1).wrapping_add(two_p);
    let q = ((diff as u128 * w_shoup as u128) >> 64) as u64;
    let y1 = diff.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));

    (y0, y1)
}
|
||||
|
||||
/// Scalar inverse butterfly for the final NTT layer (moduli below 2^62):
/// same computation as `inv_butterfly_scalar`, plus the extra conditional
/// reductions that bring both outputs fully below `p`.
#[inline(always)]
pub(crate) fn inv_last_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // Sum branch: conditionally reduce by 2p, then by p (`min` picks the
    // subtracted value only when the subtraction did not wrap).
    let y0 = z0.wrapping_add(z1);
    let y0 = y0.min(y0.wrapping_sub(two_p));
    let y0 = y0.min(y0.wrapping_sub(p));
    // Difference branch shifted into non-negative range, multiplied by the
    // twiddle via Shoup's precomputed quotient, then reduced below p.
    let t = z0.wrapping_sub(z1).wrapping_add(two_p);
    let shoup_q = ((t as u128 * w_shoup as u128) >> 64) as u64;
    let y1 = u64::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
    let y1 = y1.min(y1.wrapping_sub(p));
    (y0, y1)
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (AVX512 path, moduli below 2^62):
/// dispatches to the generic depth-first driver in `shoup.rs` with this
/// file's butterflies (lazy variant for inner layers, fully-reducing `last`
/// variant for the final layer).
///
/// NOTE(review): the `#[inline(always)]` closures are presumably kept
/// (rather than passing the butterfly functions directly) to guarantee the
/// butterflies inline into the monomorphized driver.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
    simd: crate::V4,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::fwd_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (AVX512 path, moduli below 2^62):
/// dispatches to the generic depth-first driver in `shoup.rs`, using the
/// lazy inverse butterfly for inner layers and the fully-reducing `last`
/// variant for the final layer. Note the result is unnormalized (scaled by
/// the transform size — see the tests below).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
    simd: crate::V4,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::inv_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (AVX2 path, moduli below 2^62).
/// See [`fwd_avx512`]; identical structure with the 4-lane butterflies.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
    simd: crate::V3,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::fwd_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (AVX2 path, moduli below 2^62).
/// See [`inv_avx512`]; identical structure with the 4-lane butterflies.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
    simd: crate::V3,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::inv_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (portable scalar path, moduli below
/// 2^62): dispatches to the generic depth-first driver in `shoup.rs` with
/// this file's scalar butterflies (`last` variant for the final layer).
pub(crate) fn fwd_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
    super::shoup::fwd_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (portable scalar path, moduli below
/// 2^62). Result is unnormalized (scaled by the transform size).
pub(crate) fn inv_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
    super::shoup::inv_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Round-trip tests: forward NTT of both operands, pointwise product,
    //! inverse NTT, and comparison against a reference negacyclic
    //! convolution. The inverse transform is unnormalized, hence the
    //! factor `n` in the expected values.
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime64::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };
    use alloc::vec;

    extern crate alloc;

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                // A prime in (2^61, 2^62) congruent to 1 mod 2^16, so the
                // required roots of unity exist.
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    64,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // Forward outputs must be fully reduced below p.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        // Redundant: `mul` is already imported at module scope above.
        use crate::prime64::tests::mul;

        if let Some(simd) = crate::V3::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    64,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // Forward outputs must be fully reduced below p.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
                }
            }
        }
    }

    #[test]
    fn test_product_scalar() {
        for n in [16, 32, 64, 128, 256, 512, 1024] {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62).unwrap();

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u64; n];
            let mut twid_shoup = vec![0u64; n];
            let mut inv_twid = vec![0u64; n];
            let mut inv_twid_shoup = vec![0u64; n];
            init_negacyclic_twiddles_shoup(
                p,
                n,
                64,
                &mut twid,
                &mut twid_shoup,
                &mut inv_twid,
                &mut inv_twid_shoup,
            );

            let mut prod = vec![0u64; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
            fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
            // Forward outputs must be fully reduced below p.
            for x in &lhs_fourier {
                assert!(*x < p);
            }
            for x in &rhs_fourier {
                assert!(*x < p);
            }

            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
            let result = prod;

            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
            }
        }
    }
}
|
||||
549
tfhe-ntt/src/prime64/less_than_63bit.rs
Normal file
549
tfhe-ntt/src/prime64/less_than_63bit.rs
Normal file
@@ -0,0 +1,549 @@
|
||||
#[allow(unused_imports)]
|
||||
use pulp::*;
|
||||
|
||||
/// AVX512 forward NTT butterfly (Cooley–Tukey layout) for moduli below
/// 2^63: computes `(z0 + z1*w, z0 - z1*w)` where each operand is
/// conditionally reduced below `p` before the add/sub, so the outputs stay
/// below `2p`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;
    // Conditional reduction of z0 below p.
    let z0 = simd.small_mod_u64x8(p, z0);
    // Shoup multiplication: q = high64(z1 * w_shoup); t = z1*w + q*(-p).
    let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(z1, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u64x8(p, t);
    (
        simd.wrapping_add_u64x8(z0, t),
        simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
    )
}
|
||||
|
||||
/// AVX512 forward butterfly for the final NTT layer (moduli below 2^63):
/// same as [`fwd_butterfly_avx512`], plus a final conditional reduction so
/// both outputs end up fully below `p`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;
    let z0 = simd.small_mod_u64x8(p, z0);
    // t = (z1 * w) mod p via Shoup multiplication.
    let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
    let t = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(z1, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    let t = simd.small_mod_u64x8(p, t);
    (
        simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
        simd.small_mod_u64x8(
            p,
            simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// AVX2 forward NTT butterfly for moduli below 2^63. Same algorithm as the
/// AVX512 variant on 4-lane vectors; the low half of `widening_mul_u64x4`
/// is used where AVX512 uses a plain wrapping multiply.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
    simd: crate::V3,
    z0: u64x4,
    z1: u64x4,
    w: u64x4,
    w_shoup: u64x4,
    p: u64x4,
    neg_p: u64x4,
    two_p: u64x4,
) -> (u64x4, u64x4) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;
    let z0 = simd.small_mod_u64x4(p, z0);
    // Shoup multiplication: q = high64(z1 * w_shoup); t = z1*w + q*(-p).
    let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
    let t = simd.wrapping_add_u64x4(
        simd.widening_mul_u64x4(z1, w).0,
        simd.widening_mul_u64x4(shoup_q, neg_p).0,
    );
    let t = simd.small_mod_u64x4(p, t);
    (
        simd.wrapping_add_u64x4(z0, t),
        simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
    )
}
|
||||
|
||||
/// AVX2 forward butterfly for the final NTT layer (moduli below 2^63):
/// same as [`fwd_butterfly_avx2`], plus a final conditional reduction so
/// both outputs end up fully below `p`.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
    simd: crate::V3,
    z0: u64x4,
    z1: u64x4,
    w: u64x4,
    w_shoup: u64x4,
    p: u64x4,
    neg_p: u64x4,
    two_p: u64x4,
) -> (u64x4, u64x4) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;
    let z0 = simd.small_mod_u64x4(p, z0);
    // t = (z1 * w) mod p via Shoup multiplication.
    let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
    let t = simd.wrapping_add_u64x4(
        simd.widening_mul_u64x4(z1, w).0,
        simd.widening_mul_u64x4(shoup_q, neg_p).0,
    );
    let t = simd.small_mod_u64x4(p, t);
    (
        simd.small_mod_u64x4(p, simd.wrapping_add_u64x4(z0, t)),
        simd.small_mod_u64x4(
            p,
            simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
        ),
    )
}
|
||||
|
||||
/// Scalar forward NTT butterfly (Cooley–Tukey layout) for moduli below
/// 2^63: computes `(z0 + z1*w, z0 - z1*w)` with both operands conditionally
/// reduced below `p` first, so the outputs stay below `2p`.
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // `two_p` is unused in this modulus width; signature kept uniform with
    // the other butterfly implementations.
    let _ = two_p;

    // Conditional reduction: the subtraction wraps (becomes huge) when
    // z0 < p, so `min` picks the reduced value exactly when z0 >= p.
    let z0 = z0.min(z0.wrapping_sub(p));

    // Shoup multiplication: q ~ floor(z1*w/p), then z1*w + q*(-p) mod 2^64,
    // followed by one conditional reduction below p.
    let q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
    let prod = z1.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));
    let prod = prod.min(prod.wrapping_sub(p));

    (z0.wrapping_add(prod), z0.wrapping_sub(prod).wrapping_add(p))
}
|
||||
|
||||
/// Scalar forward butterfly for the final NTT layer (moduli below 2^63):
/// identical to [`fwd_butterfly_scalar`] except for an extra conditional
/// reduction so both outputs end up fully below `p`.
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;

    // Conditional reduction of z0 below p.
    let z0 = z0.min(z0.wrapping_sub(p));

    // t = (z1 * w) mod p via Shoup's precomputed quotient.
    let q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
    let t = z1.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));
    let t = t.min(t.wrapping_sub(p));

    // Butterfly outputs in [0, 2p), then one final conditional subtract each.
    let hi = z0.wrapping_add(t);
    let lo = z0.wrapping_sub(t).wrapping_add(p);
    (hi.min(hi.wrapping_sub(p)), lo.min(lo.wrapping_sub(p)))
}
|
||||
|
||||
/// AVX512 inverse NTT butterfly (Gentleman–Sande layout) for moduli below
/// 2^63: returns `(z0 + z1, (z0 - z1) * w)` with both outputs conditionally
/// reduced below `p`. Because this variant already fully reduces, it also
/// serves as the last-layer butterfly (see `inv_avx512` below).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
    simd: crate::V4,
    z0: u64x8,
    z1: u64x8,
    w: u64x8,
    w_shoup: u64x8,
    p: u64x8,
    neg_p: u64x8,
    two_p: u64x8,
) -> (u64x8, u64x8) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;

    // Sum branch, conditionally reduced below p.
    let y0 = simd.wrapping_add_u64x8(z0, z1);
    let y0 = simd.small_mod_u64x8(p, y0);
    // Difference branch shifted into non-negative range.
    let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), p);

    // Shoup multiplication by the twiddle, then a conditional reduction.
    let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
    let y1 = simd.wrapping_add_u64x8(
        simd.wrapping_mul_u64x8(t, w),
        simd.wrapping_mul_u64x8(shoup_q, neg_p),
    );
    let y1 = simd.small_mod_u64x8(p, y1);

    (y0, y1)
}
|
||||
|
||||
/// AVX2 inverse NTT butterfly for moduli below 2^63. Same algorithm as the
/// AVX512 variant on 4-lane vectors; fully reduces both outputs below `p`,
/// so it doubles as the last-layer butterfly (see `inv_avx2` below).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
    simd: crate::V3,
    z0: u64x4,
    z1: u64x4,
    w: u64x4,
    w_shoup: u64x4,
    p: u64x4,
    neg_p: u64x4,
    two_p: u64x4,
) -> (u64x4, u64x4) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;

    // Sum branch, conditionally reduced below p.
    let y0 = simd.wrapping_add_u64x4(z0, z1)
    let y0 = simd.small_mod_u64x4(p, y0);
    // Difference branch shifted into non-negative range.
    let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), p);

    // Shoup multiplication by the twiddle, then a conditional reduction.
    let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
    let y1 = simd.wrapping_add_u64x4(
        simd.widening_mul_u64x4(t, w).0,
        simd.widening_mul_u64x4(shoup_q, neg_p).0,
    );
    let y1 = simd.small_mod_u64x4(p, y1);

    (y0, y1)
}
|
||||
|
||||
/// Scalar inverse NTT butterfly (Gentleman–Sande layout) for moduli below
/// 2^63: computes `(z0 + z1, (z0 - z1) * w)` with both outputs conditionally
/// reduced below `p`. Fully reducing, so it also serves as the last-layer
/// butterfly (see `inv_scalar` below).
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
    z0: u64,
    z1: u64,
    w: u64,
    w_shoup: u64,
    p: u64,
    neg_p: u64,
    two_p: u64,
) -> (u64, u64) {
    // `two_p` is unused in this modulus width; signature kept uniform.
    let _ = two_p;

    // Sum branch, conditionally reduced below p via the `min` trick.
    let sum = z0.wrapping_add(z1);
    let y0 = sum.min(sum.wrapping_sub(p));

    // Difference branch shifted into non-negative range, multiplied by the
    // twiddle with Shoup's precomputed quotient, then reduced below p.
    let diff = z0.wrapping_sub(z1).wrapping_add(p);
    let q = ((diff as u128 * w_shoup as u128) >> 64) as u64;
    let y1 = diff.wrapping_mul(w).wrapping_add(q.wrapping_mul(neg_p));
    let y1 = y1.min(y1.wrapping_sub(p));

    (y0, y1)
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (AVX512 path, moduli below 2^63):
/// dispatches to the generic depth-first driver in `shoup.rs` with this
/// file's butterflies (lazy for inner layers, fully-reducing `last` variant
/// for the final layer).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
    simd: crate::V4,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::fwd_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (AVX512 path, moduli below 2^63).
/// The same fully-reducing `inv_butterfly_avx512` is deliberately passed
/// for both the inner and the last layer: in this modulus width the inverse
/// butterfly already reduces both outputs below `p`, so no separate `last`
/// variant exists. Result is unnormalized (scaled by the transform size).
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
    simd: crate::V4,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::inv_depth_first_avx512(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (AVX2 path, moduli below 2^63).
/// See [`fwd_avx512`]; identical structure with the 4-lane butterflies.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
    simd: crate::V3,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::fwd_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (AVX2 path, moduli below 2^63).
/// As in [`inv_avx512`], the fully-reducing butterfly is intentionally used
/// for both the inner and the last layer.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
    simd: crate::V3,
    p: u64,
    data: &mut [u64],
    twid: &[u64],
    twid_shoup: &[u64],
) {
    super::shoup::inv_depth_first_avx2(
        simd,
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Forward negacyclic NTT over `data` (portable scalar path, moduli below
/// 2^63), via the generic depth-first driver in `shoup.rs`.
pub(crate) fn fwd_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
    super::shoup::fwd_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
/// Inverse negacyclic NTT over `data` (portable scalar path, moduli below
/// 2^63). The fully-reducing `inv_butterfly_scalar` is intentionally used
/// for both the inner and the last layer (no separate `last` variant is
/// needed at this modulus width). Result is unnormalized.
pub(crate) fn inv_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
    super::shoup::inv_depth_first_scalar(
        p,
        data,
        twid,
        twid_shoup,
        // Initial recursion parameters for the depth-first driver.
        0,
        0,
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
        #[inline(always)]
        |z0, z1, w, w_shoup, p, neg_p, two_p| {
            inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
        },
    )
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Round-trip tests for the sub-2^63 modulus: forward NTT of both
    //! operands, pointwise product, inverse NTT, comparison against a
    //! reference negacyclic convolution. The inverse is unnormalized,
    //! hence the factor `n` in the expected values.
    use super::*;
    use crate::{
        prime::largest_prime_in_arithmetic_progression64,
        prime64::{
            init_negacyclic_twiddles_shoup,
            tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
        },
    };
    use alloc::vec;

    extern crate alloc;

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn test_product_avx512() {
        if let Some(simd) = crate::V4::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                // A prime in (2^62, 2^63) congruent to 1 mod 2^16.
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    64,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // Forward outputs must be fully reduced below p.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64),);
                }
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn test_product_avx2() {
        if let Some(simd) = crate::V3::try_new() {
            for n in [16, 32, 64, 128, 256, 512, 1024] {
                let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63)
                    .unwrap();

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, p);

                let mut twid = vec![0u64; n];
                let mut twid_shoup = vec![0u64; n];
                let mut inv_twid = vec![0u64; n];
                let mut inv_twid_shoup = vec![0u64; n];
                init_negacyclic_twiddles_shoup(
                    p,
                    n,
                    64,
                    &mut twid,
                    &mut twid_shoup,
                    &mut inv_twid,
                    &mut inv_twid_shoup,
                );

                let mut prod = vec![0u64; n];
                let mut lhs_fourier = lhs.clone();
                let mut rhs_fourier = rhs.clone();

                fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
                fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
                // Forward outputs must be fully reduced below p.
                for x in &lhs_fourier {
                    assert!(*x < p);
                }
                for x in &rhs_fourier {
                    assert!(*x < p);
                }

                for i in 0..n {
                    prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
                }

                inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
                let result = prod;

                for i in 0..n {
                    assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64),);
                }
            }
        }
    }

    #[test]
    fn test_product_scalar() {
        for n in [16, 32, 64, 128, 256, 512, 1024] {
            let p =
                largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap();

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, p);

            let mut twid = vec![0u64; n];
            let mut twid_shoup = vec![0u64; n];
            let mut inv_twid = vec![0u64; n];
            let mut inv_twid_shoup = vec![0u64; n];
            init_negacyclic_twiddles_shoup(
                p,
                n,
                64,
                &mut twid,
                &mut twid_shoup,
                &mut inv_twid,
                &mut inv_twid_shoup,
            );

            let mut prod = vec![0u64; n];
            let mut lhs_fourier = lhs.clone();
            let mut rhs_fourier = rhs.clone();

            fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
            fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
            // Forward outputs must be fully reduced below p.
            for x in &lhs_fourier {
                assert!(*x < p);
            }
            for x in &rhs_fourier {
                assert!(*x < p);
            }

            for i in 0..n {
                prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
            }

            inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
            let result = prod;

            for i in 0..n {
                assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64),);
            }
        }
    }
}
|
||||
1468
tfhe-ntt/src/prime64/shoup.rs
Normal file
1468
tfhe-ntt/src/prime64/shoup.rs
Normal file
File diff suppressed because it is too large
Load Diff
1168
tfhe-ntt/src/product.rs
Normal file
1168
tfhe-ntt/src/product.rs
Normal file
File diff suppressed because it is too large
Load Diff
132
tfhe-ntt/src/roots.rs
Normal file
132
tfhe-ntt/src/roots.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
use crate::{
|
||||
fastdiv::Div64,
|
||||
prime::{exp_mod64, mul_mod64},
|
||||
};
|
||||
|
||||
pub const fn get_q_s64(p: Div64) -> (u64, u64) {
|
||||
let p = p.divisor();
|
||||
let mut q = p - 1;
|
||||
let mut s = 0;
|
||||
while q % 2 == 0 {
|
||||
q /= 2;
|
||||
s += 1;
|
||||
}
|
||||
(q, s)
|
||||
}
|
||||
|
||||
pub const fn get_z64(p: Div64) -> Option<u64> {
|
||||
let p_val = p.divisor();
|
||||
|
||||
let mut n = 2;
|
||||
while n < p_val {
|
||||
if exp_mod64(p, n, (p_val - 1) / 2) == p_val - 1 {
|
||||
return Some(n);
|
||||
}
|
||||
n += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Modular square root via the Tonelli–Shanks algorithm.
///
/// Preconditions: `p` wraps an odd prime, `p - 1 == q * 2^s` with `q` odd
/// (see [`get_q_s64`]), and `z` is a quadratic non-residue mod `p`
/// (see [`get_z64`]). Returns `Some(r)` with `r*r == n (mod p)`, or `None`
/// when `n` is not a quadratic residue.
///
/// <https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm#The_algorithm>
pub const fn sqrt_mod_ex64(p: Div64, q: u64, s: u64, z: u64, n: u64) -> Option<u64> {
    let mut m = s;
    let mut c = exp_mod64(p, z, q);
    let mut t = exp_mod64(p, n, q);
    let mut r = exp_mod64(p, n, (q + 1) / 2);

    loop {
        if t == 0 {
            // n == 0 mod p: the square root is 0.
            return Some(0);
        }
        if t == 1 {
            // Invariant r^2 == n * t holds, so t == 1 means r is the root.
            return Some(r);
        }

        // Find the least i (1 <= i <= m) with t^(2^i) == 1 by repeated
        // squaring.
        let mut i = 0;
        let mut t_pow = t;
        while i < m {
            t_pow = mul_mod64(p, t_pow, t_pow);
            i += 1;
            if t_pow == 1 {
                break;
            }
        }
        let i = i;
        if i == m {
            // The order of t did not shrink: n is a non-residue.
            assert!(t_pow == 1);
            return None;
        }

        // Update step: b = c^(2^(m-i-1)); shrink m and fold b into c, t, r.
        let b = exp_mod64(p, c, 1 << (m - i - 1));
        m = i;
        c = mul_mod64(p, b, b);
        t = mul_mod64(p, t, c);
        r = mul_mod64(p, r, b);
    }
}
|
||||
|
||||
/// Finds a primitive `degree`-th root of unity modulo the prime wrapped by
/// `p`, where `degree` is a power of two greater than 1.
///
/// Starts from `p - 1` (which is `-1`, a primitive square root of unity)
/// and takes `log2(degree) - 1` successive square roots via Tonelli–Shanks.
/// Returns `None` when a required square root does not exist (e.g. when
/// `degree` does not divide `p - 1`).
///
/// # Panics
/// Panics if `degree` is not a power of two, or is `<= 1`.
pub const fn find_primitive_root64(p: Div64, degree: u64) -> Option<u64> {
    assert!(degree.is_power_of_two());
    assert!(degree > 1);
    // degree == 2^n.
    let n = degree.trailing_zeros();

    let p_val = p.divisor();
    // -1 mod p: a primitive 2nd root of unity.
    let mut root = p_val - 1;
    let (q, s) = get_q_s64(p);
    let z = match get_z64(p) {
        Some(z) => z,
        None => return None,
    };

    // Each square root doubles the order of `root`.
    let mut i = 0;
    while i < n - 1 {
        root = match sqrt_mod_ex64(p, q, s, z, root) {
            Some(r) => r,
            None => return None,
        };
        i += 1;
    }

    Some(root)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{fastdiv::Div64, prime::largest_prime_in_arithmetic_progression64};

    // Convenience wrapper: derives the Tonelli–Shanks precomputation
    // (q, s, z) from p before calling `sqrt_mod_ex64`. p == 2 is special:
    // every value is its own square root.
    const fn sqrt_mod64(p: Div64, n: u64) -> Option<u64> {
        if p.divisor() == 2 {
            Some(n)
        } else {
            let z = match get_z64(p) {
                Some(z) => z,
                None => panic!(),
            };
            let (q, s) = get_q_s64(p);
            sqrt_mod_ex64(p, q, s, z, n)
        }
    }

    #[test]
    fn test_sqrt() {
        // Chained square roots of -1 must square back to their arguments.
        let p_val = largest_prime_in_arithmetic_progression64(1 << 10, 1, 0, u64::MAX).unwrap();
        let p = Div64::new(p_val);
        let i = sqrt_mod64(p, p_val - 1).unwrap();
        let j = sqrt_mod64(p, i).unwrap();
        assert_eq!(mul_mod64(p, i, i), p_val - 1);
        assert_eq!(mul_mod64(p, j, j), i);
    }

    #[test]
    fn test_primitive_root() {
        // A primitive root has order exactly `deg`: no smaller power is 1.
        let deg = 1 << 10;
        let p_val = largest_prime_in_arithmetic_progression64(deg, 1, 0, u64::MAX).unwrap();
        let p = Div64::new(p_val);
        let root = find_primitive_root64(p, deg).unwrap();
        for i in 1..deg {
            assert_ne!(exp_mod64(p, root, i), 1);
        }
        assert_eq!(exp_mod64(p, root, deg), 1);
    }
}
|
||||
137
tfhe-ntt/src/u256_impl.rs
Normal file
137
tfhe-ntt/src/u256_impl.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
/// Unsigned 256-bit integer stored as four 64-bit limbs in little-endian
/// limb order: `x0` is the least significant limb, `x3` the most
/// significant (see `div_rem_u256_u64`, which processes `x3` first with the
/// running remainder as the high half).
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct u256 {
    pub x0: u64,
    pub x1: u64,
    pub x2: u64,
    pub x3: u64,
}
|
||||
|
||||
/// Packs two 64-bit limbs into one `u128` "double digit": `lo` occupies the
/// low 64 bits, `hi` the high 64 bits.
#[inline(always)]
pub const fn to_double_digit(lo: u64, hi: u64) -> u128 {
    let high_half = (hi as u128) << u64::BITS;
    high_half | lo as u128
}
|
||||
|
||||
/// Add-with-carry: returns `l + r + c` truncated to 64 bits, together with
/// the outgoing carry. At most one of the two additions can overflow, so
/// OR-ing the flags is exact.
#[inline(always)]
pub const fn adc(l: u64, r: u64, c: bool) -> (u64, bool) {
    let (partial, carry_a) = l.overflowing_add(r);
    let (sum, carry_b) = partial.overflowing_add(c as u64);
    (sum, carry_a | carry_b)
}
|
||||
|
||||
/// Limb multiply-accumulate: computes `l * r + c` exactly in 128 bits and
/// returns it as `(low, high)` 64-bit halves. Cannot overflow, since
/// `(2^64 - 1)^2 + (2^64 - 1) < 2^128`.
#[inline(always)]
pub const fn mul_with_carry(l: u64, r: u64, c: u64) -> (u64, u64) {
    let wide = (l as u128) * (r as u128) + c as u128;
    (wide as u64, (wide >> u64::BITS) as u64)
}
|
||||
|
||||
impl u256 {
    /// The largest representable value, 2^256 - 1.
    pub const MAX: Self = Self {
        x0: u64::MAX,
        x1: u64::MAX,
        x2: u64::MAX,
        x3: u64::MAX,
    };

    /// Schoolbook 256-bit addition: adds limb by limb from least to most
    /// significant, threading the carry through `adc`. Returns the wrapped
    /// sum and whether the final carry overflowed past bit 255.
    #[inline(always)]
    pub const fn overflowing_add(self, rhs: Self) -> (Self, bool) {
        let lhs = self;

        let mut carry = false;
        let x0;
        let x1;
        let x2;
        let x3;

        // Limb order matters: each step consumes the carry of the previous one.
        (x0, carry) = adc(lhs.x0, rhs.x0, carry);
        (x1, carry) = adc(lhs.x1, rhs.x1, carry);
        (x2, carry) = adc(lhs.x2, rhs.x2, carry);
        (x3, carry) = adc(lhs.x3, rhs.x3, carry);

        (Self { x0, x1, x2, x3 }, carry)
    }

    /// Addition modulo 2^256 (the overflow flag is discarded).
    #[inline(always)]
    pub const fn wrapping_add(self, rhs: Self) -> Self {
        self.overflowing_add(rhs).0
    }

    /// Long division of a 256-bit value by a 64-bit divisor, returning
    /// `(quotient, remainder)`. Processes limbs from most to least
    /// significant; at each step the running remainder (always < rhs, so it
    /// fits in 64 bits) is prepended to the next limb and divided again.
    ///
    /// # Panics
    ///
    /// Panics if `rhs == 0` (division by zero).
    pub const fn div_rem_u256_u64(self, rhs: u64) -> (Self, u64) {
        let lhs = self;

        let mut rem = 0;
        // Widen once so each per-limb division is u128 / u128.
        let rhs = rhs as u128;

        // Most significant limb first; `rem` is 0 on the first step.
        let double = to_double_digit(lhs.x3, rem);
        let q = double / rhs;
        let r = double % rhs;
        rem = r as u64;
        let x3 = q as u64;

        let double = to_double_digit(lhs.x2, rem);
        let q = double / rhs;
        let r = double % rhs;
        rem = r as u64;
        let x2 = q as u64;

        let double = to_double_digit(lhs.x1, rem);
        let q = double / rhs;
        let r = double % rhs;
        rem = r as u64;
        let x1 = q as u64;

        let double = to_double_digit(lhs.x0, rem);
        let q = double / rhs;
        let r = double % rhs;
        rem = r as u64;
        let x0 = q as u64;

        (Self { x0, x1, x2, x3 }, rem)
    }

    /// Multiplies a 256-bit value by a 64-bit value, returning the low
    /// 256 bits and the 64-bit overflow limb (bits 256..320 of the
    /// 320-bit product).
    #[inline(always)]
    pub const fn mul_u256_u64(self, rhs: u64) -> (Self, u64) {
        let mut carry = 0;
        let (x0, x1, x2, x3);

        // Each step folds the high half of the previous partial product
        // into the next limb's multiply.
        (x0, carry) = mul_with_carry(self.x0, rhs, carry);
        (x1, carry) = mul_with_carry(self.x1, rhs, carry);
        (x2, carry) = mul_with_carry(self.x2, rhs, carry);
        (x3, carry) = mul_with_carry(self.x3, rhs, carry);

        (Self { x0, x1, x2, x3 }, carry)
    }

    /// Multiplies a 256-bit value by a 128-bit value, returning the low
    /// 256 bits and the 128-bit overflow (bits 256..384 of the 384-bit
    /// product): `self * lo(rhs) + (self * hi(rhs)) << 64`.
    #[inline(always)]
    pub const fn mul_u256_u128(self, rhs: u128) -> (Self, u128) {
        // x: 320-bit partial product with the low 64 bits of rhs.
        let (x, x4) = Self::mul_u256_u64(self, rhs as u64);
        // y: 320-bit partial product with the high 64 bits of rhs,
        // to be shifted left by one limb before being added in.
        let (y, y5) = Self::mul_u256_u64(self, (rhs >> 64) as u64);
        // Shift y left by 64 bits: y.x3 moves into the overflow limb y4.
        let y4 = y.x3;
        let y = u256 {
            x0: 0,
            x1: y.x0,
            x2: y.x1,
            x3: y.x2,
        };

        // Add the two partial products, propagating the carry through the
        // overflow limbs. `y5 + carry` cannot overflow: the full product
        // fits in 384 bits, so bits 320..384 never reach u64::MAX with a
        // pending carry.
        let (r, carry) = x.overflowing_add(y);
        let (r4, carry) = adc(x4, y4, carry);
        let r5 = y5 + carry as u64;

        (r, to_double_digit(r4, r5))
    }

    /// Multiplication by a 128-bit value modulo 2^256: same shift-and-add
    /// scheme as [`Self::mul_u256_u128`], but all overflow limbs are
    /// discarded.
    #[inline(always)]
    pub const fn wrapping_mul_u256_u128(self, rhs: u128) -> Self {
        let (x, _) = Self::mul_u256_u64(self, rhs as u64);
        let (y, _) = Self::mul_u256_u64(self, (rhs >> 64) as u64);
        // Shift the high partial product left by one limb (mod 2^256).
        let y = u256 {
            x0: 0,
            x1: y.x0,
            x2: y.x1,
            x3: y.x2,
        };
        x.wrapping_add(y)
    }
}
|
||||
Reference in New Issue
Block a user