Rweber/tfhe mem (#368)

Simplify `Scratch` type: replace the `linked-list` crate's `Cursor`-based stack bookkeeping with a plain free pool (`Rc<RefCell<LinkedList<*mut Allocation>>>`), and back `Scratch` buffers and all `dst!`-generated containers with SIMD-aligned `aligned_vec::AVec` storage, dropping the `linked-list` dependency.
rickwebiii committed 2024-03-04 16:25:46 -07:00 (committed by GitHub)
parent 8d6b64b4b8 · commit f773923b4b
28 changed files with 179 additions and 245 deletions

Cargo.lock generated
View File

@@ -38,6 +38,15 @@ dependencies = [
"memchr",
]
[[package]]
name = "aligned-vec"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
dependencies = [
"serde",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
@@ -1629,12 +1638,6 @@ dependencies = [
"cc",
]
[[package]]
name = "linked-list"
version = "0.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a4dacf969043dc69f1f731b5042eb05e030d264bcf34f2242889fcbdc7a65f06"
[[package]]
name = "linux-raw-sys"
version = "0.4.13"
@@ -3086,9 +3089,9 @@ dependencies = [
name = "sunscreen_tfhe"
version = "0.1.0"
dependencies = [
"aligned-vec",
"bytemuck",
"criterion 0.5.1",
"linked-list",
"logproof",
"merlin",
"num",

View File

@@ -38,6 +38,7 @@ lto = false
codegen-units = 16
[workspace.dependencies]
+ aligned-vec = { version = "0.5.0", features = ["serde"] }
bytemuck = "1.13.0"
lazy_static = "1.4.0"
metal = "0.26.0"

View File

@@ -16,9 +16,8 @@ readme = "crates-io.md"
[dependencies]
+ aligned-vec = { workspace = true }
bytemuck = { workspace = true }
- # TODO: Remove when Rust stabilizes Cursor API
- linked-list = "0.0.3"
logproof = { workspace = true, optional = true }
num = { workspace = true }
paste = { workspace = true }

View File

@@ -1,5 +1,23 @@
use crate::scratch::Pod;
+ macro_rules! avec {
+     ($elem:expr; $count:expr) => {
+         aligned_vec::AVec::__from_elem(crate::scratch::SIMD_ALIGN, $elem, $count)
+     };
+ }
+
+ macro_rules! avec_from_iter {
+     ($iter:expr) => {
+         aligned_vec::AVec::from_iter(crate::scratch::SIMD_ALIGN, $iter)
+     };
+ }
+
+ macro_rules! avec_from_slice {
+     ($slice:expr) => {
+         aligned_vec::AVec::from_slice(crate::scratch::SIMD_ALIGN, $slice)
+     };
+ }
+
macro_rules! dst {
($(#[$meta:meta])* $t:ty, $ref_t:ty, $wrapper:ty, ($($derive:ident),* $(,)? ), ($($t_bounds:ty),* $(,)? )) => {
paste::paste! {
@@ -7,7 +25,7 @@ macro_rules! dst {
$(#[$meta])*
#[derive($($derive,)*)]
pub struct $t<T> where T: Clone $(+ $t_bounds)* {
data: Vec<$wrapper<T>>
data: aligned_vec::AVec<$wrapper<T>, aligned_vec::ConstAlign<{ crate::scratch::SIMD_ALIGN }>>
}
/// A reference to the data structure.
@@ -36,14 +54,6 @@ macro_rules! dst {
pub fn as_mut_slice(&mut self) -> &mut [$wrapper<T>] {
&mut self.data
}
-     #[allow(unused)]
-     /// Move the contents of rhs into self.
-     pub fn move_from(&mut self, rhs: $t<T>) {
-         for (l, r) in self.data.iter_mut().zip(rhs.data.into_iter()) {
-             *l = r;
-         }
-     }
}
impl<T> crate::dst::FromSlice<$wrapper<T>> for $ref_t<T> where T: Clone $(+ $t_bounds)* {
@@ -98,7 +108,7 @@ macro_rules! dst {
type Owned = $t<T>;
fn to_owned(&self) -> Self::Owned {
-     $t { data: self.data.to_owned() }
+     $t { data: aligned_vec::AVec::from_slice(crate::scratch::SIMD_ALIGN, &self.data) }
}
}
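Aside: the three `avec*` macros above, plus the `AVec`-backed `dst!` storage, are the heart of this commit; every owned container is now pinned to `SIMD_ALIGN`. A minimal standalone sketch of the `aligned_vec` calls the macros wrap (the literal 64 is an assumption, standing in for the cfg-selected `SIMD_ALIGN`):

```rust
use aligned_vec::{AVec, ConstAlign};

const ALIGN: usize = 64; // assumed stand-in for crate::scratch::SIMD_ALIGN

fn main() {
    // What avec_from_slice!($slice) expands to:
    let v: AVec<u64, ConstAlign<{ ALIGN }>> = AVec::from_slice(ALIGN, &[1u64, 2, 3]);

    // What avec_from_iter!($iter) expands to:
    let w: AVec<u64, ConstAlign<{ ALIGN }>> = AVec::from_iter(ALIGN, (0..8u64).map(|i| i * i));

    // The payoff: the backing pointer is SIMD-aligned, so downstream
    // FFT/SIMD code can rely on aligned loads.
    assert_eq!(v.as_ptr() as usize % ALIGN, 0);
    assert_eq!(w.as_ptr() as usize % ALIGN, 0);
}
```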

View File

@@ -43,7 +43,7 @@ impl<S: TorusOps> BivariateLookupTable<S> {
F: Fn(u64, u64) -> u64,
{
let mut lut = BivariateLookupTable {
-     data: vec![Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)],
+     data: avec!(Torus::zero(); BivariateLookupTableRef::<S>::size(glwe.dim)),
};
lut.fill_trivial_from_fn(map, glwe, plaintext_bits, carry_bits);

View File

@@ -41,7 +41,7 @@ impl<S: TorusOps> BlindRotationShift<S> {
let len = BlindRotationShiftRef::<S>::size((params.dim, radix.count));
Self {
-     data: vec![Torus::zero(); len],
+     data: avec![Torus::zero(); len],
}
}
}
@@ -94,7 +94,7 @@ impl BlindRotationShiftFft<Complex<f64>> {
let len = BlindRotationShiftFftRef::size((params.dim, radix.count));
Self {
-     data: vec![Complex::zero(); len],
+     data: avec![Complex::zero(); len],
}
}
}

View File

@@ -42,7 +42,7 @@ impl<S: TorusOps> BootstrapKey<S> {
let len = BootstrapKeyRef::<S>::size((lwe_params.dim, glwe_params.dim, radix.count));
Self {
-     data: vec![Torus::zero(); len],
+     data: avec![Torus::zero(); len],
}
}
}
@@ -145,7 +145,7 @@ impl BootstrapKeyFft<Complex<f64>> {
let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count));
Self {
-     data: vec![Complex::zero(); len],
+     data: avec![Complex::zero(); len],
}
}
}

View File

@@ -44,7 +44,7 @@ impl<S: TorusOps> CircuitBootstrappingKeyswitchKeys<S> {
));
Self {
-     data: vec![Torus::zero(); len],
+     data: avec![Torus::zero(); len],
}
}
}

View File

@@ -42,7 +42,7 @@ where
let elems = GgswCiphertextRef::<S>::size((params.dim, radix.count));
Self {
-     data: vec![Torus::zero(); elems],
+     data: avec![Torus::zero(); elems],
}
}
@@ -53,7 +53,7 @@ where
assert_eq!(data.len(), elems);
Self {
-     data: data.to_vec(),
+     data: avec_from_slice!(data),
}
}

View File

@@ -34,7 +34,7 @@ impl GgswCiphertextFft<Complex<f64>> {
let len = GgswCiphertextFftRef::size((params.dim, radix.count));
GgswCiphertextFft {
-     data: vec![Complex::zero(); len],
+     data: avec![Complex::zero(); len],
}
}
}

View File

@@ -51,9 +51,9 @@ where
let len = GlweCiphertextRef::<S>::size(params.dim);
- let data = (0..len).map(|_| Torus::<S>::zero()).collect::<Vec<_>>();
- GlweCiphertext { data }
+ GlweCiphertext {
+     data: avec![Torus::zero(); len],
+ }
}
/// Computes the external product of a GLWE ciphertext and a GGSW ciphertext.
@@ -72,10 +72,7 @@ where
assert_eq!(data.len(), GlweCiphertextRef::<S>::size(params.dim));
GlweCiphertext {
-     data: data
-         .iter()
-         .map(|x| Torus::from(*x))
-         .collect::<Vec<Torus<S>>>(),
+     data: avec_from_iter!(data.iter().map(|x| Torus::from(*x))),
}
}
}

View File

@@ -34,7 +34,7 @@ impl GlweCiphertextFft<Complex<f64>> {
let len = GlweCiphertextFftRef::size(params.dim);
Self {
-     data: vec![Complex::zero(); len],
+     data: avec![Complex::zero(); len],
}
}
}

View File

@@ -44,7 +44,7 @@ where
let elems = GlweKeyswitchKeyRef::<S>::size((params.dim, radix.count));
Self {
-     data: vec![Torus::zero(); elems],
+     data: avec![Torus::zero(); elems],
}
}
}

View File

@@ -55,9 +55,7 @@ where
let len = GlweSecretKeyRef::<S>::size(params.dim);
GlweSecretKey {
-     data: (0..len)
-         .map(|_| torus_element_generator())
-         .collect::<Vec<_>>(),
+     data: avec_from_iter!((0..len).map(|_| torus_element_generator())),
}
}
@@ -212,12 +210,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_add(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_add(b)));
let sk3 = sk + sk2;
@@ -231,12 +228,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let mut sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk2_expected = sk
+ let sk2_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_add(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_add(b)));
sk2 += sk;
@@ -250,12 +246,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_add(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_add(b)));
let sk3 = sk.as_ref() + sk2.as_ref();
@@ -269,12 +264,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_add(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_add(b)));
let sk3 = sk.wrapping_add(&sk2);
@@ -290,12 +284,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_sub(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_sub(b)));
let sk3 = sk - sk2;
@@ -309,12 +302,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let mut sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk2_expected = sk2
+ let sk2_expected = avec_from_iter!(sk2
      .data
      .iter()
      .zip(sk.data.iter())
-     .map(|(a, b)| a.wrapping_sub(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_sub(b)));
sk2 -= sk;
@@ -328,12 +320,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_sub(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_sub(b)));
let sk3 = sk.as_ref() - sk2.as_ref();
@@ -347,12 +338,11 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
let sk2 = keygen::generate_uniform_glwe_sk(&params);
- let sk3_expected = sk
+ let sk3_expected = avec_from_iter!(sk
      .data
      .iter()
      .zip(sk2.data.iter())
-     .map(|(a, b)| a.wrapping_sub(b))
-     .collect::<Vec<_>>();
+     .map(|(a, b)| a.wrapping_sub(b)));
let sk3 = sk.wrapping_sub(&sk2);
@@ -367,7 +357,7 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
- let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
+ let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
let sk2 = -sk;
assert_eq!(sk2_expected, sk2.data)
@@ -379,7 +369,7 @@ mod tests {
let sk = keygen::generate_uniform_glwe_sk(&params);
- let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
+ let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
let sk2 = -sk.as_ref();
assert_eq!(sk2_expected, sk2.data)
@@ -391,7 +381,7 @@ mod tests {
let sk = keygen::generate_binary_glwe_sk(&params);
- let sk2_expected = sk.data.iter().map(|a| a.wrapping_neg()).collect::<Vec<_>>();
+ let sk2_expected = avec_from_iter!(sk.data.iter().map(|a| a.wrapping_neg()));
let sk2 = sk.wrapping_neg();
assert_eq!(sk2_expected, sk2.data)

View File

@@ -40,7 +40,7 @@ impl<S: TorusOps> LweCiphertext<S> {
/// Create a new LWE ciphertext with all coefficients set to zero.
pub fn zero(params: &LweDef) -> Self {
- let data = vec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
+ let data = avec![Torus::zero(); LweCiphertextRef::<S>::size(params.dim)];
Self { data }
}

View File

@@ -31,7 +31,7 @@ impl<S: TorusOps> LweCiphertextList<S> {
/// [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
pub fn new(lwe: &LweDef, count: usize) -> Self {
Self {
-     data: vec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
+     data: avec![Torus::zero(); LweCiphertextListRef::<S>::size((lwe.dim, count))],
}
}
}

View File

@@ -48,7 +48,7 @@ where
LweKeyswitchKeyRef::<S>::size((original_params.dim, new_params.dim, radix.count));
Self {
-     data: vec![Torus::zero(); elems],
+     data: avec![Torus::zero(); elems],
}
}
}

View File

@@ -55,7 +55,7 @@ where
sk.assert_valid(params);
let mut pk = LwePublicKey {
-     data: vec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
+     data: avec![Torus::zero(); LwePublicKeyRef::<S>::size(params.dim)],
};
let enc_zeros = pk.enc_zeros_mut(params);

View File

@@ -43,9 +43,7 @@ where
let len = LweSecretKeyRef::<S>::size(params.dim);
LweSecretKey {
-     data: (0..len)
-         .map(|_| torus_element_generator())
-         .collect::<Vec<_>>(),
+     data: avec_from_iter!((0..len).map(|_| torus_element_generator())),
}
}

View File

@@ -43,7 +43,7 @@ where
/// Create a new polynomial from a slice of coefficients.
pub fn new(data: &[T]) -> Polynomial<T> {
Polynomial {
-     data: data.to_owned(),
+     data: avec_from_slice!(data),
}
}
@@ -53,7 +53,7 @@ where
T: Zero,
{
Polynomial {
-     data: vec![T::zero(); len],
+     data: avec![T::zero(); len],
}
}
}
@@ -64,7 +64,7 @@ where
{
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
Self {
-     data: iter.into_iter().collect::<Vec<_>>(),
+     data: avec_from_iter!(iter),
}
}
}
@@ -91,7 +91,7 @@ where
U: Clone,
{
Polynomial {
-     data: self.data.iter().map(f).collect::<Vec<_>>(),
+     data: avec_from_iter!(self.data.iter().map(f)),
}
}
@@ -284,13 +284,12 @@ where
fn add(self, rhs: &PolynomialRef<S>) -> Self::Output {
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
- let coeffs = self
+ let coeffs = avec_from_iter!(self
      .coeffs()
      .as_ref()
      .iter()
      .zip(rhs.coeffs().as_ref().iter())
-     .map(|(a, b)| *a + *b)
-     .collect::<Vec<_>>();
+     .map(|(a, b)| *a + *b));
Polynomial { data: coeffs }
}
@@ -325,13 +324,12 @@ where
fn sub(self, rhs: &PolynomialRef<S>) -> Self::Output {
assert_eq!(self.data.as_ref().len(), rhs.data.as_ref().len());
- let coeffs = self
+ let coeffs = avec_from_iter!(self
      .coeffs()
      .as_ref()
      .iter()
      .zip(rhs.coeffs().as_ref().iter())
-     .map(|(a, b)| *a - *b)
-     .collect::<Vec<_>>();
+     .map(|(a, b)| *a - *b));
Polynomial { data: coeffs }
}
@@ -358,7 +356,7 @@ where
assert_eq!(rhs.len(), self.len());
let mut c = Polynomial {
-     data: vec![Torus::zero(); rhs.len()],
+     data: avec![Torus::zero(); rhs.len()],
};
polynomial_external_mad(&mut c, self, rhs);

View File

@@ -46,7 +46,7 @@ where
/// Create a new polynomial with the given length in the fourier domain.
pub fn new(data: &[T]) -> Self {
Self {
-     data: data.to_owned(),
+     data: avec_from_slice!(data),
}
}
}

View File

@@ -31,7 +31,7 @@ where
/// Create a new polynomial list, where each polynomial has the same degree.
pub fn new(degree: PolynomialDegree, count: usize) -> Self {
Self {
-     data: vec![S::zero(); degree.0 * count],
+     data: avec![S::zero(); degree.0 * count],
}
}
}

View File

@@ -66,7 +66,7 @@ impl<S: TorusOps> PrivateFunctionalKeyswitchKey<S> {
lwe_count: &PrivateFunctionalKeyswitchLweCount,
) -> Self {
Self {
-     data: vec![
+     data: avec![
Torus::zero();
PrivateFunctionalKeyswitchKeyRef::<S>::size((
from_lwe.dim,

View File

@@ -36,7 +36,7 @@ impl<S: TorusOps> PublicFunctionalKeyswitchKey<S> {
PublicFunctionalKeyswitchKeyRef::<S>::size((from_lwe.dim, to_glwe.dim, radix.count));
Self {
-     data: vec![Torus::zero(); len],
+     data: avec![Torus::zero(); len],
}
}
}

View File

@@ -41,7 +41,7 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
F: Fn(u64) -> u64,
{
let mut lut = UnivariateLookupTable {
-     data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
+     data: avec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};
lut.fill_trivial_from_fns(&[map], glwe, plaintext_bits);
@@ -66,7 +66,7 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
assert!(maps.len() > 1);
let mut lut = UnivariateLookupTable {
-     data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
+     data: avec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};
lut.fill_trivial_from_fns(maps, glwe, plaintext_bits);

View File

@@ -84,7 +84,7 @@ macro_rules! impl_unary_op {
// We call the wrapping trait instead of using the dot
// syntax because the dot syntax can dereference the value
// and can cause problems with Deref.
- let data = self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)).collect();
+ let data = avec_from_iter!(self.data.iter().map(|a| num::traits::[<Wrapping $op>]::[<wrapping_ $op:lower>](a)));
$type { data }
}
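Aside: the comment above is about method-resolution hygiene; `a.wrapping_neg()` can resolve through `Deref` on a wrapper type, while the fully qualified form cannot. A standalone sketch of the qualified call (real `num` trait, illustrative types):

```rust
use num::traits::WrappingNeg;

fn wrapping_neg_all(data: &[i64]) -> Vec<i64> {
    // Fully qualified, like the paste!-generated call in the macro:
    // resolution is fixed to the trait impl, never routed via Deref.
    data.iter().map(|a| WrappingNeg::wrapping_neg(a)).collect()
}

fn main() {
    assert_eq!(wrapping_neg_all(&[1, i64::MIN]), vec![-1, i64::MIN]);
}
```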

View File

@@ -212,7 +212,7 @@ impl TorusOps for u32 {}
/// A wrapper around a type that supports Torus operations.
#[repr(transparent)]
- #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
+ #[derive(Debug, Default, Clone, Copy, PartialEq, PartialOrd, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Torus<S: TorusOps = u64>(S);

View File

@@ -1,11 +1,12 @@
- use linked_list::{Cursor, LinkedList};
+ use aligned_vec::{avec_rt, AVec, RuntimeAlign};
  use num::{Complex, Float};
  use rustfft::FftNum;
  use std::{
-     alloc::Layout,
      cell::RefCell,
+     collections::LinkedList,
      marker::PhantomData,
-     mem::{size_of, transmute},
+     mem::{align_of, size_of},
+     rc::Rc,
  };
use crate::{Torus, TorusOps};
@@ -27,6 +28,15 @@ macro_rules! allocate_scratch_ref {
pub(crate) use allocate_scratch_ref;
#[cfg(target_feature = "neon")]
pub const SIMD_ALIGN: usize = align_of::<std::arch::aarch64::float64x2_t>();
#[cfg(target_arch = "x86_64")]
pub const SIMD_ALIGN: usize = align_of::<std::arch::x86_64::__m512d>();
#[cfg(not(any(target_feature = "neon", target_arch = "x86_64")))]
pub const SIMD_ALIGN: usize = align_of::<u128>();
/// Indicates this is a "Plain Old Data" type. For `T` qualify as such,
/// all bit patterns must be considered a properly initialized instance of
/// `T`.
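Aside: `Pod` here is this module's own marker trait (the tests below write `unsafe impl Pod for Foo {}`), not `bytemuck::Pod`. A hypothetical impl sketching the contract the doc comment states:

```rust
// Hypothetical example type; assumes this module's Pod trait is in scope.
#[repr(C)]
#[derive(Copy, Clone, Default)]
struct Block {
    words: [u64; 8],
}

// SAFETY: Block is repr(C) and contains only u64s with no padding bytes,
// so every bit pattern is a properly initialized Block.
unsafe impl Pod for Block {}
```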
@@ -91,19 +101,14 @@ where
/// The references doled out by allocate have `'static`
/// lifetimes. This is needed so you can have mutable references
/// to different allocations at the same time.
  ///
+ /// # Safety
+ /// [`Scratch`] must not drop before all its outstanding allocations have done so.
+ ///
  /// Please note the lack of [`Sync`] and [`Send`] on this object. It is not sound to
  /// share these between threads.
  struct Scratch {
-     // Only accessed through the cursor, so compiler thinks it's unused.
-     #[allow(unused)]
-     stack: Box<LinkedList<Allocation>>,
-     top: *mut Cursor<'static, Allocation>,
- }
-
- impl Drop for Scratch {
-     fn drop(&mut self) {
-         let top = unsafe { Box::from_raw(self.top) };
-         std::mem::drop(top);
-     }
+     stack: Rc<RefCell<LinkedList<*mut Allocation>>>,
  }
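Aside: the single `Rc<RefCell<...>>` field does double duty. It is the free pool shared with every outstanding `ScratchBuffer`, and it also makes `Scratch` `!Send`/`!Sync` automatically, which is exactly the property the doc comment asks for. A standalone sketch:

```rust
// Rc<RefCell<...>> (and raw pointers) are neither Send nor Sync, so a
// struct holding them loses both auto-traits without any manual opt-out.
use std::{cell::RefCell, collections::LinkedList, rc::Rc};

struct Pool {
    stack: Rc<RefCell<LinkedList<*mut u8>>>,
}

fn require_send<T: Send>(_: &T) {}

fn main() {
    let pool = Pool { stack: Rc::new(RefCell::new(LinkedList::new())) };
    require_send(&42); // fine: i32 is Send
    // require_send(&pool); // does not compile: Rc<...> is not Send
    let _ = &pool.stack;
}
```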
impl Scratch {
@@ -111,12 +116,9 @@ impl Scratch {
/// the only way to use scratch memory, which will allocate memory
/// using a thread_local allocator.
fn new() -> Self {
-     let mut list = Box::new(LinkedList::new());
-     let cursor = Box::new(list.cursor());
-     let top = unsafe { transmute(Box::into_raw(cursor)) };
-     Self { stack: list, top }
+     Self {
+         stack: Rc::new(RefCell::new(LinkedList::new())),
+     }
}
/// Allocate a buffer matching the given specification.
@@ -126,102 +128,93 @@ impl Scratch {
{
assert_ne!(size_of::<T>(), 0);
let top = unsafe { &mut *self.top };
let alignment = usize::max(SIMD_ALIGN, align_of::<T>());
let u8_len = count * size_of::<T>();
// Push the top as far down until we hit the bottom or an allocation
// currently in use.
loop {
let prev = top.peek_prev();
if let Some(x) = prev {
if x.is_free {
top.prev().unwrap();
continue;
}
}
break;
}
let layout = Layout::array::<T>(count).unwrap();
let req_len = layout.size() + layout.align();
let allocation = match top.peek_next() {
Some(d) => {
assert!(d.is_free);
// Resize the allocation if needed.
if d.data.len() < req_len {
d.data.resize(req_len, 0u8);
}
d.requested_len = count;
top.next().unwrap()
}
None => {
let data = vec![0u8; req_len];
let allocation = unsafe {
let allocation = self.stack.borrow_mut().pop_back();
if allocation.is_none() {
// If we don't have an existing allocation, make one
let allocation = Allocation {
requested_len: count,
is_free: false,
data,
data: avec_rt!([alignment]| u8::default(); u8_len),
};
top.insert(allocation);
top.next().unwrap()
let allocation = Box::new(allocation);
Box::into_raw(allocation)
} else if (*allocation.unwrap()).data.alignment() < alignment
|| (*allocation.unwrap()).data.len() < u8_len
{
// If we found an allocation, but its size and len requirements
// are insufficient.
let allocation = allocation.unwrap();
let drop_box = Box::from_raw(allocation);
std::mem::drop(drop_box);
let allocation = Allocation {
data: avec_rt!([alignment]| u8::default(); u8_len),
};
let allocation = Box::new(allocation);
Box::into_raw(allocation)
} else {
// Otherwise, reuse the allocation.
allocation.unwrap()
}
};
allocation.is_free = false;
ScratchBuffer {
allocation: allocation as *mut Allocation,
allocation,
pool: self.stack.clone(),
requested_len: count,
_phantom: PhantomData,
}
}
}
struct Allocation {
-     requested_len: usize,
-     data: Vec<u8>,
-     is_free: bool,
+     data: AVec<u8, RuntimeAlign>,
}
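Aside: with `Allocation` reduced to a bag of aligned bytes, the allocator's lifecycle is pop, reuse, push. A sketch of the round trip in the style of this module's tests (it mirrors the pointer-reuse test shown further down):

```rust
#[test]
fn allocate_drop_reuse_sketch() {
    let mut scratch = Scratch::new();

    let mut b = scratch.allocate::<u64>(64);
    let first_ptr = b.as_mut_slice().as_ptr();
    // Dropping the buffer pushes its Allocation back onto the pool...
    std::mem::drop(b);

    // ...so an equally sized, equally aligned request pops it back out
    // instead of hitting the system allocator.
    let b = scratch.allocate::<u64>(64);
    assert_eq!(b.as_slice().as_ptr(), first_ptr);
}
```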
/// An allocation returned by [`Scratch::allocate`].
///
/// # Safety
/// Please note the lack of [`Sync`] and [`Send`] on this object. It is not sound to
/// share these between threads.
///
/// You *may* however share the slice returned by [`Self::as_slice`] or
/// [`Self::as_mut_slice`], as these obey standard lifetime rules.
pub struct ScratchBuffer<'a, T> {
allocation: *mut Allocation,
pool: Rc<RefCell<LinkedList<*mut Allocation>>>,
requested_len: usize,
_phantom: PhantomData<&'a T>,
}
impl<'a, T> ScratchBuffer<'a, T> {
#[allow(unused)]
/// Get a slice to the underlying data.
///
/// # Remarks
/// While not extremely expensive, this operation does require capturing
/// an aligned slice of data in an underlying allocation. As such,
/// you should avoid repeated calls.
pub fn as_slice(&self) -> &[T] {
-         let count = unsafe { (*self.allocation).requested_len };
-         let (_, slice, _) = unsafe { (*self.allocation).data.align_to::<T>() };
-         unsafe { transmute(&slice[0..count]) }
+         let slice =
+             unsafe { std::mem::transmute::<&[u8], &[T]>((*self.allocation).data.as_slice()) };
+         &slice[0..self.requested_len]
}
/// Get a mutable slice to the underlying data.
///
/// # Remarks
/// While not extremely expensive, this operation does require capturing
/// an aligned slice of data in an underlying allocation. As such,
/// you should avoid repeated calls.
pub fn as_mut_slice(&mut self) -> &mut [T] {
-         let count = unsafe { (*self.allocation).requested_len };
-         let (_pre, slice, _post) = unsafe { (*self.allocation).data.align_to_mut::<T>() };
-         unsafe { transmute(&mut slice[0..count]) }
+         let slice = unsafe {
+             std::mem::transmute::<&mut [u8], &mut [T]>((*self.allocation).data.as_mut_slice())
+         };
+         &mut slice[0..self.requested_len]
}
}
impl<'a, T> Drop for ScratchBuffer<'a, T> {
fn drop(&mut self) {
-         unsafe { (*self.allocation).is_free = true };
+         self.pool.borrow_mut().push_back(self.allocation);
}
}
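Aside: this `Drop` impl is now the entire deallocation story; buffers are parked for reuse, never freed mid-session, and unsuitable pool entries are discarded at the next checkout. A safe-Rust reduction of the pattern (the real code stores `*mut Allocation` and reinterprets bytes; this sketch uses `Box<[u8]>` to stay in safe code):

```rust
use std::{cell::RefCell, collections::LinkedList, rc::Rc};

type Pool = Rc<RefCell<LinkedList<Box<[u8]>>>>;

struct PooledBuf {
    data: Option<Box<[u8]>>, // Option so Drop can move the storage out
    pool: Pool,
}

impl Drop for PooledBuf {
    fn drop(&mut self) {
        // Park the storage for reuse instead of freeing it.
        self.pool.borrow_mut().push_back(self.data.take().unwrap());
    }
}

fn checkout(pool: &Pool, len: usize) -> PooledBuf {
    let data = pool
        .borrow_mut()
        .pop_back()
        // Like the else-if branch in allocate: an unsuitable pooled
        // buffer is dropped and replaced with a fresh one.
        .filter(|d| d.len() >= len)
        .unwrap_or_else(|| vec![0u8; len].into_boxed_slice());
    PooledBuf { data: Some(data), pool: pool.clone() }
}

fn main() {
    let pool: Pool = Rc::new(RefCell::new(LinkedList::new()));
    let a = checkout(&pool, 64);
    drop(a); // parked, not freed
    let b = checkout(&pool, 32); // reuses the 64-byte buffer
    assert!(b.data.as_ref().unwrap().len() >= 32);
}
```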
@@ -243,8 +236,6 @@ mod tests {
for (i, d_i) in d.iter_mut().enumerate() {
*d_i = i as u64;
}
- assert_eq!(scratch.stack.len(), 1);
}
#[test]
@@ -253,7 +244,6 @@ mod tests {
let b = scratch.allocate::<u64>(64);
- assert_eq!(scratch.stack.len(), 1);
let b_slice = b.as_slice();
assert_eq!(b_slice.len(), 64);
assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
@@ -264,7 +254,6 @@ mod tests {
let mut b = scratch.allocate::<u64>(64);
let b_slice = b.as_mut_slice();
assert_eq!(first_ptr, b_slice.as_ptr());
- assert_eq!(scratch.stack.len(), 1);
assert_eq!(b_slice.len(), 64);
for (i, b_i) in b_slice.iter_mut().enumerate() {
@@ -272,36 +261,6 @@ mod tests {
}
}
- #[test]
- #[ignore]
- fn reallocate_on_bigger_request() {
-     let mut scratch = Scratch::new();
-     let mut b = scratch.allocate::<u64>(64);
-     assert_eq!(scratch.stack.len(), 1);
-     let b_slice = b.as_mut_slice();
-     assert_eq!(b_slice.len(), 64);
-     assert_eq!(b_slice.as_ptr().align_offset(align_of::<u64>()), 0);
-     let first_ptr = b_slice.as_ptr();
-     for (i, b_i) in b_slice.iter_mut().enumerate() {
-         *b_i = i as u64;
-     }
-     std::mem::drop(b);
-     let mut b = scratch.allocate::<u64>(16384);
-     let b = b.as_mut_slice();
-     assert_ne!(first_ptr, b.as_ptr());
-     assert_eq!(scratch.stack.len(), 1);
-     assert_eq!(b.len(), 16384);
-     for (i, b_i) in b.iter_mut().enumerate() {
-         *b_i = i as u64;
-     }
- }
#[test]
fn allocate_two_buffers() {
let mut scratch = Scratch::new();
@@ -314,7 +273,6 @@ mod tests {
assert_eq!(a.len(), 12);
assert_eq!(b.len(), 12);
- assert_eq!(scratch.stack.len(), 2);
assert_ne!(a.as_mut_ptr(), b.as_mut_ptr());
for i in 0..a.len() {
@@ -330,8 +288,6 @@ mod tests {
.map(|_| scratch.allocate::<u128>(10))
.collect::<Vec<_>>();
- assert_eq!(scratch.stack.len(), 10);
for b in buffers.iter_mut() {
let b = b.as_mut_slice();
assert_eq!(b.len(), 10);
@@ -347,7 +303,7 @@ mod tests {
// Choose an alignment larger than any reasonable OS's page size
// to try to force the alignment algorithm into play.
#[repr(C, align(65536))]
- #[derive(Copy, Clone)]
+ #[derive(Copy, Clone, Default)]
struct Foo {
x: u32,
}
@@ -370,37 +326,6 @@ mod tests {
}
}
- #[test]
- fn stack_coalesces_correctly() {
-     let mut scratch = Scratch::new();
-     let a = scratch.allocate::<u64>(16);
-     let mut b: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
-     let b_ptr = b.as_mut_slice().as_mut_ptr();
-     let c: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
-     let d: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
-     std::mem::drop(b);
-     assert_eq!(scratch.stack.len(), 4);
-     // We can't reuse b's buffer until c, d, e get dropped.
-     let mut e: ScratchBuffer<'_, u64> = scratch.allocate::<u64>(16);
-     assert_ne!(b_ptr, e.as_mut_slice().as_mut_ptr());
-     assert_eq!(scratch.stack.len(), 5);
-     std::mem::drop(c);
-     std::mem::drop(d);
-     std::mem::drop(e);
-     // Now we can reuse b's buffer.
-     let mut f = scratch.allocate::<u64>(16);
-     assert_eq!(f.as_mut_slice().as_mut_ptr(), b_ptr);
-     assert_eq!(scratch.stack.len(), 5);
-     std::mem::drop(a);
- }
#[test]
fn zero_size_allocations() {
let mut scratch = Scratch::new();
@@ -410,7 +335,6 @@ mod tests {
let a_slice = a.as_slice();
let b_slice = b.as_slice();
- assert_eq!(scratch.stack.len(), 2);
assert_eq!(a_slice.len(), 2);
assert_eq!(b_slice.len(), 0);
}
@@ -418,10 +342,24 @@ mod tests {
#[test]
#[should_panic]
fn zst_allocations_should_panic() {
+ #[derive(Default)]
struct Foo {}
unsafe impl Pod for Foo {}
let mut scratch = Scratch::new();
let _ = scratch.allocate::<Foo>(0x1 << 48);
}
+ #[test]
+ fn simd_alignment() {
+     #[cfg(target_arch = "aarch64")]
+     {
+         assert_eq!(SIMD_ALIGN, 16);
+     }
+     #[cfg(target_arch = "x86_64")]
+     {
+         assert_eq!(SIMD_ALIGN, 64);
+     }
+ }
}
}