chore(ntt): bring concrete-ntt in the repo as tfhe-ntt

Arthur Meyre
2024-11-15 16:17:24 +01:00
parent fcc0378c98
commit 36deaec607
47 changed files with 20093 additions and 5 deletions

1
tfhe-ntt/.gitignore vendored Normal file

@@ -0,0 +1 @@
benchmarks_parameters/

36
tfhe-ntt/Cargo.toml Normal file

@@ -0,0 +1,36 @@
[package]
name = "tfhe-ntt"
version = "0.3.0"
edition = "2021"
description = "tfhe-ntt is a pure Rust high performance number theoretic transform library."
readme = "README.md"
repository = "https://github.com/zama-ai/tfhe-rs"
license = "BSD-3-Clause-Clear"
homepage = "https://zama.ai/"
keywords = ["ntt"]
rust-version = "1.67"
[dependencies]
aligned-vec = { workspace = true }
bytemuck = { workspace = true }
pulp = { workspace = true }
[features]
default = ["std"]
std = ["pulp/std", "aligned-vec/std"]
nightly = ["pulp/nightly"]
[dev-dependencies]
criterion = "0.4"
rand = "0.8"
serde = "1.0.163"
serde_json = "1.0.96"
[[bench]]
name = "ntt"
harness = false
[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--html-in-header", "katex-header.html", "--cfg", "docsrs"]

33
tfhe-ntt/LICENSE Normal file

@@ -0,0 +1,33 @@
BSD 3-Clause Clear License
Copyright © 2023 ZAMA.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the documentation and/or other
materials provided with the distribution.
3. Neither the name of ZAMA nor the names of its contributors may be used to endorse
or promote products derived from this software without specific prior written permission.
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*.
THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
ZAMA OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive,
free and non-commercial license on all patents filed in its name relating to the open-source
code (the "Patents") for the sole purpose of evaluation, development, research, prototyping
and experimentation.

64
tfhe-ntt/README.md Normal file

@@ -0,0 +1,64 @@
tfhe-ntt is a pure Rust high performance Number Theoretic Transform library that processes
vectors of sizes that are powers of two.
This library provides three kinds of NTT:
- The prime NTT computes the transform in a field $\mathbb{Z}/p \mathbb{Z}$
with $p$ prime, allowing for arithmetic operations on the polynomial modulo $p$.
- The native NTT internally computes the transform of the first kind with
several primes, allowing the simulation of arithmetic modulo the product of
those primes, and truncates the
result when the inverse transform is desired. The truncated result is guaranteed to be as if
the computations were performed with wrapping arithmetic, as long as the full integer result
would have been smaller than half the product of the primes, in absolute value. It is guaranteed
to be suitable for multiplying two polynomials with arbitrary coefficients, and returns the
result in wrapping arithmetic.
- The native binary NTT is similar to the native NTT, but is optimized for the case where one
of the operands of the multiplication has coefficients in $\lbrace 0, 1 \rbrace$.
# Rust requirements
tfhe-ntt requires a Rust version >= 1.67.0.
# Features
- `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
- `nightly`: This enables unstable Rust features to further speed up the NTT, by enabling
AVX512 instructions on CPUs that support them. This feature requires a nightly Rust
toolchain.
# Example
```rust
use tfhe_ntt::prime32::Plan;
const N: usize = 32;
let p = 1062862849;
let plan = Plan::try_new(N, p).unwrap();
let data = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 26, 27, 28, 29, 30, 31,
];
let mut transformed_fwd = data;
plan.fwd(&mut transformed_fwd);
let mut transformed_inv = transformed_fwd;
plan.inv(&mut transformed_inv);
for (&actual, expected) in transformed_inv.iter().zip(data.iter().map(|x| x * N as u32)) {
assert_eq!(expected, actual);
}
```
More examples can be found in the `examples` directory.
- `mul_poly_prime.rs`: Negacyclic polynomial multiplication with a prime modulus.
Run the example with `cargo run --example mul_poly_prime`.
- `mul_poly_native.rs`: Negacyclic polynomial multiplication with a native modulus (`2^32`, `2^64`, or `2^128`).
Run the example with `cargo run --example mul_poly_native`.
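For the native-modulus path described above, a minimal sketch using `tfhe_ntt::native64::Plan32` (the same `negacyclic_polymul` API as in `mul_poly_native.rs`, here with the `2^64` modulus):
```rust
use tfhe_ntt::native64::Plan32;

// Power-of-two polynomial size; coefficients may be arbitrary u64 values.
let n = 1024;
let lhs = vec![1u64; n];
let rhs = vec![2u64; n];
let mut prod = vec![0u64; n];

// The plan internally runs several prime NTTs and recombines the results.
let plan = Plan32::try_new(n).unwrap();
// Computes the negacyclic product of `lhs` and `rhs`, reduced modulo 2^64.
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
```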
# Benchmarks
Benchmarks can be executed with `cargo bench`. If a nightly toolchain is
available, then AVX512 acceleration can be enabled by passing the
`--features=nightly` flag.

3
tfhe-ntt/benches/lib.rs Normal file

@@ -0,0 +1,3 @@
#![allow(dead_code)]
mod ntt;

238
tfhe-ntt/benches/ntt.rs Normal file

@@ -0,0 +1,238 @@
use serde::Serialize;
use std::{fs, path::PathBuf};
use criterion::*;
use tfhe_ntt::{prime::largest_prime_in_arithmetic_progression64, *};
#[derive(Serialize)]
enum PrimeModulus {
// 32 bits section
FitsIn30Bits,
FitsIn31Bits,
FitsIn32Bits,
Native32,
// 64 bits section
FitsIn50Bits,
FitsIn51Bits,
FitsIn52Bits,
FitsIn62Bits,
FitsIn63Bits,
FitsIn64Bits,
Native64,
// 128 bits section
Native128,
}
impl PrimeModulus {
fn from_u64(p: u64) -> Self {
if p < 1 << 30 {
Self::FitsIn30Bits
} else if p < 1 << 31 {
Self::FitsIn31Bits
} else if p < 1 << 32 {
Self::FitsIn32Bits
} else if p < 1 << 50 {
Self::FitsIn50Bits
} else if p < 1 << 51 {
Self::FitsIn51Bits
} else if p < 1 << 52 {
Self::FitsIn52Bits
} else if p < 1 << 62 {
Self::FitsIn62Bits
} else if p < 1 << 63 {
Self::FitsIn63Bits
} else {
Self::FitsIn64Bits
}
}
}
#[derive(Serialize)]
struct BenchmarkParametersRecord {
display_name: String,
polynomial_size: usize,
prime_modulus: PrimeModulus,
// A value of 0 means that the modulus is not a prime number.
prime_number: u64,
}
/// Writes benchmark parameters to disk in JSON format.
fn write_to_json(
bench_id: &str,
display_name: impl Into<String>,
polynomial_size: usize,
prime_modulus: PrimeModulus,
prime_number: u64,
) {
let record = BenchmarkParametersRecord {
display_name: display_name.into(),
polynomial_size,
prime_modulus,
prime_number,
};
let mut params_directory = ["benchmarks_parameters", bench_id]
.iter()
.collect::<PathBuf>();
fs::create_dir_all(&params_directory).unwrap();
params_directory.push("parameters.json");
fs::write(params_directory, serde_json::to_string(&record).unwrap()).unwrap();
}
fn criterion_bench(c: &mut Criterion) {
let ns = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768];
for n in ns {
let mut data = vec![0; n];
for p in [
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap(),
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap(),
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 31, 1 << 32).unwrap(),
] {
let p_u64 = p;
let p = p as u32;
let plan = prime32::Plan::try_new(n, p).unwrap();
let bench_id = format!("fwd-32-{p}-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.fwd(&mut data));
});
write_to_json(&bench_id, "fwd-32", n, PrimeModulus::from_u64(p_u64), p_u64);
let bench_id = format!("inv-32-{p}-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.inv(&mut data));
});
write_to_json(&bench_id, "inv-32", n, PrimeModulus::from_u64(p_u64), p_u64);
}
}
for n in ns {
let mut data = vec![0; n];
for p in [
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 49, 1 << 50).unwrap(),
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap(),
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62).unwrap(),
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap(),
prime64::Solinas::P,
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 63, u64::MAX).unwrap(),
] {
let plan = prime64::Plan::try_new(n, p).unwrap();
let bench_id = format!("fwd-64-{p}-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.fwd(&mut data));
});
write_to_json(&bench_id, "fwd-64", n, PrimeModulus::from_u64(p), p);
let bench_id = format!("inv-64-{p}-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.inv(&mut data));
});
write_to_json(&bench_id, "inv-64", n, PrimeModulus::from_u64(p), p);
}
}
for n in ns {
let mut prod = vec![0; n];
let lhs = vec![0; n];
let rhs = vec![0; n];
let plan = native32::Plan32::try_new(n).unwrap();
let bench_id = format!("native32-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "native32-32", n, PrimeModulus::Native32, 0);
let plan = native_binary32::Plan32::try_new(n).unwrap();
let bench_id = format!("nativebinary32-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "nativebinary32-32", n, PrimeModulus::Native32, 0);
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
{
if let Some(plan) = native32::Plan52::try_new(n) {
let bench_id = format!("native32-52-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "native32-52", n, PrimeModulus::Native32, 0);
}
if let Some(plan) = native_binary32::Plan52::try_new(n) {
let bench_id = format!("nativebinary32-52-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "nativebinary32-52", n, PrimeModulus::Native32, 0);
}
}
}
for n in ns {
let mut prod = vec![0; n];
let lhs = vec![0; n];
let rhs = vec![0; n];
let plan = native64::Plan32::try_new(n).unwrap();
let bench_id = format!("native64-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "native64-32", n, PrimeModulus::Native64, 0);
let plan = native_binary64::Plan32::try_new(n).unwrap();
let bench_id = format!("nativebinary64-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "nativebinary64-32", n, PrimeModulus::Native64, 0);
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
{
if let Some(plan) = native64::Plan52::try_new(n) {
let bench_id = format!("native64-52-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "native64-52", n, PrimeModulus::Native64, 0);
}
if let Some(plan) = native_binary64::Plan52::try_new(n) {
let bench_id = format!("nativebinary64-52-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "nativebinary64-52", n, PrimeModulus::Native64, 0);
}
}
}
for n in ns {
let mut prod = vec![0; n];
let lhs = vec![0; n];
let rhs = vec![0; n];
let plan = native128::Plan32::try_new(n).unwrap();
let bench_id = format!("native128-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(&bench_id, "native128-32", n, PrimeModulus::Native128, 0);
let plan = native_binary128::Plan32::try_new(n).unwrap();
let bench_id = format!("nativebinary128-32-{n}");
c.bench_function(&bench_id, |b| {
b.iter(|| plan.negacyclic_polymul(&mut prod, &lhs, &rhs));
});
write_to_json(
&bench_id,
"nativebinary128-32",
n,
PrimeModulus::Native128,
0,
);
}
}
criterion_group!(benches, criterion_bench);
criterion_main!(benches);

38
tfhe-ntt/examples/mul_poly_native.rs Normal file

@@ -0,0 +1,38 @@
use rand::random;
use tfhe_ntt::native32::Plan32;
fn main() {
// Define a suitable polynomial size. Power-of-two polynomial sizes up to `2^16` are supported.
let polynomial_size = 1024;
let lhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>()).collect();
let rhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>()).collect();
// method 1: schoolbook algorithm
let add = |x: u32, y: u32| x.wrapping_add(y);
let sub = |x: u32, y: u32| x.wrapping_sub(y);
let mul = |x: u32, y: u32| x.wrapping_mul(y);
let mut full_convolution = vec![0; 2 * polynomial_size];
for i in 0..polynomial_size {
for j in 0..polynomial_size {
full_convolution[i + j] = add(full_convolution[i + j], mul(lhs_poly[i], rhs_poly[j]));
}
}
let mut negacyclic_convolution = vec![0; polynomial_size];
for i in 0..polynomial_size {
negacyclic_convolution[i] = sub(full_convolution[i], full_convolution[polynomial_size + i]);
}
// method 2: NTT
let plan = Plan32::try_new(polynomial_size).unwrap();
let mut product_poly = vec![0; polynomial_size];
// convert to NTT domain
plan.negacyclic_polymul(&mut product_poly, &lhs_poly, &rhs_poly);
// check that method 1 and method 2 give the same result
assert_eq!(product_poly, negacyclic_convolution);
println!("Success!");
}

49
tfhe-ntt/examples/mul_poly_prime.rs Normal file

@@ -0,0 +1,49 @@
use rand::random;
use tfhe_ntt::prime32::Plan;
fn main() {
// Define a suitable NTT prime and polynomial size.
let p: u32 = 1073479681;
let polynomial_size = 1024;
// unwrapping is fine here because we know roots of unity exist for the combination
// `(polynomial_size, p)`
let lhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>() % p).collect();
let rhs_poly: Vec<u32> = (0..polynomial_size).map(|_| random::<u32>() % p).collect();
// method 1: schoolbook algorithm
let add = |x: u32, y: u32| ((x as u64 + y as u64) % p as u64) as u32;
let sub = |x: u32, y: u32| add(x, p - y);
let mul = |x: u32, y: u32| ((x as u64 * y as u64) % p as u64) as u32;
let mut full_convolution = vec![0; 2 * polynomial_size];
for i in 0..polynomial_size {
for j in 0..polynomial_size {
full_convolution[i + j] = add(full_convolution[i + j], mul(lhs_poly[i], rhs_poly[j]));
}
}
let mut negacyclic_convolution = vec![0; polynomial_size];
for i in 0..polynomial_size {
negacyclic_convolution[i] = sub(full_convolution[i], full_convolution[polynomial_size + i]);
}
// method 2: NTT
let plan = Plan::try_new(polynomial_size, p).unwrap();
let mut lhs_ntt = lhs_poly;
let mut rhs_ntt = rhs_poly;
// convert to NTT domain
plan.fwd(&mut lhs_ntt);
plan.fwd(&mut rhs_ntt);
// perform elementwise multiplication and normalize (result is stored in `lhs_ntt`)
plan.mul_assign_normalize(&mut lhs_ntt, &rhs_ntt);
// convert back to standard domain
plan.inv(&mut lhs_ntt);
// check that method 1 and method 2 give the same result
assert_eq!(lhs_ntt, negacyclic_convolution);
println!("Success!");
}

15
tfhe-ntt/katex-header.html Normal file

@@ -0,0 +1,15 @@
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/katex.min.css" integrity="sha384-KiWOvVjnN8qwAZbuQyWDIbfCLFhLXNETzBQjA/92pIowpC0d2O3nppDGQVgwd2nB" crossorigin="anonymous">
<script src="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/katex.min.js" integrity="sha384-0fdwu/T/EQMsQlrHCCHoH10pkPLlKA1jL5dFyUOvB3lfeT2540/2g6YgSi2BL14p" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/katex@0.15.3/dist/contrib/auto-render.min.js" integrity="sha384-+XBljXPPiv+OzfbB3cVmLHf4hdUFHlWNZN5spNQ7rmHTXpd7WvJum6fIACpNNfIR" crossorigin="anonymous"></script>
<script>
document.addEventListener("DOMContentLoaded", function() {
renderMathInElement(document.body, {
delimiters: [
{left: "$$", right: "$$", display: true},
{left: "\\(", right: "\\)", display: false},
{left: "$", right: "$", display: false},
{left: "\\[", right: "\\]", display: true}
]
});
});
</script>

5
tfhe-ntt/rustfmt.toml Normal file

@@ -0,0 +1,5 @@
unstable_features = true
imports_granularity="Crate"
format_code_in_doc_comments = true
wrap_comments = true
comment_width = 100

196
tfhe-ntt/src/fastdiv.rs Normal file

@@ -0,0 +1,196 @@
use crate::u256;
#[inline(always)]
pub(crate) const fn mul128_u32(lowbits: u64, d: u32) -> u32 {
((lowbits as u128 * d as u128) >> 64) as u32
}
#[inline(always)]
pub(crate) const fn mul128_u64(lowbits: u128, d: u64) -> u64 {
let mut bottom_half = (lowbits & 0xFFFF_FFFF_FFFF_FFFF) * d as u128;
bottom_half >>= 64;
let top_half = (lowbits >> 64) * d as u128;
let both_halves = bottom_half + top_half;
(both_halves >> 64) as u64
}
#[inline(always)]
pub(crate) const fn mul256_u128(lowbits: u256, d: u128) -> u128 {
lowbits.mul_u256_u128(d).1
}
#[inline(always)]
pub(crate) const fn mul256_u64(lowbits: u256, d: u64) -> u64 {
lowbits.mul_u256_u64(d).1
}
/// Divisor representing a 32bit denominator.
#[derive(Copy, Clone, Debug)]
pub struct Div32 {
pub double_reciprocal: u128,
pub single_reciprocal: u64,
pub divisor: u32,
}
/// Divisor representing a 64bit denominator.
#[derive(Copy, Clone, Debug)]
pub struct Div64 {
pub double_reciprocal: u256,
pub single_reciprocal: u128,
pub divisor: u64,
}
impl Div32 {
/// Returns the division structure holding the given divisor.
///
/// # Panics
/// Panics if the divisor is zero or one.
pub const fn new(divisor: u32) -> Self {
assert!(divisor > 1);
let single_reciprocal = (u64::MAX / divisor as u64) + 1;
let double_reciprocal = (u128::MAX / divisor as u128) + 1;
Self {
double_reciprocal,
single_reciprocal,
divisor,
}
}
/// Returns the quotient of the division of `n` by `d`.
#[inline(always)]
pub const fn div(n: u32, d: Self) -> u32 {
mul128_u32(d.single_reciprocal, n)
}
/// Returns the remainder of the division of `n` by `d`.
#[inline(always)]
pub const fn rem(n: u32, d: Self) -> u32 {
let low_bits = d.single_reciprocal.wrapping_mul(n as u64);
mul128_u32(low_bits, d.divisor)
}
/// Returns the quotient of the division of `n` by `d`.
#[inline(always)]
pub const fn div_u64(n: u64, d: Self) -> u64 {
mul128_u64(d.double_reciprocal, n)
}
/// Returns the remainder of the division of `n` by `d`.
#[inline(always)]
pub const fn rem_u64(n: u64, d: Self) -> u32 {
let low_bits = d.double_reciprocal.wrapping_mul(n as u128);
mul128_u64(low_bits, d.divisor as u64) as u32
}
/// Returns the internal divisor as an integer.
#[inline(always)]
pub const fn divisor(&self) -> u32 {
self.divisor
}
}
impl Div64 {
/// Returns the division structure holding the given divisor.
///
/// # Panics
/// Panics if the divisor is zero or one.
pub const fn new(divisor: u64) -> Self {
assert!(divisor > 1);
let single_reciprocal = ((u128::MAX) / divisor as u128) + 1;
let double_reciprocal = u256::MAX
.div_rem_u256_u64(divisor)
.0
.overflowing_add(u256 {
x0: 1,
x1: 0,
x2: 0,
x3: 0,
})
.0;
Self {
double_reciprocal,
single_reciprocal,
divisor,
}
}
/// Returns the quotient of the division of `n` by `d`.
#[inline(always)]
pub const fn div(n: u64, d: Self) -> u64 {
mul128_u64(d.single_reciprocal, n)
}
/// Returns the remainder of the division of `n` by `d`.
#[inline(always)]
pub const fn rem(n: u64, d: Self) -> u64 {
let low_bits = d.single_reciprocal.wrapping_mul(n as u128);
mul128_u64(low_bits, d.divisor)
}
/// Returns the quotient of the division of `n` by `d`.
#[inline(always)]
pub const fn div_u128(n: u128, d: Self) -> u128 {
mul256_u128(d.double_reciprocal, n)
}
/// Returns the remainder of the division of `n` by `d`.
#[inline(always)]
pub const fn rem_u128(n: u128, d: Self) -> u64 {
let low_bits = d.double_reciprocal.wrapping_mul_u256_u128(n);
mul256_u64(low_bits, d.divisor)
}
/// Returns the internal divisor as an integer.
#[inline(always)]
pub const fn divisor(&self) -> u64 {
self.divisor
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::random;
#[test]
fn test_div64() {
for _ in 0..1000 {
let divisor = loop {
let d = random();
if d > 1 {
break d;
}
};
let div = Div64::new(divisor);
let n = random();
let m = random();
assert_eq!(Div64::div(m, div), m / divisor);
assert_eq!(Div64::rem(m, div), m % divisor);
assert_eq!(Div64::div_u128(n, div), n / divisor as u128);
assert_eq!(Div64::rem_u128(n, div) as u128, n % divisor as u128);
}
}
#[test]
fn test_div32() {
for _ in 0..1000 {
let divisor = loop {
let d = random();
if d > 1 {
break d;
}
};
let div = Div32::new(divisor);
let n = random();
let m = random();
assert_eq!(Div32::div(m, div), m / divisor);
assert_eq!(Div32::rem(m, div), m % divisor);
assert_eq!(Div32::div_u64(n, div), n / divisor as u64);
assert_eq!(Div32::rem_u64(n, div) as u64, n % divisor as u64);
}
}
}
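A short usage sketch of the `Div32` helper above (the divisor value is arbitrary, chosen only for illustration): once `Div32::new` has precomputed the reciprocals, `div`/`rem` replace hardware division with widening multiplications and shifts.
```rust
use tfhe_ntt::fastdiv::Div32;

// Precompute the reciprocals once for a fixed divisor.
let d = Div32::new(1_000_003);

let n: u32 = 123_456_789;
assert_eq!(Div32::div(n, d), n / 1_000_003);
assert_eq!(Div32::rem(n, d), n % 1_000_003);

// The same structure also divides 64-bit numerators by the 32-bit divisor.
let m: u64 = 987_654_321_012;
assert_eq!(Div32::div_u64(m, d), m / 1_000_003);
assert_eq!(Div32::rem_u64(m, d), (m % 1_000_003) as u32);
```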

908
tfhe-ntt/src/lib.rs Normal file

@@ -0,0 +1,908 @@
//! tfhe-ntt is a pure Rust high performance number theoretic transform library that processes
//! vectors of sizes that are powers of two.
//!
//! This library provides three kinds of NTT:
//! - The prime NTT computes the transform in a field $\mathbb{Z}/p\mathbb{Z}$ with $p$ prime,
//! allowing for arithmetic operations on the polynomial modulo $p$.
//! - The native NTT internally computes the transform of the first kind with several primes,
//! allowing the simulation of arithmetic modulo the product of those primes, and truncates the
//! result when the inverse transform is desired. The truncated result is guaranteed to be as if
//! the computations were performed with wrapping arithmetic, as long as the full integer result
//! would have been smaller than half the product of the primes, in absolute value. It is
//! guaranteed to be suitable for multiplying two polynomials with arbitrary coefficients, and
//! returns the result in wrapping arithmetic.
//! - The native binary NTT is similar to the native NTT, but is optimized for the case where one of
//! the operands of the multiplication has coefficients in $\lbrace 0, 1 \rbrace$.
//!
//! # Features
//!
//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
//! - `nightly`: This enables unstable Rust features to further speed up the NTT, by enabling AVX512
//! instructions on CPUs that support them. This feature requires a nightly Rust toolchain.
//!
//! # Example
//!
//! ```
//! use tfhe_ntt::prime32::Plan;
//!
//! const N: usize = 32;
//! let p = 1062862849;
//! let plan = Plan::try_new(N, p).unwrap();
//!
//! let data = [
//! 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
//! 25, 26, 27, 28, 29, 30, 31,
//! ];
//!
//! let mut transformed_fwd = data;
//! plan.fwd(&mut transformed_fwd);
//!
//! let mut transformed_inv = transformed_fwd;
//! plan.inv(&mut transformed_inv);
//!
//! for (&actual, expected) in transformed_inv
//! .iter()
//! .zip(data.iter().map(|x| x * N as u32))
//! {
//! assert_eq!(expected, actual);
//! }
//! ```
#![cfg_attr(
all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")),
feature(avx512_target_feature, stdarch_x86_avx512)
)]
#![cfg_attr(not(feature = "std"), no_std)]
#![allow(clippy::too_many_arguments, clippy::let_unit_value)]
#![cfg_attr(docsrs, feature(doc_cfg))]
/// Implementation notes:
///
/// we use `NullaryFnOnce` instead of a closure because we need the `#[inline(always)]`
/// annotation, which doesn't always work with closures for some reason.
///
/// Shoup modular multiplication
/// <https://pdfs.semanticscholar.org/e000/fa109f1b2a6a3e52e04462bac4b7d58140c9.pdf>
///
/// Lemire modular reduction
/// <https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/>
///
/// Barrett reduction
/// <https://arxiv.org/pdf/2103.16400.pdf> Algorithm 8
///
/// Chinese remainder theorem solution:
/// The art of computer programming (Donald E. Knuth), section 4.3.2
#[allow(dead_code)]
fn implementation_notes() {}
use u256_impl::u256;
#[allow(unused_imports)]
use pulp::*;
#[doc(hidden)]
pub mod prime;
mod roots;
mod u256_impl;
/// Fast division by a constant divisor.
pub mod fastdiv;
/// 32bit negacyclic NTT for a prime modulus.
pub mod prime32;
/// 64bit negacyclic NTT for a prime modulus.
pub mod prime64;
/// Negacyclic NTT for multiplying two polynomials with values less than `2^128`.
pub mod native128;
/// Negacyclic NTT for multiplying two polynomials with values less than `2^32`.
pub mod native32;
/// Negacyclic NTT for multiplying two polynomials with values less than `2^64`.
pub mod native64;
/// Negacyclic NTT for multiplying a polynomial with values less than `2^128` with a binary
/// polynomial.
pub mod native_binary128;
/// Negacyclic NTT for multiplying a polynomial with values less than `2^32` with a binary
/// polynomial.
pub mod native_binary32;
/// Negacyclic NTT for multiplying a polynomial with values less than `2^64` with a binary
/// polynomial.
pub mod native_binary64;
pub mod product;
// Fn arguments are (simd, z0, z1, w, w_shoup, p, neg_p, two_p)
trait Butterfly<S: Copy, V: Copy>: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V) {}
impl<F: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V), S: Copy, V: Copy> Butterfly<S, V> for F {}
#[inline]
fn bit_rev(nbits: u32, i: usize) -> usize {
i.reverse_bits() >> (usize::BITS - nbits)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct V3(pulp::x86::V3);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct V4(pulp::x86::V4);
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
pulp::simd_type! {
struct V4IFma {
pub sse: "sse",
pub sse2: "sse2",
pub fxsr: "fxsr",
pub sse3: "sse3",
pub ssse3: "ssse3",
pub sse4_1: "sse4.1",
pub sse4_2: "sse4.2",
pub popcnt: "popcnt",
pub avx: "avx",
pub avx2: "avx2",
pub bmi1: "bmi1",
pub bmi2: "bmi2",
pub fma: "fma",
pub lzcnt: "lzcnt",
pub avx512f: "avx512f",
pub avx512bw: "avx512bw",
pub avx512cd: "avx512cd",
pub avx512dq: "avx512dq",
pub avx512vl: "avx512vl",
pub avx512ifma: "avx512ifma",
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4 {
#[inline]
pub fn try_new() -> Option<Self> {
pulp::x86::V4::try_new().map(Self)
}
/// Returns separately two vectors containing the low 64 bits of the result,
/// and the high 64 bits of the result.
#[inline(always)]
pub fn widening_mul_u64x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
// https://stackoverflow.com/a/28827013
let avx = self.avx512f;
let x = cast(a);
let y = cast(b);
let lo_mask = avx._mm512_set1_epi64(0x0000_0000_FFFF_FFFFu64 as _);
let x_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(x);
let y_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(y);
let z_lo_lo = avx._mm512_mul_epu32(x, y);
let z_lo_hi = avx._mm512_mul_epu32(x, y_hi);
let z_hi_lo = avx._mm512_mul_epu32(x_hi, y);
let z_hi_hi = avx._mm512_mul_epu32(x_hi, y_hi);
let z_lo_lo_shift = avx._mm512_srli_epi64::<32>(z_lo_lo);
let sum_tmp = avx._mm512_add_epi64(z_lo_hi, z_lo_lo_shift);
let sum_lo = avx._mm512_and_si512(sum_tmp, lo_mask);
let sum_mid = avx._mm512_srli_epi64::<32>(sum_tmp);
let sum_mid2 = avx._mm512_add_epi64(z_hi_lo, sum_lo);
let sum_mid2_hi = avx._mm512_srli_epi64::<32>(sum_mid2);
let sum_hi = avx._mm512_add_epi64(z_hi_hi, sum_mid);
let prod_hi = avx._mm512_add_epi64(sum_hi, sum_mid2_hi);
let prod_lo = avx._mm512_add_epi64(
avx._mm512_slli_epi64::<32>(avx._mm512_add_epi64(z_lo_hi, z_hi_lo)),
z_lo_lo,
);
(cast(prod_lo), cast(prod_hi))
}
/// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
#[inline(always)]
pub fn mul_low_32_bits_u64x8(self, a: u64x8, b: u64x8) -> u64x8 {
pulp::cast(self.avx512f._mm512_mul_epu32(pulp::cast(a), pulp::cast(b)))
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4IFma {
/// Returns separately two vectors containing the low 52 bits of the result,
/// and the high 52 bits of the result.
#[inline(always)]
pub fn widening_mul_u52x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
let a = cast(a);
let b = cast(b);
let zero = cast(self.splat_u64x8(0));
(
cast(self.avx512ifma._mm512_madd52lo_epu64(zero, a, b)),
cast(self.avx512ifma._mm512_madd52hi_epu64(zero, a, b)),
)
}
/// (a * b + c) mod 2^52 for each 52 bit integer in a, b, and c.
#[inline(always)]
pub fn wrapping_mul_add_u52x8(self, a: u64x8, b: u64x8, c: u64x8) -> u64x8 {
self.and_u64x8(
cast(
self.avx512ifma
._mm512_madd52lo_epu64(cast(c), cast(a), cast(b)),
),
self.splat_u64x8((1u64 << 52) - 1),
)
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
trait SupersetOfV4: Copy {
fn get_v4(self) -> V4;
fn vectorize(self, f: impl pulp::NullaryFnOnce);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4 {
#[inline(always)]
fn get_v4(self) -> V4 {
self
}
#[inline(always)]
fn vectorize(self, f: impl pulp::NullaryFnOnce) {
self.0.vectorize(f);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4IFma {
#[inline(always)]
fn get_v4(self) -> V4 {
*self
}
#[inline(always)]
fn vectorize(self, f: impl pulp::NullaryFnOnce) {
self.vectorize(f);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl V3 {
#[inline]
pub fn try_new() -> Option<Self> {
pulp::x86::V3::try_new().map(Self)
}
/// Returns separately two vectors containing the low 64 bits of the result,
/// and the high 64 bits of the result.
#[inline(always)]
pub fn widening_mul_u64x4(self, a: u64x4, b: u64x4) -> (u64x4, u64x4) {
// https://stackoverflow.com/a/28827013
let avx = self.avx;
let avx2 = self.avx2;
let x = cast(a);
let y = cast(b);
let lo_mask = avx._mm256_set1_epi64x(0x0000_0000_FFFF_FFFFu64 as _);
let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(x);
let y_hi = avx2._mm256_shuffle_epi32::<0b10110001>(y);
let z_lo_lo = avx2._mm256_mul_epu32(x, y);
let z_lo_hi = avx2._mm256_mul_epu32(x, y_hi);
let z_hi_lo = avx2._mm256_mul_epu32(x_hi, y);
let z_hi_hi = avx2._mm256_mul_epu32(x_hi, y_hi);
let z_lo_lo_shift = avx2._mm256_srli_epi64::<32>(z_lo_lo);
let sum_tmp = avx2._mm256_add_epi64(z_lo_hi, z_lo_lo_shift);
let sum_lo = avx2._mm256_and_si256(sum_tmp, lo_mask);
let sum_mid = avx2._mm256_srli_epi64::<32>(sum_tmp);
let sum_mid2 = avx2._mm256_add_epi64(z_hi_lo, sum_lo);
let sum_mid2_hi = avx2._mm256_srli_epi64::<32>(sum_mid2);
let sum_hi = avx2._mm256_add_epi64(z_hi_hi, sum_mid);
let prod_hi = avx2._mm256_add_epi64(sum_hi, sum_mid2_hi);
let prod_lo = avx2._mm256_add_epi64(
avx2._mm256_slli_epi64::<32>(avx2._mm256_add_epi64(z_lo_hi, z_hi_lo)),
z_lo_lo,
);
(cast(prod_lo), cast(prod_hi))
}
/// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
#[inline(always)]
pub fn mul_low_32_bits_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
pulp::cast(self.avx2._mm256_mul_epu32(pulp::cast(a), pulp::cast(b)))
}
// (a * (b mod 2^32)) mod 2^64 for each element in a and b.
#[inline(always)]
pub fn wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
let a = cast(a);
let b = cast(b);
let avx2 = self.avx2;
let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(a);
let z_lo_lo = avx2._mm256_mul_epu32(a, b);
let z_hi_lo = avx2._mm256_mul_epu32(x_hi, b);
cast(avx2._mm256_add_epi64(avx2._mm256_slli_epi64::<32>(z_hi_lo), z_lo_lo))
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4 {
type Target = pulp::x86::V4;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4IFma {
type Target = V4;
#[inline]
fn deref(&self) -> &Self::Target {
let Self {
sse,
sse2,
fxsr,
sse3,
ssse3,
sse4_1,
sse4_2,
popcnt,
avx,
avx2,
bmi1,
bmi2,
fma,
lzcnt,
avx512f,
avx512bw,
avx512cd,
avx512dq,
avx512vl,
avx512ifma: _,
} = *self;
let simd_ref = (pulp::x86::V4 {
sse,
sse2,
fxsr,
sse3,
ssse3,
sse4_1,
sse4_2,
popcnt,
avx,
avx2,
bmi1,
bmi2,
fma,
lzcnt,
avx512f,
avx512bw,
avx512cd,
avx512dq,
avx512vl,
})
.to_ref();
// SAFETY
// `pulp::x86::V4` and `crate::V4` have the same layout, since the latter is
// #[repr(transparent)].
unsafe { &*(simd_ref as *const pulp::x86::V4 as *const V4) }
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl core::ops::Deref for V3 {
type Target = pulp::x86::V3;
#[inline]
fn deref(&self) -> &Self::Target {
&self.0
}
}
// the magic constants are such that
// for all x < 2^64
// x / P_i == ((x * P_i_MAGIC) >> 64) >> P_i_MAGIC_SHIFT
//
// this can be used to implement the modulo operation in constant time to avoid side channel
// attacks, and can also speed up the operation x % P_i, since the compiler doesn't manage to
// vectorize it on its own.
//
// how to:
// run `cargo test generate_primes -- --nocapture`
//
// copy paste the generated primes in this function
// ```
// pub fn codegen(x: u64) -> u64 {
// x / $PRIME
// }
// ```
//
// look at the generated assembly for codegen
// extract primes that satisfy the desired property
//
// asm should look like this on x86_64
// ```
// mov rax, rdi
// movabs rcx, P_MAGIC (as i64 signed value)
// mul rcx
// mov rax, rdx
// shr rax, P_MAGIC_SHIFT
// ret
// ```
#[allow(dead_code)]
pub(crate) mod primes32 {
use crate::{
fastdiv::{Div32, Div64},
prime::exp_mod32,
};
pub const P0: u32 = 0b0011_1111_0101_1010_0000_0000_0000_0001;
pub const P1: u32 = 0b0011_1111_0101_1101_0000_0000_0000_0001;
pub const P2: u32 = 0b0011_1111_0111_0110_0000_0000_0000_0001;
pub const P3: u32 = 0b0011_1111_1000_0010_0000_0000_0000_0001;
pub const P4: u32 = 0b0011_1111_1010_1100_0000_0000_0000_0001;
pub const P5: u32 = 0b0011_1111_1010_1111_0000_0000_0000_0001;
pub const P6: u32 = 0b0011_1111_1011_0001_0000_0000_0000_0001;
pub const P7: u32 = 0b0011_1111_1011_1011_0000_0000_0000_0001;
pub const P8: u32 = 0b0011_1111_1101_1110_0000_0000_0000_0001;
pub const P9: u32 = 0b0011_1111_1111_1100_0000_0000_0000_0001;
pub const P0_MAGIC: u64 = 9317778228489988551;
pub const P1_MAGIC: u64 = 4658027473943558643;
pub const P2_MAGIC: u64 = 1162714878353869247;
pub const P3_MAGIC: u64 = 4647426722536610861;
pub const P4_MAGIC: u64 = 9270903515973367219;
pub const P5_MAGIC: u64 = 2317299382174935855;
pub const P6_MAGIC: u64 = 9268060552616330319;
pub const P7_MAGIC: u64 = 2315594963384859737;
pub const P8_MAGIC: u64 = 9242552129100825291;
pub const P9_MAGIC: u64 = 576601523622774689;
pub const P0_MAGIC_SHIFT: u32 = 29;
pub const P1_MAGIC_SHIFT: u32 = 28;
pub const P2_MAGIC_SHIFT: u32 = 26;
pub const P3_MAGIC_SHIFT: u32 = 28;
pub const P4_MAGIC_SHIFT: u32 = 29;
pub const P5_MAGIC_SHIFT: u32 = 27;
pub const P6_MAGIC_SHIFT: u32 = 29;
pub const P7_MAGIC_SHIFT: u32 = 27;
pub const P8_MAGIC_SHIFT: u32 = 29;
pub const P9_MAGIC_SHIFT: u32 = 25;
const fn mul_mod(modulus: u32, a: u32, b: u32) -> u32 {
let wide = a as u64 * b as u64;
(wide % modulus as u64) as u32
}
const fn inv_mod(modulus: u32, x: u32) -> u32 {
exp_mod32(Div32::new(modulus), x, modulus - 2)
}
const fn shoup(modulus: u32, w: u32) -> u32 {
(((w as u64) << 32) / modulus as u64) as u32
}
const fn mul_mod64(modulus: u64, a: u64, b: u64) -> u64 {
let wide = a as u128 * b as u128;
(wide % modulus as u128) as u64
}
const fn exp_mod64(modulus: u64, base: u64, pow: u64) -> u64 {
crate::prime::exp_mod64(Div64::new(modulus), base, pow)
}
const fn shoup64(modulus: u64, w: u64) -> u64 {
(((w as u128) << 64) / modulus as u128) as u64
}
pub const P0_INV_MOD_P1: u32 = inv_mod(P1, P0);
pub const P0_INV_MOD_P1_SHOUP: u32 = shoup(P1, P0_INV_MOD_P1);
pub const P01_INV_MOD_P2: u32 = inv_mod(P2, mul_mod(P2, P0, P1));
pub const P01_INV_MOD_P2_SHOUP: u32 = shoup(P2, P01_INV_MOD_P2);
pub const P012_INV_MOD_P3: u32 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
pub const P012_INV_MOD_P3_SHOUP: u32 = shoup(P3, P012_INV_MOD_P3);
pub const P0123_INV_MOD_P4: u32 =
inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
pub const P0123_INV_MOD_P4_SHOUP: u32 = shoup(P4, P0123_INV_MOD_P4);
pub const P0_MOD_P2_SHOUP: u32 = shoup(P2, P0);
pub const P0_MOD_P3_SHOUP: u32 = shoup(P3, P0);
pub const P1_MOD_P3_SHOUP: u32 = shoup(P3, P1);
pub const P0_MOD_P4_SHOUP: u32 = shoup(P4, P0);
pub const P1_MOD_P4_SHOUP: u32 = shoup(P4, P1);
pub const P2_MOD_P4_SHOUP: u32 = shoup(P4, P2);
pub const P1_INV_MOD_P2: u32 = inv_mod(P2, P1);
pub const P1_INV_MOD_P2_SHOUP: u32 = shoup(P2, P1_INV_MOD_P2);
pub const P3_INV_MOD_P4: u32 = inv_mod(P4, P3);
pub const P3_INV_MOD_P4_SHOUP: u32 = shoup(P4, P3_INV_MOD_P4);
pub const P12: u64 = P1 as u64 * P2 as u64;
pub const P34: u64 = P3 as u64 * P4 as u64;
pub const P0_INV_MOD_P12: u64 =
exp_mod64(P12, P0 as u64, (P1 as u64 - 1) * (P2 as u64 - 1) - 1);
pub const P0_INV_MOD_P12_SHOUP: u64 = shoup64(P12, P0_INV_MOD_P12);
pub const P0_MOD_P34_SHOUP: u64 = shoup64(P34, P0 as u64);
pub const P012_INV_MOD_P34: u64 = exp_mod64(
P34,
mul_mod64(P34, P0 as u64, P12),
(P3 as u64 - 1) * (P4 as u64 - 1) - 1,
);
pub const P012_INV_MOD_P34_SHOUP: u64 = shoup64(P34, P012_INV_MOD_P34);
pub const P2_INV_MOD_P3: u32 = inv_mod(P3, P2);
pub const P2_INV_MOD_P3_SHOUP: u32 = shoup(P3, P2_INV_MOD_P3);
pub const P4_INV_MOD_P5: u32 = inv_mod(P5, P4);
pub const P4_INV_MOD_P5_SHOUP: u32 = shoup(P5, P4_INV_MOD_P5);
pub const P6_INV_MOD_P7: u32 = inv_mod(P7, P6);
pub const P6_INV_MOD_P7_SHOUP: u32 = shoup(P7, P6_INV_MOD_P7);
pub const P8_INV_MOD_P9: u32 = inv_mod(P9, P8);
pub const P8_INV_MOD_P9_SHOUP: u32 = shoup(P9, P8_INV_MOD_P9);
pub const P01: u64 = P0 as u64 * P1 as u64;
pub const P23: u64 = P2 as u64 * P3 as u64;
pub const P45: u64 = P4 as u64 * P5 as u64;
pub const P67: u64 = P6 as u64 * P7 as u64;
pub const P89: u64 = P8 as u64 * P9 as u64;
pub const P01_MOD_P45_SHOUP: u64 = shoup64(P45, P01);
pub const P01_MOD_P67_SHOUP: u64 = shoup64(P67, P01);
pub const P01_MOD_P89_SHOUP: u64 = shoup64(P89, P01);
pub const P23_MOD_P67_SHOUP: u64 = shoup64(P67, P23);
pub const P23_MOD_P89_SHOUP: u64 = shoup64(P89, P23);
pub const P45_MOD_P89_SHOUP: u64 = shoup64(P89, P45);
pub const P01_INV_MOD_P23: u64 = exp_mod64(P23, P01, (P2 as u64 - 1) * (P3 as u64 - 1) - 1);
pub const P01_INV_MOD_P23_SHOUP: u64 = shoup64(P23, P01_INV_MOD_P23);
pub const P0123_INV_MOD_P45: u64 = exp_mod64(
P45,
mul_mod64(P45, P01, P23),
(P4 as u64 - 1) * (P5 as u64 - 1) - 1,
);
pub const P0123_INV_MOD_P45_SHOUP: u64 = shoup64(P45, P0123_INV_MOD_P45);
pub const P012345_INV_MOD_P67: u64 = exp_mod64(
P67,
mul_mod64(P67, mul_mod64(P67, P01, P23), P45),
(P6 as u64 - 1) * (P7 as u64 - 1) - 1,
);
pub const P012345_INV_MOD_P67_SHOUP: u64 = shoup64(P67, P012345_INV_MOD_P67);
pub const P01234567_INV_MOD_P89: u64 = exp_mod64(
P89,
mul_mod64(P89, mul_mod64(P89, mul_mod64(P89, P01, P23), P45), P67),
(P8 as u64 - 1) * (P9 as u64 - 1) - 1,
);
pub const P01234567_INV_MOD_P89_SHOUP: u64 = shoup64(P89, P01234567_INV_MOD_P89);
pub const P0123: u128 = u128::wrapping_mul(P01 as u128, P23 as u128);
pub const P012345: u128 = u128::wrapping_mul(P0123, P45 as u128);
pub const P01234567: u128 = u128::wrapping_mul(P012345, P67 as u128);
pub const P0123456789: u128 = u128::wrapping_mul(P01234567, P89 as u128);
}
#[allow(dead_code)]
pub(crate) mod primes52 {
use crate::fastdiv::Div64;
pub const P0: u64 = 0b0011_1111_1111_1111_1111_1111_1110_0111_0111_0000_0000_0000_0001;
pub const P1: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1011_1001_0000_0000_0000_0001;
pub const P2: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1100_1000_0000_0000_0000_0001;
pub const P3: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1000_1011_0000_0000_0000_0001;
pub const P4: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1011_1000_0000_0000_0000_0001;
pub const P5: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1100_0111_0000_0000_0000_0001;
pub const P0_MAGIC: u64 = 9223372247845040859;
pub const P1_MAGIC: u64 = 4611686106205779591;
pub const P2_MAGIC: u64 = 4611686102179247601;
pub const P3_MAGIC: u64 = 2305843024917166187;
pub const P4_MAGIC: u64 = 4611686037754736721;
pub const P5_MAGIC: u64 = 4611686033728204851;
pub const P0_MAGIC_SHIFT: u32 = 49;
pub const P1_MAGIC_SHIFT: u32 = 48;
pub const P2_MAGIC_SHIFT: u32 = 48;
pub const P3_MAGIC_SHIFT: u32 = 47;
pub const P4_MAGIC_SHIFT: u32 = 48;
pub const P5_MAGIC_SHIFT: u32 = 48;
const fn mul_mod(modulus: u64, a: u64, b: u64) -> u64 {
let wide = a as u128 * b as u128;
(wide % modulus as u128) as u64
}
const fn inv_mod(modulus: u64, x: u64) -> u64 {
crate::prime::exp_mod64(Div64::new(modulus), x, modulus - 2)
}
const fn shoup(modulus: u64, w: u64) -> u64 {
(((w as u128) << 52) / modulus as u128) as u64
}
pub const P0_INV_MOD_P1: u64 = inv_mod(P1, P0);
pub const P0_INV_MOD_P1_SHOUP: u64 = shoup(P1, P0_INV_MOD_P1);
pub const P01_INV_MOD_P2: u64 = inv_mod(P2, mul_mod(P2, P0, P1));
pub const P01_INV_MOD_P2_SHOUP: u64 = shoup(P2, P01_INV_MOD_P2);
pub const P012_INV_MOD_P3: u64 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
pub const P012_INV_MOD_P3_SHOUP: u64 = shoup(P3, P012_INV_MOD_P3);
pub const P0123_INV_MOD_P4: u64 =
inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
pub const P0123_INV_MOD_P4_SHOUP: u64 = shoup(P4, P0123_INV_MOD_P4);
pub const P0_MOD_P2_SHOUP: u64 = shoup(P2, P0);
pub const P0_MOD_P3_SHOUP: u64 = shoup(P3, P0);
pub const P1_MOD_P3_SHOUP: u64 = shoup(P3, P1);
pub const P0_MOD_P4_SHOUP: u64 = shoup(P4, P0);
pub const P1_MOD_P4_SHOUP: u64 = shoup(P4, P1);
pub const P2_MOD_P4_SHOUP: u64 = shoup(P4, P2);
}
macro_rules! izip {
(@ __closure @ ($a:expr)) => { |a| (a,) };
(@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) };
(@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) };
(@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) };
( $first:expr $(,)?) => {
{
::core::iter::IntoIterator::into_iter($first)
}
};
( $first:expr, $($rest:expr),+ $(,)?) => {
{
::core::iter::IntoIterator::into_iter($first)
$(.zip($rest))*
.map(crate::izip!(@ __closure @ ($first, $($rest),*)))
}
};
}
pub(crate) use izip;
#[cfg(test)]
mod tests {
use crate::prime::largest_prime_in_arithmetic_progression64;
use rand::random;
#[test]
fn test_barrett32() {
let p =
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;
let big_q: u32 = p.ilog2() + 1;
let big_l: u32 = big_q + 31;
let k: u32 = ((1u128 << big_l) / p as u128).try_into().unwrap();
for _ in 0..10000 {
let a = random::<u32>() % p;
let b = random::<u32>() % p;
let d = a as u64 * b as u64;
// Q < 31
// d < 2^(2Q)
// (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u32
let c1 = (d >> (big_q - 1)) as u32;
// c2 < 2^(Q+33)
let c3 = ((c1 as u64 * k as u64) >> 32) as u32;
let c = (d as u32).wrapping_sub(p.wrapping_mul(c3));
let c = if c >= p { c - p } else { c };
assert_eq!(c as u64, d % p as u64);
}
}
#[test]
fn test_barrett52() {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();
let big_q: u32 = p.ilog2() + 1;
let big_l: u32 = big_q + 51;
let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();
for _ in 0..10000 {
let a = random::<u64>() % p;
let b = random::<u64>() % p;
let d = a as u128 * b as u128;
// Q < 51
// d < 2^(2Q)
// (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u64
let c1 = (d >> (big_q - 1)) as u64;
// c2 < 2^(Q+53)
let c3 = ((c1 as u128 * k as u128) >> 52) as u64;
let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
let c = if c >= p { c - p } else { c };
assert_eq!(c as u128, d % p as u128);
}
}
#[test]
fn test_barrett64() {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap();
let big_q: u32 = p.ilog2() + 1;
let big_l: u32 = big_q + 63;
let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();
for _ in 0..10000 {
let a = random::<u64>() % p;
let b = random::<u64>() % p;
let d = a as u128 * b as u128;
// Q < 63
// d < 2^(2Q)
// (d >> (Q-1)) < 2^(Q+1) -> c1 fits in u64
let c1 = (d >> (big_q - 1)) as u64;
// c2 < 2^(Q+65)
let c3 = ((c1 as u128 * k as u128) >> 64) as u64;
let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
let c = if c >= p { c - p } else { c };
assert_eq!(c as u128, d % p as u128);
}
}
// primes should be of the form x * LARGEST_POLYNOMIAL_SIZE(2^16) + 1
// primes should be < 2^30 or < 2^50, for NTT efficiency
// primes should satisfy the magic property documented above the primes32 module
// primes should be as large as possible
#[cfg(feature = "std")]
#[test]
fn generate_primes() {
let mut p = 1u64 << 30;
for _ in 0..100 {
p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
println!("{p:#034b}");
}
let mut p = 1u64 << 50;
for _ in 0..100 {
p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
println!("{p:#054b}");
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(test)]
mod x86_tests {
use super::*;
use rand::random as rnd;
#[test]
fn test_widening_mul() {
if let Some(simd) = crate::V3::try_new() {
let a = u64x4(rnd(), rnd(), rnd(), rnd());
let b = u64x4(rnd(), rnd(), rnd(), rnd());
let (lo, hi) = simd.widening_mul_u64x4(a, b);
assert_eq!(
lo,
u64x4(
u64::wrapping_mul(a.0, b.0),
u64::wrapping_mul(a.1, b.1),
u64::wrapping_mul(a.2, b.2),
u64::wrapping_mul(a.3, b.3),
),
);
assert_eq!(
hi,
u64x4(
((a.0 as u128 * b.0 as u128) >> 64) as u64,
((a.1 as u128 * b.1 as u128) >> 64) as u64,
((a.2 as u128 * b.2 as u128) >> 64) as u64,
((a.3 as u128 * b.3 as u128) >> 64) as u64,
),
);
}
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4::try_new() {
let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
let (lo, hi) = simd.widening_mul_u64x8(a, b);
assert_eq!(
lo,
u64x8(
u64::wrapping_mul(a.0, b.0),
u64::wrapping_mul(a.1, b.1),
u64::wrapping_mul(a.2, b.2),
u64::wrapping_mul(a.3, b.3),
u64::wrapping_mul(a.4, b.4),
u64::wrapping_mul(a.5, b.5),
u64::wrapping_mul(a.6, b.6),
u64::wrapping_mul(a.7, b.7),
),
);
assert_eq!(
hi,
u64x8(
((a.0 as u128 * b.0 as u128) >> 64) as u64,
((a.1 as u128 * b.1 as u128) >> 64) as u64,
((a.2 as u128 * b.2 as u128) >> 64) as u64,
((a.3 as u128 * b.3 as u128) >> 64) as u64,
((a.4 as u128 * b.4 as u128) >> 64) as u64,
((a.5 as u128 * b.5 as u128) >> 64) as u64,
((a.6 as u128 * b.6 as u128) >> 64) as u64,
((a.7 as u128 * b.7 as u128) >> 64) as u64,
),
);
}
}
#[test]
fn test_mul_low_32_bits() {
if let Some(simd) = crate::V3::try_new() {
let a = u64x4(rnd(), rnd(), rnd(), rnd());
let b = u64x4(rnd(), rnd(), rnd(), rnd());
let res = simd.mul_low_32_bits_u64x4(a, b);
assert_eq!(
res,
u64x4(
a.0 as u32 as u64 * b.0 as u32 as u64,
a.1 as u32 as u64 * b.1 as u32 as u64,
a.2 as u32 as u64 * b.2 as u32 as u64,
a.3 as u32 as u64 * b.3 as u32 as u64,
),
);
}
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4::try_new() {
let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
let res = simd.mul_low_32_bits_u64x8(a, b);
assert_eq!(
res,
u64x8(
a.0 as u32 as u64 * b.0 as u32 as u64,
a.1 as u32 as u64 * b.1 as u32 as u64,
a.2 as u32 as u64 * b.2 as u32 as u64,
a.3 as u32 as u64 * b.3 as u32 as u64,
a.4 as u32 as u64 * b.4 as u32 as u64,
a.5 as u32 as u64 * b.5 as u32 as u64,
a.6 as u32 as u64 * b.6 as u32 as u64,
a.7 as u32 as u64 * b.7 as u32 as u64,
),
);
}
}
#[test]
fn test_mul_lhs_with_low_32_bits_of_rhs() {
if let Some(simd) = crate::V3::try_new() {
let a = u64x4(rnd(), rnd(), rnd(), rnd());
let b = u64x4(rnd(), rnd(), rnd(), rnd());
let res = simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(a, b);
assert_eq!(
res,
u64x4(
u64::wrapping_mul(a.0, b.0 as u32 as u64),
u64::wrapping_mul(a.1, b.1 as u32 as u64),
u64::wrapping_mul(a.2, b.2 as u32 as u64),
u64::wrapping_mul(a.3, b.3 as u32 as u64),
),
);
}
}
}
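The implementation notes near the top of this file reference Shoup modular multiplication. As an illustrative scalar sketch (independent of the crate's SIMD code paths), assuming a fixed operand `w < p` and a modulus `p < 2^63`, the trick looks like this; the precomputation matches the `shoup64` helper in the `primes32` module above:
```rust
// Precompute floor(w * 2^64 / p) once for the fixed operand `w`.
fn shoup_precompute(w: u64, p: u64) -> u64 {
    (((w as u128) << 64) / p as u128) as u64
}

// Compute (x * w) mod p with one widening multiply, two wrapping multiplies
// and one conditional subtraction; `q` is within 1 of floor(x * w / p),
// so the intermediate `r` stays in [0, 2p).
fn shoup_mul_mod(x: u64, w: u64, w_shoup: u64, p: u64) -> u64 {
    let q = ((x as u128 * w_shoup as u128) >> 64) as u64;
    let r = x.wrapping_mul(w).wrapping_sub(q.wrapping_mul(p));
    if r >= p { r - p } else { r }
}

let p: u64 = 1_000_000_007;
let w: u64 = 123_456_789;
let w_shoup = shoup_precompute(w, p);
let x: u64 = 0xDEAD_BEEF_1234_5678;
assert_eq!(
    shoup_mul_mod(x, w, w_shoup, p),
    ((x as u128 * w as u128) % p as u128) as u64
);
```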

448
tfhe-ntt/src/native128.rs Normal file

@@ -0,0 +1,448 @@
pub(crate) use crate::native64::{mul_mod32, mul_mod64};
use aligned_vec::avec;
/// Negacyclic NTT plan for multiplying two 128bit polynomials.
#[derive(Clone, Debug)]
pub struct Plan32(
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
);
#[inline(always)]
fn reconstruct_32bit_0123456789_v2(
mod_p0: u32,
mod_p1: u32,
mod_p2: u32,
mod_p3: u32,
mod_p4: u32,
mod_p5: u32,
mod_p6: u32,
mod_p7: u32,
mod_p8: u32,
mod_p9: u32,
) -> u128 {
use crate::primes32::*;
let mod_p01 = {
let v0 = mod_p0;
let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
v0 as u64 + (v1 as u64 * P0 as u64)
};
let mod_p23 = {
let v2 = mod_p2;
let v3 = mul_mod32(P3, P2_INV_MOD_P3, 2 * P3 + mod_p3 - v2);
v2 as u64 + (v3 as u64 * P2 as u64)
};
let mod_p45 = {
let v4 = mod_p4;
let v5 = mul_mod32(P5, P4_INV_MOD_P5, 2 * P5 + mod_p5 - v4);
v4 as u64 + (v5 as u64 * P4 as u64)
};
let mod_p67 = {
let v6 = mod_p6;
let v7 = mul_mod32(P7, P6_INV_MOD_P7, 2 * P7 + mod_p7 - v6);
v6 as u64 + (v7 as u64 * P6 as u64)
};
let mod_p89 = {
let v8 = mod_p8;
let v9 = mul_mod32(P9, P8_INV_MOD_P9, 2 * P9 + mod_p9 - v8);
v8 as u64 + (v9 as u64 * P8 as u64)
};
let v01 = mod_p01;
let v23 = mul_mod64(
P23.wrapping_neg(),
2 * P23 + mod_p23 - v01,
P01_INV_MOD_P23,
P01_INV_MOD_P23_SHOUP,
);
let v45 = mul_mod64(
P45.wrapping_neg(),
2 * P45 + mod_p45 - (v01 + mul_mod64(P45.wrapping_neg(), v23, P01, P01_MOD_P45_SHOUP)),
P0123_INV_MOD_P45,
P0123_INV_MOD_P45_SHOUP,
);
let v67 = mul_mod64(
P67.wrapping_neg(),
2 * P67 + mod_p67
- (v01
+ mul_mod64(
P67.wrapping_neg(),
v23 + mul_mod64(P67.wrapping_neg(), v45, P23, P23_MOD_P67_SHOUP),
P01,
P01_MOD_P67_SHOUP,
)),
P012345_INV_MOD_P67,
P012345_INV_MOD_P67_SHOUP,
);
let v89 = mul_mod64(
P89.wrapping_neg(),
2 * P89 + mod_p89
- (v01
+ mul_mod64(
P89.wrapping_neg(),
v23 + mul_mod64(
P89.wrapping_neg(),
v45 + mul_mod64(P89.wrapping_neg(), v67, P45, P45_MOD_P89_SHOUP),
P23,
P23_MOD_P89_SHOUP,
),
P01,
P01_MOD_P89_SHOUP,
)),
P01234567_INV_MOD_P89,
P01234567_INV_MOD_P89_SHOUP,
);
let sign = v89 > (P89 / 2);
let pos = (v01 as u128)
.wrapping_add(u128::wrapping_mul(v23 as u128, P01 as u128))
.wrapping_add(u128::wrapping_mul(v45 as u128, P0123))
.wrapping_add(u128::wrapping_mul(v67 as u128, P012345))
.wrapping_add(u128::wrapping_mul(v89 as u128, P01234567));
let neg = pos.wrapping_sub(P0123456789);
if sign {
neg
} else {
pos
}
}
impl Plan32 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime32::Plan, primes32::*};
Some(Self(
Plan::try_new(n, P0)?,
Plan::try_new(n, P1)?,
Plan::try_new(n, P2)?,
Plan::try_new(n, P3)?,
Plan::try_new(n, P4)?,
Plan::try_new(n, P5)?,
Plan::try_new(n, P6)?,
Plan::try_new(n, P7)?,
Plan::try_new(n, P8)?,
Plan::try_new(n, P9)?,
))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
#[inline]
pub fn ntt_0(&self) -> &crate::prime32::Plan {
&self.0
}
#[inline]
pub fn ntt_1(&self) -> &crate::prime32::Plan {
&self.1
}
#[inline]
pub fn ntt_2(&self) -> &crate::prime32::Plan {
&self.2
}
#[inline]
pub fn ntt_3(&self) -> &crate::prime32::Plan {
&self.3
}
#[inline]
pub fn ntt_4(&self) -> &crate::prime32::Plan {
&self.4
}
#[inline]
pub fn ntt_5(&self) -> &crate::prime32::Plan {
&self.5
}
#[inline]
pub fn ntt_6(&self) -> &crate::prime32::Plan {
&self.6
}
#[inline]
pub fn ntt_7(&self) -> &crate::prime32::Plan {
&self.7
}
#[inline]
pub fn ntt_8(&self) -> &crate::prime32::Plan {
&self.8
}
#[inline]
pub fn ntt_9(&self) -> &crate::prime32::Plan {
&self.9
}
pub fn fwd(
&self,
value: &[u128],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
mod_p3: &mut [u32],
mod_p4: &mut [u32],
mod_p5: &mut [u32],
mod_p6: &mut [u32],
mod_p7: &mut [u32],
mod_p8: &mut [u32],
mod_p9: &mut [u32],
) {
for (
value,
mod_p0,
mod_p1,
mod_p2,
mod_p3,
mod_p4,
mod_p5,
mod_p6,
mod_p7,
mod_p8,
mod_p9,
) in crate::izip!(
value,
&mut *mod_p0,
&mut *mod_p1,
&mut *mod_p2,
&mut *mod_p3,
&mut *mod_p4,
&mut *mod_p5,
&mut *mod_p6,
&mut *mod_p7,
&mut *mod_p8,
&mut *mod_p9,
) {
*mod_p0 = (value % crate::primes32::P0 as u128) as u32;
*mod_p1 = (value % crate::primes32::P1 as u128) as u32;
*mod_p2 = (value % crate::primes32::P2 as u128) as u32;
*mod_p3 = (value % crate::primes32::P3 as u128) as u32;
*mod_p4 = (value % crate::primes32::P4 as u128) as u32;
*mod_p5 = (value % crate::primes32::P5 as u128) as u32;
*mod_p6 = (value % crate::primes32::P6 as u128) as u32;
*mod_p7 = (value % crate::primes32::P7 as u128) as u32;
*mod_p8 = (value % crate::primes32::P8 as u128) as u32;
*mod_p9 = (value % crate::primes32::P9 as u128) as u32;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
self.3.fwd(mod_p3);
self.4.fwd(mod_p4);
self.5.fwd(mod_p5);
self.6.fwd(mod_p6);
self.7.fwd(mod_p7);
self.8.fwd(mod_p8);
self.9.fwd(mod_p9);
}
pub fn inv(
&self,
value: &mut [u128],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
mod_p3: &mut [u32],
mod_p4: &mut [u32],
mod_p5: &mut [u32],
mod_p6: &mut [u32],
mod_p7: &mut [u32],
mod_p8: &mut [u32],
mod_p9: &mut [u32],
) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
self.2.inv(mod_p2);
self.3.inv(mod_p3);
self.4.inv(mod_p4);
self.5.inv(mod_p5);
self.6.inv(mod_p6);
self.7.inv(mod_p7);
self.8.inv(mod_p8);
self.9.inv(mod_p9);
for (
value,
&mod_p0,
&mod_p1,
&mod_p2,
&mod_p3,
&mod_p4,
&mod_p5,
&mod_p6,
&mod_p7,
&mod_p8,
&mod_p9,
) in crate::izip!(
value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4, &*mod_p5, &*mod_p6, &*mod_p7,
&*mod_p8, &*mod_p9,
) {
*value = reconstruct_32bit_0123456789_v2(
mod_p0, mod_p1, mod_p2, mod_p3, mod_p4, mod_p5, mod_p6, mod_p7, mod_p8, mod_p9,
);
}
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut lhs2 = avec![0; n];
let mut lhs3 = avec![0; n];
let mut lhs4 = avec![0; n];
let mut lhs5 = avec![0; n];
let mut lhs6 = avec![0; n];
let mut lhs7 = avec![0; n];
let mut lhs8 = avec![0; n];
let mut lhs9 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
let mut rhs2 = avec![0; n];
let mut rhs3 = avec![0; n];
let mut rhs4 = avec![0; n];
let mut rhs5 = avec![0; n];
let mut rhs6 = avec![0; n];
let mut rhs7 = avec![0; n];
let mut rhs8 = avec![0; n];
let mut rhs9 = avec![0; n];
self.fwd(
lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
&mut lhs7, &mut lhs8, &mut lhs9,
);
self.fwd(
rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4, &mut rhs5, &mut rhs6,
&mut rhs7, &mut rhs8, &mut rhs9,
);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.2.mul_assign_normalize(&mut lhs2, &rhs2);
self.3.mul_assign_normalize(&mut lhs3, &rhs3);
self.4.mul_assign_normalize(&mut lhs4, &rhs4);
self.5.mul_assign_normalize(&mut lhs5, &rhs5);
self.6.mul_assign_normalize(&mut lhs6, &rhs6);
self.7.mul_assign_normalize(&mut lhs7, &rhs7);
self.8.mul_assign_normalize(&mut lhs8, &rhs8);
self.9.mul_assign_normalize(&mut lhs9, &rhs9);
self.inv(
prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4, &mut lhs5, &mut lhs6,
&mut lhs7, &mut lhs8, &mut lhs9,
);
}
}
#[cfg(test)]
pub mod tests {
use super::*;
use alloc::{vec, vec::Vec};
use rand::random;
extern crate alloc;
pub fn negacyclic_convolution(n: usize, lhs: &[u128], rhs: &[u128]) -> Vec<u128> {
let mut full_convolution = vec![0u128; 2 * n];
let mut negacyclic_convolution = vec![0u128; n];
for i in 0..n {
for j in 0..n {
full_convolution[i + j] =
full_convolution[i + j].wrapping_add(lhs[i].wrapping_mul(rhs[j]));
}
}
for i in 0..n {
negacyclic_convolution[i] = full_convolution[i].wrapping_sub(full_convolution[i + n]);
}
negacyclic_convolution
}
pub fn random_lhs_rhs_with_negacyclic_convolution(
n: usize,
) -> (Vec<u128>, Vec<u128>, Vec<u128>) {
let mut lhs = vec![0u128; n];
let mut rhs = vec![0u128; n];
for x in &mut lhs {
*x = random();
}
for x in &mut rhs {
*x = random();
}
let lhs = lhs;
let rhs = rhs;
let negacyclic_convolution = negacyclic_convolution(n, &lhs, &rhs);
(lhs, rhs, negacyclic_convolution)
}
#[test]
fn reconstruct_32bit() {
for n in [32, 64, 256, 1024, 2048] {
let value = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
let mut value_roundtrip = vec![0; n];
let mut mod_p0 = vec![0; n];
let mut mod_p1 = vec![0; n];
let mut mod_p2 = vec![0; n];
let mut mod_p3 = vec![0; n];
let mut mod_p4 = vec![0; n];
let mut mod_p5 = vec![0; n];
let mut mod_p6 = vec![0; n];
let mut mod_p7 = vec![0; n];
let mut mod_p8 = vec![0; n];
let mut mod_p9 = vec![0; n];
let plan = Plan32::try_new(n).unwrap();
plan.fwd(
&value,
&mut mod_p0,
&mut mod_p1,
&mut mod_p2,
&mut mod_p3,
&mut mod_p4,
&mut mod_p5,
&mut mod_p6,
&mut mod_p7,
&mut mod_p8,
&mut mod_p9,
);
plan.inv(
&mut value_roundtrip,
&mut mod_p0,
&mut mod_p1,
&mut mod_p2,
&mut mod_p3,
&mut mod_p4,
&mut mod_p5,
&mut mod_p6,
&mut mod_p7,
&mut mod_p8,
&mut mod_p9,
);
for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
assert_eq!(value_roundtrip, value.wrapping_mul(n as u128));
}
let (lhs, rhs, negacyclic_convolution) = random_lhs_rhs_with_negacyclic_convolution(n);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
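    // An extra illustrative check (a minimal usage sketch of `negacyclic_polymul`):
    // multiplying by the constant polynomial `1` must return `lhs` unchanged, since the
    // product is taken modulo `X^n + 1` with wrapping arithmetic on the coefficients.
    #[test]
    fn negacyclic_polymul_by_one_is_identity() {
        for n in [32, 256] {
            let plan = Plan32::try_new(n).unwrap();
            let lhs = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
            let mut one = vec![0u128; n];
            one[0] = 1;
            let mut prod = vec![0u128; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &one);
            assert_eq!(prod, lhs);
        }
    }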
}

597
tfhe-ntt/src/native32.rs Normal file
View File

@@ -0,0 +1,597 @@
use aligned_vec::avec;
#[allow(unused_imports)]
use pulp::*;
/// Negacyclic NTT plan for multiplying two 32bit polynomials.
#[derive(Clone, Debug)]
pub struct Plan32(
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
);
/// Negacyclic NTT plan for multiplying two 32bit polynomials.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);
#[inline(always)]
pub(crate) fn mul_mod32(p: u32, a: u32, b: u32) -> u32 {
let wide = a as u64 * b as u64;
(wide % p as u64) as u32
}
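// The `reconstruct_*` functions below undo the residue-number-system representation used
// by the plans: from `x mod P0, x mod P1, ...` they rebuild Garner's mixed-radix form
// `x = v0 + v1 * P0 + v2 * P0 * P1 + ...` one digit at a time (each `mul_mod32` call
// computes one digit, with a `2 * P` offset guarding against underflow), pick the
// (approximately) centered representative via the `sign` test on the top digit, and
// finally wrap the result to the target word size.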
#[inline(always)]
pub(crate) fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u32 {
use crate::primes32::*;
let v0 = mod_p0;
let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
let v2 = mul_mod32(
P2,
P01_INV_MOD_P2,
2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
);
let sign = v2 > (P2 / 2);
const _0: u32 = P0;
const _01: u32 = _0.wrapping_mul(P1);
const _012: u32 = _01.wrapping_mul(P2);
let pos = v0
.wrapping_add(v1.wrapping_mul(_0))
.wrapping_add(v2.wrapping_mul(_01));
let neg = pos.wrapping_sub(_012);
if sign {
neg
} else {
pos
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod32_avx2(
simd: crate::V3,
p: u32x8,
a: u32x8,
b: u32x8,
b_shoup: u32x8,
) -> u32x8 {
let shoup_q = simd.widening_mul_u32x8(a, b_shoup).1;
let t = simd.wrapping_sub_u32x8(
simd.wrapping_mul_u32x8(a, b),
simd.wrapping_mul_u32x8(shoup_q, p),
);
simd.small_mod_u32x8(p, t)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod32_avx512(
simd: crate::V4IFma,
p: u32x16,
a: u32x16,
b: u32x16,
b_shoup: u32x16,
) -> u32x16 {
let shoup_q = simd.widening_mul_u32x16(a, b_shoup).1;
let t = simd.wrapping_sub_u32x16(
simd.wrapping_mul_u32x16(a, b),
simd.wrapping_mul_u32x16(shoup_q, p),
);
simd.small_mod_u32x16(p, t)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod52_avx512(
simd: crate::V4IFma,
p: u64x8,
neg_p: u64x8,
a: u64x8,
b: u64x8,
b_shoup: u64x8,
) -> u64x8 {
let shoup_q = simd.widening_mul_u52x8(a, b_shoup).1;
let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(a, b).0);
simd.small_mod_u64x8(p, t)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_012_avx2(
simd: crate::V3,
mod_p0: u32x8,
mod_p1: u32x8,
mod_p2: u32x8,
) -> u32x8 {
use crate::primes32::*;
let p0 = simd.splat_u32x8(P0);
let p1 = simd.splat_u32x8(P1);
let p2 = simd.splat_u32x8(P2);
let two_p1 = simd.splat_u32x8(2 * P1);
let two_p2 = simd.splat_u32x8(2 * P2);
let half_p2 = simd.splat_u32x8(P2 / 2);
let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);
let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));
let p012 = simd.splat_u32x8(P0.wrapping_mul(P1).wrapping_mul(P2));
let v0 = mod_p0;
let v1 = mul_mod32_avx2(
simd,
p1,
simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let v2 = mul_mod32_avx2(
simd,
p2,
simd.wrapping_sub_u32x8(
simd.wrapping_add_u32x8(two_p2, mod_p2),
simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
),
p01_inv_mod_p2,
p01_inv_mod_p2_shoup,
);
let sign = simd.cmp_gt_u32x8(v2, half_p2);
let pos = simd.wrapping_add_u32x8(
simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0)),
simd.wrapping_mul_u32x8(v2, p01),
);
let neg = simd.wrapping_sub_u32x8(pos, p012);
simd.select_u32x8(sign, neg, pos)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
simd: crate::V4IFma,
mod_p0: u32x16,
mod_p1: u32x16,
mod_p2: u32x16,
) -> u32x16 {
use crate::primes32::*;
let p0 = simd.splat_u32x16(P0);
let p1 = simd.splat_u32x16(P1);
let p2 = simd.splat_u32x16(P2);
let two_p1 = simd.splat_u32x16(2 * P1);
let two_p2 = simd.splat_u32x16(2 * P2);
let half_p2 = simd.splat_u32x16(P2 / 2);
let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);
let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));
let p012 = simd.splat_u32x16(P0.wrapping_mul(P1).wrapping_mul(P2));
let v0 = mod_p0;
let v1 = mul_mod32_avx512(
simd,
p1,
simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let v2 = mul_mod32_avx512(
simd,
p2,
simd.wrapping_sub_u32x16(
simd.wrapping_add_u32x16(two_p2, mod_p2),
simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
),
p01_inv_mod_p2,
p01_inv_mod_p2_shoup,
);
let sign = simd.cmp_gt_u32x16(v2, half_p2);
let pos = simd.wrapping_add_u32x16(
simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0)),
simd.wrapping_mul_u32x16(v2, p01),
);
let neg = simd.wrapping_sub_u32x16(pos, p012);
simd.select_u32x16(sign, neg, pos)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u32x8 {
use crate::primes52::*;
let p0 = simd.splat_u64x8(P0);
let p1 = simd.splat_u64x8(P1);
let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
let two_p1 = simd.splat_u64x8(2 * P1);
let half_p1 = simd.splat_u64x8(P1 / 2);
let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);
let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));
let v0 = mod_p0;
let v1 = mul_mod52_avx512(
simd,
p1,
neg_p1,
simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let sign = simd.cmp_gt_u64x8(v1, half_p1);
let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
let neg = simd.wrapping_sub_u64x8(pos, p01);
simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
simd: crate::V3,
value: &mut [u32],
mod_p0: &[u32],
mod_p1: &[u32],
mod_p2: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
*value = cast(reconstruct_32bit_012_avx2(
simd,
cast(mod_p0),
cast(mod_p1),
cast(mod_p2),
));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
simd: crate::V4IFma,
value: &mut [u32],
mod_p0: &[u32],
mod_p1: &[u32],
mod_p2: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<16, _>(value).0;
let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
*value = cast(reconstruct_32bit_012_avx512(
simd,
cast(mod_p0),
cast(mod_p1),
cast(mod_p2),
));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
simd: crate::V4IFma,
value: &mut [u32],
mod_p0: &[u64],
mod_p1: &[u64],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
*value = cast(reconstruct_52bit_01_avx512(
simd,
cast(mod_p0),
cast(mod_p1),
));
}
},
);
}
impl Plan32 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime32::Plan, primes32::*};
Some(Self(
Plan::try_new(n, P0)?,
Plan::try_new(n, P1)?,
Plan::try_new(n, P2)?,
))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
#[inline]
pub fn ntt_0(&self) -> &crate::prime32::Plan {
&self.0
}
#[inline]
pub fn ntt_1(&self) -> &crate::prime32::Plan {
&self.1
}
#[inline]
pub fn ntt_2(&self) -> &crate::prime32::Plan {
&self.2
}
pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
for (value, mod_p0, mod_p1, mod_p2) in
crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
{
*mod_p0 = value % crate::primes32::P0;
*mod_p1 = value % crate::primes32::P1;
*mod_p2 = value % crate::primes32::P2;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
}
pub fn inv(
&self,
value: &mut [u32],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
self.2.inv(mod_p2);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4IFma::try_new() {
reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
return;
}
if let Some(simd) = crate::V3::try_new() {
reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
return;
}
}
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
{
*value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
}
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut lhs2 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
let mut rhs2 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.2.mul_assign_normalize(&mut lhs2, &rhs2);
self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters, or if the AVX512
/// instruction set isn't detected.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime64::Plan, primes52::*};
let simd = crate::V4IFma::try_new()?;
Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
self.2.vectorize(
#[inline(always)]
|| {
for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
*mod_p0 = *value as u64;
*mod_p1 = *value as u64;
}
},
);
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
}
pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
let simd = self.2;
reconstruct_slice_52bit_01_avx512(simd, value, mod_p0, mod_p1);
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1);
self.fwd(rhs, &mut rhs0, &mut rhs1);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.inv(prod, &mut lhs0, &mut lhs1);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prime32::tests::random_lhs_rhs_with_negacyclic_convolution;
use rand::random;
extern crate alloc;
use alloc::{vec, vec::Vec};
#[test]
fn reconstruct_32bit() {
for n in [32, 64, 256, 1024, 2048] {
let plan = Plan32::try_new(n).unwrap();
let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
let mut value_roundtrip = vec![0u32; n];
let mut mod_p0 = vec![0u32; n];
let mut mod_p1 = vec![0u32; n];
let mut mod_p2 = vec![0u32; n];
plan.fwd(&value, &mut mod_p0, &mut mod_p1, &mut mod_p2);
plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1, &mut mod_p2);
for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
}
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, 0);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
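    // A minimal sketch of the pipeline that `negacyclic_polymul` wraps: forward NTT of
    // both operands modulo each prime, pointwise multiplication with normalization, then
    // inverse NTT and CRT reconstruction. The manual version must match the convenience
    // method exactly.
    #[test]
    fn manual_pipeline_matches_negacyclic_polymul() {
        let n = 256;
        let plan = Plan32::try_new(n).unwrap();
        let (lhs, rhs, _) = random_lhs_rhs_with_negacyclic_convolution(n, 0);
        let mut expected = vec![0u32; n];
        plan.negacyclic_polymul(&mut expected, &lhs, &rhs);
        let mut lhs0 = vec![0u32; n];
        let mut lhs1 = vec![0u32; n];
        let mut lhs2 = vec![0u32; n];
        let mut rhs0 = vec![0u32; n];
        let mut rhs1 = vec![0u32; n];
        let mut rhs2 = vec![0u32; n];
        plan.fwd(&lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        plan.fwd(&rhs, &mut rhs0, &mut rhs1, &mut rhs2);
        plan.ntt_0().mul_assign_normalize(&mut lhs0, &rhs0);
        plan.ntt_1().mul_assign_normalize(&mut lhs1, &rhs1);
        plan.ntt_2().mul_assign_normalize(&mut lhs2, &rhs2);
        let mut prod = vec![0u32; n];
        plan.inv(&mut prod, &mut lhs0, &mut lhs1, &mut lhs2);
        assert_eq!(prod, expected);
    }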
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn reconstruct_52bit() {
for n in [32, 64, 256, 1024, 2048] {
if let Some(plan) = Plan52::try_new(n) {
let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
let mut value_roundtrip = vec![0u32; n];
let mut mod_p0 = vec![0u64; n];
let mut mod_p1 = vec![0u64; n];
plan.fwd(&value, &mut mod_p0, &mut mod_p1);
plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1);
for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
}
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, 0);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[test]
fn reconstruct_32bit_avx() {
for n in [16, 32, 64, 256, 1024, 2048] {
use crate::primes32::*;
let mut value = vec![0u32; n];
let mut value_avx2 = vec![0u32; n];
#[cfg(feature = "nightly")]
let mut value_avx512 = vec![0u32; n];
let mod_p0 = (0..n).map(|_| random::<u32>() % P0).collect::<Vec<_>>();
let mod_p1 = (0..n).map(|_| random::<u32>() % P1).collect::<Vec<_>>();
let mod_p2 = (0..n).map(|_| random::<u32>() % P2).collect::<Vec<_>>();
for (value, &mod_p0, &mod_p1, &mod_p2) in
crate::izip!(&mut value, &mod_p0, &mod_p1, &mod_p2)
{
*value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
}
if let Some(simd) = crate::V3::try_new() {
reconstruct_slice_32bit_012_avx2(simd, &mut value_avx2, &mod_p0, &mod_p1, &mod_p2);
assert_eq!(value, value_avx2);
}
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4IFma::try_new() {
reconstruct_slice_32bit_012_avx512(
simd,
&mut value_avx512,
&mod_p0,
&mod_p1,
&mod_p2,
);
assert_eq!(value, value_avx512);
}
}
}
}

1294
tfhe-ntt/src/native64.rs Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,222 @@
pub(crate) use crate::native64::{mul_mod32, mul_mod64};
use aligned_vec::avec;
/// Negacyclic NTT plan for multiplying two 128bit polynomials, where the RHS contains binary
/// coefficients.
pub struct Plan32(
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
);
#[inline(always)]
fn reconstruct_32bit_01234_v2(
mod_p0: u32,
mod_p1: u32,
mod_p2: u32,
mod_p3: u32,
mod_p4: u32,
) -> u128 {
use crate::primes32::*;
let mod_p12 = {
let v1 = mod_p1;
let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
v1 as u64 + (v2 as u64 * P1 as u64)
};
let mod_p34 = {
let v3 = mod_p3;
let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
v3 as u64 + (v4 as u64 * P3 as u64)
};
let v0 = mod_p0 as u64;
let v12 = mul_mod64(
P12.wrapping_neg(),
2 * P12 + mod_p12 - v0,
P0_INV_MOD_P12,
P0_INV_MOD_P12_SHOUP,
);
let v34 = mul_mod64(
P34.wrapping_neg(),
2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
P012_INV_MOD_P34,
P012_INV_MOD_P34_SHOUP,
);
let sign = v34 > (P34 / 2);
const _0: u128 = P0 as u128;
const _012: u128 = _0.wrapping_mul(P12 as u128);
const _01234: u128 = _012.wrapping_mul(P34 as u128);
let pos = (v0 as u128)
.wrapping_add((v12 as u128).wrapping_mul(_0))
.wrapping_add((v34 as u128).wrapping_mul(_012));
let neg = pos.wrapping_sub(_01234);
if sign {
neg
} else {
pos
}
}
impl Plan32 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime32::Plan, primes32::*};
Some(Self(
Plan::try_new(n, P0)?,
Plan::try_new(n, P1)?,
Plan::try_new(n, P2)?,
Plan::try_new(n, P3)?,
Plan::try_new(n, P4)?,
))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(
&self,
value: &[u128],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
mod_p3: &mut [u32],
mod_p4: &mut [u32],
) {
for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
value,
&mut *mod_p0,
&mut *mod_p1,
&mut *mod_p2,
&mut *mod_p3,
&mut *mod_p4,
) {
*mod_p0 = (value % crate::primes32::P0 as u128) as u32;
*mod_p1 = (value % crate::primes32::P1 as u128) as u32;
*mod_p2 = (value % crate::primes32::P2 as u128) as u32;
*mod_p3 = (value % crate::primes32::P3 as u128) as u32;
*mod_p4 = (value % crate::primes32::P4 as u128) as u32;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
self.3.fwd(mod_p3);
self.4.fwd(mod_p4);
}
pub fn fwd_binary(
&self,
value: &[u128],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
mod_p3: &mut [u32],
mod_p4: &mut [u32],
) {
for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
value,
&mut *mod_p0,
&mut *mod_p1,
&mut *mod_p2,
&mut *mod_p3,
&mut *mod_p4,
) {
*mod_p0 = *value as u32;
*mod_p1 = *value as u32;
*mod_p2 = *value as u32;
*mod_p3 = *value as u32;
*mod_p4 = *value as u32;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
self.3.fwd(mod_p3);
self.4.fwd(mod_p4);
}
pub fn inv(
&self,
value: &mut [u128],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
mod_p3: &mut [u32],
mod_p4: &mut [u32],
) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
self.2.inv(mod_p2);
self.3.inv(mod_p3);
self.4.inv(mod_p4);
for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
{
*value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
}
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut lhs2 = avec![0; n];
let mut lhs3 = avec![0; n];
let mut lhs4 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
let mut rhs2 = avec![0; n];
let mut rhs3 = avec![0; n];
let mut rhs4 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
self.fwd_binary(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.2.mul_assign_normalize(&mut lhs2, &rhs2);
self.3.mul_assign_normalize(&mut lhs3, &rhs3);
self.4.mul_assign_normalize(&mut lhs4, &rhs4);
self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::native128::tests::negacyclic_convolution;
use alloc::{vec, vec::Vec};
use rand::random;
extern crate alloc;
#[test]
fn reconstruct_32bit() {
for n in [32, 64, 256, 1024, 2048] {
let plan = Plan32::try_new(n).unwrap();
let lhs = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
let rhs = (0..n).map(|_| random::<u128>() % 2).collect::<Vec<_>>();
let negacyclic_convolution = negacyclic_convolution(n, &lhs, &rhs);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
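    // An extra illustrative check of the negacyclic structure: multiplying by the binary
    // monomial `X` rotates the coefficients by one position and negates the one that
    // wraps around, since the product is taken modulo `X^n + 1`.
    #[test]
    fn binary_polymul_by_x_rotates_and_negates() {
        let n = 256;
        let plan = Plan32::try_new(n).unwrap();
        let lhs = (0..n).map(|_| random::<u128>()).collect::<Vec<_>>();
        let mut x = vec![0u128; n];
        x[1] = 1;
        let mut prod = vec![0u128; n];
        plan.negacyclic_polymul(&mut prod, &lhs, &x);
        assert_eq!(prod[0], lhs[n - 1].wrapping_neg());
        for i in 1..n {
            assert_eq!(prod[i], lhs[i - 1]);
        }
    }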
}

View File

@@ -0,0 +1,364 @@
use aligned_vec::avec;
#[allow(unused_imports)]
use pulp::*;
use crate::native32::mul_mod32;
/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
/// coefficients.
#[derive(Clone, Debug)]
pub struct Plan32(crate::prime32::Plan, crate::prime32::Plan);
/// Negacyclic NTT plan for multiplying two 32bit polynomials, where the RHS contains binary
/// coefficients.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::V4IFma);
#[inline(always)]
pub(crate) fn reconstruct_32bit_01(mod_p0: u32, mod_p1: u32) -> u32 {
use crate::primes32::*;
let v0 = mod_p0;
let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
let sign = v1 > (P1 / 2);
const _0: u32 = P0;
const _01: u32 = _0.wrapping_mul(P1);
let pos = v0.wrapping_add(v1.wrapping_mul(_0));
let neg = pos.wrapping_sub(_01);
if sign {
neg
} else {
pos
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_01_avx2(simd: crate::V3, mod_p0: u32x8, mod_p1: u32x8) -> u32x8 {
use crate::{native32::mul_mod32_avx2, primes32::*};
let p0 = simd.splat_u32x8(P0);
let p1 = simd.splat_u32x8(P1);
let two_p1 = simd.splat_u32x8(2 * P1);
let half_p1 = simd.splat_u32x8(P1 / 2);
let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));
let v0 = mod_p0;
let v1 = mul_mod32_avx2(
simd,
p1,
simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let sign = simd.cmp_gt_u32x8(v1, half_p1);
let pos = simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0));
let neg = simd.wrapping_sub_u32x8(pos, p01);
simd.select_u32x8(sign, neg, pos)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01_avx512(simd: crate::V4IFma, mod_p0: u32x16, mod_p1: u32x16) -> u32x16 {
use crate::{native32::mul_mod32_avx512, primes32::*};
let p0 = simd.splat_u32x16(P0);
let p1 = simd.splat_u32x16(P1);
let two_p1 = simd.splat_u32x16(2 * P1);
let half_p1 = simd.splat_u32x16(P1 / 2);
let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));
let v0 = mod_p0;
let v1 = mul_mod32_avx512(
simd,
p1,
simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let sign = simd.cmp_gt_u32x16(v1, half_p1);
let pos = simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0));
let neg = simd.wrapping_sub_u32x16(pos, p01);
simd.select_u32x16(sign, neg, pos)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_0_avx512(simd: crate::V4IFma, mod_p0: u64x8) -> u32x8 {
use crate::primes52::*;
let p0 = simd.splat_u64x8(P0);
let half_p0 = simd.splat_u64x8(P0 / 2);
let v0 = mod_p0;
let sign = simd.cmp_gt_u64x8(v0, half_p0);
let pos = v0;
let neg = simd.wrapping_sub_u64x8(pos, p0);
simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_01_avx2(
simd: crate::V3,
value: &mut [u32],
mod_p0: &[u32],
mod_p1: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
*value = cast(reconstruct_32bit_01_avx2(simd, cast(mod_p0), cast(mod_p1)));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01_avx512(
simd: crate::V4IFma,
value: &mut [u32],
mod_p0: &[u32],
mod_p1: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<16, _>(value).0;
let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
*value = cast(reconstruct_32bit_01_avx512(
simd,
cast(mod_p0),
cast(mod_p1),
));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_0_avx512(simd: crate::V4IFma, value: &mut [u32], mod_p0: &[u64]) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
for (value, &mod_p0) in crate::izip!(value, mod_p0) {
*value = cast(reconstruct_52bit_0_avx512(simd, cast(mod_p0)));
}
},
);
}
impl Plan32 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime32::Plan, primes32::*};
Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
*mod_p0 = value % crate::primes32::P0;
*mod_p1 = value % crate::primes32::P1;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
}
pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
*mod_p0 = *value;
*mod_p1 = *value;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
}
pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4IFma::try_new() {
reconstruct_slice_32bit_01_avx512(simd, value, mod_p0, mod_p1);
return;
}
if let Some(simd) = crate::V3::try_new() {
reconstruct_slice_32bit_01_avx2(simd, value, mod_p0, mod_p1);
return;
}
}
for (value, &mod_p0, &mod_p1) in crate::izip!(value, &*mod_p0, &*mod_p1) {
*value = reconstruct_32bit_01(mod_p0, mod_p1);
}
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs_binary.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1);
self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.inv(prod, &mut lhs0, &mut lhs1);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters, or if the AVX512
/// instruction set isn't detected.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime64::Plan, primes52::*};
let simd = crate::V4IFma::try_new()?;
Some(Self(Plan::try_new(n, P0)?, simd))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64]) {
self.1.vectorize(
#[inline(always)]
|| {
for (value, mod_p0) in crate::izip!(value, &mut *mod_p0) {
*mod_p0 = *value as u64;
}
},
);
self.0.fwd(mod_p0);
}
pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u64]) {
self.fwd(value, mod_p0);
}
pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64]) {
self.0.inv(mod_p0);
let simd = self.1;
reconstruct_slice_52bit_0_avx512(simd, value, mod_p0);
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs_binary.len());
let mut lhs0 = avec![0; n];
let mut rhs0 = avec![0; n];
self.fwd(lhs, &mut lhs0);
self.fwd_binary(rhs_binary, &mut rhs0);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.inv(prod, &mut lhs0);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prime32::tests::negacyclic_convolution;
use alloc::{vec, vec::Vec};
use rand::random;
extern crate alloc;
#[test]
fn reconstruct_32bit() {
for n in [32, 64, 256, 1024, 2048] {
let plan = Plan32::try_new(n).unwrap();
let lhs = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
let rhs = (0..n).map(|_| random::<u32>() % 2).collect::<Vec<_>>();
let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn reconstruct_52bit() {
for n in [32, 64, 256, 1024, 2048] {
if let Some(plan) = Plan52::try_new(n) {
let lhs = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
let rhs = (0..n).map(|_| random::<u32>() % 2).collect::<Vec<_>>();
let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
}
}

View File

@@ -0,0 +1,563 @@
use aligned_vec::avec;
#[allow(unused_imports)]
use pulp::*;
pub(crate) use crate::native32::mul_mod32;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use crate::native32::mul_mod32_avx2;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) use crate::native32::{mul_mod32_avx512, mul_mod52_avx512};
/// Negacyclic NTT plan for multiplying two 64bit polynomials, where the RHS contains binary
/// coefficients.
#[derive(Clone, Debug)]
pub struct Plan32(
crate::prime32::Plan,
crate::prime32::Plan,
crate::prime32::Plan,
);
/// Negacyclic NTT plan for multiplying two 64bit polynomials, where the RHS contains binary
/// coefficients.
/// This can be more efficient than [`Plan32`], but requires the AVX512 instruction set.
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);
#[inline(always)]
#[allow(dead_code)]
fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u64 {
use crate::primes32::*;
let v0 = mod_p0;
let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
let v2 = mul_mod32(
P2,
P01_INV_MOD_P2,
2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
);
let sign = v2 > (P2 / 2);
const _0: u64 = P0 as u64;
const _01: u64 = _0.wrapping_mul(P1 as u64);
const _012: u64 = _01.wrapping_mul(P2 as u64);
let pos = (v0 as u64)
.wrapping_add((v1 as u64).wrapping_mul(_0))
.wrapping_add((v2 as u64).wrapping_mul(_01));
let neg = pos.wrapping_sub(_012);
if sign {
neg
} else {
pos
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx2(
simd: crate::V3,
mod_p0: u32x8,
mod_p1: u32x8,
mod_p2: u32x8,
) -> [u64x4; 2] {
use crate::primes32::*;
let p0 = simd.splat_u32x8(P0);
let p1 = simd.splat_u32x8(P1);
let p2 = simd.splat_u32x8(P2);
let two_p1 = simd.splat_u32x8(2 * P1);
let two_p2 = simd.splat_u32x8(2 * P2);
let half_p2 = simd.splat_u32x8(P2 / 2);
let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);
let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
let p01 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64));
let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
let v0 = mod_p0;
let v1 = mul_mod32_avx2(
simd,
p1,
simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let v2 = mul_mod32_avx2(
simd,
p2,
simd.wrapping_sub_u32x8(
simd.wrapping_add_u32x8(two_p2, mod_p2),
simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
),
p01_inv_mod_p2,
p01_inv_mod_p2_shoup,
);
let sign = simd.cmp_gt_u32x8(v2, half_p2);
let sign: [i32x4; 2] = pulp::cast(sign);
// sign extend so that -1i32 becomes -1i64
let sign0: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[0])) };
let sign1: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[1])) };
let v0: [u32x4; 2] = pulp::cast(v0);
let v1: [u32x4; 2] = pulp::cast(v1);
let v2: [u32x4; 2] = pulp::cast(v2);
let v00 = simd.convert_u32x4_to_u64x4(v0[0]);
let v01 = simd.convert_u32x4_to_u64x4(v0[1]);
let v10 = simd.convert_u32x4_to_u64x4(v1[0]);
let v11 = simd.convert_u32x4_to_u64x4(v1[1]);
let v20 = simd.convert_u32x4_to_u64x4(v2[0]);
let v21 = simd.convert_u32x4_to_u64x4(v2[1]);
let pos0 = v00;
let pos0 = simd.wrapping_add_u64x4(pos0, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v10));
let pos0 = simd.wrapping_add_u64x4(
pos0,
simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v20),
);
let pos1 = v01;
let pos1 = simd.wrapping_add_u64x4(pos1, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v11));
let pos1 = simd.wrapping_add_u64x4(
pos1,
simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v21),
);
let neg0 = simd.wrapping_sub_u64x4(pos0, p012);
let neg1 = simd.wrapping_sub_u64x4(pos1, p012);
[
simd.select_u64x4(sign0, neg0, pos0),
simd.select_u64x4(sign1, neg1, pos1),
]
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
simd: crate::V4IFma,
mod_p0: u32x16,
mod_p1: u32x16,
mod_p2: u32x16,
) -> [u64x8; 2] {
use crate::primes32::*;
let p0 = simd.splat_u32x16(P0);
let p1 = simd.splat_u32x16(P1);
let p2 = simd.splat_u32x16(P2);
let two_p1 = simd.splat_u32x16(2 * P1);
let two_p2 = simd.splat_u32x16(2 * P2);
let half_p2 = simd.splat_u32x16(P2 / 2);
let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);
let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
let p01 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64));
let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
let v0 = mod_p0;
let v1 = mul_mod32_avx512(
simd,
p1,
simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let v2 = mul_mod32_avx512(
simd,
p2,
simd.wrapping_sub_u32x16(
simd.wrapping_add_u32x16(two_p2, mod_p2),
simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
),
p01_inv_mod_p2,
p01_inv_mod_p2_shoup,
);
let sign = simd.cmp_gt_u32x16(v2, half_p2).0;
let sign0 = b8(sign as u8);
let sign1 = b8((sign >> 8) as u8);
let v0: [u32x8; 2] = pulp::cast(v0);
let v1: [u32x8; 2] = pulp::cast(v1);
let v2: [u32x8; 2] = pulp::cast(v2);
let v00 = simd.convert_u32x8_to_u64x8(v0[0]);
let v01 = simd.convert_u32x8_to_u64x8(v0[1]);
let v10 = simd.convert_u32x8_to_u64x8(v1[0]);
let v11 = simd.convert_u32x8_to_u64x8(v1[1]);
let v20 = simd.convert_u32x8_to_u64x8(v2[0]);
let v21 = simd.convert_u32x8_to_u64x8(v2[1]);
let pos0 = v00;
let pos0 = simd.wrapping_add_u64x8(pos0, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v10));
let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p01, v20));
let pos1 = v01;
let pos1 = simd.wrapping_add_u64x8(pos1, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v11));
let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p01, v21));
let neg0 = simd.wrapping_sub_u64x8(pos0, p012);
let neg1 = simd.wrapping_sub_u64x8(pos1, p012);
[
simd.select_u64x8(sign0, neg0, pos0),
simd.select_u64x8(sign1, neg1, pos1),
]
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u64x8 {
use crate::primes52::*;
let p0 = simd.splat_u64x8(P0);
let p1 = simd.splat_u64x8(P1);
let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
let two_p1 = simd.splat_u64x8(2 * P1);
let half_p1 = simd.splat_u64x8(P1 / 2);
let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);
let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));
let v0 = mod_p0;
let v1 = mul_mod52_avx512(
simd,
p1,
neg_p1,
simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
p0_inv_mod_p1,
p0_inv_mod_p1_shoup,
);
let sign = simd.cmp_gt_u64x8(v1, half_p1);
let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
let neg = simd.wrapping_sub_u64x8(pos, p01);
simd.select_u64x8(sign, neg, pos)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
simd: crate::V3,
value: &mut [u64],
mod_p0: &[u32],
mod_p1: &[u32],
mod_p2: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
*value = cast(reconstruct_32bit_012_avx2(
simd,
cast(mod_p0),
cast(mod_p1),
cast(mod_p2),
));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
simd: crate::V4IFma,
value: &mut [u64],
mod_p0: &[u32],
mod_p1: &[u32],
mod_p2: &[u32],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<16, _>(value).0;
let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
*value = cast(reconstruct_32bit_012_avx512(
simd,
cast(mod_p0),
cast(mod_p1),
cast(mod_p2),
));
}
},
);
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
simd: crate::V4IFma,
value: &mut [u64],
mod_p0: &[u64],
mod_p1: &[u64],
) {
simd.vectorize(
#[inline(always)]
move || {
let value = pulp::as_arrays_mut::<8, _>(value).0;
let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
*value = cast(reconstruct_52bit_01_avx512(
simd,
cast(mod_p0),
cast(mod_p1),
));
}
},
);
}
impl Plan32 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime32::Plan, primes32::*};
Some(Self(
Plan::try_new(n, P0)?,
Plan::try_new(n, P1)?,
Plan::try_new(n, P2)?,
))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(&self, value: &[u64], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
for (value, mod_p0, mod_p1, mod_p2) in
crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
{
*mod_p0 = (value % crate::primes32::P0 as u64) as u32;
*mod_p1 = (value % crate::primes32::P1 as u64) as u32;
*mod_p2 = (value % crate::primes32::P2 as u64) as u32;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
}
pub fn fwd_binary(
&self,
value: &[u64],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
) {
for (value, mod_p0, mod_p1, mod_p2) in
crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
{
*mod_p0 = *value as u32;
*mod_p1 = *value as u32;
*mod_p2 = *value as u32;
}
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
self.2.fwd(mod_p2);
}
pub fn inv(
&self,
value: &mut [u64],
mod_p0: &mut [u32],
mod_p1: &mut [u32],
mod_p2: &mut [u32],
) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
self.2.inv(mod_p2);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
#[cfg(feature = "nightly")]
if let Some(simd) = crate::V4IFma::try_new() {
reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
return;
}
if let Some(simd) = crate::V3::try_new() {
reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
return;
}
}
for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
{
*value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
}
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs_binary.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut lhs2 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
let mut rhs2 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1, &mut rhs2);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.2.mul_assign_normalize(&mut lhs2, &rhs2);
self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
/// Returns a negacyclic NTT plan for the given polynomial size, or `None` if no
/// suitable roots of unity can be found for the wanted parameters, or if the AVX512
/// instruction set isn't detected.
pub fn try_new(n: usize) -> Option<Self> {
use crate::{prime64::Plan, primes52::*};
let simd = crate::V4IFma::try_new()?;
Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
}
/// Returns the polynomial size of the negacyclic NTT plan.
#[inline]
pub fn ntt_size(&self) -> usize {
self.0.ntt_size()
}
pub fn fwd(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
use crate::primes52::*;
self.2.vectorize(
#[inline(always)]
|| {
for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
*mod_p0 = value % P0;
*mod_p1 = value % P1;
}
},
);
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
}
pub fn fwd_binary(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
self.2.vectorize(
#[inline(always)]
|| {
for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
*mod_p0 = value;
*mod_p1 = value;
}
},
);
self.0.fwd(mod_p0);
self.1.fwd(mod_p1);
}
pub fn inv(&self, value: &mut [u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
self.0.inv(mod_p0);
self.1.inv(mod_p1);
reconstruct_slice_52bit_01_avx512(self.2, value, mod_p0, mod_p1);
}
/// Computes the negacyclic polynomial product of `lhs` and `rhs`, and stores the result in
/// `prod`.
pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
let n = prod.len();
assert_eq!(n, lhs.len());
assert_eq!(n, rhs_binary.len());
let mut lhs0 = avec![0; n];
let mut lhs1 = avec![0; n];
let mut rhs0 = avec![0; n];
let mut rhs1 = avec![0; n];
self.fwd(lhs, &mut lhs0, &mut lhs1);
self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);
self.0.mul_assign_normalize(&mut lhs0, &rhs0);
self.1.mul_assign_normalize(&mut lhs1, &rhs1);
self.inv(prod, &mut lhs0, &mut lhs1);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prime64::tests::negacyclic_convolution;
use alloc::{vec, vec::Vec};
use rand::random;
extern crate alloc;
#[test]
fn reconstruct_32bit() {
for n in [32, 64, 256, 1024, 2048] {
let plan = Plan32::try_new(n).unwrap();
let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn reconstruct_52bit() {
for n in [32, 64, 256, 1024, 2048] {
if let Some(plan) = Plan52::try_new(n) {
let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);
let mut prod = vec![0; n];
plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
assert_eq!(prod, negacyclic_convolution);
}
}
}
}

223
tfhe-ntt/src/prime.rs Normal file
View File

@@ -0,0 +1,223 @@
use crate::fastdiv::{Div32, Div64};
/// Computes `x * y mod n`, where `n` is a precomputed [`Div32`] divisor.
#[inline(always)]
pub const fn mul_mod32(n: Div32, x: u32, y: u32) -> u32 {
Div32::rem_u64(x as u64 * y as u64, n)
}
/// Computes `x * y mod n`, where `n` is a precomputed [`Div64`] divisor.
#[inline(always)]
pub const fn mul_mod64(n: Div64, x: u64, y: u64) -> u64 {
Div64::rem_u128(x as u128 * y as u128, n)
}
/// Computes `base^pow mod n` by square-and-multiply.
pub const fn exp_mod32(n: Div32, base: u32, pow: u32) -> u32 {
if pow == 0 {
1
} else {
let mut pow = pow;
let mut y = 1;
let mut x = base;
while pow > 1 {
if pow % 2 == 1 {
y = mul_mod32(n, x, y);
}
x = mul_mod32(n, x, x);
pow /= 2;
}
mul_mod32(n, x, y)
}
}
/// Computes `base^pow mod n` by square-and-multiply.
pub const fn exp_mod64(n: Div64, base: u64, pow: u64) -> u64 {
if pow == 0 {
1
} else {
let mut pow = pow;
let mut y = 1;
let mut x = base;
while pow > 1 {
if pow % 2 == 1 {
y = mul_mod64(n, x, y);
}
x = mul_mod64(n, x, x);
pow /= 2;
}
mul_mod64(n, x, y)
}
}
const fn is_prime_miller_rabin_iter(n: Div64, s: u64, d: u64, a: u64) -> bool {
let mut x = exp_mod64(n, a, d);
let n_minus_1 = n.divisor() - 1;
if x == 1 || x == n_minus_1 {
true
} else {
let mut count = 0;
while count < s - 1 {
x = mul_mod64(n, x, x);
if x == n_minus_1 {
return true;
}
count += 1;
}
false
}
}
const fn max64(a: u64, b: u64) -> u64 {
if a > b {
a
} else {
b
}
}
/// Checks whether `n` is prime. The test is deterministic for all `u64` inputs.
pub const fn is_prime64(n: u64) -> bool {
// 0 and 1 are not prime
if n < 2 {
return false;
}
// test divisibility by small primes
// hand-unrolled for the compiler to optimize divisions
#[rustfmt::skip]
{
if n % 2 == 0 { return n == 2; }
if n % 3 == 0 { return n == 3; }
if n % 5 == 0 { return n == 5; }
if n % 7 == 0 { return n == 7; }
if n % 11 == 0 { return n == 11; }
if n % 13 == 0 { return n == 13; }
if n % 17 == 0 { return n == 17; }
if n % 19 == 0 { return n == 19; }
if n % 23 == 0 { return n == 23; }
if n % 29 == 0 { return n == 29; }
if n % 31 == 0 { return n == 31; }
if n % 37 == 0 { return n == 37; }
};
    // Deterministic Miller-Rabin test: with the fixed witness bases below, the result
    // is exact for every n < 2^64 (the small primes used as bases were already handled
    // by the trial divisions above).
// https://en.wikipedia.org/wiki/Miller-Rabin_primality_test#Testing_against_small_sets_of_bases
let mut s = 0;
let mut d = n - 1;
while d % 2 == 0 {
s += 1;
d /= 2;
}
let (s, d) = (s, d);
let n = Div64::new(n);
is_prime_miller_rabin_iter(n, s, d, 2)
&& is_prime_miller_rabin_iter(n, s, d, 3)
&& is_prime_miller_rabin_iter(n, s, d, 5)
&& is_prime_miller_rabin_iter(n, s, d, 7)
&& is_prime_miller_rabin_iter(n, s, d, 11)
&& is_prime_miller_rabin_iter(n, s, d, 13)
&& is_prime_miller_rabin_iter(n, s, d, 17)
&& is_prime_miller_rabin_iter(n, s, d, 19)
&& is_prime_miller_rabin_iter(n, s, d, 23)
&& is_prime_miller_rabin_iter(n, s, d, 29)
&& is_prime_miller_rabin_iter(n, s, d, 31)
&& is_prime_miller_rabin_iter(n, s, d, 37)
}
/// Returns the largest prime of the form `factor * x + offset` in the range
/// `[lo, hi]`, or `None` if there is none.
pub const fn largest_prime_in_arithmetic_progression64(
factor: u64,
offset: u64,
lo: u64,
hi: u64,
) -> Option<u64> {
if lo > hi {
return None;
}
let a = factor;
let b = offset;
// lo <= ax + b <= hi
// (lo - b)/a <= x <= (hi - b)/a
if b > hi {
return None;
}
if a == 0 {
if lo <= b && b <= hi && is_prime64(b) {
return Some(b);
} else {
return None;
}
}
let mut x_lo = (max64(lo, b) - b) / a;
let rem = (max64(lo, b) - b) % a;
if rem != 0 {
x_lo += 1;
}
let x_hi = (hi - b) / a;
let mut x = x_hi;
let mut in_range = true;
while in_range {
let val = a * x + b;
if is_prime64(val) {
return Some(val);
}
if x == x_lo {
in_range = false;
} else {
x -= 1;
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use crate::prime64::Solinas;
#[test]
fn test_is_prime() {
let primes_under_1000 = [
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179,
181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271,
277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379,
383, 389, 397, 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479,
487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599,
601, 607, 613, 617, 619, 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701,
709, 719, 727, 733, 739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911, 919, 929, 937, 941,
947, 953, 967, 971, 977, 983, 991, 997,
];
for n in 0..1000 {
assert_eq!(primes_under_1000.contains(&n), is_prime64(n));
}
assert!(is_prime64(Solinas::P));
}
#[test]
#[rustfmt::skip]
fn test_prime_search() {
assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 1, 4), Some(2));
assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 2, 2), Some(2));
assert_eq!(largest_prime_in_arithmetic_progression64(0, 2, 2, 1), None);
assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 14, 16), None);
assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 14, 17), Some(17));
assert_eq!(largest_prime_in_arithmetic_progression64(1, 0, 17, 18), Some(17));
assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 14, 16), None);
assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 14, 17), Some(17));
assert_eq!(largest_prime_in_arithmetic_progression64(2, 1, 17, 18), Some(17));
assert_eq!(largest_prime_in_arithmetic_progression64(6, 5, 0, u64::MAX) , Some(18446744073709551557));
assert_eq!(largest_prime_in_arithmetic_progression64(6, 1, 0, u64::MAX) , Some(18446744073709551427));
}
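    // A sketch of how this helper is typically used to pick NTT-friendly moduli: for a
    // power-of-two polynomial size `n`, a prime `p` with `p % (2 * n) == 1` guarantees
    // that `Z/pZ` contains the `2n`-th roots of unity needed by the negacyclic NTT.
    #[test]
    fn find_ntt_friendly_prime() {
        let n = 1024u64;
        let p = largest_prime_in_arithmetic_progression64(2 * n, 1, 0, 1 << 31).unwrap();
        assert!(is_prime64(p));
        assert_eq!(p % (2 * n), 1);
        assert!(p <= 1 << 31);
    }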
}

1648
tfhe-ntt/src/prime32.rs Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,616 @@
#[allow(unused_imports)]
use pulp::*;
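// The butterflies in this module follow a lazy-reduction strategy with Shoup precomputed
// twiddles: `w_shoup` holds `floor(w * 2^32 / p)`, so the high half of `z1 * w_shoup`
// estimates the quotient and `z1 * w - q * p` lands in `[0, 2 * p)` without any division.
// Intermediate coefficients are allowed to grow up to `4 * p` (hence the `two_p`
// corrections), and only the `*_last_butterfly_*` variants reduce everything back into
// `[0, p)`; this requires `4 * p` to fit in 32 bits.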
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = p;
let z0 = simd.small_mod_u32x16(two_p, z0);
let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
let t = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(z1, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
(
simd.wrapping_add_u32x16(z0, t),
simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), two_p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let z0 = simd.small_mod_u32x16(two_p, z0);
let z0 = simd.small_mod_u32x16(p, z0);
let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
let t = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(z1, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
let t = simd.small_mod_u32x16(p, t);
(
simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(z0, t)),
simd.small_mod_u32x16(
p,
simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = p;
let z0 = simd.small_mod_u32x8(two_p, z0);
let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
let t = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(z1, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
(
simd.wrapping_add_u32x8(z0, t),
simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), two_p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let z0 = simd.small_mod_u32x8(two_p, z0);
let z0 = simd.small_mod_u32x8(p, z0);
let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
let t = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(z1, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
let t = simd.small_mod_u32x8(p, t);
(
simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(z0, t)),
simd.small_mod_u32x8(
p,
simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = p;
let z0 = z0.min(z0.wrapping_sub(two_p));
let shoup_q = ((z1 as u64 * w_shoup as u64) >> 32) as u32;
let t = u32::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
(z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(two_p))
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = p;
let z0 = z0.min(z0.wrapping_sub(two_p));
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u64 * w_shoup as u64) >> 32) as u32;
let t = u32::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
(
res.0.min(res.0.wrapping_sub(p)),
res.1.min(res.1.wrapping_sub(p)),
)
}
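// Inverse (Gentleman-Sande style) butterflies: compute (z0 + z1, (z0 - z1) * w)
// with the same Shoup multiplication trick, the subtraction being offset by 2p
// so it stays non-negative under the lazy-reduction bounds above.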
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = p;
let y0 = simd.wrapping_add_u32x16(z0, z1);
let y0 = simd.small_mod_u32x16(two_p, y0);
let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), two_p);
let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(t, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = p;
let y0 = simd.wrapping_add_u32x16(z0, z1);
let y0 = simd.small_mod_u32x16(two_p, y0);
let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), two_p);
let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(t, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
(simd.small_mod_u32x16(p, y0), simd.small_mod_u32x16(p, y1))
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = p;
let y0 = simd.wrapping_add_u32x8(z0, z1);
let y0 = simd.small_mod_u32x8(two_p, y0);
let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(t, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = p;
let y0 = simd.wrapping_add_u32x8(z0, z1);
let y0 = simd.small_mod_u32x8(two_p, y0);
let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(t, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
(simd.small_mod_u32x8(p, y0), simd.small_mod_u32x8(p, y1))
}
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(two_p));
let t = z0.wrapping_sub(z1).wrapping_add(two_p);
let shoup_q = ((t as u64 * w_shoup as u64) >> 32) as u32;
let y1 = u32::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_last_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(two_p));
let t = z0.wrapping_sub(z1).wrapping_add(two_p);
let shoup_q = ((t as u64 * w_shoup as u64) >> 32) as u32;
let y1 = u32::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
(y0.min(y0.wrapping_sub(p)), y1.min(y1.wrapping_sub(p)))
}
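// Drivers plugging the butterflies above into the generic depth-first Shoup NTT
// kernels of the parent module, one per SIMD level (AVX-512 behind the
// `nightly` feature, AVX2, and scalar).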
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
simd: crate::V4,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
simd: crate::V4,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::inv_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
simd: crate::V3,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::fwd_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
simd: crate::V3,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::inv_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn fwd_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
super::shoup::fwd_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
super::shoup::inv_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime32::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
extern crate alloc;
use alloc::vec;
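// Sanity check: forward-transform both operands, multiply pointwise, inverse
// transform, and compare against the negacyclic convolution scaled by n (the
// inverse transform here is unnormalized, hence the `mul(p, ..., n)` on the
// expected side).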
#[test]
fn test_product() {
for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30).unwrap()
as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[test]
fn test_product_avx2() {
if let Some(simd) = crate::V3::try_new() {
for n in [32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30)
.unwrap() as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn test_product_avx512() {
if let Some(simd) = crate::V4::try_new() {
for n in [32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 29, 1 << 30)
.unwrap() as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
}
}

View File

@@ -0,0 +1,546 @@
#[allow(unused_imports)]
use pulp::*;
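// 32-bit Shoup butterflies with eager reduction: operands are reduced below p
// before each multiply and outputs stay below 2p, which admits the larger
// primes used in the tests below (between 2^30 and 2^31).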
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = two_p;
let z0 = simd.small_mod_u32x16(p, z0);
let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
let t = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(z1, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
let t = simd.small_mod_u32x16(p, t);
(
simd.wrapping_add_u32x16(z0, t),
simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = two_p;
let z0 = simd.small_mod_u32x16(p, z0);
let shoup_q = simd.widening_mul_u32x16(z1, w_shoup).1;
let t = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(z1, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
let t = simd.small_mod_u32x16(p, t);
(
simd.small_mod_u32x16(p, simd.wrapping_add_u32x16(z0, t)),
simd.small_mod_u32x16(
p,
simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, t), p),
),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = two_p;
let z0 = simd.small_mod_u32x8(p, z0);
let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
let t = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(z1, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
let t = simd.small_mod_u32x8(p, t);
(
simd.wrapping_add_u32x8(z0, t),
simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = two_p;
let z0 = simd.small_mod_u32x8(p, z0);
let shoup_q = simd.widening_mul_u32x8(z1, w_shoup).1;
let t = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(z1, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
let t = simd.small_mod_u32x8(p, t);
(
simd.small_mod_u32x8(p, simd.wrapping_add_u32x8(z0, t)),
simd.small_mod_u32x8(
p,
simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = two_p;
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u64 * w_shoup as u64) >> 32) as u32;
let t = u32::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
(z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p))
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = two_p;
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u64 * w_shoup as u64) >> 32) as u32;
let t = u32::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
(
res.0.min(res.0.wrapping_sub(p)),
res.1.min(res.1.wrapping_sub(p)),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4,
z0: u32x16,
z1: u32x16,
w: u32x16,
w_shoup: u32x16,
p: u32x16,
neg_p: u32x16,
two_p: u32x16,
) -> (u32x16, u32x16) {
let _ = two_p;
let y0 = simd.wrapping_add_u32x16(z0, z1);
let y0 = simd.small_mod_u32x16(p, y0);
let t = simd.wrapping_add_u32x16(simd.wrapping_sub_u32x16(z0, z1), p);
let shoup_q = simd.widening_mul_u32x16(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x16(
simd.wrapping_mul_u32x16(t, w),
simd.wrapping_mul_u32x16(shoup_q, neg_p),
);
let y1 = simd.small_mod_u32x16(p, y1);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
simd: crate::V3,
z0: u32x8,
z1: u32x8,
w: u32x8,
w_shoup: u32x8,
p: u32x8,
neg_p: u32x8,
two_p: u32x8,
) -> (u32x8, u32x8) {
let _ = two_p;
let y0 = simd.wrapping_add_u32x8(z0, z1);
let y0 = simd.small_mod_u32x8(p, y0);
let t = simd.wrapping_add_u32x8(simd.wrapping_sub_u32x8(z0, z1), p);
let shoup_q = simd.widening_mul_u32x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u32x8(
simd.wrapping_mul_u32x8(t, w),
simd.wrapping_mul_u32x8(shoup_q, neg_p),
);
let y1 = simd.small_mod_u32x8(p, y1);
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
z0: u32,
z1: u32,
w: u32,
w_shoup: u32,
p: u32,
neg_p: u32,
two_p: u32,
) -> (u32, u32) {
let _ = two_p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(p));
let t = z0.wrapping_sub(z1).wrapping_add(p);
let shoup_q = ((t as u64 * w_shoup as u64) >> 32) as u32;
let y1 = u32::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let y1 = y1.min(y1.wrapping_sub(p));
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
simd: crate::V4,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
simd: crate::V4,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::inv_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
simd: crate::V3,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::fwd_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
simd: crate::V3,
p: u32,
data: &mut [u32],
twid: &[u32],
twid_shoup: &[u32],
) {
super::shoup::inv_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn fwd_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
super::shoup::fwd_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_scalar(p: u32, data: &mut [u32], twid: &[u32], twid_shoup: &[u32]) {
super::shoup::inv_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|(), z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime32::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
extern crate alloc;
use alloc::vec;
#[test]
fn test_product() {
for n in [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap()
as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[test]
fn test_product_avx2() {
if let Some(simd) = crate::V3::try_new() {
for n in [32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31)
.unwrap() as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn test_product_avx512() {
if let Some(simd) = crate::V4::try_new() {
for n in [32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31)
.unwrap() as u32;
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u32; n];
let mut twid_shoup = vec![0u32; n];
let mut inv_twid = vec![0u32; n];
let mut inv_twid_shoup = vec![0u32; n];
init_negacyclic_twiddles_shoup(
p,
n,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u32; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u32));
}
}
}
}
}

File diff suppressed because it is too large

1883
tfhe-ntt/src/prime64.rs Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,213 @@
use pulp::u64x8;
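// 64-bit butterflies specialized for AVX-512 IFMA: products use the 52-bit
// `widening_mul_u52x8` / `wrapping_mul_add_u52x8` primitives, so the Shoup
// twiddles are scaled by 2^52 (see the `52` passed to
// `init_negacyclic_twiddles_shoup` in the test) and the lazy 4p bound suggests
// moduli below 2^50 (the test primes lie between 2^49 and 2^50).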
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let z0 = simd.small_mod_u64x8(two_p, z0);
let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
(
simd.wrapping_add_u64x8(z0, t),
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), two_p),
)
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let z0 = simd.small_mod_u64x8(two_p, z0);
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
let t = simd.small_mod_u64x8(p, t);
(
simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
simd.small_mod_u64x8(
p,
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(two_p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(two_p, y0);
let y0 = simd.small_mod_u64x8(p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);
let y1 = simd.small_mod_u64x8(p, y1);
(y0, y1)
}
pub(crate) fn fwd_avx512(
simd: crate::V4IFma,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_avx512(
simd: crate::V4IFma,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime64::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
use alloc::vec;
extern crate alloc;
#[test]
fn test_product() {
if let Some(simd) = crate::V4IFma::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 49, 1 << 50)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
52,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
}

View File

@@ -0,0 +1,190 @@
use pulp::u64x8;
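// IFMA-based butterflies with eager reduction below p at each step; the relaxed
// modulus bound is reflected by the test primes, which lie between 2^50 and 2^51.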
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
let t = simd.small_mod_u64x8(p, t);
(
simd.wrapping_add_u64x8(z0, t),
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
)
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u52x8(z1, w_shoup).1;
let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(z1, w).0);
let t = simd.small_mod_u64x8(p, t);
(
simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
simd.small_mod_u64x8(
p,
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4IFma,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), p);
let shoup_q = simd.widening_mul_u52x8(t, w_shoup).1;
let y1 = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(t, w).0);
let y1 = simd.small_mod_u64x8(p, y1);
(y0, y1)
}
pub(crate) fn fwd_avx512(
simd: crate::V4IFma,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_avx512(
simd: crate::V4IFma,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_breadth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime64::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
use alloc::vec;
extern crate alloc;
#[test]
fn test_product() {
if let Some(simd) = crate::V4IFma::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
52,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
}

View File

@@ -0,0 +1,629 @@
#[allow(unused_imports)]
use pulp::*;
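// Generic 64-bit Shoup butterflies with lazy reduction (values kept below 4p),
// so the modulus is assumed to stay below 2^62 (the test primes lie between
// 2^61 and 2^62). The AVX2 path only needs the low halves of the 64x64
// products for `t`, hence the `widening_mul_u64x4(..).0` calls.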
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let z0 = simd.small_mod_u64x8(two_p, z0);
let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
let t = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(z1, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
(
simd.wrapping_add_u64x8(z0, t),
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), two_p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let z0 = simd.small_mod_u64x8(two_p, z0);
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
let t = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(z1, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
let t = simd.small_mod_u64x8(p, t);
(
simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
simd.small_mod_u64x8(
p,
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = p;
let z0 = simd.small_mod_u64x4(two_p, z0);
let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
let t = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(z1, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
(
simd.wrapping_add_u64x4(z0, t),
simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), two_p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = p;
let z0 = simd.small_mod_u64x4(two_p, z0);
let z0 = simd.small_mod_u64x4(p, z0);
let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
let t = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(z1, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
let t = simd.small_mod_u64x4(p, t);
(
simd.small_mod_u64x4(p, simd.wrapping_add_u64x4(z0, t)),
simd.small_mod_u64x4(
p,
simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = p;
let z0 = z0.min(z0.wrapping_sub(two_p));
let shoup_q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
let t = u64::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
(z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(two_p))
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = p;
let z0 = z0.min(z0.wrapping_sub(two_p));
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
let t = u64::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
(
res.0.min(res.0.wrapping_sub(p)),
res.1.min(res.1.wrapping_sub(p)),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(two_p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(t, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(two_p, y0);
let y0 = simd.small_mod_u64x8(p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), two_p);
let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(t, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
let y1 = simd.small_mod_u64x8(p, y1);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = p;
let y0 = simd.wrapping_add_u64x4(z0, z1);
let y0 = simd.small_mod_u64x4(two_p, y0);
let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), two_p);
let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(t, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_last_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = p;
let y0 = simd.wrapping_add_u64x4(z0, z1);
let y0 = simd.small_mod_u64x4(two_p, y0);
let y0 = simd.small_mod_u64x4(p, y0);
let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), two_p);
let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(t, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
let y1 = simd.small_mod_u64x4(p, y1);
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(two_p));
let t = z0.wrapping_sub(z1).wrapping_add(two_p);
let shoup_q = ((t as u128 * w_shoup as u128) >> 64) as u64;
let y1 = u64::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_last_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(two_p));
let y0 = y0.min(y0.wrapping_sub(p));
let t = z0.wrapping_sub(z1).wrapping_add(two_p);
let shoup_q = ((t as u128 * w_shoup as u128) >> 64) as u64;
let y1 = u64::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let y1 = y1.min(y1.wrapping_sub(p));
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
simd: crate::V4,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
simd: crate::V4,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
simd: crate::V3,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
simd: crate::V3,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn fwd_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
super::shoup::fwd_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
super::shoup::inv_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime64::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
use alloc::vec;
extern crate alloc;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn test_product_avx512() {
if let Some(simd) = crate::V4::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[test]
fn test_product_avx2() {
use crate::prime64::tests::mul;
if let Some(simd) = crate::V3::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
#[test]
fn test_product_scalar() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p =
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 61, 1 << 62).unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}

View File

@@ -0,0 +1,549 @@
#[allow(unused_imports)]
use pulp::*;
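// Eagerly reduced 64-bit Shoup butterflies: values are kept below 2p
// throughout, which admits moduli up to 2^63 (the test primes lie between
// 2^62 and 2^63).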
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
let t = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(z1, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
let t = simd.small_mod_u64x8(p, t);
(
simd.wrapping_add_u64x8(z0, t),
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let z0 = simd.small_mod_u64x8(p, z0);
let shoup_q = simd.widening_mul_u64x8(z1, w_shoup).1;
let t = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(z1, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
let t = simd.small_mod_u64x8(p, t);
(
simd.small_mod_u64x8(p, simd.wrapping_add_u64x8(z0, t)),
simd.small_mod_u64x8(
p,
simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, t), p),
),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = two_p;
let z0 = simd.small_mod_u64x4(p, z0);
let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
let t = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(z1, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
let t = simd.small_mod_u64x4(p, t);
(
simd.wrapping_add_u64x4(z0, t),
simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn fwd_last_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = two_p;
let z0 = simd.small_mod_u64x4(p, z0);
let shoup_q = simd.widening_mul_u64x4(z1, w_shoup).1;
let t = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(z1, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
let t = simd.small_mod_u64x4(p, t);
(
simd.small_mod_u64x4(p, simd.wrapping_add_u64x4(z0, t)),
simd.small_mod_u64x4(
p,
simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, t), p),
),
)
}
#[inline(always)]
pub(crate) fn fwd_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = two_p;
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
let t = u64::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
(z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p))
}
#[inline(always)]
pub(crate) fn fwd_last_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = two_p;
let z0 = z0.min(z0.wrapping_sub(p));
let shoup_q = ((z1 as u128 * w_shoup as u128) >> 64) as u64;
let t = u64::wrapping_add(z1.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let t = t.min(t.wrapping_sub(p));
let res = (z0.wrapping_add(t), z0.wrapping_sub(t).wrapping_add(p));
(
res.0.min(res.0.wrapping_sub(p)),
res.1.min(res.1.wrapping_sub(p)),
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn inv_butterfly_avx512(
simd: crate::V4,
z0: u64x8,
z1: u64x8,
w: u64x8,
w_shoup: u64x8,
p: u64x8,
neg_p: u64x8,
two_p: u64x8,
) -> (u64x8, u64x8) {
let _ = two_p;
let y0 = simd.wrapping_add_u64x8(z0, z1);
let y0 = simd.small_mod_u64x8(p, y0);
let t = simd.wrapping_add_u64x8(simd.wrapping_sub_u64x8(z0, z1), p);
let shoup_q = simd.widening_mul_u64x8(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x8(
simd.wrapping_mul_u64x8(t, w),
simd.wrapping_mul_u64x8(shoup_q, neg_p),
);
let y1 = simd.small_mod_u64x8(p, y1);
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn inv_butterfly_avx2(
simd: crate::V3,
z0: u64x4,
z1: u64x4,
w: u64x4,
w_shoup: u64x4,
p: u64x4,
neg_p: u64x4,
two_p: u64x4,
) -> (u64x4, u64x4) {
let _ = two_p;
let y0 = simd.wrapping_add_u64x4(z0, z1);
let y0 = simd.small_mod_u64x4(p, y0);
let t = simd.wrapping_add_u64x4(simd.wrapping_sub_u64x4(z0, z1), p);
let shoup_q = simd.widening_mul_u64x4(t, w_shoup).1;
let y1 = simd.wrapping_add_u64x4(
simd.widening_mul_u64x4(t, w).0,
simd.widening_mul_u64x4(shoup_q, neg_p).0,
);
let y1 = simd.small_mod_u64x4(p, y1);
(y0, y1)
}
#[inline(always)]
pub(crate) fn inv_butterfly_scalar(
z0: u64,
z1: u64,
w: u64,
w_shoup: u64,
p: u64,
neg_p: u64,
two_p: u64,
) -> (u64, u64) {
let _ = two_p;
let y0 = z0.wrapping_add(z1);
let y0 = y0.min(y0.wrapping_sub(p));
let t = z0.wrapping_sub(z1).wrapping_add(p);
let shoup_q = ((t as u128 * w_shoup as u128) >> 64) as u64;
let y1 = u64::wrapping_add(t.wrapping_mul(w), shoup_q.wrapping_mul(neg_p));
let y1 = y1.min(y1.wrapping_sub(p));
(y0, y1)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn fwd_avx512(
simd: crate::V4,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) fn inv_avx512(
simd: crate::V4,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_depth_first_avx512(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx512(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn fwd_avx2(
simd: crate::V3,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::fwd_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) fn inv_avx2(
simd: crate::V3,
p: u64,
data: &mut [u64],
twid: &[u64],
twid_shoup: &[u64],
) {
super::shoup::inv_depth_first_avx2(
simd,
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|simd, z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_avx2(simd, z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn fwd_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
super::shoup::fwd_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
fwd_last_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
pub(crate) fn inv_scalar(p: u64, data: &mut [u64], twid: &[u64], twid_shoup: &[u64]) {
super::shoup::inv_depth_first_scalar(
p,
data,
twid,
twid_shoup,
0,
0,
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
#[inline(always)]
|z0, z1, w, w_shoup, p, neg_p, two_p| {
inv_butterfly_scalar(z0, z1, w, w_shoup, p, neg_p, two_p)
},
)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
prime::largest_prime_in_arithmetic_progression64,
prime64::{
init_negacyclic_twiddles_shoup,
tests::{mul, random_lhs_rhs_with_negacyclic_convolution},
},
};
use alloc::vec;
extern crate alloc;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[test]
fn test_product_avx512() {
if let Some(simd) = crate::V4::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx512(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx512(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx512(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[test]
fn test_product_avx2() {
if let Some(simd) = crate::V3::try_new() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63)
.unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_avx2(simd, p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_avx2(simd, p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_avx2(simd, p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}
#[test]
fn test_product_scalar() {
for n in [16, 32, 64, 128, 256, 512, 1024] {
let p =
largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap();
let (lhs, rhs, negacyclic_convolution) =
random_lhs_rhs_with_negacyclic_convolution(n, p);
let mut twid = vec![0u64; n];
let mut twid_shoup = vec![0u64; n];
let mut inv_twid = vec![0u64; n];
let mut inv_twid_shoup = vec![0u64; n];
init_negacyclic_twiddles_shoup(
p,
n,
64,
&mut twid,
&mut twid_shoup,
&mut inv_twid,
&mut inv_twid_shoup,
);
let mut prod = vec![0u64; n];
let mut lhs_fourier = lhs.clone();
let mut rhs_fourier = rhs.clone();
fwd_scalar(p, &mut lhs_fourier, &twid, &twid_shoup);
fwd_scalar(p, &mut rhs_fourier, &twid, &twid_shoup);
for x in &lhs_fourier {
assert!(*x < p);
}
for x in &rhs_fourier {
assert!(*x < p);
}
for i in 0..n {
prod[i] = mul(p, lhs_fourier[i], rhs_fourier[i]);
}
inv_scalar(p, &mut prod, &inv_twid, &inv_twid_shoup);
let result = prod;
for i in 0..n {
assert_eq!(result[i], mul(p, negacyclic_convolution[i], n as u64));
}
}
}
}

File diff suppressed because it is too large

1168
tfhe-ntt/src/product.rs Normal file

File diff suppressed because it is too large

132
tfhe-ntt/src/roots.rs Normal file
View File

@@ -0,0 +1,132 @@
use crate::{
fastdiv::Div64,
prime::{exp_mod64, mul_mod64},
};
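// Helpers for computing roots of unity modulo p: `get_q_s64` writes p - 1 as
// q * 2^s with q odd (the decomposition needed by Tonelli-Shanks), and
// `get_z64` looks for a quadratic non-residue via Euler's criterion.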
pub const fn get_q_s64(p: Div64) -> (u64, u64) {
let p = p.divisor();
let mut q = p - 1;
let mut s = 0;
while q % 2 == 0 {
q /= 2;
s += 1;
}
(q, s)
}
pub const fn get_z64(p: Div64) -> Option<u64> {
let p_val = p.divisor();
let mut n = 2;
while n < p_val {
if exp_mod64(p, n, (p_val - 1) / 2) == p_val - 1 {
return Some(n);
}
n += 1;
}
None
}
/// <https://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm#The_algorithm>
pub const fn sqrt_mod_ex64(p: Div64, q: u64, s: u64, z: u64, n: u64) -> Option<u64> {
let mut m = s;
let mut c = exp_mod64(p, z, q);
let mut t = exp_mod64(p, n, q);
let mut r = exp_mod64(p, n, (q + 1) / 2);
loop {
if t == 0 {
return Some(0);
}
if t == 1 {
return Some(r);
}
let mut i = 0;
let mut t_pow = t;
while i < m {
t_pow = mul_mod64(p, t_pow, t_pow);
i += 1;
if t_pow == 1 {
break;
}
}
let i = i;
if i == m {
assert!(t_pow == 1);
return None;
}
let b = exp_mod64(p, c, 1 << (m - i - 1));
m = i;
c = mul_mod64(p, b, b);
t = mul_mod64(p, t, c);
r = mul_mod64(p, r, b);
}
}
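/// Finds a primitive `degree`-th root of unity modulo `p` (with `degree` a power of two)
/// by starting from `-1`, a primitive square root of unity, and repeatedly taking modular
/// square roots. Returns `None` if no such root exists.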
pub const fn find_primitive_root64(p: Div64, degree: u64) -> Option<u64> {
assert!(degree.is_power_of_two());
assert!(degree > 1);
let n = degree.trailing_zeros();
let p_val = p.divisor();
let mut root = p_val - 1;
let (q, s) = get_q_s64(p);
let z = match get_z64(p) {
Some(z) => z,
None => return None,
};
let mut i = 0;
while i < n - 1 {
root = match sqrt_mod_ex64(p, q, s, z, root) {
Some(r) => r,
None => return None,
};
i += 1;
}
Some(root)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{fastdiv::Div64, prime::largest_prime_in_arithmetic_progression64};
const fn sqrt_mod64(p: Div64, n: u64) -> Option<u64> {
if p.divisor() == 2 {
Some(n)
} else {
let z = match get_z64(p) {
Some(z) => z,
None => panic!(),
};
let (q, s) = get_q_s64(p);
sqrt_mod_ex64(p, q, s, z, n)
}
}
#[test]
fn test_sqrt() {
let p_val = largest_prime_in_arithmetic_progression64(1 << 10, 1, 0, u64::MAX).unwrap();
let p = Div64::new(p_val);
let i = sqrt_mod64(p, p_val - 1).unwrap();
let j = sqrt_mod64(p, i).unwrap();
assert_eq!(mul_mod64(p, i, i), p_val - 1);
assert_eq!(mul_mod64(p, j, j), i);
}
#[test]
fn test_primitive_root() {
let deg = 1 << 10;
let p_val = largest_prime_in_arithmetic_progression64(deg, 1, 0, u64::MAX).unwrap();
let p = Div64::new(p_val);
let root = find_primitive_root64(p, deg).unwrap();
for i in 1..deg {
assert_ne!(exp_mod64(p, root, i), 1);
}
assert_eq!(exp_mod64(p, root, deg), 1);
}
}

137
tfhe-ntt/src/u256_impl.rs Normal file
View File

@@ -0,0 +1,137 @@
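/// Unsigned 256-bit integer stored as four 64-bit limbs in little-endian order (`x0` is the least significant).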
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct u256 {
pub x0: u64,
pub x1: u64,
pub x2: u64,
pub x3: u64,
}
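/// Packs two 64-bit limbs into a `u128`, with `lo` in the low half and `hi` in the high half.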
#[inline(always)]
pub const fn to_double_digit(lo: u64, hi: u64) -> u128 {
(lo as u128) | ((hi as u128) << u64::BITS)
}
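/// Add with carry: returns the 64-bit sum `l + r + c` together with the carry-out.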
#[inline(always)]
pub const fn adc(l: u64, r: u64, c: bool) -> (u64, bool) {
let (lr, o0) = l.overflowing_add(r);
let (lrc, o1) = lr.overflowing_add(c as u64);
(lrc, o0 | o1)
}
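/// Computes `l * r + c`, returning the result as `(low 64 bits, high 64 bits)`.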
#[inline(always)]
pub const fn mul_with_carry(l: u64, r: u64, c: u64) -> (u64, u64) {
let res = (l as u128 * r as u128) + c as u128;
(res as u64, (res >> 64) as u64)
}
impl u256 {
pub const MAX: Self = Self {
x0: u64::MAX,
x1: u64::MAX,
x2: u64::MAX,
x3: u64::MAX,
};
#[inline(always)]
pub const fn overflowing_add(self, rhs: Self) -> (Self, bool) {
let lhs = self;
let mut carry = false;
let x0;
let x1;
let x2;
let x3;
(x0, carry) = adc(lhs.x0, rhs.x0, carry);
(x1, carry) = adc(lhs.x1, rhs.x1, carry);
(x2, carry) = adc(lhs.x2, rhs.x2, carry);
(x3, carry) = adc(lhs.x3, rhs.x3, carry);
(Self { x0, x1, x2, x3 }, carry)
}
#[inline(always)]
pub const fn wrapping_add(self, rhs: Self) -> Self {
self.overflowing_add(rhs).0
}
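/// Divides a 256-bit value by a 64-bit divisor using limb-by-limb long division,
/// returning the quotient and the remainder.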
pub const fn div_rem_u256_u64(self, rhs: u64) -> (Self, u64) {
let lhs = self;
let mut rem = 0;
let rhs = rhs as u128;
let double = to_double_digit(lhs.x3, rem);
let q = double / rhs;
let r = double % rhs;
rem = r as u64;
let x3 = q as u64;
let double = to_double_digit(lhs.x2, rem);
let q = double / rhs;
let r = double % rhs;
rem = r as u64;
let x2 = q as u64;
let double = to_double_digit(lhs.x1, rem);
let q = double / rhs;
let r = double % rhs;
rem = r as u64;
let x1 = q as u64;
let double = to_double_digit(lhs.x0, rem);
let q = double / rhs;
let r = double % rhs;
rem = r as u64;
let x0 = q as u64;
(Self { x0, x1, x2, x3 }, rem)
}
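/// Multiplies a 256-bit value by a 64-bit value, returning the low 256 bits of the
/// product and the carry (fifth) limb.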
#[inline(always)]
pub const fn mul_u256_u64(self, rhs: u64) -> (Self, u64) {
let mut carry = 0;
let (x0, x1, x2, x3);
(x0, carry) = mul_with_carry(self.x0, rhs, carry);
(x1, carry) = mul_with_carry(self.x1, rhs, carry);
(x2, carry) = mul_with_carry(self.x2, rhs, carry);
(x3, carry) = mul_with_carry(self.x3, rhs, carry);
(Self { x0, x1, x2, x3 }, carry)
}
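/// Multiplies a 256-bit value by a 128-bit value, returning the low 256 bits and the
/// high 128 bits of the full 384-bit product.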
#[inline(always)]
pub const fn mul_u256_u128(self, rhs: u128) -> (Self, u128) {
let (x, x4) = Self::mul_u256_u64(self, rhs as u64);
let (y, y5) = Self::mul_u256_u64(self, (rhs >> 64) as u64);
let y4 = y.x3;
let y = u256 {
x0: 0,
x1: y.x0,
x2: y.x1,
x3: y.x2,
};
let (r, carry) = x.overflowing_add(y);
let (r4, carry) = adc(x4, y4, carry);
let r5 = y5 + carry as u64;
(r, to_double_digit(r4, r5))
}
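/// Multiplies a 256-bit value by a 128-bit value, keeping only the low 256 bits of the product.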
#[inline(always)]
pub const fn wrapping_mul_u256_u128(self, rhs: u128) -> Self {
let (x, _) = Self::mul_u256_u64(self, rhs as u64);
let (y, _) = Self::mul_u256_u64(self, (rhs >> 64) as u64);
let y = u256 {
x0: 0,
x1: y.x0,
x2: y.x1,
x3: y.x2,
};
x.wrapping_add(y)
}
}