NTT columns batch (#424)

This PR adds the columns batch feature, enabling batch NTT computation
to be performed directly on the columns of a matrix without having to
transpose it beforehand, as requested in issue #264.
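
A minimal usage sketch (not taken from this PR: it mirrors the test code further down and assumes `config`, `scalars`, `batch_size`, `test_size`, `is_inverse` and the field type `F` are set up as in those tests):

```rust
// Hedged fragment in the style of the tests below. `scalars` is viewed as a
// (test_size x batch_size) matrix stored row-major, i.e. each of the
// `batch_size` columns is one NTT input of length `test_size`.
config.batch_size = batch_size as i32;
config.columns_batch = true; // operate on the columns; no transpose of `scalars` needed
let mut result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
ntt(&scalars, is_inverse, &config, &mut result).unwrap();
// `result` uses the same column-major layout: column i holds the NTT of column i.
```

Without the flag, the same computation would require transposing `scalars` into contiguous row batches first.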

This PR also includes some small fixes to the reordering kernels and removes
some unnecessary parameters from function interfaces.

---------

Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
HadarIngonyama authored 2024-03-13 18:46:47 +02:00, committed by GitHub
parent 89082fb561
commit 287f53ff16
10 changed files with 385 additions and 127 deletions


@@ -77,6 +77,8 @@ pub struct NTTConfig<'a, S> {
     pub coset_gen: S,
     /// The number of NTTs to compute. Default value: 1.
     pub batch_size: i32,
+    /// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
+    pub columns_batch: bool,
     /// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
     pub ordering: Ordering,
     are_inputs_on_device: bool,
@@ -101,6 +103,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
             ctx: DeviceContext::default_for_device(device_id),
             coset_gen: S::one(),
             batch_size: 1,
+            columns_batch: false,
             ordering: Ordering::kNN,
             are_inputs_on_device: false,
             are_outputs_on_device: false,
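
The new field defaults to `false`, so existing row-batch callers keep their current behaviour; enabling it only changes how the flat buffer of `batch_size * size` elements is interpreted. A standalone sketch of the two index conventions (the helper functions are illustrative, not part of the API; the column-major convention is inferred from the `transpose_flattened_matrix` test helper below):

```rust
// Illustrative helpers (not part of icicle): where element j of NTT i lives
// in a flat buffer holding `batch_size` NTTs of length `size`.
fn row_major_index(i: usize, j: usize, size: usize) -> usize {
    i * size + j // columns_batch = false: NTT i is a contiguous row
}

fn column_major_index(i: usize, j: usize, batch_size: usize) -> usize {
    j * batch_size + i // columns_batch = true: NTT i is a strided column
}

fn main() {
    let (batch_size, size) = (3usize, 4usize);
    // A row-major buffer holding 3 NTT inputs of length 4: 0..4, 4..8, 8..12.
    let rows: Vec<usize> = (0..batch_size * size).collect();
    // Re-packed column-major, as the columns-batch mode expects.
    let mut cols = vec![0usize; rows.len()];
    for i in 0..batch_size {
        for j in 0..size {
            cols[column_major_index(i, j, batch_size)] = rows[row_major_index(i, j, size)];
        }
    }
    assert_eq!(cols, vec![0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]);
}
```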


@@ -44,6 +44,14 @@ pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
     u32::from_str_radix(&reversed, 2).unwrap()
 }
 
+pub fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
+    let ncols = m.len() / nrows;
+    assert!(nrows * ncols == m.len());
+    (0..m.len())
+        .map(|i| m[(i % nrows) * ncols + i / nrows])
+        .collect()
+}
+
 pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
     l.iter()
         .enumerate()
@@ -253,11 +261,10 @@ where
         ] {
             config.coset_gen = coset_gen;
             config.ordering = ordering;
+            let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
             for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
                 config.batch_size = batch_size as i32;
                 config.ntt_algorithm = alg;
-                let mut batch_ntt_result =
-                    HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
                 ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
                 config.batch_size = 1;
                 let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
@@ -275,6 +282,20 @@ where
                     );
                 }
             }
+
+            // for now, columns batching only works with MixedRadix NTT
+            config.batch_size = batch_size as i32;
+            config.columns_batch = true;
+            let transposed_input =
+                HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
+            let mut col_batch_ntt_result =
+                HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
+            ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
+            assert_eq!(
+                batch_ntt_result[..],
+                transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
+            );
+            config.columns_batch = false;
         }
     }
 }
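
The `transpose_flattened_matrix` helper added above is what lets the test compare the two modes: the row-major batch is transposed, run with `columns_batch = true`, and the result is expected to be the transpose of the plain row-batch output. A small self-contained demo of the helper (the function body is copied from the diff above; the values are illustrative):

```rust
fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
    let ncols = m.len() / nrows;
    assert!(nrows * ncols == m.len());
    (0..m.len())
        .map(|i| m[(i % nrows) * ncols + i / nrows])
        .collect()
}

fn main() {
    // A 2x3 matrix stored row-major: rows [1, 2, 3] and [4, 5, 6].
    let m = [1, 2, 3, 4, 5, 6];
    let t = transpose_flattened_matrix(&m, 2);
    // The transpose is 3x2, stored row-major: [1, 4, 2, 5, 3, 6].
    assert_eq!(t, vec![1, 4, 2, 5, 3, 6]);
    // Transposing again (now with 3 rows) recovers the original layout.
    assert_eq!(transpose_flattened_matrix(&t, 3), m.to_vec());
}
```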