Mixed-radix NTT algorithm

Co-authored-by: hadaringonyama <hadar@ingonyama.com>
Authored by yshekel on 2024-02-08 13:52:00 +02:00; committed by ImmanuelSegol
commit 382bec4ad3 (parent d367a8c1e0)
20 changed files with 1734 additions and 81 deletions

.gitignore

@@ -17,3 +17,5 @@
**/icicle/build/
**/wrappers/rust/icicle-cuda-runtime/src/bindings.rs
**/build
**/icicle/appUtils/large_ntt/work
icicle/appUtils/large_ntt/work/test_ntt


@@ -5,28 +5,31 @@
#define CURVE_ID 1
// include NTT template
#include "appUtils/ntt/ntt.cu"
#include "appUtils/ntt/kernel_ntt.cu"
using namespace curve_config;
// Operate on scalars
typedef scalar_t S;
typedef scalar_t E;
void print_elements(const unsigned n, E * elements ) {
void print_elements(const unsigned n, E* elements)
{
for (unsigned i = 0; i < n; i++) {
std::cout << i << ": " << elements[i] << std::endl;
std::cout << i << ": " << elements[i] << std::endl;
}
}
void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elements ) {
void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E* elements)
{
// Lowest Harmonics
for (unsigned i = 0; i < ntt_size; i=i+1) {
for (unsigned i = 0; i < ntt_size; i = i + 1) {
elements[i] = E::one();
}
// print_elements(ntt_size, elements );
// Highest Harmonics
for (unsigned i = 1*ntt_size; i < 2*ntt_size; i=i+2) {
elements[i] = E::one();
elements[i+1] = E::neg(scalar_t::one());
for (unsigned i = 1 * ntt_size; i < 2 * ntt_size; i = i + 2) {
elements[i] = E::one();
elements[i + 1] = E::neg(scalar_t::one());
}
// print_elements(ntt_size, &elements[1*ntt_size] );
}
@@ -34,7 +37,7 @@ void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elem
int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* elements)
{
int nof_errors = 0;
E amplitude = E::from((uint32_t) ntt_size);
E amplitude = E::from((uint32_t)ntt_size);
// std::cout << "Amplitude: " << amplitude << std::endl;
// Lowest Harmonics
if (elements[0] != amplitude) {
@@ -44,8 +47,8 @@ int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* element
} else {
std::cout << "Validated lowest harmonics" << std::endl;
}
// Highest Harmonics
if (elements[1*ntt_size+ntt_size/2] != amplitude) {
// Highest Harmonics
if (elements[1 * ntt_size + ntt_size / 2] != amplitude) {
++nof_errors;
std::cout << "Error in highest harmonics! " << std::endl;
// print_elements(ntt_size, &elements[1*ntt_size] );
@@ -66,24 +69,24 @@ int main(int argc, char* argv[])
const unsigned nof_ntts = 2;
std::cout << "Number of NTTs: " << nof_ntts << std::endl;
const unsigned batch_size = nof_ntts * ntt_size;
std::cout << "Generating input data for lowest and highest harmonics" << std::endl;
E* input;
input = (E*) malloc(sizeof(E) * batch_size);
initialize_input(ntt_size, nof_ntts, input );
input = (E*)malloc(sizeof(E) * batch_size);
initialize_input(ntt_size, nof_ntts, input);
E* output;
output = (E*) malloc(sizeof(E) * batch_size);
output = (E*)malloc(sizeof(E) * batch_size);
std::cout << "Running NTT with on-host data" << std::endl;
cudaStream_t stream;
cudaStreamCreate(&stream);
// Create a device context
auto ctx = device_context::get_default_device_context();
// the next line is valid only for CURVE_ID 1 (will add support for other curves soon)
S rou = S{ {0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1} };
S rou = S{{0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1}};
ntt::InitDomain(rou, ctx);
// Create an NTTConfig instance
ntt::NTTConfig<S> config=ntt::DefaultNTTConfig<S>();
ntt::NTTConfig<S> config = ntt::DefaultNTTConfig<S>();
config.batch_size = nof_ntts;
config.ctx.stream = stream;
auto begin0 = std::chrono::high_resolution_clock::now();
@@ -91,7 +94,7 @@ int main(int argc, char* argv[])
auto end0 = std::chrono::high_resolution_clock::now();
auto elapsed0 = std::chrono::duration_cast<std::chrono::nanoseconds>(end0 - begin0);
printf("On-device runtime: %.3f seconds\n", elapsed0.count() * 1e-9);
validate_output(ntt_size, nof_ntts, output );
validate_output(ntt_size, nof_ntts, output);
cudaStreamDestroy(stream);
free(input);
free(output);
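
A quick sanity check of what validate_output above expects: the NTT of the all-ones vector ("lowest harmonics") is ntt_size at bin 0 and zero elsewhere, and the NTT of the alternating +1/-1 vector ("highest harmonics") is ntt_size at bin ntt_size/2. The standalone sketch below (plain C++ over the toy field Z_17 with N = 4 and omega = 4, chosen only for illustration and unrelated to icicle's scalar field) evaluates the naive NTT sum for both inputs and should print "4 0 0 0" and "0 0 4 0".

#include <cstdint>
#include <cstdio>

// modular exponentiation, used to build omega^(k*j)
static uint64_t pow_mod(uint64_t b, uint64_t e, uint64_t p)
{
  uint64_t r = 1;
  for (; e; e >>= 1, b = b * b % p)
    if (e & 1) r = r * b % p;
  return r;
}

int main()
{
  const uint64_t p = 17, omega = 4, N = 4;          // omega is a primitive 4th root of unity mod 17
  const uint64_t lowest[4] = {1, 1, 1, 1};          // "lowest harmonics" input
  const uint64_t highest[4] = {1, p - 1, 1, p - 1}; // "highest harmonics": +1, -1, +1, -1
  const uint64_t* inputs[2] = {lowest, highest};
  for (int which = 0; which < 2; ++which) {
    for (uint64_t k = 0; k < N; ++k) {
      uint64_t acc = 0; // X_k = sum_j x_j * omega^(k*j)
      for (uint64_t j = 0; j < N; ++j) acc = (acc + inputs[which][j] * pow_mod(omega, k * j, p)) % p;
      printf("%llu ", (unsigned long long)acc);
    }
    printf("\n"); // expect the amplitude N=4 at bin 0 (lowest) and at bin N/2=2 (highest)
  }
  return 0;
}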


@@ -0,0 +1,25 @@
# Make sure NVIDIA Container Toolkit is installed on your host
# Use the specified base image
FROM nvidia/cuda:12.0.0-devel-ubuntu22.04
# Update and install dependencies
RUN apt-get update && apt-get install -y \
cmake \
curl \
build-essential \
git \
libboost-all-dev \
&& rm -rf /var/lib/apt/lists/*
# Clone Icicle from a GitHub repository
RUN git clone https://github.com/ingonyama-zk/icicle.git /icicle
# Set the working directory in the container
WORKDIR /icicle-example
# Specify the default command for the container
CMD ["/bin/bash"]


@@ -0,0 +1,22 @@
{
"name": "Icicle Examples: polynomial multiplication",
"build": {
"dockerfile": "Dockerfile"
},
"runArgs": [
"--gpus",
"all"
],
"postCreateCommand": [
"nvidia-smi"
],
"customizations": {
"vscode": {
"extensions": [
"ms-vscode.cmake-tools",
"ms-python.python",
"ms-vscode.cpptools"
]
}
}
}


@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.18)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED TRUE)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
if (${CMAKE_VERSION} VERSION_LESS "3.24.0")
set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH})
else()
set(CMAKE_CUDA_ARCHITECTURES native) # 'native' requires CMake 3.24+; earlier versions take the CUDA_ARCH branch above
endif ()
project(icicle LANGUAGES CUDA CXX)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
# change the path to your Icicle location
include_directories("../../../icicle")
add_executable(
example
example.cu
)
find_library(NVML_LIBRARY nvidia-ml PATHS /usr/local/cuda-12.0/targets/x86_64-linux/lib/stubs/ )
target_link_libraries(example ${NVML_LIBRARY})
set_target_properties(example PROPERTIES CUDA_SEPARABLE_COMPILATION ON)


@@ -0,0 +1,11 @@
#!/bin/bash
# Exit immediately on error
set -e
rm -rf build
mkdir -p build
cmake -S . -B build
cmake --build build


@@ -0,0 +1,114 @@
#define CURVE_ID BLS12_381
#include <chrono>
#include <iostream>
#include <vector>
#include "curves/curve_config.cuh"
#include "appUtils/ntt/ntt.cu"
#include "appUtils/ntt/kernel_ntt.cu"
#include "utils/vec_ops.cu"
#include "utils/error_handler.cuh"
#include <memory>
typedef curve_config::scalar_t test_scalar;
typedef curve_config::scalar_t test_data;
void random_samples(test_data* res, uint32_t count)
{
for (int i = 0; i < count; i++)
res[i] = i < 1000 ? test_data::rand_host() : res[i - 1000];
}
void incremental_values(test_scalar* res, uint32_t count)
{
for (int i = 0; i < count; i++) {
res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
}
}
// calculating the polynomial multiplication A*B via NTT, pointwise multiplication and INTT
// (1) allocate A,B on CPU. Randomize first half, zero second half
// (2) allocate NttAGpu, NttBGpu on GPU
// (3) calc NTT for A and for B from cpu to GPU
// (4) multiply MulGpu = NttAGpu * NttBGpu (pointwise)
// (5) INTT MulGpu inplace
int main(int argc, char** argv)
{
cudaEvent_t start, stop;
float measured_time;
int NTT_LOG_SIZE = 23;
int NTT_SIZE = 1 << NTT_LOG_SIZE;
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;
const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";
CHK_IF_RETURN(cudaEventCreate(&start));
CHK_IF_RETURN(cudaEventCreate(&stop));
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
ntt::InitDomain(basic_root, ntt_config.ctx);
// (1) cpu allocation
auto CpuA = std::make_unique<test_data[]>(NTT_SIZE);
auto CpuB = std::make_unique<test_data[]>(NTT_SIZE);
random_samples(CpuA.get(), NTT_SIZE >> 1); // second half zeros
random_samples(CpuB.get(), NTT_SIZE >> 1); // second half zeros
test_data *GpuA, *GpuB, *MulGpu;
auto benchmark = [&](bool print, int iterations = 1) {
// start recording
CHK_IF_RETURN(cudaEventRecord(start, ntt_config.ctx.stream));
for (int iter = 0; iter < iterations; ++iter) {
// (2) gpu input allocation
CHK_IF_RETURN(cudaMallocAsync(&GpuA, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaMallocAsync(&GpuB, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
// (3) NTT for A,B from cpu to gpu
ntt_config.are_inputs_on_device = false;
ntt_config.are_outputs_on_device = true;
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
// (4) multiply A,B
CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
CHK_IF_RETURN(
vec_ops::Mul(GpuA, GpuB, NTT_SIZE, true /*=is_on_device*/, false /*=is_montgomery*/, ntt_config.ctx, MulGpu));
// (5) INTT (in place)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));
CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaFreeAsync(GpuB, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaFreeAsync(MulGpu, ntt_config.ctx.stream));
}
CHK_IF_RETURN(cudaEventRecord(stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&measured_time, start, stop));
if (print) { std::cout << measured_time / iterations << " MS" << std::endl; }
return CHK_LAST();
};
benchmark(false); // warmup
benchmark(true, 20);
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
return 0;
}
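
For reference, the identity that steps (1)-(5) above rely on is that pointwise multiplication in the NTT domain corresponds to cyclic convolution of the coefficient vectors, which equals the true polynomial product once each input is zero-padded to the full NTT size (hence the zeroed second halves of A and B). The standalone sketch below (plain C++ modulo a small illustrative prime, not icicle's field) computes the naive O(n^2) product that the NTT pipeline should reproduce.

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
  const uint64_t p = 257;    // small illustrative prime, not the BLS12-381 scalar field
  const size_t half = 4;     // number of "real" coefficients in A and B
  const size_t n = 2 * half; // NTT size: padding to 2x makes cyclic == linear convolution
  std::vector<uint64_t> A(n, 0), B(n, 0), C(n, 0);
  for (size_t i = 0; i < half; ++i) { A[i] = (i + 1) % p; B[i] = (3 * i + 2) % p; }
  // naive schoolbook product; NTT -> pointwise mul -> INTT over size n should match C
  for (size_t i = 0; i < half; ++i)
    for (size_t j = 0; j < half; ++j)
      C[i + j] = (C[i + j] + A[i] * B[j]) % p;
  for (size_t k = 0; k < n; ++k) printf("%llu%c", (unsigned long long)C[k], k + 1 < n ? ' ' : '\n');
  return 0;
}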


@@ -0,0 +1,3 @@
#!/bin/bash
./build/example 1 # radix2
./build/example 0 # mixed-radix


@@ -53,7 +53,7 @@ struct Args {
lower_bound_log_size: u8,
/// Upper bound of MSM sizes to run for
#[arg(short, long, default_value_t = 23)]
#[arg(short, long, default_value_t = 22)]
upper_bound_log_size: u8,
}


@@ -108,6 +108,7 @@ if (NOT BUILD_TESTS)
primitives/projective.cu
appUtils/msm/msm.cu
appUtils/ntt/ntt.cu
appUtils/ntt/kernel_ntt.cu
${ICICLE_SOURCES}
)
set_target_properties(icicle PROPERTIES OUTPUT_NAME "ingo_${CURVE}")


@@ -0,0 +1,6 @@
build_verification:
mkdir -p work
nvcc -o work/test_verification -I. -I.. -I../.. -I../ntt tests/verification.cu -std=c++17
test_verification: build_verification
work/test_verification


@@ -0,0 +1,640 @@
#include "appUtils/ntt/thread_ntt.cu"
#include "curves/curve_config.cuh"
#include "utils/sharedmem.cuh"
#include "appUtils/ntt/ntt.cuh" // for Ordering
namespace ntt {
static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
{
uint32_t rev_num = 0, temp, dig_len;
if (dit) {
for (int i = 4; i >= 0; i--) {
dig_len = STAGE_SIZES_DEVICE[log_size][i];
temp = num & ((1 << dig_len) - 1);
num = num >> dig_len;
rev_num = rev_num << dig_len;
rev_num = rev_num | temp;
}
} else {
for (int i = 0; i < 5; i++) {
dig_len = STAGE_SIZES_DEVICE[log_size][i];
temp = num & ((1 << dig_len) - 1);
num = num >> dig_len;
rev_num = rev_num << dig_len;
rev_num = rev_num | temp;
}
}
return rev_num;
}
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void
reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
// launch N threads
// each thread starts from one index and calculates the corresponding group
// if its index is the smallest number in the group -> do the memory transformation
// else --> do nothing
const uint32_t idx = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t next_element = idx;
uint32_t group[MAX_GROUP_SIZE];
group[0] = idx;
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = dig_rev(next_element, log_size, dit);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element;
}
if (i == 1) { // single element in group --> nothing to do (except maybe normalize for INTT)
if (is_normalize) { arr[idx] = arr[idx] * inverse_N; }
return;
}
--i;
// reaching here means I am handling this group
const E last_element_in_group = arr[group[i]];
for (; i > 0; --i) {
arr[group[i]] = is_normalize ? (arr[group[i - 1]] * inverse_N) : arr[group[i - 1]];
}
arr[idx] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
}
template <typename E, typename S>
__launch_bounds__(64) __global__
void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
uint32_t rd = tid;
uint32_t wr = dig_rev(tid, log_size, dit);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt64(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
SharedMemory<E> smem;
E* shmem = smem.getPointer();
s_meta.th_stride = 8;
s_meta.ntt_block_size = 64;
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride && dit) {
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles64(internal_twiddles, strided, inv);
#pragma unroll 1
for (uint32_t phase = 0; phase < 2; phase++) {
engine.ntt8win();
if (phase == 0) {
engine.SharedData64Columns8(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData64Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
}
}
if (twiddle_stride && !dit) {
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt32(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
SharedMemory<E> smem;
E* shmem = smem.getPointer();
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
engine.ntt8win();
engine.twiddlesInternal();
engine.SharedData32Columns8(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData32Rows4_2(shmem, false, false, strided); // load
engine.ntt4_2();
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt32dit(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
int32_t tw_log_size,
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
SharedMemory<E> smem;
E* shmem = smem.getPointer();
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
engine.ntt4_2();
engine.SharedData32Columns4_2(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData32Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt16(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
SharedMemory<E> smem;
E* shmem = smem.getPointer();
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
engine.ntt8win();
engine.twiddlesInternal();
engine.SharedData16Columns8(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData16Rows2_4(shmem, false, false, strided); // load
engine.ntt2_4();
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt16dit(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t data_stride,
uint32_t log_data_stride,
uint32_t twiddle_stride,
bool strided,
uint32_t stage_num,
bool inv,
bool dit)
{
NTTEngine<E, S> engine;
stage_metadata s_meta;
SharedMemory<E> smem;
E* shmem = smem.getPointer();
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
engine.loadBasicTwiddles(basic_twiddles, inv);
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (twiddle_stride) {
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
engine.ntt2_4();
engine.SharedData16Columns2_4(shmem, true, false, strided); // store
__syncthreads();
engine.SharedData16Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
}
template <typename E, typename S>
__global__ void normalize_kernel(E* data, S norm_factor)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
data[tid] = data[tid] * norm_factor;
}
template <typename S>
__global__ void generate_base_table(S basic_root, S* base_table, uint32_t skip)
{
S w = basic_root;
S t = S::one();
for (int i = 0; i < 64; i += skip) {
base_table[i] = t;
t = t * w;
}
}
template <typename S>
__global__ void generate_basic_twiddles(S basic_root, S* w6_table, S* basic_twiddles)
{
S w0 = basic_root * basic_root;
S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
S w2 = (basic_root - w0 * basic_root) * S::inv_log_size(1);
basic_twiddles[0] = w0;
basic_twiddles[1] = w1;
basic_twiddles[2] = w2;
S basic_inv = w6_table[64 - 8];
w0 = basic_inv * basic_inv;
w1 = (basic_inv + w0 * basic_inv) * S::inv_log_size(1);
w2 = (basic_inv - w0 * basic_inv) * S::inv_log_size(1);
basic_twiddles[3] = w0;
basic_twiddles[4] = w1;
basic_twiddles[5] = w2;
}
template <typename S>
__global__ void generate_twiddle_combinations_generic(
S* w6_table, S* w12_table, S* w18_table, S* w24_table, S* w30_table, S* external_twiddles, uint32_t log_size)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
uint32_t exp = tid << (30 - log_size);
S w6, w12, w18, w24, w30;
w6 = w6_table[exp >> 24];
w12 = w12_table[((exp >> 18) & 0x3f)];
w18 = w18_table[((exp >> 12) & 0x3f)];
w24 = w24_table[((exp >> 6) & 0x3f)];
w30 = w30_table[(exp & 0x3f)];
S t = w6 * w12 * w18 * w24 * w30;
external_twiddles[tid] = t;
}
template <typename S>
__global__ void set_value(S* arr, int idx, S val)
{
arr[idx] = val;
}
template <typename S>
cudaError_t generate_external_twiddles_generic(
const S& basic_root,
S* external_twiddles,
S*& internal_twiddles,
S*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream)
{
CHK_INIT_IF_RETURN();
const int n = pow(2, log_size);
CHK_IF_RETURN(cudaMallocAsync(&basic_twiddles, 6 * sizeof(S), stream));
S* w6_table;
S* w12_table;
S* w18_table;
S* w24_table;
S* w30_table;
CHK_IF_RETURN(cudaMallocAsync(&w6_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w12_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w18_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w24_table, sizeof(S) * 64, stream));
CHK_IF_RETURN(cudaMallocAsync(&w30_table, sizeof(S) * 64, stream));
// Note: for compatibility with radix-2 INTT, need ONE in last element (in addition to first element)
set_value<<<1, 1, 0, stream>>>(external_twiddles, n /*last element idx*/, S::one());
cudaStreamSynchronize(stream);
S temp_root = basic_root;
generate_base_table<<<1, 1, 0, stream>>>(basic_root, w30_table, 1 << (30 - log_size));
if (log_size > 24)
for (int i = 0; i < 6 - (30 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w24_table, 1 << (log_size > 24 ? 0 : 24 - log_size));
if (log_size > 18)
for (int i = 0; i < 6 - (log_size > 24 ? 0 : 24 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w18_table, 1 << (log_size > 18 ? 0 : 18 - log_size));
if (log_size > 12)
for (int i = 0; i < 6 - (log_size > 18 ? 0 : 18 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w12_table, 1 << (log_size > 12 ? 0 : 12 - log_size));
if (log_size > 6)
for (int i = 0; i < 6 - (log_size > 12 ? 0 : 12 - log_size); i++)
temp_root = temp_root * temp_root;
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w6_table, 1 << (log_size > 6 ? 0 : 6 - log_size));
if (log_size > 2)
for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
temp_root = temp_root * temp_root;
generate_basic_twiddles<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
const int NOF_BLOCKS = (log_size >= 8) ? (1 << (log_size - 8)) : 1;
const int NOF_THREADS = (log_size >= 8) ? 256 : (1 << log_size);
generate_twiddle_combinations_generic<<<NOF_BLOCKS, NOF_THREADS, 0, stream>>>(
w6_table, w12_table, w18_table, w24_table, w30_table, external_twiddles, log_size);
internal_twiddles = w6_table;
CHK_IF_RETURN(cudaFreeAsync(w12_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w18_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w24_table, stream));
CHK_IF_RETURN(cudaFreeAsync(w30_table, stream));
return CHK_LAST();
}
template <typename E, typename S>
cudaError_t large_ntt(
E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
bool inv,
bool normalize,
bool dit,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
if (log_size == 1 || log_size == 2 || log_size == 3 || log_size == 7) {
throw IcicleError(IcicleError_t::InvalidArgument, "size not implemented for mixed-radix-NTT");
}
if (log_size == 4) {
if (dit) {
ntt16dit<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
} else { // dif
ntt16<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
return CHK_LAST();
}
if (log_size == 5) {
if (dit) {
ntt32dit<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
} else { // dif
ntt32<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
return CHK_LAST();
}
if (log_size == 6) {
ntt64<<<1, 8, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
if (normalize) normalize_kernel<<<1, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
return CHK_LAST();
}
if (log_size == 8) {
if (dit) {
ntt16dit<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
ntt16dit<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1,
inv,
dit); // we need threads 32+ although 16-31 are idle
} else { // dif
ntt16<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, inv,
dit); // we need threads 32+ although 16-31 are idle
ntt16<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
dit);
}
if (normalize) normalize_kernel<<<1, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
return CHK_LAST();
}
// general case:
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
for (int j = 0; j < i; j++)
stride_log += STAGE_SIZES_HOST[log_size][j];
if (stage_size == 6)
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 5)
ntt32dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 4)
ntt16dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
}
} else { // dif
bool first_run = false, prev_stage = false;
for (int i = 4; i >= 0; i--) {
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
uint32_t stride_log = 0;
for (int j = 0; j < i; j++)
stride_log += STAGE_SIZES_HOST[log_size][j];
first_run = stage_size && !prev_stage;
if (stage_size == 6)
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 5)
ntt32<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
else if (stage_size == 4)
ntt16<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
prev_stage = stage_size;
}
}
if (normalize) normalize_kernel<<<1 << (log_size - 8), 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
return CHK_LAST();
}
template <typename E, typename S>
cudaError_t mixed_radix_ntt(
E* d_input,
E* d_output,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
int ntt_size,
int max_logn,
bool is_inverse,
Ordering ordering,
cudaStream_t cuda_stream)
{
CHK_INIT_IF_RETURN();
// TODO: can we support all orderings? Note that the reversal here is a digit reversal (a generalization of bit reversal)
if (ordering != Ordering::kNN) {
throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
}
const int logn = int(log2(ntt_size));
const int NOF_BLOCKS = (1 << (max(logn, 6) - 6));
const int NOF_THREADS = min(64, 1 << logn);
const bool reverse_input = ordering == Ordering::kNN;
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
bool is_normalize = is_inverse;
if (reverse_input) {
// Note: fusing reorder with normalize for INTT
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
}
// inplace ntt
CHK_IF_RETURN(large_ntt(
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, is_inverse,
is_normalize, is_dit, cuda_stream));
return CHK_LAST();
}
// Explicit instantiation for scalar type
template cudaError_t generate_external_twiddles_generic(
const curve_config::scalar_t& basic_root,
curve_config::scalar_t* external_twiddles,
curve_config::scalar_t*& internal_twiddles,
curve_config::scalar_t*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream);
template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
curve_config::scalar_t* d_input,
curve_config::scalar_t* d_output,
curve_config::scalar_t* external_twiddles,
curve_config::scalar_t* internal_twiddles,
curve_config::scalar_t* basic_twiddles,
int ntt_size,
int max_logn,
bool is_inverse,
Ordering ordering,
cudaStream_t cuda_stream);
} // namespace ntt
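
To make the reorder step above concrete, here is a host-side sketch (illustration only, not part of the commit) of the digit-reversal permutation and cycle-leader rule that dig_rev and reorder_digits_inplace_kernel implement on the GPU. It uses the {4, 5, 4} digit decomposition of log_size 13 from STAGE_SIZES_HOST; only the smallest index in each cycle rotates its cycle, matching the "smallest number in the group" rule in the kernel comment.

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// reverse the order of the digit groups of x, where the group widths are given by `digits`
static uint32_t digit_reverse(uint32_t x, const std::vector<uint32_t>& digits)
{
  uint32_t rev = 0;
  for (uint32_t d : digits) {
    rev = (rev << d) | (x & ((1u << d) - 1)); // consume from the low end, emit high first
    x >>= d;
  }
  return rev;
}

int main()
{
  const std::vector<uint32_t> digits = {4, 5, 4}; // STAGE_SIZES_HOST[13]
  const uint32_t n = 1u << 13;
  std::vector<uint32_t> arr(n), expected(n);
  std::iota(arr.begin(), arr.end(), 0u);
  for (uint32_t i = 0; i < n; ++i) expected[digit_reverse(i, digits)] = i; // out-of-place reference
  for (uint32_t idx = 0; idx < n; ++idx) {
    // collect the cycle starting at idx; skip it unless idx is its smallest member
    std::vector<uint32_t> cycle = {idx};
    uint32_t next = digit_reverse(idx, digits);
    bool leader = true;
    while (next != idx) {
      if (next < idx) { leader = false; break; }
      cycle.push_back(next);
      next = digit_reverse(next, digits);
    }
    if (!leader || cycle.size() == 1) continue;
    // rotate exactly as the kernel does: arr[cycle[k]] <- arr[cycle[k-1]], last wraps to idx
    const uint32_t last = arr[cycle.back()];
    for (size_t k = cycle.size() - 1; k > 0; --k) arr[cycle[k]] = arr[cycle[k - 1]];
    arr[idx] = last;
  }
  printf("%s", arr == expected ? "in-place reorder matches reference\n" : "MISMATCH\n");
  return 0;
}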


@@ -7,6 +7,7 @@
#include "utils/sharedmem.cuh"
#include "utils/utils_kernels.cuh"
#include "utils/utils.h"
#include "appUtils/ntt/ntt_impl.cuh"
namespace ntt {
@@ -25,7 +26,10 @@ namespace ntt {
int idx = threadId % n;
int batch_idx = threadId / n;
int idx_reversed = __brev(idx) >> (32 - logn);
arr_reversed[batch_idx * n + idx_reversed] = arr[batch_idx * n + idx];
E val = arr[batch_idx * n + idx];
if (arr == arr_reversed) { __syncthreads(); } // for in-place (when pointers arr==arr_reversed)
arr_reversed[batch_idx * n + idx_reversed] = val;
}
}
@@ -357,25 +361,24 @@ namespace ntt {
template <typename S>
class Domain
{
static int max_size;
static S* twiddles;
static std::unordered_map<S, int> coset_index;
static inline int max_size = 0;
static inline int max_log_size = 0;
static inline S* twiddles = nullptr;
static inline std::unordered_map<S, int> coset_index = {};
static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT
public:
template <typename U>
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
template <typename U, typename E>
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
};
template <typename S>
int Domain<S>::max_size = 0;
template <typename S>
S* Domain<S>::twiddles = nullptr;
template <typename S>
std::unordered_map<S, int> Domain<S>::coset_index = {};
template <typename S>
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
{
@@ -385,34 +388,70 @@ namespace ntt {
// please note that this is not thread-safe at all,
// but it's a singleton that is supposed to be initialized once per program lifetime
if (!Domain<S>::twiddles) {
bool found_logn = false;
S omega = primitive_root;
unsigned omegas_count = S::get_omegas_count();
for (int i = 0; i < omegas_count; i++)
for (int i = 0; i < omegas_count; i++) {
omega = S::sqr(omega);
if (!found_logn) {
++Domain<S>::max_log_size;
found_logn = omega == S::one();
if (found_logn) break;
}
}
Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
if (omega != S::one()) {
std::cerr << "Primitive root provided to the InitDomain function is not in the subgroup" << '\n';
throw -1;
throw IcicleError(
IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
}
std::vector<S> h_twiddles;
h_twiddles.push_back(S::one());
int n = 1;
do {
Domain<S>::coset_index[h_twiddles.at(n - 1)] = n - 1;
h_twiddles.push_back(h_twiddles.at(n - 1) * primitive_root);
} while (h_twiddles.at(n++) != S::one());
CHK_IF_RETURN(cudaMallocAsync(&Domain<S>::twiddles, n * sizeof(S), ctx.stream));
CHK_IF_RETURN(
cudaMemcpyAsync(Domain<S>::twiddles, &h_twiddles.front(), n * sizeof(S), cudaMemcpyHostToDevice, ctx.stream));
Domain<S>::max_size = n - 1;
// allocate and calculate twiddles on GPU
// Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
// Managed allocation allows the host to read the (logn) elements it needs without copying all (n) twiddle factors back to the host
CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
CHK_IF_RETURN(generate_external_twiddles_generic(
primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
Domain<S>::max_log_size, ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
const bool is_map_only_powers_of_primitive_root = true;
if (is_map_only_powers_of_primitive_root) {
// populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
// etc.)
Domain<S>::coset_index[S::one()] = 0;
for (int i = 0; i < Domain<S>::max_log_size; ++i) {
const int index = (int)pow(2, i);
Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
}
} else {
// populate all values
for (int i = 0; i < Domain<S>::max_size; ++i) {
Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
}
}
}
return CHK_LAST();
}
template <typename S>
cudaError_t Domain<S>::ReleaseDomain(device_context::DeviceContext& ctx)
{
CHK_INIT_IF_RETURN();
max_size = 0;
max_log_size = 0;
cudaFreeAsync(twiddles, ctx.stream);
twiddles = nullptr;
cudaFreeAsync(internal_twiddles, ctx.stream);
internal_twiddles = nullptr;
cudaFreeAsync(basic_twiddles, ctx.stream);
basic_twiddles = nullptr;
coset_index.clear();
return CHK_LAST();
}
template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
{
@@ -431,6 +470,20 @@ namespace ntt {
bool are_inputs_on_device = config.are_inputs_on_device;
bool are_outputs_on_device = config.are_outputs_on_device;
E* d_input;
if (are_inputs_on_device) {
d_input = input;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
}
E* d_output;
if (are_outputs_on_device) {
d_output = output;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
}
S* coset = nullptr;
int coset_index = 0;
try {
@@ -448,45 +501,44 @@ namespace ntt {
h_coset.clear();
}
E* d_input;
if (are_inputs_on_device) {
d_input = input;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
}
E* d_output;
if (are_outputs_on_device) {
d_output = output;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
}
const bool is_small_ntt = logn < 16; // cutoff point where mixed-radix is faster than radix-2
const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
const bool is_batch_ntt = batch_size > 1; // batch not supported by mixed-radix algorithm yet
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
const bool is_radix2_algorithm = config.is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;
bool ct_butterfly = true;
bool reverse_input = false;
switch (config.ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
ct_butterfly = false;
break;
case Ordering::kRR:
reverse_input = true;
ct_butterfly = false;
break;
if (is_radix2_algorithm) {
bool ct_butterfly = true;
bool reverse_input = false;
switch (config.ordering) {
case Ordering::kNN:
reverse_input = true;
break;
case Ordering::kNR:
ct_butterfly = false;
break;
case Ordering::kRR:
reverse_input = true;
ct_butterfly = false;
break;
}
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
} else { // mixed-radix algorithm
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
Domain<S>::max_log_size, dir == NTTDir::kInverse, config.ordering, stream));
}
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
CHK_IF_RETURN(ntt_inplace_batch_template(
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
if (!are_outputs_on_device)
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
@@ -506,6 +558,7 @@ namespace ntt {
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
false, // is_force_radix2
};
return config;
}
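
The dispatch rule added above can be condensed into one pure function for readability (a sketch mirroring the booleans in ntt.cu, not an exported API): the mixed-radix path is taken only for a single, non-coset, NN-ordered NTT of size 2^16 or larger, unless is_force_radix2 overrides it.

#include <cstdio>

// Sketch of the radix-2 vs mixed-radix selection introduced in ntt.cu (not a library API).
static bool use_radix2_algorithm(bool is_force_radix2, int logn, int batch_size, bool is_on_coset, bool is_NN)
{
  const bool is_small_ntt = logn < 16;      // below this cutoff radix-2 is faster
  const bool is_batch_ntt = batch_size > 1; // batch not supported by mixed-radix yet
  return is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;
}

int main()
{
  // 2^20 single NN NTT -> mixed-radix; the same size in a batch of 16 -> radix-2
  printf("%d %d\n", use_radix2_algorithm(false, 20, 1, false, true),
                    use_radix2_algorithm(false, 20, 16, false, true)); // prints "0 1"
  return 0;
}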


@@ -80,6 +80,8 @@ namespace ntt {
* non-blocking and you'd need to synchronize it explicitly by running
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
* function will block the current CPU thread. */
bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
radix-2 or mixed-radix algorithm based on heuristics). */
};
/**


@@ -0,0 +1,33 @@
#pragma once
#ifndef _NTT_IMPL_H
#define _NTT_IMPL_H
#include <stdint.h>
#include "appUtils/ntt/ntt.cuh" // for enum Ordering
namespace ntt {
template <typename S>
cudaError_t generate_external_twiddles_generic(
const S& basic_root,
S* external_twiddles,
S*& internal_twiddles,
S*& basic_twiddles,
uint32_t log_size,
cudaStream_t& stream);
template <typename E, typename S>
cudaError_t mixed_radix_ntt(
E* d_input,
E* d_output,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
int ntt_size,
int max_logn,
bool is_inverse,
Ordering ordering,
cudaStream_t cuda_stream);
} // namespace ntt
#endif //_NTT_IMPL_H


@@ -0,0 +1,155 @@
#define CURVE_ID BLS12_381
#include "primitives/field.cuh"
#include "primitives/projective.cuh"
#include "utils/cuda_utils.cuh"
#include <chrono>
#include <iostream>
#include <vector>
#include "curves/curve_config.cuh"
#include "ntt/ntt.cu"
#include "ntt/ntt_impl.cuh"
#include <memory>
typedef curve_config::scalar_t test_scalar;
typedef curve_config::scalar_t test_data;
#include "kernel_ntt.cu"
void random_samples(test_data* res, uint32_t count)
{
for (int i = 0; i < count; i++)
res[i] = i < 1000 ? test_data::rand_host() : res[i - 1000];
}
void incremental_values(test_scalar* res, uint32_t count)
{
for (int i = 0; i < count; i++) {
res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
}
}
int main(int argc, char** argv)
{
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
float icicle_time, new_time;
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // first command-line argument selects the log-size
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : true;
int INV = (argc > 3) ? atoi(argv[3]) : true;
const ntt::Ordering ordering = ntt::Ordering::kNN;
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
: ordering == ntt::Ordering::kNR ? "NR"
: ordering == ntt::Ordering::kRN ? "RN"
: "RR";
printf("running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV);
cudaFree(nullptr); // init GPU context (warmup)
// init domain
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
ntt_config.ordering = ordering;
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
CHK_IF_RETURN(cudaEventCreate(&new_start));
CHK_IF_RETURN(cudaEventCreate(&new_stop));
auto start = std::chrono::high_resolution_clock::now();
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
ntt::InitDomain(basic_root, ntt_config.ctx);
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;
// cpu allocation
auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE);
auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE);
auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE);
// gpu allocation
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE));
// init inputs
incremental_values(CpuScalars.get(), NTT_SIZE);
CHK_IF_RETURN(cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
// inplace
if (INPLACE) {
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}
// run ntt
auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
// NEW
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
GpuOutputNew);
}
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
ntt_config.is_force_radix2 = true;
for (size_t i = 0; i < iterations; i++) {
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
}
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
if (is_print) {
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
printf("New Runtime=%0.3f MS\n", new_time / iterations);
}
return CHK_LAST();
};
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
int count = INPLACE ? 1 : 10;
if (INPLACE) {
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}
CHK_IF_RETURN(benchmark(true /*=print*/, count));
// verify
CHK_IF_RETURN(cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
CHK_IF_RETURN(cudaMemcpy(CpuOutputOld.get(), GpuOutputOld, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
bool success = true;
for (int i = 0; i < NTT_SIZE; i++) {
if (CpuOutputNew[i] != CpuOutputOld[i]) {
success = false;
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
break;
} else {
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
// break;
}
}
const char* success_str = success ? "SUCCESS!" : "FAIL!";
printf("%s\n", success_str);
CHK_IF_RETURN(cudaFree(GpuScalars));
CHK_IF_RETURN(cudaFree(GpuOutputOld));
CHK_IF_RETURN(cudaFree(GpuOutputNew));
return CHK_LAST();
}


@@ -0,0 +1,542 @@
#ifndef T_NTT
#define T_NTT
#pragma once
#include <stdio.h>
#include <stdint.h>
#include "curves/curve_config.cuh"
struct stage_metadata {
uint32_t th_stride;
uint32_t ntt_block_size;
uint32_t ntt_block_id;
uint32_t ntt_inp_id;
};
uint32_t constexpr STAGE_SIZES_HOST[31][5] = {
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
{0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
{4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
{6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
{6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = {
{0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
{0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
{4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
{6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
{6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
template <typename E, typename S>
class NTTEngine
{
public:
E X[8];
S WB[3];
S WI[7];
S WE[8];
__device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles, bool inv)
{
#pragma unroll
for (int i = 0; i < 3; i++) {
WB[i] = basic_twiddles[inv ? i + 3 : i];
}
}
__device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
uint32_t exp = ((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1);
WI[i] = data[(inv && exp) ? 64 - exp : exp]; // if exp = 0 we also take exp and not 64-exp
}
}
__device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
uint32_t exp = 2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1);
WI[i] = data[(inv && exp) ? 64 - exp : exp];
}
}
__device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride, bool inv)
{
#pragma unroll
for (int i = 0; i < 7; i++) {
uint32_t exp = 4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1);
WI[i] = data[(inv && exp) ? 64 - exp : exp];
}
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric64(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
uint32_t exp = (s_meta.ntt_inp_id + 8 * i) * (s_meta.ntt_block_id & (tw_order - 1))
<< (tw_log_size - tw_log_order - 6);
WE[i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
}
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric32(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
uint32_t exp = (s_meta.ntt_inp_id * 2 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
<< (tw_log_size - tw_log_order - 5);
WE[4 * j + i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
}
}
}
__device__ __forceinline__ void loadExternalTwiddlesGeneric16(
E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
{
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
uint32_t exp = (s_meta.ntt_inp_id * 4 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
<< (tw_log_size - tw_log_order - 4);
WE[2 * j + i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
}
}
}
__device__ __forceinline__ void loadGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[s_meta.th_stride * i * data_stride];
}
}
__device__ __forceinline__ void storeGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
}
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
data[s_meta.th_stride * i * data_stride] = X[i];
}
}
__device__ __forceinline__ void loadGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
}
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
X[4 * j + i] = data[(8 * i + j) * data_stride];
}
}
}
__device__ __forceinline__ void storeGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
}
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
data[(8 * i + j) * data_stride] = X[4 * j + i];
}
}
}
__device__ __forceinline__ void loadGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
}
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
X[2 * j + i] = data[(8 * i + j) * data_stride];
}
}
}
__device__ __forceinline__ void storeGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
} else {
data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
}
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
data[(8 * i + j) * data_stride] = X[2 * j + i];
}
}
}
__device__ __forceinline__ void ntt4_2()
{
#pragma unroll
for (int i = 0; i < 2; i++) {
ntt4(X[4 * i], X[4 * i + 1], X[4 * i + 2], X[4 * i + 3]);
}
}
__device__ __forceinline__ void ntt2_4()
{
#pragma unroll
for (int i = 0; i < 4; i++) {
ntt2(X[2 * i], X[2 * i + 1]);
}
}
__device__ __forceinline__ void ntt2(E& X0, E& X1)
{
E T;
T = X0 + X1;
X1 = X0 - X1;
X0 = T;
}
__device__ __forceinline__ void ntt4(E& X0, E& X1, E& X2, E& X3)
{
E T;
T = X0 + X2;
X2 = X0 - X2;
X0 = X1 + X3;
X1 = X1 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3
X1 = X1 * WB[0];
X3 = X2 - X1;
X1 = X2 + X1;
X2 = T - X0;
X0 = T + X0;
}
// rbo version
__device__ __forceinline__ void ntt4rbo(E& X0, E& X1, E& X2, E& X3)
{
E T;
T = X0 - X1;
X0 = X0 + X1;
X1 = X2 + X3;
X3 = X2 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3
X3 = X3 * WB[0];
X2 = X0 - X1;
X0 = X0 + X1;
X1 = T + X3;
X3 = T - X3;
}
__device__ __forceinline__ void ntt8(E& X0, E& X1, E& X2, E& X3, E& X4, E& X5, E& X6, E& X7)
{
E T;
// out of 56,623,104 possible mappings, we have:
T = X3 - X7;
X7 = X3 + X7;
X3 = X1 - X5;
X5 = X1 + X5;
X1 = X2 + X6;
X2 = X2 - X6;
X6 = X0 + X4;
X0 = X0 - X4;
T = T * WB[1];
X2 = X2 * WB[1];
X4 = X6 + X1;
X6 = X6 - X1;
X1 = X3 + T;
X3 = X3 - T;
T = X5 + X7;
X5 = X5 - X7;
X7 = X0 + X2;
X0 = X0 - X2;
X1 = X1 * WB[0];
X5 = X5 * WB[1];
X3 = X3 * WB[2];
X2 = X6 + X5;
X6 = X6 - X5;
X5 = X7 - X1;
X1 = X7 + X1;
X7 = X0 - X3;
X3 = X0 + X3;
X0 = X4 + T;
X4 = X4 - T;
}
__device__ __forceinline__ void ntt8win()
{
E T;
T = X[3] - X[7];
X[7] = X[3] + X[7];
X[3] = X[1] - X[5];
X[5] = X[1] + X[5];
X[1] = X[2] + X[6];
X[2] = X[2] - X[6];
X[6] = X[0] + X[4];
X[0] = X[0] - X[4];
X[2] = X[2] * WB[0];
X[4] = X[6] + X[1];
X[6] = X[6] - X[1];
X[1] = X[3] + T;
X[3] = X[3] - T;
T = X[5] + X[7];
X[5] = X[5] - X[7];
X[7] = X[0] + X[2];
X[0] = X[0] - X[2];
X[1] = X[1] * WB[1];
X[5] = X[5] * WB[0];
X[3] = X[3] * WB[2];
X[2] = X[6] + X[5];
X[6] = X[6] - X[5];
X[5] = X[1] + X[3];
X[3] = X[1] - X[3];
X[1] = X[7] + X[5];
X[5] = X[7] - X[5];
X[7] = X[0] - X[3];
X[3] = X[0] + X[3];
X[0] = X[4] + T;
X[4] = X[4] - T;
}
__device__ __forceinline__ void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
uint32_t column_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 64 + i * 8 + column_id] = X[i];
} else {
X[i] = shmem[ntt_id * 64 + i * 8 + column_id];
}
}
}
__device__ __forceinline__ void SharedData64Rows8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
uint32_t row_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 64 + row_id * 8 + i] = X[i];
} else {
X[i] = shmem[ntt_id * 64 + row_id * 8 + i];
}
}
}
__device__ __forceinline__ void SharedData32Columns8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
uint32_t column_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 32 + i * 4 + column_id] = X[i];
} else {
X[i] = shmem[ntt_id * 32 + i * 4 + column_id];
}
}
}
__device__ __forceinline__ void SharedData32Rows8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
uint32_t row_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 32 + row_id * 8 + i] = X[i];
} else {
X[i] = shmem[ntt_id * 32 + row_id * 8 + i];
}
}
}
__device__ __forceinline__ void SharedData32Columns4_2(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
uint32_t column_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
if (store) {
shmem[ntt_id * 32 + i * 8 + column_id + j] = X[4 * j + i];
} else {
X[4 * j + i] = shmem[ntt_id * 32 + i * 8 + column_id + j];
}
}
}
}
__device__ __forceinline__ void SharedData32Rows4_2(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
uint32_t row_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
if (store) {
shmem[ntt_id * 32 + row_id * 4 + 4 * j + i] = X[4 * j + i];
} else {
X[4 * j + i] = shmem[ntt_id * 32 + row_id * 4 + 4 * j + i];
}
}
}
}
__device__ __forceinline__ void SharedData16Columns8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
uint32_t column_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 16 + i * 2 + column_id] = X[i];
} else {
X[i] = shmem[ntt_id * 16 + i * 2 + column_id];
}
}
}
__device__ __forceinline__ void SharedData16Rows8(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
uint32_t row_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
if (store) {
shmem[ntt_id * 16 + row_id * 8 + i] = X[i];
} else {
X[i] = shmem[ntt_id * 16 + row_id * 8 + i];
}
}
}
__device__ __forceinline__ void SharedData16Columns2_4(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
uint32_t column_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
if (store) {
shmem[ntt_id * 16 + i * 8 + column_id + j] = X[2 * j + i];
} else {
X[2 * j + i] = shmem[ntt_id * 16 + i * 8 + column_id + j];
}
}
}
}
__device__ __forceinline__ void SharedData16Rows2_4(E* shmem, bool store, bool high_bits, bool stride)
{
uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
uint32_t row_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
if (store) {
shmem[ntt_id * 16 + row_id * 2 + 2 * j + i] = X[2 * j + i];
} else {
X[2 * j + i] = shmem[ntt_id * 16 + row_id * 2 + 2 * j + i];
}
}
}
}
__device__ __forceinline__ void twiddlesInternal()
{
#pragma unroll
for (int i = 1; i < 8; i++) {
X[i] = X[i] * WI[i - 1];
}
}
__device__ __forceinline__ void twiddlesExternal()
{
#pragma unroll
for (int i = 0; i < 8; i++) {
X[i] = X[i] * WE[i];
}
}
};
#endif
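
Assuming the intended invariant of the tables above is that row logn lists the radix exponents (4, 5 or 6, i.e. sub-NTTs of size 16, 32 or 64) whose sum is logn, with all-zero rows marking the sizes that large_ntt rejects (logn 1-3 and 7), the standalone check below makes that explicit.

#include <cstdint>
#include <cstdio>

// local copy of STAGE_SIZES_HOST from thread_ntt.cu, reproduced so the check is standalone
constexpr uint32_t STAGE_SIZES[31][5] = {
  {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
  {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
  {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
  {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
  {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};

int main()
{
  int errors = 0;
  for (uint32_t logn = 0; logn < 31; ++logn) {
    uint32_t sum = 0;
    for (uint32_t i = 0; i < 5; ++i) {
      const uint32_t d = STAGE_SIZES[logn][i];
      if (d != 0 && (d < 4 || d > 6)) ++errors; // only radix-16/32/64 building blocks exist
      sum += d;
    }
    if (sum != 0 && sum != logn) { ++errors; printf("row %u sums to %u\n", logn, sum); }
  }
  printf("%s", errors ? "FAIL\n" : "OK: every supported row decomposes logn into 4/5/6-bit digits\n");
  return errors;
}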


@@ -75,11 +75,19 @@ public:
return Field{omega_inv.storages[logn - 1]};
}
static HOST_INLINE Field inv_log_size(uint32_t logn)
static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
{
if (logn == 0) { return Field{CONFIG::one}; }
#ifndef __CUDA_ARCH__
if (logn > CONFIG::omegas_count) THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid inv index");
#else
if (logn > CONFIG::omegas_count) {
printf(
"CUDA ERROR: field.cuh: error on inv_log_size(logn): logn(=%u) > omegas_count (=%u)", logn,
CONFIG::omegas_count);
assert(false);
}
#endif // __CUDA_ARCH__
storage_array<CONFIG::omegas_count, TLC> const inv = CONFIG::inv;
return Field{inv.storages[logn - 1]};
}


@@ -143,11 +143,15 @@ public:
return res;
}
friend HOST_DEVICE_INLINE Projective operator*(const Projective& point, SCALAR_FF scalar) { return scalar * point; }
friend HOST_DEVICE_INLINE bool operator==(const Projective& p1, const Projective& p2)
{
return (p1.x * p2.z == p2.x * p1.z) && (p1.y * p2.z == p2.y * p1.z);
}
friend HOST_DEVICE_INLINE bool operator!=(const Projective& p1, const Projective& p2) { return !(p1 == p2); }
friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Projective& point)
{
os << "Point { x: " << point.x << "; y: " << point.y << "; z: " << point.z << " }";


@@ -57,6 +57,8 @@ pub struct NTTConfig<'a, S> {
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
pub is_async: bool,
/// Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects radix-2 or mixed-radix algorithm based on heuristics).
pub is_force_radix2: bool,
}
impl<'a, S: FieldImpl> NTTConfig<'a, S> {
@@ -70,6 +72,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
are_inputs_on_device: false,
are_outputs_on_device: false,
is_async: false,
is_force_radix2: false,
}
}
}