cpu_ntt pre-parallel

This commit is contained in:
Shanie Winitz
2024-07-28 15:26:24 +03:00
parent 2a3dcd776a
commit 96debb96a8

View File

@@ -6,6 +6,7 @@
#include "icicle/fields/field_config.h"
#include "icicle/vec_ops.h"
#include <thread>
#include <vector>
#include <algorithm>
#include <iostream>
@@ -19,6 +20,14 @@ using namespace icicle;
namespace ntt_cpu {
// TODO SHANIE - after implementing real parallelism, try different sizes to choose the optimal one. Or consider using
// a function to calculate subset sizes
constexpr uint32_t layers_subntt_log_size[31][3] = {
{0, 0, 0}, {1, 0, 0}, {2, 0, 0}, {3, 0, 0}, {4, 0, 0}, {5, 0, 0}, {3, 3, 0}, {4, 3, 0},
{4, 4, 0}, {5, 4, 0}, {5, 5, 0}, {4, 4, 3}, {4, 4, 4}, {5, 4, 4}, {5, 5, 4}, {5, 5, 5},
{8, 8, 0}, {9, 8, 0}, {9, 9, 0}, {10, 9, 0}, {10, 10, 0}, {11, 10, 0}, {11, 11, 0}, {12, 11, 0},
{12, 12, 0}, {13, 12, 0}, {13, 13, 0}, {14, 13, 0}, {14, 14, 0}, {15, 14, 0}, {15, 15, 0}};
template <typename S>
class CpuNttDomain
{
@@ -28,6 +37,8 @@ namespace ntt_cpu {
std::mutex domain_mutex;
public:
std::unordered_map<S, int> coset_index = {};
static eIcicleError
cpu_ntt_init_domain(const Device& device, const S& primitive_root, const NTTInitDomainConfig& config);
static eIcicleError cpu_ntt_release_domain(const Device& device);
@@ -92,6 +103,7 @@ namespace ntt_cpu {
temp_twiddles[0] = S::one();
for (int i = 1; i <= s_ntt_domain.max_size; i++) {
temp_twiddles[i] = temp_twiddles[i - 1] * tw_omega;
s_ntt_domain.coset_index[temp_twiddles[i]] = i;
}
s_ntt_domain.twiddles = std::move(temp_twiddles); // Assign twiddles using unique_ptr
}
@@ -130,36 +142,101 @@ namespace ntt_cpu {
return rev;
}
template <typename E = scalar_t>
eIcicleError reorder_by_bit_reverse(int logn, E* output, int batch_size)
inline uint64_t idx_in_mem(
int element, int block_idx, int subntt_idx, const std::vector<int> layers_sntt_log_size = {}, int layer = 0)
{
uint64_t size = 1 << logn;
int s0 = layers_sntt_log_size[0];
int s1 = layers_sntt_log_size[1];
int s2 = layers_sntt_log_size[2];
switch (layer) {
case 0:
return block_idx + ((subntt_idx + (element << s1)) << s2);
case 1:
return block_idx + ((element + (subntt_idx << s1)) << s2);
case 2:
return ((block_idx << (s1 + s2)) & ((1 << (s0 + s1 + s2)) - 1)) +
(((block_idx << (s1 + s2)) >> (s0 + s1 + s2)) << s2) + element;
default:
ICICLE_ASSERT(false) << "Unsupported layer";
}
return -1;
}
template <typename E = scalar_t>
eIcicleError reorder_by_bit_reverse(
int log_original_size,
E* elements,
int batch_size,
bool columns_batch,
int block_idx = 0,
int subntt_idx = 0,
std::vector<int> layers_sntt_log_size = {},
int layer = 0)
{
uint64_t subntt_size = (layers_sntt_log_size.empty()) ? 1 << log_original_size : 1 << layers_sntt_log_size[layer];
int subntt_log_size = (layers_sntt_log_size.empty()) ? log_original_size : layers_sntt_log_size[layer];
uint64_t original_size = (1 << log_original_size);
int stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* current_output = output + batch * size;
int rev;
for (int i = 0; i < size; ++i) {
rev = bit_reverse(i, logn);
if (i < rev) { std::swap(current_output[i], current_output[rev]); }
E* current_elements = columns_batch ? elements + batch : elements + batch * original_size;
uint64_t rev;
uint64_t i_mem_idx;
uint64_t rev_mem_idx;
for (uint64_t i = 0; i < subntt_size; ++i) {
rev = bit_reverse(i, subntt_log_size);
if (!layers_sntt_log_size.empty()) {
i_mem_idx = idx_in_mem(i, block_idx, subntt_idx, layers_sntt_log_size, layer);
rev_mem_idx = idx_in_mem(rev, block_idx, subntt_idx, layers_sntt_log_size, layer);
} else {
i_mem_idx = i;
rev_mem_idx = rev;
}
if (i < rev) {
if (i_mem_idx < original_size && rev_mem_idx < original_size) { // Ensure indices are within bounds
std::swap(current_elements[stride * i_mem_idx], current_elements[stride * rev_mem_idx]);
} else {
// Handle out-of-bounds error
ICICLE_LOG_ERROR << "i=" << i << ", rev=" << rev << ", original_size=" << original_size;
ICICLE_LOG_ERROR << "Index out of bounds: i_mem_idx=" << i_mem_idx << ", rev_mem_idx=" << rev_mem_idx;
return eIcicleError::INVALID_ARGUMENT;
}
}
}
}
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
void dit_ntt(E* elements, uint64_t size, int batch_size, const S* twiddles, NTTDir dir, int domain_max_size)
void dit_ntt(
E* elements,
uint64_t total_ntt_size,
int batch_size,
bool columns_batch,
const S* twiddles,
NTTDir dir,
int domain_max_size,
int block_idx = 0,
int subntt_idx = 0,
std::vector<int> layers_sntt_log_size = {},
int layer = 0) // R --> N
{
uint64_t subntt_size = 1 << layers_sntt_log_size[layer];
int stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* current_elements = elements + batch * size;
for (int len = 2; len <= size; len <<= 1) {
E* current_elements = columns_batch ? elements + batch : elements + batch * total_ntt_size;
for (int len = 2; len <= subntt_size; len <<= 1) {
int half_len = len / 2;
int step = (size / len) * (domain_max_size / size);
for (int i = 0; i < size; i += len) {
int step = (subntt_size / len) * (domain_max_size / subntt_size);
for (int i = 0; i < subntt_size; i += len) {
for (int j = 0; j < half_len; ++j) {
int tw_idx = (dir == NTTDir::kForward) ? j * step : domain_max_size - j * step;
E u = current_elements[i + j];
E v = current_elements[i + j + half_len] * twiddles[tw_idx];
current_elements[i + j] = u + v;
current_elements[i + j + half_len] = u - v;
uint64_t u_mem_idx = stride * idx_in_mem(i + j, block_idx, subntt_idx, layers_sntt_log_size, layer);
uint64_t v_mem_idx =
stride * idx_in_mem(i + j + half_len, block_idx, subntt_idx, layers_sntt_log_size, layer);
E u = current_elements[u_mem_idx];
E v = current_elements[v_mem_idx] * twiddles[tw_idx];
current_elements[u_mem_idx] = u + v;
current_elements[v_mem_idx] = u - v;
}
}
}
@@ -167,53 +244,61 @@ namespace ntt_cpu {
}
template <typename S = scalar_t, typename E = scalar_t>
void dif_ntt(E* elements, uint64_t size, int batch_size, const S* twiddles, NTTDir dir, int domain_max_size)
void dif_ntt(
E* elements,
uint64_t total_ntt_size,
int batch_size,
bool columns_batch,
const S* twiddles,
NTTDir dir,
int domain_max_size,
int block_idx = 0,
int subntt_idx = 0,
std::vector<int> layers_sntt_log_size = {},
int layer = 0)
{
uint64_t subntt_size = 1 << layers_sntt_log_size[layer];
int stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* current_elements = elements + batch * size;
for (int len = size; len >= 2; len >>= 1) {
E* current_elements = columns_batch ? elements + batch : elements + batch * total_ntt_size;
for (int len = subntt_size; len >= 2; len >>= 1) {
int half_len = len / 2;
int step = (size / len) * (domain_max_size / size);
for (int i = 0; i < size; i += len) {
int step = (subntt_size / len) * (domain_max_size / subntt_size);
for (int i = 0; i < subntt_size; i += len) {
for (int j = 0; j < half_len; ++j) {
int tw_idx = (dir == NTTDir::kForward) ? j * step : domain_max_size - j * step;
E u = current_elements[i + j];
E v = current_elements[i + j + half_len];
current_elements[i + j] = u + v;
current_elements[i + j + half_len] = (u - v) * twiddles[tw_idx];
uint64_t u_mem_idx = stride * idx_in_mem(i + j, block_idx, subntt_idx, layers_sntt_log_size, layer);
uint64_t v_mem_idx =
stride * idx_in_mem(i + j + half_len, block_idx, subntt_idx, layers_sntt_log_size, layer);
E u = current_elements[u_mem_idx];
E v = current_elements[v_mem_idx];
current_elements[u_mem_idx] = u + v;
current_elements[v_mem_idx] = (u - v) * twiddles[tw_idx];
}
}
}
}
}
template <typename E = scalar_t>
void transpose(const E* input, E* output, int rows, int cols)
{
for (int col = 0; col < cols; ++col) {
for (int row = 0; row < rows; ++row) {
output[col * rows + row] = input[row * cols + col];
}
}
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError coset_mul(
int logn,
int domain_max_size,
E* elements,
int batch_size,
bool columns_batch,
const S* twiddles = nullptr,
int stride = 0,
const std::unique_ptr<S[]>& arbitrary_coset = nullptr,
bool bit_rev = false,
NTTDir dir = NTTDir::kForward,
bool columns_batch = false)
NTTDir dir = NTTDir::kForward)
{
uint64_t size = 1 << logn;
uint64_t i_mem_idx;
int idx;
int batch_stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* current_elements = elements + batch * size;
E* current_elements = columns_batch ? elements + batch : elements + batch * size;
if (arbitrary_coset) {
for (int i = 1; i < size; ++i) {
idx = columns_batch ? batch : i;
@@ -224,7 +309,7 @@ namespace ntt_cpu {
for (int i = 1; i < size; ++i) {
idx = bit_rev ? stride * (bit_reverse(i, logn)) : stride * i;
idx = dir == NTTDir::kForward ? idx : domain_max_size - idx;
current_elements[i] = current_elements[i] * twiddles[idx];
current_elements[batch_stride * i] = current_elements[batch_stride * i] * twiddles[idx];
}
}
}
@@ -232,49 +317,318 @@ namespace ntt_cpu {
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError
cpu_ntt_ref(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
void refactor_and_reorder(
E* layer_output,
E* next_layer_input,
const S* twiddles,
int batch_size,
bool columns_batch,
int domain_max_size,
std::vector<int> layers_sntt_log_size = {},
int layer = 0,
icicle::NTTDir dir = icicle::NTTDir::kForward)
{
if (size & (size - 1)) {
ICICLE_LOG_ERROR << "Size must be a power of 2. Size = " << size;
int sntt_size = 1 << layers_sntt_log_size[1];
int nof_sntts = 1 << layers_sntt_log_size[0];
int ntt_size = 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1]);
auto temp_elements =
std::make_unique<E[]>(ntt_size * batch_size); // TODO shanie - consider using an algorithm for sorting in-place
int stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* cur_layer_output = columns_batch ? layer_output + batch : layer_output + batch * ntt_size;
E* cur_temp_elements = columns_batch ? temp_elements.get() + batch : temp_elements.get() + batch * ntt_size;
for (int sntt_idx = 0; sntt_idx < nof_sntts; sntt_idx++) {
for (int elem = 0; elem < sntt_size; elem++) {
uint64_t tw_idx = (dir == NTTDir::kForward)
? ((domain_max_size / ntt_size) * sntt_idx * elem)
: domain_max_size - ((domain_max_size / ntt_size) * sntt_idx * elem);
cur_temp_elements[stride * (sntt_idx * sntt_size + elem)] =
cur_layer_output[stride * (elem * nof_sntts + sntt_idx)] * twiddles[tw_idx];
}
}
}
std::copy(temp_elements.get(), temp_elements.get() + ntt_size * batch_size, next_layer_input);
}
template <typename S = scalar_t, typename E = scalar_t>
void refactor_output(
E* layer_output,
E* next_layer_input,
uint64_t tot_ntt_size,
int batch_size,
bool columns_batch,
const S* twiddles,
int domain_max_size,
std::vector<int> layers_sntt_log_size = {},
int layer = 0,
icicle::NTTDir dir = icicle::NTTDir::kForward)
{
int subntt_size = 1 << layers_sntt_log_size[0];
int nof_subntts = 1 << layers_sntt_log_size[1];
int nof_blocks = 1 << layers_sntt_log_size[2];
int i, j;
int ntt_size = layer == 0 ? 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1])
: 1 << (layers_sntt_log_size[0] + layers_sntt_log_size[1] + layers_sntt_log_size[2]);
int stride = columns_batch ? batch_size : 1;
for (int batch = 0; batch < batch_size; ++batch) {
E* current_layer_output = columns_batch ? layer_output + batch : layer_output + batch * tot_ntt_size;
E* current_next_layer_input = columns_batch ? next_layer_input + batch : next_layer_input + batch * tot_ntt_size;
for (int block_idx = 0; block_idx < nof_blocks; block_idx++) {
for (int sntt_idx = 0; sntt_idx < nof_subntts; sntt_idx++) {
for (int elem = 0; elem < subntt_size; elem++) {
uint64_t elem_mem_idx = stride * idx_in_mem(elem, block_idx, sntt_idx, layers_sntt_log_size, 0);
i = (layer == 0) ? elem : elem + sntt_idx * subntt_size;
j = (layer == 0) ? sntt_idx : block_idx;
uint64_t tw_idx = (dir == NTTDir::kForward) ? ((domain_max_size / ntt_size) * j * i)
: domain_max_size - ((domain_max_size / ntt_size) * j * i);
current_next_layer_input[elem_mem_idx] = current_layer_output[elem_mem_idx] * twiddles[tw_idx];
}
}
}
}
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError reorder_input(
E* input, uint64_t size, int batch_size, bool columns_batch, const std::vector<int> layers_sntt_log_size = {})
{ // TODO shanie future - consider using an algorithm for efficient reordering
if (layers_sntt_log_size.empty()) {
ICICLE_LOG_ERROR << "layers_sntt_log_size is null";
return eIcicleError::INVALID_ARGUMENT;
}
int stride = columns_batch ? batch_size : 1;
auto temp_input = std::make_unique<E[]>(batch_size * size);
for (int batch = 0; batch < batch_size; ++batch) {
E* current_elements = columns_batch ? input + batch : input + batch * size;
E* current_temp_input = columns_batch ? temp_input.get() + batch : temp_input.get() + batch * size;
uint64_t idx = 0;
uint64_t new_idx = 0;
int cur_ntt_log_size = layers_sntt_log_size[0];
int next_ntt_log_size = layers_sntt_log_size[1];
for (int i = 0; i < size; i++) {
int subntt_idx = i >> cur_ntt_log_size;
int element = i & ((1 << cur_ntt_log_size) - 1);
new_idx = subntt_idx + (element << next_ntt_log_size);
current_temp_input[stride * i] = current_elements[stride * new_idx];
}
}
std::copy(temp_input.get(), temp_input.get() + batch_size * size, input);
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError reorder_output(
E* output,
uint64_t size,
const std::vector<int> layers_sntt_log_size = {},
int batch_size = 1,
bool columns_batch = 0)
{ // TODO shanie future - consider using an algorithm for efficient reordering
if (layers_sntt_log_size.empty()) {
ICICLE_LOG_ERROR << "layers_sntt_log_size is null";
return eIcicleError::INVALID_ARGUMENT;
}
int temp_output_size = columns_batch ? size * batch_size : size;
auto temp_output = std::make_unique<E[]>(temp_output_size);
uint64_t idx = 0;
uint64_t mem_idx = 0;
uint64_t new_idx = 0;
int subntt_idx;
int element;
int s0 = layers_sntt_log_size[0];
int s1 = layers_sntt_log_size[1];
int s2 = layers_sntt_log_size[2];
int p0, p1, p2;
int stride = columns_batch ? batch_size : 1;
int rep = columns_batch ? batch_size : 1;
for (int batch = 0; batch < rep; ++batch) {
E* current_elements =
columns_batch
? output + batch
: output; // if columns_batch=false, then output is already shifted by batch*size when calling the function
E* current_temp_output = columns_batch ? temp_output.get() + batch : temp_output.get();
for (int i = 0; i < size; i++) {
if (layers_sntt_log_size[2]) {
p0 = (i >> (s1 + s2));
p1 = (((i >> s2) & ((1 << (s1)) - 1)) << s0);
p2 = ((i & ((1 << s2) - 1)) << (s0 + s1));
new_idx = p0 + p1 + p2;
current_temp_output[stride * new_idx] = current_elements[stride * i];
} else {
subntt_idx = i >> s1;
element = i & ((1 << s1) - 1);
new_idx = subntt_idx + (element << s0);
current_temp_output[stride * new_idx] = current_elements[stride * i];
}
}
}
std::copy(temp_output.get(), temp_output.get() + temp_output_size, output);
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError cpu_ntt_basic(
const icicle::Device& device,
E* input,
uint64_t original_size,
icicle::NTTDir dir,
const icicle::NTTConfig<S>& config,
E* output,
int block_idx = 0,
int subntt_idx = 0,
const std::vector<int> layers_sntt_log_size = {},
int layer = 0)
{
const uint64_t subntt_size = (1 << layers_sntt_log_size[layer]);
const uint64_t total_memory_size = original_size * config.batch_size;
const int log_original_size = int(log2(original_size));
const S* twiddles = CpuNttDomain<S>::s_ntt_domain.get_twiddles();
const int domain_max_size = CpuNttDomain<S>::s_ntt_domain.get_max_size();
if (domain_max_size < subntt_size) {
ICICLE_LOG_ERROR << "NTT domain size is less than input size. Domain size = " << domain_max_size
<< ", Input size = " << subntt_size;
return eIcicleError::INVALID_ARGUMENT;
}
// Copy input to "temp_elements" instead of pointing temp_elements to input to ensure freeing temp_elements does not
// free the input, preventing a potential double-free error.
// TODO [SHANIE]: Later, remove temp_elements and perform all calculations in-place
// (implement NTT for the case where columns_batch=true, in-place).
bool dit = true;
bool input_rev = false;
bool output_rev = false;
// bool need_to_reorder = false;
bool need_to_reorder = true;
// switch (config.ordering) { // kNN, kNR, kRN, kRR, kNM, kMN
// case Ordering::kNN: //dit R --> N
// need_to_reorder = true;
// break;
// case Ordering::kNR: // dif N --> R
// case Ordering::kNM: // dif N --> R
// dit = false;
// output_rev = true;
// break;
// case Ordering::kRR: // dif N --> R
// input_rev = true;
// output_rev = true;
// need_to_reorder = true;
// dit = false; // dif
// break;
// case Ordering::kRN: //dit R --> N
// case Ordering::kMN: //dit R --> N
// input_rev = true;
// break;
// default:
// return eIcicleError::INVALID_ARGUMENT;
// }
const uint64_t total_size = size * config.batch_size;
auto temp_elements = std::make_unique<E[]>(total_size);
auto vec_ops_config = default_vec_ops_config();
if (config.columns_batch) {
transpose(input, temp_elements.get(), size, config.batch_size);
if (need_to_reorder) {
reorder_by_bit_reverse(
log_original_size, input, config.batch_size, config.columns_batch, block_idx, subntt_idx, layers_sntt_log_size,
layer);
} // TODO - check if access the fixed indexes instead of reordering may be more efficient?
// NTT/INTT
if (dit) {
dit_ntt<S, E>(
input, original_size, config.batch_size, config.columns_batch, twiddles, dir, domain_max_size, block_idx,
subntt_idx, layers_sntt_log_size, layer); // R --> N
} else {
std::copy(input, input + total_size, temp_elements.get());
dif_ntt<S, E>(
input, original_size, config.batch_size, config.columns_batch, twiddles, dir, domain_max_size, block_idx,
subntt_idx, layers_sntt_log_size, layer); // N --> R
}
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError cpu_ntt_parallel(
const Device& device,
uint64_t size,
uint64_t original_size,
NTTDir dir,
const NTTConfig<S>& config,
E* output,
const S* twiddles,
const int domain_max_size = 0)
{
const int logn = int(log2(size));
std::vector<int> layers_sntt_log_size(
std::begin(layers_subntt_log_size[logn]), std::end(layers_subntt_log_size[logn]));
// Assuming that NTT fits in the cache, so we split the NTT to layers and calculate them one after the other.
// Subntts inside the same laye are calculate in parallel.
// Sorting is not needed, since the elements needed for each subntt are close to each other in memory.
// Instead of sorting, we are using the function idx_in_mem to calculate the memory index of each element.
for (int layer = 0; layer < layers_sntt_log_size.size(); layer++) {
if (layer == 0) {
int log_nof_subntts = layers_sntt_log_size[1];
int log_nof_blocks = layers_sntt_log_size[2];
for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
for (int subntt_idx = 0; subntt_idx < (1 << log_nof_subntts); subntt_idx++) {
cpu_ntt_basic(
device, output, original_size, dir, config, output, block_idx, subntt_idx, layers_sntt_log_size, layer);
}
}
}
if (layer == 1 && layers_sntt_log_size[1]) {
int log_nof_subntts = layers_sntt_log_size[0];
int log_nof_blocks = layers_sntt_log_size[2];
for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
for (int subntt_idx = 0; subntt_idx < (1 << log_nof_subntts); subntt_idx++) {
cpu_ntt_basic(
device, output /*input*/, original_size, dir, config, output, block_idx, subntt_idx, layers_sntt_log_size,
layer); // input=output (in-place)
}
}
}
if (layer == 2 && layers_sntt_log_size[2]) {
int log_nof_blocks = layers_sntt_log_size[0] + layers_sntt_log_size[1];
for (int block_idx = 0; block_idx < (1 << log_nof_blocks); block_idx++) {
cpu_ntt_basic(
device, output /*input*/, original_size, dir, config, output, block_idx, 0 /*subntt_idx - not used*/,
layers_sntt_log_size, layer); // input=output (in-place)
}
}
if (layer != 2 && layers_sntt_log_size[layer + 1] != 0) {
refactor_output<S, E>(
output, output /*input for next layer*/, original_size, config.batch_size, config.columns_batch, twiddles,
domain_max_size, layers_sntt_log_size, layer, dir);
}
}
// Sort the output at the end so that elements will be in right order.
// TODO SHANIE - After implementing for different ordering, maybe this should be done in a different place
// - When implementing real parallelism, consider sorting in parallel and in-place
if (layers_sntt_log_size[1]) { // at least 2 layers
if (config.columns_batch) {
reorder_output(output, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
} else {
for (int b = 0; b < config.batch_size; b++) {
reorder_output(
output + b * original_size, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
}
}
}
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
{
if (size & (size - 1)) {
ICICLE_LOG_ERROR << "Size must be a power of 2. size = " << size;
return eIcicleError::INVALID_ARGUMENT;
}
const int logn = int(log2(size));
const S* twiddles = CpuNttDomain<S>::s_ntt_domain.get_twiddles();
const int domain_max_size = CpuNttDomain<S>::s_ntt_domain.get_max_size();
std::unique_ptr<S[]> arbitrary_coset = nullptr;
// TODO SHANIE - move to init domain
int coset_stride = 0;
if (domain_max_size < size) {
ICICLE_LOG_ERROR << "NTT domain size is less than input size. Domain size = " << domain_max_size
<< ", Input size = " << size;
return eIcicleError::INVALID_ARGUMENT;
}
std::unique_ptr<S[]> arbitrary_coset = nullptr;
if (config.coset_gen != S::one()) { // TODO SHANIE - implement more efficient way to find coset_stride
for (int i = 1; i <= domain_max_size; i++) {
if (twiddles[i] == config.coset_gen) {
coset_stride = i;
break;
}
}
if (coset_stride == 0) { // if the coset_gen is not found in the twiddles, calculate arbitrary coset
try {
coset_stride = CpuNttDomain<S>::s_ntt_domain.coset_index.at(config.coset_gen);
ICICLE_LOG_DEBUG << "Coset generator found in twiddles. coset_stride=" << coset_stride;
} catch (const std::out_of_range& oor) {
ICICLE_LOG_DEBUG << "Coset generator not found in twiddles. Calculating arbitrary coset.";
auto temp_cosets = std::make_unique<S[]>(domain_max_size + 1);
arbitrary_coset = std::make_unique<S[]>(domain_max_size + 1);
arbitrary_coset[0] = S::one();
S coset_gen = dir == NTTDir::kForward ? config.coset_gen : S::inverse(config.coset_gen); // inverse for INTT
@@ -284,74 +638,81 @@ namespace ntt_cpu {
}
}
bool dit = true;
bool input_rev = false;
bool output_rev = false;
bool need_to_reorder = false;
bool coset = (config.coset_gen != S::one() && dir == NTTDir::kForward);
switch (config.ordering) { // kNN, kNR, kRN, kRR, kNM, kMN
case Ordering::kNN:
need_to_reorder = true;
break;
case Ordering::kNR:
case Ordering::kNM:
dit = false; // dif
output_rev = true;
break;
case Ordering::kRR:
input_rev = true;
output_rev = true;
need_to_reorder = true;
dit = false; // dif
break;
case Ordering::kRN:
case Ordering::kMN:
input_rev = true;
break;
default:
return eIcicleError::INVALID_ARGUMENT;
std::copy(input, input + size * config.batch_size, output);
if (config.ordering == Ordering::kRN || config.ordering == Ordering::kRR) {
reorder_by_bit_reverse(
logn, output, config.batch_size,
config.columns_batch); // TODO - check if access the fixed indexes instead of reordering may be more efficient?
}
if (coset) {
if (config.coset_gen != S::one() && dir == NTTDir::kForward) {
// bool input_rev = config.ordering == Ordering::kRR || config.ordering == Ordering::kMN || config.ordering ==
// Ordering::kRN;
bool input_rev = false;
coset_mul(
logn, domain_max_size, temp_elements.get(), config.batch_size, twiddles, coset_stride, arbitrary_coset,
input_rev);
logn, domain_max_size, output, config.batch_size, config.columns_batch, twiddles, coset_stride, arbitrary_coset,
input_rev, dir);
}
std::vector<int> layers_sntt_log_size(
std::begin(layers_subntt_log_size[logn]), std::end(layers_subntt_log_size[logn]));
if (need_to_reorder) { reorder_by_bit_reverse(logn, temp_elements.get(), config.batch_size); }
if (logn > 15) {
// TODO future - maybe can start 4'rth layer in parallel to 3'rd layer?
// Assuming that NTT doesn't fit in the cache, so we split the NTT to 2 layers and calculate them one after the
// other. Inside each layer each sub-NTT calculation is split to layers as well, and those are calculated in
// parallel. Sorting is done between the layers, so that the elements needed for each sunbtt are close to each
// other in memory.
int stride = config.columns_batch ? config.batch_size : 1;
reorder_input(output, size, config.batch_size, config.columns_batch, layers_sntt_log_size);
for (int subntt_idx = 0; subntt_idx < (1 << layers_sntt_log_size[1]); subntt_idx++) {
E* current_elements =
output + stride * (subntt_idx << layers_sntt_log_size[0]); // output + subntt_idx * subntt_size
cpu_ntt_parallel(
device, (1 << layers_sntt_log_size[0]), size, dir, config, current_elements, twiddles, domain_max_size);
}
refactor_and_reorder<S, E>(
output, output /*input for next layer*/, twiddles, config.batch_size, config.columns_batch, domain_max_size,
layers_sntt_log_size, 0 /*layer*/, dir);
for (int subntt_idx = 0; subntt_idx < (1 << layers_sntt_log_size[0]); subntt_idx++) {
E* current_elements =
output + stride * (subntt_idx << layers_sntt_log_size[1]); // output + subntt_idx * subntt_size
cpu_ntt_parallel(
device, (1 << layers_sntt_log_size[1]), size, dir, config, current_elements, twiddles, domain_max_size);
}
if (config.columns_batch) {
reorder_output(output, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
} else {
for (int b = 0; b < config.batch_size; b++) {
reorder_output(output + b * size, size, layers_sntt_log_size, config.batch_size, config.columns_batch);
}
}
// NTT/INTT
if (dit) {
dit_ntt<S, E>(temp_elements.get(), size, config.batch_size, twiddles, dir, domain_max_size);
} else {
dif_ntt<S, E>(temp_elements.get(), size, config.batch_size, twiddles, dir, domain_max_size);
cpu_ntt_parallel(device, size, size, dir, config, output, twiddles, domain_max_size);
}
if (dir == NTTDir::kInverse) {
// Normalize results
if (dir == NTTDir::kInverse) { // TODO SHANIE - do that in parallel
S inv_size = S::inv_log_size(logn);
for (int i = 0; i < total_size; ++i) {
temp_elements[i] = temp_elements[i] * inv_size;
for (uint64_t i = 0; i < size * config.batch_size; ++i) {
output[i] = output[i] * inv_size;
}
if (config.coset_gen != S::one()) {
// bool output_rev = config.ordering == Ordering::kNR || config.ordering == Ordering::kNM || config.ordering ==
// Ordering::kRR;
bool output_rev = false;
coset_mul(
logn, domain_max_size, temp_elements.get(), config.batch_size, twiddles, coset_stride, arbitrary_coset,
output_rev, dir);
logn, domain_max_size, output, config.batch_size, config.columns_batch, twiddles, coset_stride,
arbitrary_coset, output_rev, dir);
}
}
if (config.columns_batch) {
transpose(temp_elements.get(), output, config.batch_size, size);
} else {
std::copy(temp_elements.get(), temp_elements.get() + total_size, output);
if (config.ordering == Ordering::kNR || config.ordering == Ordering::kRR) {
reorder_by_bit_reverse(
logn, output, config.batch_size,
config.columns_batch); // TODO - check if access the fixed indexes instead of reordering may be more efficient?
}
return eIcicleError::SUCCESS;
}
template <typename S = scalar_t, typename E = scalar_t>
eIcicleError cpu_ntt(const Device& device, const E* input, uint64_t size, NTTDir dir, NTTConfig<S>& config, E* output)
{
return cpu_ntt_ref(device, input, size, dir, config, output);
}
} // namespace ntt_cpu