mirror of
https://github.com/Sunscreen-tech/Sunscreen.git
synced 2026-01-10 06:08:00 -05:00
17
Cargo.lock
generated
17
Cargo.lock
generated
@@ -38,6 +38,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aligned-vec"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "allocator-api2"
|
||||
version = "0.2.16"
|
||||
@@ -1629,12 +1638,6 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "linked-list"
|
||||
version = "0.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4dacf969043dc69f1f731b5042eb05e030d264bcf34f2242889fcbdc7a65f06"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.13"
|
||||
@@ -3086,9 +3089,9 @@ dependencies = [
|
||||
name = "sunscreen_tfhe"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"aligned-vec",
|
||||
"bytemuck",
|
||||
"criterion 0.5.1",
|
||||
"linked-list",
|
||||
"logproof",
|
||||
"merlin",
|
||||
"num",
|
||||
|
||||
@@ -38,6 +38,7 @@ lto = false
|
||||
codegen-units = 16
|
||||
|
||||
[workspace.dependencies]
|
||||
aligned-vec = { version = "0.5.0", features = ["serde"] }
|
||||
bytemuck = "1.13.0"
|
||||
lazy_static = "1.4.0"
|
||||
metal = "0.26.0"
|
||||
|
||||
@@ -16,9 +16,8 @@ readme = "crates-io.md"
|
||||
|
||||
|
||||
[dependencies]
|
||||
aligned-vec = { workspace = true }
|
||||
bytemuck = { workspace = true }
|
||||
# TODO: Remove when Rust stabilizes Cursor API
|
||||
linked-list = "0.0.3"
|
||||
logproof = { workspace = true, optional = true }
|
||||
num = { workspace = true }
|
||||
paste = { workspace = true }
|
||||
|
||||
@@ -1,5 +1,23 @@
|
||||
use crate::scratch::Pod;
|
||||
|
||||
macro_rules! avec {
|
||||
($elem:expr; $count:expr) => {
|
||||
aligned_vec::AVec::__from_elem(crate::scratch::SIMD_ALIGN, $elem, $count)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! avec_from_iter {
|
||||
($iter:expr) => {
|
||||
aligned_vec::AVec::from_iter(crate::scratch::SIMD_ALIGN, $iter)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! avec_from_slice {
|
||||
($slice:expr) => {
|
||||
aligned_vec::AVec::from_slice(crate::scratch::SIMD_ALIGN, $slice)
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! dst {
|
||||
($(#[$meta:meta])* $t:ty, $ref_t:ty, $wrapper:ty, ($($derive:ident),* $(,)? ), ($($t_bounds:ty),* $(,)? )) => {
|
||||
paste::paste! {
|
||||
@@ -7,7 +25,7 @@ macro_rules! dst {
|
||||
$(#[$meta])*
|
||||
#[derive($($derive,)*)]
|
||||
pub struct $t<T> where T: Clone $(+ $t_bounds)* {
|
||||
data: Vec<$wrapper<T>>
|
||||
data: aligned_vec::AVec<$wrapper<T>, aligned_vec::ConstAlign<{ crate::scratch::SIMD_ALIGN }>>
|
||||
}
|
||||
|
||||
/// A reference to the data structure.
|
||||
@@ -36,14 +54,6 @@ macro_rules! dst {
|
||||
pub fn as_mut_slice(&mut self) -> &mut [$wrapper<T>] {
|
||||
&mut self.data
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
/// Move the contents of rhs into self.
|
||||
pub fn move_from(&mut self, rhs: $t<T>) {
|
||||
for (l, r) in self.data.iter_mut().zip(rhs.data.into_iter()) {
|
||||
*l = r;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> crate::dst::FromSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
|
||||
@@ -98,7 +108,7 @@ macro_rules! dst {
|
||||
type Owned = $t<T>;
|
||||
|
||||
fn to_owned(&self) -> Self::Owned {
|
||||
$t { data: self.data.to_owned() }
|
||||
$t { data: aligned_vec::AVec::from_slice(crate::scratch::SIMD_ALIGN, &self.data) }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ impl<S: TorusOps> BivariateLookupTable<S> {
|
||||
F: Fn(u64, u64) -> u64,
|
||||
{
|
||||
let mut lut = BivariateLookupTable {
|
||||
data: vec![Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
data: avec!(Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)),
|
||||
};
|
||||
|
||||
lut.fill_trivial_from_fn(map, glwe, plaintext_bits, carry_bits);
|
||||
|
||||
@@ -41,7 +41,7 @@ impl<S: TorusOps> BlindRotationShift<S> {
|
||||
let len = BlindRotationShiftRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
data: avec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,7 @@ impl BlindRotationShiftFft<Complex<f64>> {
|
||||
let len = BlindRotationShiftFftRef::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
data: avec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ impl<S: TorusOps> BootstrapKey<S> {
|
||||
let len = BootstrapKeyRef::<S>::size((lwe_params.dim, glwe_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
data: avec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,7 +145,7 @@ impl BootstrapKeyFft<Complex<f64>> {
|
||||
let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
data: avec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ impl<S: TorusOps> CircuitBootstrappingKeyswitchKeys<S> {
|
||||
));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
data: avec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ where
|
||||
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
data: avec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ where
|
||||
assert_eq!(data.len(), elems);
|
||||
|
||||
Self {
|
||||
data: data.to_vec(),
|
||||
data: avec_from_slice!(data),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ impl GgswCiphertextFft<Complex<f64>> {
|
||||
let len = GgswCiphertextFftRef::size((params.dim, radix.count));
|
||||
|
||||
GgswCiphertextFft {
|
||||
data: vec![Complex::zero(); len],
|
||||
data: avec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,9 +51,9 @@ where
|
||||
|
||||
let len = GlweCiphertextRef::<S>::size(params.dim);
|
||||
|
||||
let data = (0..len).map(|_| Torus::<S>::zero()).collect::<Vec<_>>();
|
||||
|
||||
GlweCiphertext { data }
|
||||
GlweCiphertext {
|
||||
data: avec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the external product of a GLWE ciphertext and a GGSW ciphertext.
|
||||
@@ -72,10 +72,7 @@ where
|
||||
assert_eq!(data.len(), GlweCiphertextRef::<S>::size(params.dim));
|
||||
|
||||
GlweCiphertext {
|
||||
data: data
|
||||
.iter()
|
||||
.map(|x| Torus::from(*x))
|
||||
.collect::<Vec<Torus<S>>>(),
|
||||
data: avec_from_iter!(data.iter().map(|x| Torus::from(*x))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ impl GlweCiphertextFft<Complex<f64>> {
|
||||
let len = GlweCiphertextFftRef::size(params.dim);
|
||||
|
||||
Self {
|
||||
data: vec![Complex::zero(); len],
|
||||
data: avec![Complex::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ where
|
||||
let elems = GlweKeyswitchKeyRef::<S>::size((params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
data: avec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,9 +55,7 @@ where
|
||||
let len = GlweSecretKeyRef::<S>::size(params.dim);
|
||||
|
||||
GlweSecretKey {
|
||||
data: (0..len)
|
||||
.map(|_| torus_element_generator())
|
||||
.collect::<Vec<_>>(),
|
||||
data: avec_from_iter!((0..len).map(|_| torus_element_generator())),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,12 +210,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_add(b)));
|
||||
|
||||
let sk3 = sk + sk2;
|
||||
|
||||
@@ -231,12 +228,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk
|
||||
let sk2_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_add(b)));
|
||||
|
||||
sk2 += sk;
|
||||
|
||||
@@ -250,12 +246,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_add(b)));
|
||||
|
||||
let sk3 = sk.as_ref() + sk2.as_ref();
|
||||
|
||||
@@ -269,12 +264,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_add(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_add(b)));
|
||||
|
||||
let sk3 = sk.wrapping_add(&sk2);
|
||||
|
||||
@@ -290,12 +284,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_sub(b)));
|
||||
|
||||
let sk3 = sk - sk2;
|
||||
|
||||
@@ -309,12 +302,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let mut sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk2
|
||||
let sk2_expected = avec_from_iter!(sk2
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_sub(b)));
|
||||
|
||||
sk2 -= sk;
|
||||
|
||||
@@ -328,12 +320,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_sub(b)));
|
||||
|
||||
let sk3 = sk.as_ref() - sk2.as_ref();
|
||||
|
||||
@@ -347,12 +338,11 @@ mod tests {
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
let sk2 = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk3_expected = sk
|
||||
let sk3_expected = avec_from_iter!(sk
|
||||
.data
|
||||
.iter()
|
||||
.zip(sk2.data.iter())
|
||||
.map(|(a, b)| a.wrapping_sub(b))
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| a.wrapping_sub(b)));
|
||||
|
||||
let sk3 = sk.wrapping_sub(&sk2);
|
||||
|
||||
@@ -367,7 +357,7 @@ mod tests {
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
|
||||
let sk2 = -sk;
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
@@ -379,7 +369,7 @@ mod tests {
|
||||
|
||||
let sk = keygen::generate_uniform_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
|
||||
let sk2 = -sk.as_ref();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
@@ -391,7 +381,7 @@ mod tests {
|
||||
|
||||
let sk = keygen::generate_binary_glwe_sk(¶ms);
|
||||
|
||||
let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
|
||||
let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
|
||||
let sk2 = sk.wrapping_neg();
|
||||
|
||||
assert_eq!(sk2_expected, sk2.data)
|
||||
|
||||
@@ -40,7 +40,7 @@ impl<S: TorusOps> LweCiphertext<S> {
|
||||
|
||||
/// Create a new LWE ciphertext with all coefficients set to zero.
|
||||
pub fn zero(params: &LweDef) -> Self {
|
||||
let data = vec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
|
||||
let data = avec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
|
||||
|
||||
Self { data }
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ impl<S: TorusOps> LweCiphertextList<S> {
|
||||
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
|
||||
pub fn new(lwe: &LweDef, count: usize) -> Self {
|
||||
Self {
|
||||
data: vec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
|
||||
data: avec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ where
|
||||
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); elems],
|
||||
data: avec![Torus::zero(); elems],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ where
|
||||
sk.assert_valid(params);
|
||||
|
||||
let mut pk = LwePublicKey {
|
||||
data: vec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
|
||||
data: avec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
|
||||
};
|
||||
let enc_zeros = pk.enc_zeros_mut(params);
|
||||
|
||||
|
||||
@@ -43,9 +43,7 @@ where
|
||||
let len = LweSecretKeyRef::<S>::size(params.dim);
|
||||
|
||||
LweSecretKey {
|
||||
data: (0..len)
|
||||
.map(|_| torus_element_generator())
|
||||
.collect::<Vec<_>>(),
|
||||
data: avec_from_iter!((0..len).map(|_| torus_element_generator())),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ where
|
||||
/// Create a new polynomial from a slice of coefficients.
|
||||
pub fn new(data: &[T]) -> Polynomial<T> {
|
||||
Polynomial {
|
||||
data: data.to_owned(),
|
||||
data: avec_from_slice!(data),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ where
|
||||
T: Zero,
|
||||
{
|
||||
Polynomial {
|
||||
data: vec![T::zero(); len],
|
||||
data: avec![T::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,7 +64,7 @@ where
|
||||
{
|
||||
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
|
||||
Self {
|
||||
data: iter.into_iter().collect::<Vec<_>>(),
|
||||
data: avec_from_iter!(iter),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -91,7 +91,7 @@ where
|
||||
U: Clone,
|
||||
{
|
||||
Polynomial {
|
||||
data: self.data.iter().map(f).collect::<Vec<_>>(),
|
||||
data: avec_from_iter!(self.data.iter().map(f)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -284,13 +284,12 @@ where
|
||||
fn add(self, rhs: &PolynomialRef<S>) -> Self::Output {
|
||||
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
|
||||
|
||||
let coeffs = self
|
||||
let coeffs = avec_from_iter!(self
|
||||
.coeffs()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(rhs.coeffs().as_ref().iter())
|
||||
.map(|(a, b)| *a + *b)
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| *a + *b));
|
||||
|
||||
Polynomial { data: coeffs }
|
||||
}
|
||||
@@ -325,13 +324,12 @@ where
|
||||
fn sub(self, rhs: &PolynomialRef<S>) -> Self::Output {
|
||||
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
|
||||
|
||||
let coeffs = self
|
||||
let coeffs = avec_from_iter!(self
|
||||
.coeffs()
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(rhs.coeffs().as_ref().iter())
|
||||
.map(|(a, b)| *a - *b)
|
||||
.collect::<Vec<_>>();
|
||||
.map(|(a, b)| *a - *b));
|
||||
|
||||
Polynomial { data: coeffs }
|
||||
}
|
||||
@@ -358,7 +356,7 @@ where
|
||||
assert_eq!(rhs.len(), self.len());
|
||||
|
||||
let mut c = Polynomial {
|
||||
data: vec![Torus::zero(); rhs.len()],
|
||||
data: avec![Torus::zero(); rhs.len()],
|
||||
};
|
||||
|
||||
polynomial_external_mad(&mut c, self, rhs);
|
||||
|
||||
@@ -46,7 +46,7 @@ where
|
||||
/// Create a new polynomial with the given length in the fourier domain.
|
||||
pub fn new(data: &[T]) -> Self {
|
||||
Self {
|
||||
data: data.to_owned(),
|
||||
data: avec_from_slice!(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ where
|
||||
/// Create a new polynomial list, where each polynomial has the same degree.
|
||||
pub fn new(degree: PolynomialDegree, count: usize) -> Self {
|
||||
Self {
|
||||
data: vec![S::zero(); degree.0 * count],
|
||||
data: avec![S::zero(); degree.0 * count],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ impl<S: TorusOps> PrivateFunctionalKeyswitchKey<S> {
|
||||
lwe_count: &PrivateFunctionalKeyswitchLweCount,
|
||||
) -> Self {
|
||||
Self {
|
||||
data: vec![
|
||||
data: avec![
|
||||
Torus::zero();
|
||||
PrivateFunctionalKeyswitchKeyRef::<S>::size((
|
||||
from_lwe.dim,
|
||||
|
||||
@@ -36,7 +36,7 @@ impl<S: TorusOps> PublicFunctionalKeyswitchKey<S> {
|
||||
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count));
|
||||
|
||||
Self {
|
||||
data: vec![Torus::zero(); len],
|
||||
data: avec![Torus::zero(); len],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
|
||||
F: Fn(u64) -> u64,
|
||||
{
|
||||
let mut lut = UnivariateLookupTable {
|
||||
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
data: avec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
};
|
||||
|
||||
lut.fill_trivial_from_fns(&[map], glwe, plaintext_bits);
|
||||
@@ -66,7 +66,7 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
|
||||
assert!(maps.len() > 1);
|
||||
|
||||
let mut lut = UnivariateLookupTable {
|
||||
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
data: avec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
|
||||
};
|
||||
|
||||
lut.fill_trivial_from_fns(maps, glwe, plaintext_bits);
|
||||
|
||||
@@ -84,7 +84,7 @@ macro_rules! impl_unary_op {
|
||||
// We call the wrapping trait instead of using the dot
|
||||
// syntax because the dot syntax can dereference the value
|
||||
// and can cause problems with Deref.
|
||||
let data = self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)).collect();
|
||||
let data = avec_from_iter!(self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)));
|
||||
|
||||
$type { data }
|
||||
}
|
||||
|
||||
@@ -212,7 +212,7 @@ impl TorusOps for u32 {}
|
||||
|
||||
/// A wrapper around a type that supports Torus operations.
|
||||
#[repr(transparent)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct Torus<S: TorusOps = u64>(S);
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use linked_list::{Cursor, LinkedList};
|
||||
use aligned_vec::{avec_rt, AVec, RuntimeAlign};
|
||||
use num::{Complex, Float};
|
||||
use rustfft::FftNum;
|
||||
use std::{
|
||||
alloc::Layout,
|
||||
cell::RefCell,
|
||||
collections::LinkedList,
|
||||
marker::PhantomData,
|
||||
mem::{size_of, transmute},
|
||||
mem::{align_of, size_of},
|
||||
rc::Rc,
|
||||
};
|
||||
|
||||
use crate::{Torus, TorusOps};
|
||||
@@ -27,6 +28,15 @@ macro_rules! allocate_scratch_ref {
|
||||
|
||||
pub(crate) use allocate_scratch_ref;
|
||||
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub const SIMD_ALIGN: usize = align_of::<std::arch::aarch64::float64x2_t>();
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
pub const SIMD_ALIGN: usize = align_of::<std::arch::x86_64::__m512d>();
|
||||
|
||||
#[cfg(not(any(target_feature = "neon", target_arch = "x86_64")))]
|
||||
pub const SIMD_ALIGN: usize = align_of::<u128>();
|
||||
|
||||
/// Indicates this is a "Plain Old Data" type. For `T` qualify as such,
|
||||
/// all bit patterns must be considered a properly initialized instance of
|
||||
/// `T`.
|
||||
@@ -91,19 +101,14 @@ where
|
||||
/// The references doled out by allocate have `'static``
|
||||
/// lifetimes. This is needed so you can have mutable references
|
||||
/// to different allocations at the same time.
|
||||
///
|
||||
/// # Safety
|
||||
/// [`Scratch`] must not drop before all its outstanding allocations have done so.
|
||||
///
|
||||
/// Please note the lack of [`Sync`] and [`Send`] on this object. It is not sound to
|
||||
/// share these between threads.
|
||||
struct Scratch {
|
||||
// Only accessed through the cursor, so compiler thinks it's unused.
|
||||
#[allow(unused)]
|
||||
stack: Box<LinkedList<Allocation>>,
|
||||
top: *mut Cursor<'static, Allocation>,
|
||||
}
|
||||
|
||||
impl Drop for Scratch {
|
||||
fn drop(&mut self) {
|
||||
let top = unsafe { Box::from_raw(self.top) };
|
||||
|
||||
std::mem::drop(top);
|
||||
}
|
||||
stack: Rc<RefCell<LinkedList<*mut Allocation>>>,
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
@@ -111,12 +116,9 @@ impl Scratch {
|
||||
/// the only way to use scratch memory, which will allocate memory
|
||||
/// using a thread_local allocator.
|
||||
fn new() -> Self {
|
||||
let mut list = Box::new(LinkedList::new());
|
||||
|
||||
let cursor = Box::new(list.cursor());
|
||||
let top = unsafe { transmute(Box::into_raw(cursor)) };
|
||||
|
||||
Self { stack: list, top }
|
||||
Self {
|
||||
stack: Rc::new(RefCell::new(LinkedList::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a buffer matching the given specification.
|
||||
@@ -126,102 +128,93 @@ impl Scratch {
|
||||
{
|
||||
assert_ne!(size_of::<T>(), 0);
|
||||
|
||||
let top = unsafe { &mut *self.top };
|
||||
let alignment = usize::max(SIMD_ALIGN, align_of::<T>());
|
||||
let u8_len = count * size_of::<T>();
|
||||
|
||||
// Push the top as far down until we hit the bottom or an allocation
|
||||
// currently in use.
|
||||
loop {
|
||||
let prev = top.peek_prev();
|
||||
|
||||
if let Some(x) = prev {
|
||||
if x.is_free {
|
||||
top.prev().unwrap();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
let layout = Layout::array::<T>(count).unwrap();
|
||||
let req_len = layout.size() + layout.align();
|
||||
|
||||
let allocation = match top.peek_next() {
|
||||
Some(d) => {
|
||||
assert!(d.is_free);
|
||||
|
||||
// Resize the allocation if needed.
|
||||
if d.data.len() < req_len {
|
||||
d.data.resize(req_len, 0u8);
|
||||
}
|
||||
|
||||
d.requested_len = count;
|
||||
top.next().unwrap()
|
||||
}
|
||||
None => {
|
||||
let data = vec![0u8; req_len];
|
||||
let allocation = unsafe {
|
||||
let allocation = self.stack.borrow_mut().pop_back();
|
||||
|
||||
if allocation.is_none() {
|
||||
// If we don't have an existing allocation, make one
|
||||
let allocation = Allocation {
|
||||
requested_len: count,
|
||||
is_free: false,
|
||||
data,
|
||||
data: avec_rt!([alignment]| u8::default(); u8_len),
|
||||
};
|
||||
|
||||
top.insert(allocation);
|
||||
top.next().unwrap()
|
||||
let allocation = Box::new(allocation);
|
||||
Box::into_raw(allocation)
|
||||
} else if (*allocation.unwrap()).data.alignment() < alignment
|
||||
|| (*allocation.unwrap()).data.len() < u8_len
|
||||
{
|
||||
// If we found an allocation, but its size and len requirements
|
||||
// are insufficient.
|
||||
let allocation = allocation.unwrap();
|
||||
let drop_box = Box::from_raw(allocation);
|
||||
std::mem::drop(drop_box);
|
||||
|
||||
let allocation = Allocation {
|
||||
data: avec_rt!([alignment]| u8::default(); u8_len),
|
||||
};
|
||||
|
||||
let allocation = Box::new(allocation);
|
||||
Box::into_raw(allocation)
|
||||
} else {
|
||||
// Otherwise, reuse the allocation.
|
||||
allocation.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
allocation.is_free = false;
|
||||
|
||||
ScratchBuffer {
|
||||
allocation: allocation as *mut Allocation,
|
||||
allocation,
|
||||
pool: self.stack.clone(),
|
||||
requested_len: count,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Allocation {
|
||||
requested_len: usize,
|
||||
data: Vec<u8>,
|
||||
is_free: bool,
|
||||
data: AVec<u8, RuntimeAlign>,
|
||||
}
|
||||
|
||||
/// An allocation returned by [`Scratch::allocate`].
|
||||
///
|
||||
/// # Safety
|
||||
/// Please note the lack of [`Sync`] and [`Send`] on this object. It is not sound to
|
||||
/// share these between threads.
|
||||
///
|
||||
/// You *may* however share the slice returned by [`Self::as_slice`] or
|
||||
/// [`Self::as_mut_slice`], as these obey standard lifetime rules.
|
||||
pub struct ScratchBuffer<'a, T> {
|
||||
allocation: *mut Allocation,
|
||||
pool: Rc<RefCell<LinkedList<*mut Allocation>>>,
|
||||
requested_len: usize,
|
||||
_phantom: PhantomData<&'a T>,
|
||||
}
|
||||
|
||||
impl<'a, T> ScratchBuffer<'a, T> {
|
||||
#[allow(unused)]
|
||||
/// Get a slice to the underlying data.
|
||||
///
|
||||
/// # Remarks
|
||||
/// While not extremely expensive, this operation does require capturing
|
||||
/// an aligned slice of data in an underlying allocation. As such,
|
||||
/// you should avoid repeated calls.
|
||||
|
||||
pub fn as_slice(&self) -> &[T] {
|
||||
let count = unsafe { (*self.allocation).requested_len };
|
||||
let (_, slice, _) = unsafe { (*self.allocation).data.align_to::<T>() };
|
||||
unsafe { transmute(&slice[0..count]) }
|
||||
let slice =
|
||||
unsafe { std::mem::transmute::<&[u8], &[T]>((*self.allocation).data.as_slice()) };
|
||||
|
||||
&slice[0..self.requested_len]
|
||||
}
|
||||
|
||||
/// Get a mutable slice to the underlying data.
|
||||
///
|
||||
/// # Remarks
|
||||
/// While not extremely expensive, this operation does require capturing
|
||||
/// an aligned slice of data in an underlying allocation. As such,
|
||||
/// you should avoid repeated calls.
|
||||
pub fn as_mut_slice(&mut self) -> &mut [T] {
|
||||
let count = unsafe { (*self.allocation).requested_len };
|
||||
let (_pre, slice, _post) = unsafe { (*self.allocation).data.align_to_mut::<T>() };
|
||||
unsafe { transmute(&mut slice[0..count]) }
|
||||
let slice = unsafe {
|
||||
std::mem::transmute::<&mut [u8], &mut [T]>((*self.allocation).data.as_mut_slice())
|
||||
};
|
||||
|
||||
&mut slice[0..self.requested_len]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for ScratchBuffer<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { (*self.allocation).is_free = true };
|
||||
self.pool.borrow_mut().push_back(self.allocation);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,8 +236,6 @@ mod tests {
|
||||
for (i, d_i) in d.iter_mut().enumerate() {
|
||||
*d_i = i as u64;
|
||||
}
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -253,7 +244,6 @@ mod tests {
|
||||
|
||||
let b = scratch.allocate::<u64>(64);
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
let b_slice = b.as_slice();
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
|
||||
@@ -264,7 +254,6 @@ mod tests {
|
||||
let mut b = scratch.allocate::<u64>(64);
|
||||
let b_slice = b.as_mut_slice();
|
||||
assert_eq!(first_ptr, b_slice.as_ptr());
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
|
||||
for (i, b_i) in b_slice.iter_mut().enumerate() {
|
||||
@@ -272,36 +261,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn reallocate_on_bigger_request() {
|
||||
let mut scratch = Scratch::new();
|
||||
|
||||
let mut b = scratch.allocate::<u64>(64);
|
||||
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
let b_slice = b.as_mut_slice();
|
||||
assert_eq!(b_slice.len(), 64);
|
||||
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
|
||||
let first_ptr = b_slice.as_ptr();
|
||||
|
||||
for (i, b_i) in b_slice.iter_mut().enumerate() {
|
||||
*b_i = i as u64;
|
||||
}
|
||||
|
||||
std::mem::drop(b);
|
||||
|
||||
let mut b = scratch.allocate::<u64>(16384);
|
||||
let b = b.as_mut_slice();
|
||||
assert_ne!(first_ptr, b.as_ptr());
|
||||
assert_eq!(scratch.stack.len(), 1);
|
||||
assert_eq!(b.len(), 16384);
|
||||
|
||||
for (i, b_i) in b.iter_mut().enumerate() {
|
||||
*b_i = i as u64;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allocate_two_buffers() {
|
||||
let mut scratch = Scratch::new();
|
||||
@@ -314,7 +273,6 @@ mod tests {
|
||||
|
||||
assert_eq!(a.len(), 12);
|
||||
assert_eq!(b.len(), 12);
|
||||
assert_eq!(scratch.stack.len(), 2);
|
||||
assert_ne!(a.as_mut_ptr(), b.as_mut_ptr());
|
||||
|
||||
for i in 0..a.len() {
|
||||
@@ -330,8 +288,6 @@ mod tests {
|
||||
.map(|_| scratch.allocate::<u128>(10))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(scratch.stack.len(), 10);
|
||||
|
||||
for b in buffers.iter_mut() {
|
||||
let b = b.as_mut_slice();
|
||||
assert_eq!(b.len(), 10);
|
||||
@@ -347,7 +303,7 @@ mod tests {
|
||||
// Chose an alignment larger than any reasonable OS's page size
|
||||
// to try to force the alignment algorithm into play.
|
||||
#[repr(C, align(65536))]
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Default)]
|
||||
struct Foo {
|
||||
x: u32,
|
||||
}
|
||||
@@ -370,37 +326,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stack_coalesces_correctly() {
|
||||
let mut scratch = Scratch::new();
|
||||
let a = scratch.allocate::<u64>(16);
|
||||
let mut b: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
let b_ptr = b.as_mut_slice().as_mut_ptr();
|
||||
|
||||
let c: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
let d: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
|
||||
std::mem::drop(b);
|
||||
assert_eq!(scratch.stack.len(), 4);
|
||||
|
||||
// We can't reuse b's buffer until c, d, e get dropped.
|
||||
let mut e: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
|
||||
assert_ne!(b_ptr, e.as_mut_slice().as_mut_ptr());
|
||||
|
||||
assert_eq!(scratch.stack.len(), 5);
|
||||
|
||||
std::mem::drop(c);
|
||||
std::mem::drop(d);
|
||||
std::mem::drop(e);
|
||||
|
||||
// Now we can reuse b's buffer.
|
||||
let mut f = scratch.allocate::<u64>(16);
|
||||
assert_eq!(f.as_mut_slice().as_mut_ptr(), b_ptr);
|
||||
assert_eq!(scratch.stack.len(), 5);
|
||||
|
||||
std::mem::drop(a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zero_size_allocations() {
|
||||
let mut scratch = Scratch::new();
|
||||
@@ -410,7 +335,6 @@ mod tests {
|
||||
let a_slice = a.as_slice();
|
||||
let b_slice = b.as_slice();
|
||||
|
||||
assert_eq!(scratch.stack.len(), 2);
|
||||
assert_eq!(a_slice.len(), 2);
|
||||
assert_eq!(b_slice.len(), 0);
|
||||
}
|
||||
@@ -418,10 +342,24 @@ mod tests {
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zst_allocations_should_panic() {
|
||||
#[derive(Default)]
|
||||
struct Foo {}
|
||||
unsafe impl Pod for Foo {}
|
||||
|
||||
let mut scratch = Scratch::new();
|
||||
let _ = scratch.allocate::<Foo>(0x1 << 48);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simd_alignment() {
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
{
|
||||
assert_eq!(SIMD_ALIGN, 16);
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
assert_eq!(SIMD_ALIGN, 64);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user