Mirror of https://github.com/pseXperiments/icicle.git (synced 2026-01-08 23:17:54 -05:00)
NTT columns batch (#424)
This PR adds the columns batch feature, enabling batch NTT computation to be performed directly on the columns of a matrix without having to transpose it beforehand, as requested in issue #264. Some small fixes to the reordering kernels were also added, and some unnecessary parameters were removed from function interfaces. --------- Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
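As a rough usage sketch (the names and call shape mirror the Rust test in the diff below; the setup of `config`, the hypothetical input `columns_input`, and the direction argument `is_inverse` stand in for test code not shown here, so treat them as assumptions), enabling the feature amounts to flipping `columns_batch` on the existing `NTTConfig` and passing the column-major data directly:

    // columns_input holds batch_size NTTs of length test_size laid out column-wise,
    // i.e. element j of NTT i sits at flat index j * batch_size + i
    // (see the transpose_flattened_matrix helper further down).
    config.batch_size = batch_size as i32;
    config.ntt_algorithm = NttAlgorithm::MixedRadix; // for now, columns batching only works with MixedRadix
    config.columns_batch = true;
    let mut result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
    ntt(&columns_input, is_inverse, &config, &mut result).unwrap();
    // result is laid out column-wise as well; no explicit transpose of the input was needed.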
@@ -77,6 +77,8 @@ pub struct NTTConfig<'a, S> {
    pub coset_gen: S,
    /// The number of NTTs to compute. Default value: 1.
    pub batch_size: i32,
    /// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
    pub columns_batch: bool,
    /// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
    pub ordering: Ordering,
    are_inputs_on_device: bool,
@@ -101,6 +103,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
            ctx: DeviceContext::default_for_device(device_id),
            coset_gen: S::one(),
            batch_size: 1,
            columns_batch: false,
            ordering: Ordering::kNN,
            are_inputs_on_device: false,
            are_outputs_on_device: false,
@@ -44,6 +44,14 @@ pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
    u32::from_str_radix(&reversed, 2).unwrap()
}

pub fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
    let ncols = m.len() / nrows;
    assert!(nrows * ncols == m.len());
    (0..m.len())
        .map(|i| m[(i % nrows) * ncols + i / nrows])
        .collect()
}

pub fn list_to_reverse_bit_order<T: Copy>(l: &[T]) -> Vec<T> {
    l.iter()
        .enumerate()
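To make the index arithmetic concrete, here is a small self-contained check (not part of the diff; the helper body is copied verbatim from the hunk above) showing that `transpose_flattened_matrix` turns a row-major 2x3 matrix into its 3x2 transpose:

    fn transpose_flattened_matrix<T: Copy>(m: &[T], nrows: usize) -> Vec<T> {
        let ncols = m.len() / nrows;
        assert!(nrows * ncols == m.len());
        (0..m.len())
            .map(|i| m[(i % nrows) * ncols + i / nrows])
            .collect()
    }

    fn main() {
        // [1 2 3]            [1 4]
        // [4 5 6]  becomes   [2 5]
        //                    [3 6]
        let m = [1, 2, 3, 4, 5, 6];
        assert_eq!(transpose_flattened_matrix(&m, 2), vec![1, 4, 2, 5, 3, 6]);
    }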
@@ -253,11 +261,10 @@ where
        ] {
            config.coset_gen = coset_gen;
            config.ordering = ordering;
            let mut batch_ntt_result = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
            for alg in [NttAlgorithm::Radix2, NttAlgorithm::MixedRadix] {
                config.batch_size = batch_size as i32;
                config.ntt_algorithm = alg;
                let mut batch_ntt_result =
                    HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
                ntt(&scalars, is_inverse, &config, &mut batch_ntt_result).unwrap();
                config.batch_size = 1;
                let mut one_ntt_result = HostOrDeviceSlice::on_host(vec![F::one(); test_size]);
@@ -275,6 +282,20 @@ where
                    );
                }
            }

            // for now, columns batching only works with MixedRadix NTT
            config.batch_size = batch_size as i32;
            config.columns_batch = true;
            let transposed_input =
                HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
            let mut col_batch_ntt_result =
                HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
            ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
            assert_eq!(
                batch_ntt_result[..],
                transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
            );
            config.columns_batch = false;
        }
    }
}