mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
fix: bug regarding MixedRadix coset (I)NTT for NM/MN ordering (#497)
The bug is in how twiddles array is indexed when multiplied by a mixed (M) vector to implement (I)NTT on cosets. The fix is to use the DIF-digit-reverse to compute the index of the element in the natural (N) vector that moved to index 'i' in the M vector. This is emulating a DIT-digit-reverse (which is mixing like a DIF-compute) reorder of the twiddles array and element-wise multiplication without reordering the twiddles memory.
This commit is contained in:
@@ -32,6 +32,7 @@ namespace mxntt {
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
|
||||
@@ -134,14 +134,23 @@ namespace mxntt {
|
||||
int n_scalars,
|
||||
uint32_t log_size,
|
||||
eRevType rev_type,
|
||||
bool dit,
|
||||
bool fast_tw,
|
||||
E* out_vec)
|
||||
{
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (tid >= size * batch_size) return;
|
||||
int64_t scalar_id = (tid / columns_batch_size) % size;
|
||||
if (rev_type != eRevType::None)
|
||||
scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
|
||||
if (rev_type != eRevType::None) {
|
||||
// Note: when we multiply an in_vec that is mixed (by DIF (I)NTT), we want to shuffle the
|
||||
// scalars the same way (then multiply element-wise). This would be a DIT-digit-reverse shuffle. (this is
|
||||
// confusing but) BUT to avoid shuffling the scalars, we instead want to ask which element in the non-shuffled
|
||||
// vec is now placed at index tid, which is the opposite of a DIT-digit-reverse --> this is the DIF-digit-reverse.
|
||||
// Therefore we use the DIF-digit-reverse to know which element moved to index tid and use it to access the
|
||||
// corresponding element in scalars vec.
|
||||
const bool dif = rev_type == eRevType::NaturalToMixedRev;
|
||||
scalar_id =
|
||||
generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, !dif, fast_tw, rev_type);
|
||||
}
|
||||
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
|
||||
}
|
||||
|
||||
@@ -903,6 +912,7 @@ namespace mxntt {
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
S* linear_twiddle, // twiddles organized as [1,w,w^2,...] for coset-eval in fast-tw mode
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
@@ -958,8 +968,8 @@ namespace mxntt {
|
||||
if (is_on_coset && !is_inverse) {
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
|
||||
reverse_coset, dit, d_output);
|
||||
arbitrary_coset ? arbitrary_coset : linear_twiddle, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
|
||||
reverse_coset, fast_tw, d_output);
|
||||
|
||||
d_input = d_output;
|
||||
}
|
||||
@@ -991,8 +1001,8 @@ namespace mxntt {
|
||||
if (is_on_coset && is_inverse) {
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
|
||||
n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
arbitrary_coset ? arbitrary_coset : linear_twiddle + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
|
||||
n_twiddles, logn, reverse_coset, fast_tw, d_output);
|
||||
}
|
||||
|
||||
return CHK_LAST();
|
||||
@@ -1021,6 +1031,8 @@ namespace mxntt {
|
||||
scalar_t* external_twiddles,
|
||||
scalar_t* internal_twiddles,
|
||||
scalar_t* basic_twiddles,
|
||||
scalar_t* linear_twiddles,
|
||||
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
@@ -1039,6 +1051,8 @@ namespace mxntt {
|
||||
scalar_t* external_twiddles,
|
||||
scalar_t* internal_twiddles,
|
||||
scalar_t* basic_twiddles,
|
||||
scalar_t* linear_twiddles,
|
||||
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
|
||||
@@ -717,8 +717,7 @@ namespace ntt {
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
const bool is_on_coset = (coset_index != 0) || coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr);
|
||||
S* twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
|
||||
: domain.twiddles;
|
||||
@@ -728,9 +727,11 @@ namespace ntt {
|
||||
S* basic_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
|
||||
: domain.basic_twiddles;
|
||||
S* linear_twiddles = domain.twiddles; // twiddles organized as [1,w,w^2,...]
|
||||
CHK_IF_RETURN(mxntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, linear_twiddles, size, domain.max_log_size,
|
||||
batch_size, config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index,
|
||||
stream));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -378,6 +378,13 @@ macro_rules! impl_ntt_tests {
|
||||
check_ntt_coset_from_subgroup::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[parallel]
|
||||
fn test_ntt_coset_interpolation_nm() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
|
||||
check_ntt_coset_interpolation_nm::<$field>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[parallel]
|
||||
fn test_ntt_arbitrary_coset() {
|
||||
|
||||
@@ -190,6 +190,62 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_ntt_coset_interpolation_nm<F: FieldImpl + ArkConvertible>()
|
||||
where
|
||||
F::ArkEquivalent: FftField,
|
||||
<F as FieldImpl>::Config: NTT<F, F> + GenerateRandom<F>,
|
||||
{
|
||||
let test_sizes = [1 << 9, 1 << 10, 1 << 11, 1 << 13, 1 << 14, 1 << 16];
|
||||
for test_size in test_sizes {
|
||||
let test_size_rou = F::ArkEquivalent::get_root_of_unity((test_size << 1) as u64).unwrap();
|
||||
let coset_generators = [F::from_ark(test_size_rou), F::Config::generate_random(1)[0]];
|
||||
|
||||
let scalars: Vec<F> = F::Config::generate_random(test_size);
|
||||
|
||||
let ark_domain = GeneralEvaluationDomain::<F::ArkEquivalent>::new(test_size).unwrap();
|
||||
|
||||
for coset_gen in coset_generators {
|
||||
// (1) intt from evals to coeffs
|
||||
let mut config = NTTConfig::default();
|
||||
config.ordering = Ordering::kNM;
|
||||
config.ntt_algorithm = NttAlgorithm::MixedRadix;
|
||||
|
||||
let mut intt_result = vec![F::zero(); test_size];
|
||||
let intt_result = HostSlice::from_mut_slice(&mut intt_result);
|
||||
ntt(HostSlice::from_slice(&scalars), NTTDir::kInverse, &config, intt_result).unwrap();
|
||||
|
||||
let mut ark_scalars = scalars
|
||||
.iter()
|
||||
.map(|v| v.to_ark())
|
||||
.collect::<Vec<F::ArkEquivalent>>();
|
||||
ark_domain.ifft_in_place(&mut ark_scalars);
|
||||
|
||||
// (2) coset-ntt (compute coset evals)
|
||||
config.coset_gen = coset_gen;
|
||||
config.ordering = Ordering::kMN;
|
||||
let mut coset_evals = vec![F::zero(); test_size];
|
||||
ntt(
|
||||
intt_result,
|
||||
NTTDir::kForward,
|
||||
&config,
|
||||
HostSlice::from_mut_slice(&mut coset_evals),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let ark_coset_domain = ark_domain
|
||||
.get_coset(coset_gen.to_ark())
|
||||
.unwrap();
|
||||
ark_coset_domain.fft_in_place(&mut ark_scalars); // to reuse in next iteration
|
||||
|
||||
let coest_evals_as_ark = coset_evals
|
||||
.iter()
|
||||
.map(|v| v.to_ark())
|
||||
.collect::<Vec<F::ArkEquivalent>>();
|
||||
assert_eq!(coest_evals_as_ark, ark_scalars);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_ntt_arbitrary_coset<F: FieldImpl + ArkConvertible>()
|
||||
where
|
||||
F::ArkEquivalent: FftField + ArkField,
|
||||
|
||||
Reference in New Issue
Block a user