Mirror of https://github.com/pseXperiments/icicle.git
Mixed-radix NTT algorithm
Co-authored-by: hadaringonyama <hadar@ingonyama.com>
2
.gitignore
vendored
@@ -17,3 +17,5 @@
**/icicle/build/
**/wrappers/rust/icicle-cuda-runtime/src/bindings.rs
**/build
**/icicle/appUtils/large_ntt/work
icicle/appUtils/large_ntt/work/test_ntt
@@ -5,28 +5,31 @@
|
||||
#define CURVE_ID 1
|
||||
// include NTT template
|
||||
#include "appUtils/ntt/ntt.cu"
|
||||
#include "appUtils/ntt/kernel_ntt.cu"
|
||||
using namespace curve_config;
|
||||
|
||||
// Operate on scalars
|
||||
typedef scalar_t S;
|
||||
typedef scalar_t E;
|
||||
|
||||
void print_elements(const unsigned n, E * elements ) {
|
||||
void print_elements(const unsigned n, E* elements)
|
||||
{
|
||||
for (unsigned i = 0; i < n; i++) {
|
||||
std::cout << i << ": " << elements[i] << std::endl;
|
||||
std::cout << i << ": " << elements[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elements ) {
|
||||
void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E* elements)
|
||||
{
|
||||
// Lowest Harmonics
|
||||
for (unsigned i = 0; i < ntt_size; i=i+1) {
|
||||
for (unsigned i = 0; i < ntt_size; i = i + 1) {
|
||||
elements[i] = E::one();
|
||||
}
|
||||
// print_elements(ntt_size, elements );
|
||||
// Highest Harmonics
|
||||
for (unsigned i = 1*ntt_size; i < 2*ntt_size; i=i+2) {
|
||||
elements[i] = E::one();
|
||||
elements[i+1] = E::neg(scalar_t::one());
|
||||
for (unsigned i = 1 * ntt_size; i < 2 * ntt_size; i = i + 2) {
|
||||
elements[i] = E::one();
|
||||
elements[i + 1] = E::neg(scalar_t::one());
|
||||
}
|
||||
// print_elements(ntt_size, &elements[1*ntt_size] );
|
||||
}
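// (Editorial note, not part of the original diff) Why these test vectors: the forward NTT of the
// all-ones vector concentrates its entire amplitude (ntt_size) in bin 0, and the NTT of the
// alternating +1/-1 vector concentrates it in bin ntt_size/2, so validate_output() below only
// needs to check those two bins against `amplitude`.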
|
||||
@@ -34,7 +37,7 @@ void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elem
|
||||
int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* elements)
|
||||
{
|
||||
int nof_errors = 0;
|
||||
E amplitude = E::from((uint32_t) ntt_size);
|
||||
E amplitude = E::from((uint32_t)ntt_size);
|
||||
// std::cout << "Amplitude: " << amplitude << std::endl;
|
||||
// Lowest Harmonics
|
||||
if (elements[0] != amplitude) {
|
||||
@@ -44,8 +47,8 @@ int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* element
|
||||
} else {
|
||||
std::cout << "Validated lowest harmonics" << std::endl;
|
||||
}
|
||||
// Highest Harmonics
|
||||
if (elements[1*ntt_size+ntt_size/2] != amplitude) {
|
||||
// Highest Harmonics
|
||||
if (elements[1 * ntt_size + ntt_size / 2] != amplitude) {
|
||||
++nof_errors;
|
||||
std::cout << "Error in highest harmonics! " << std::endl;
|
||||
// print_elements(ntt_size, &elements[1*ntt_size] );
|
||||
@@ -66,24 +69,24 @@ int main(int argc, char* argv[])
|
||||
const unsigned nof_ntts = 2;
|
||||
std::cout << "Number of NTTs: " << nof_ntts << std::endl;
|
||||
const unsigned batch_size = nof_ntts * ntt_size;
|
||||
|
||||
|
||||
std::cout << "Generating input data for lowest and highest harmonics" << std::endl;
|
||||
E* input;
|
||||
input = (E*) malloc(sizeof(E) * batch_size);
|
||||
initialize_input(ntt_size, nof_ntts, input );
|
||||
input = (E*)malloc(sizeof(E) * batch_size);
|
||||
initialize_input(ntt_size, nof_ntts, input);
|
||||
E* output;
|
||||
output = (E*) malloc(sizeof(E) * batch_size);
|
||||
|
||||
output = (E*)malloc(sizeof(E) * batch_size);
|
||||
|
||||
std::cout << "Running NTT with on-host data" << std::endl;
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
// Create a device context
|
||||
auto ctx = device_context::get_default_device_context();
|
||||
// the next line is valid only for CURVE_ID 1 (will add support for other curves soon)
|
||||
S rou = S{ {0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1} };
|
||||
S rou = S{{0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1}};
|
||||
ntt::InitDomain(rou, ctx);
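// (Editorial note) `rou` above is a hard-coded element of the scalar field for the default curve
// (CURVE_ID 1): a generator of a large power-of-two subgroup, i.e. a root of unity of high 2-adic
// order. InitDomain squares it to discover the maximum supported size and derives every twiddle
// factor used by later NTT calls from it, which is why the literal is curve-specific.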
|
||||
// Create an NTTConfig instance
|
||||
ntt::NTTConfig<S> config=ntt::DefaultNTTConfig<S>();
|
||||
ntt::NTTConfig<S> config = ntt::DefaultNTTConfig<S>();
|
||||
config.batch_size = nof_ntts;
|
||||
config.ctx.stream = stream;
|
||||
auto begin0 = std::chrono::high_resolution_clock::now();
|
||||
@@ -91,7 +94,7 @@ int main(int argc, char* argv[])
|
||||
auto end0 = std::chrono::high_resolution_clock::now();
|
||||
auto elapsed0 = std::chrono::duration_cast<std::chrono::nanoseconds>(end0 - begin0);
|
||||
printf("On-device runtime: %.3f seconds\n", elapsed0.count() * 1e-9);
|
||||
validate_output(ntt_size, nof_ntts, output );
|
||||
validate_output(ntt_size, nof_ntts, output);
|
||||
cudaStreamDestroy(stream);
|
||||
free(input);
|
||||
free(output);
|
||||
|
||||
@@ -0,0 +1,25 @@
# Make sure NVIDIA Container Toolkit is installed on your host

# Use the specified base image
FROM nvidia/cuda:12.0.0-devel-ubuntu22.04

# Update and install dependencies
RUN apt-get update && apt-get install -y \
    cmake \
    curl \
    build-essential \
    git \
    libboost-all-dev \
    && rm -rf /var/lib/apt/lists/*

# Clone Icicle from a GitHub repository
RUN git clone https://github.com/ingonyama-zk/icicle.git /icicle

# Set the working directory in the container
WORKDIR /icicle-example

# Specify the default command for the container
CMD ["/bin/bash"]

@@ -0,0 +1,22 @@
{
  "name": "Icicle Examples: polynomial multiplication",
  "build": {
    "dockerfile": "Dockerfile"
  },
  "runArgs": [
    "--gpus",
    "all"
  ],
  "postCreateCommand": [
    "nvidia-smi"
  ],
  "customizations": {
    "vscode": {
      "extensions": [
        "ms-vscode.cmake-tools",
        "ms-python.python",
        "ms-vscode.cpptools"
      ]
    }
  }
}
26
examples/c++/polynomial_multiplication/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
cmake_minimum_required(VERSION 3.18)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED TRUE)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
if (${CMAKE_VERSION} VERSION_LESS "3.24.0")
  set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH})
else()
  set(CMAKE_CUDA_ARCHITECTURES native) # 'native' needs CMake 3.24+; on earlier versions it would be ignored and no architecture set, hence the ${CUDA_ARCH} fallback above
endif ()
project(icicle LANGUAGES CUDA CXX)

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
# change the path to your Icicle location
include_directories("../../../icicle")
add_executable(
  example
  example.cu
)

find_library(NVML_LIBRARY nvidia-ml PATHS /usr/local/cuda-12.0/targets/x86_64-linux/lib/stubs/ )
target_link_libraries(example ${NVML_LIBRARY})
set_target_properties(example PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
11
examples/c++/polynomial_multiplication/compile.sh
Executable file
@@ -0,0 +1,11 @@
#!/bin/bash

# Exit immediately on error
set -e

rm -rf build
mkdir -p build
cmake -S . -B build
cmake --build build
114
examples/c++/polynomial_multiplication/example.cu
Normal file
@@ -0,0 +1,114 @@
|
||||
#define CURVE_ID BLS12_381
|
||||
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "curves/curve_config.cuh"
|
||||
#include "appUtils/ntt/ntt.cu"
|
||||
#include "appUtils/ntt/kernel_ntt.cu"
|
||||
#include "utils/vec_ops.cu"
|
||||
#include "utils/error_handler.cuh"
|
||||
#include <memory>
|
||||
|
||||
typedef curve_config::scalar_t test_scalar;
|
||||
typedef curve_config::scalar_t test_data;
|
||||
|
||||
void random_samples(test_data* res, uint32_t count)
|
||||
{
|
||||
for (int i = 0; i < count; i++)
|
||||
res[i] = i < 1000 ? test_data::rand_host() : res[i - 1000];
|
||||
}
|
||||
|
||||
void incremental_values(test_scalar* res, uint32_t count)
|
||||
{
|
||||
for (int i = 0; i < count; i++) {
|
||||
res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
|
||||
}
|
||||
}
|
||||
|
||||
// calculating polynomial multiplication A*B via NTT, pointwise multiplication and INTT
|
||||
// (1) allocate A,B on CPU. Randomize first half, zero second half
|
||||
// (2) allocate NttAGpu, NttBGpu on GPU
|
||||
// (3) calc NTT for A and for B from cpu to GPU
|
||||
// (4) multiply MulGpu = NttAGpu * NttBGpu (pointwise)
|
||||
// (5) INTT MulGpu inplace
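// (Editorial note) Why this works: with A and B zero-padded so that deg(A) + deg(B) < NTT_SIZE,
// the coefficient vector of A*B equals the cyclic convolution of the two inputs, and by the
// convolution theorem that is INTT( NTT(A) .* NTT(B) ), where .* is the element-wise product of
// step (4). Randomizing only the first half of each input in step (1) is what guarantees the
// padding, so no wrap-around occurs.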
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaEvent_t start, stop;
|
||||
float measured_time;
|
||||
|
||||
int NTT_LOG_SIZE = 23;
|
||||
int NTT_SIZE = 1 << NTT_LOG_SIZE;
|
||||
|
||||
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context
|
||||
|
||||
// init domain
|
||||
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
|
||||
ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward
|
||||
ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false;
|
||||
|
||||
const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix";
|
||||
std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: ";
|
||||
|
||||
CHK_IF_RETURN(cudaEventCreate(&start));
|
||||
CHK_IF_RETURN(cudaEventCreate(&stop));
|
||||
|
||||
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
|
||||
ntt::InitDomain(basic_root, ntt_config.ctx);
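// (Editorial note) omega(NTT_LOG_SIZE) is a root of unity of order 2^NTT_LOG_SIZE, so the domain
// initialized here covers exactly the transform size used below; InitDomain is called once and
// its twiddle tables are reused by every forward and inverse NTT issued on this context.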
|
||||
|
||||
// (1) cpu allocation
|
||||
auto CpuA = std::make_unique<test_data[]>(NTT_SIZE);
|
||||
auto CpuB = std::make_unique<test_data[]>(NTT_SIZE);
|
||||
random_samples(CpuA.get(), NTT_SIZE >> 1); // second half zeros
|
||||
random_samples(CpuB.get(), NTT_SIZE >> 1); // second half zeros
|
||||
|
||||
test_data *GpuA, *GpuB, *MulGpu;
|
||||
|
||||
auto benchmark = [&](bool print, int iterations = 1) {
|
||||
// start recording
|
||||
CHK_IF_RETURN(cudaEventRecord(start, ntt_config.ctx.stream));
|
||||
|
||||
for (int iter = 0; iter < iterations; ++iter) {
|
||||
// (2) gpu input allocation
|
||||
CHK_IF_RETURN(cudaMallocAsync(&GpuA, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&GpuB, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
|
||||
|
||||
// (3) NTT for A,B from cpu to gpu
|
||||
ntt_config.are_inputs_on_device = false;
|
||||
ntt_config.are_outputs_on_device = true;
|
||||
CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA));
|
||||
CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB));
|
||||
|
||||
// (4) multiply A,B
|
||||
CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(
|
||||
vec_ops::Mul(GpuA, GpuB, NTT_SIZE, true /*=is_on_device*/, false /*=is_montgomery*/, ntt_config.ctx, MulGpu));
|
||||
|
||||
// (5) INTT (in place)
|
||||
ntt_config.are_inputs_on_device = true;
|
||||
ntt_config.are_outputs_on_device = true;
|
||||
CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu));
|
||||
|
||||
CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaFreeAsync(GpuB, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaFreeAsync(MulGpu, ntt_config.ctx.stream));
|
||||
}
|
||||
|
||||
CHK_IF_RETURN(cudaEventRecord(stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&measured_time, start, stop));
|
||||
|
||||
if (print) { std::cout << measured_time / iterations << " MS" << std::endl; }
|
||||
|
||||
return CHK_LAST();
|
||||
};
|
||||
|
||||
benchmark(false); // warmup
|
||||
benchmark(true, 20);
|
||||
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
|
||||
return 0;
|
||||
}
|
||||
3
examples/c++/polynomial_multiplication/run.sh
Executable file
@@ -0,0 +1,3 @@
#!/bin/bash
./build/example 1 # radix2
./build/example 0 # mixed-radix
@@ -53,7 +53,7 @@ struct Args {
lower_bound_log_size: u8,

/// Upper bound of MSM sizes to run for
#[arg(short, long, default_value_t = 23)]
#[arg(short, long, default_value_t = 22)]
upper_bound_log_size: u8,
}
@@ -108,6 +108,7 @@ if (NOT BUILD_TESTS)
primitives/projective.cu
appUtils/msm/msm.cu
appUtils/ntt/ntt.cu
appUtils/ntt/kernel_ntt.cu
${ICICLE_SOURCES}
)
set_target_properties(icicle PROPERTIES OUTPUT_NAME "ingo_${CURVE}")
6
icicle/appUtils/ntt/Makefile
Normal file
@@ -0,0 +1,6 @@
build_verification:
	mkdir -p work
	nvcc -o work/test_verification -I. -I.. -I../.. -I../ntt tests/verification.cu -std=c++17

test_verification: build_verification
	work/test_verification
640
icicle/appUtils/ntt/kernel_ntt.cu
Normal file
@@ -0,0 +1,640 @@
|
||||
|
||||
#include "appUtils/ntt/thread_ntt.cu"
|
||||
#include "curves/curve_config.cuh"
|
||||
#include "utils/sharedmem.cuh"
|
||||
#include "appUtils/ntt/ntt.cuh" // for Ordering
|
||||
|
||||
namespace ntt {
|
||||
|
||||
static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit)
|
||||
{
|
||||
uint32_t rev_num = 0, temp, dig_len;
|
||||
if (dit) {
|
||||
for (int i = 4; i >= 0; i--) {
|
||||
dig_len = STAGE_SIZES_DEVICE[log_size][i];
|
||||
temp = num & ((1 << dig_len) - 1);
|
||||
num = num >> dig_len;
|
||||
rev_num = rev_num << dig_len;
|
||||
rev_num = rev_num | temp;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
dig_len = STAGE_SIZES_DEVICE[log_size][i];
|
||||
temp = num & ((1 << dig_len) - 1);
|
||||
num = num >> dig_len;
|
||||
rev_num = rev_num << dig_len;
|
||||
rev_num = rev_num | temp;
|
||||
}
|
||||
}
|
||||
return rev_num;
|
||||
}
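// (Editorial example) dig_rev generalizes bit-reversal: the index is split into the per-stage
// digits listed in STAGE_SIZES_DEVICE[log_size] and the digits are emitted in reverse order.
// For log_size = 8 the stages are {4, 4}, so an 8-bit index d1:d0 maps to d0:d1; with 1-bit
// digits this would reduce to the usual radix-2 bit-reversal.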
|
||||
|
||||
// Note: the following reorder kernels are fused with normalization for INTT
|
||||
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
|
||||
static __global__ void
|
||||
reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
|
||||
{
|
||||
// launch N threads
|
||||
// each thread starts from one index and calculates the corresponding group
|
||||
// if its index is the smallest number in the group -> do the memory transformation
|
||||
// else --> do nothing
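// (Editorial clarification) The digit-reversal permutation splits the index set into disjoint
// cycles; each thread walks the cycle containing its own index, and only the thread holding the
// cycle's minimum index rotates that cycle, so every element is moved exactly once with no
// locking. MAX_GROUP_SIZE is the longest cycle a single thread will track.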
|
||||
|
||||
const uint32_t idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
uint32_t next_element = idx;
|
||||
uint32_t group[MAX_GROUP_SIZE];
|
||||
group[0] = idx;
|
||||
|
||||
uint32_t i = 1;
|
||||
for (; i < MAX_GROUP_SIZE;) {
|
||||
next_element = dig_rev(next_element, log_size, dit);
|
||||
if (next_element < idx) return; // not handling this group
|
||||
if (next_element == idx) break; // calculated whole group
|
||||
group[i++] = next_element;
|
||||
}
|
||||
|
||||
if (i == 1) { // single element in group --> nothing to do (except maybe normalize for INTT)
|
||||
if (is_normalize) { arr[idx] = arr[idx] * inverse_N; }
|
||||
return;
|
||||
}
|
||||
--i;
|
||||
// reaching here means I am handling this group
|
||||
const E last_element_in_group = arr[group[i]];
|
||||
for (; i > 0; --i) {
|
||||
arr[group[i]] = is_normalize ? (arr[group[i - 1]] * inverse_N) : arr[group[i - 1]];
|
||||
}
|
||||
arr[idx] = is_normalize ? (last_element_in_group * inverse_N) : last_element_in_group;
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__
|
||||
void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N)
|
||||
{
|
||||
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
uint32_t rd = tid;
|
||||
uint32_t wr = dig_rev(tid, log_size, dit);
|
||||
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt64(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
uint32_t twiddle_stride,
|
||||
bool strided,
|
||||
uint32_t stage_num,
|
||||
bool inv,
|
||||
bool dit)
|
||||
{
|
||||
NTTEngine<E, S> engine;
|
||||
stage_metadata s_meta;
|
||||
SharedMemory<E> smem;
|
||||
E* shmem = smem.getPointer();
|
||||
|
||||
s_meta.th_stride = 8;
|
||||
s_meta.ntt_block_size = 64;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
|
||||
|
||||
engine.loadBasicTwiddles(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (twiddle_stride && dit) {
|
||||
engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.loadInternalTwiddles64(internal_twiddles, strided, inv);
|
||||
|
||||
#pragma unroll 1
|
||||
for (uint32_t phase = 0; phase < 2; phase++) {
|
||||
engine.ntt8win();
|
||||
if (phase == 0) {
|
||||
engine.SharedData64Columns8(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData64Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
}
|
||||
}
|
||||
|
||||
if (twiddle_stride && !dit) {
|
||||
engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt32(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
uint32_t twiddle_stride,
|
||||
bool strided,
|
||||
uint32_t stage_num,
|
||||
bool inv,
|
||||
bool dit)
|
||||
{
|
||||
NTTEngine<E, S> engine;
|
||||
stage_metadata s_meta;
|
||||
|
||||
SharedMemory<E> smem;
|
||||
E* shmem = smem.getPointer();
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
engine.loadBasicTwiddles(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
|
||||
engine.ntt8win();
|
||||
engine.twiddlesInternal();
|
||||
engine.SharedData32Columns8(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData32Rows4_2(shmem, false, false, strided); // load
|
||||
engine.ntt4_2();
|
||||
if (twiddle_stride) {
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt32dit(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
int32_t tw_log_size,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
uint32_t twiddle_stride,
|
||||
bool strided,
|
||||
uint32_t stage_num,
|
||||
bool inv,
|
||||
bool dit)
|
||||
{
|
||||
NTTEngine<E, S> engine;
|
||||
stage_metadata s_meta;
|
||||
|
||||
SharedMemory<E> smem;
|
||||
E* shmem = smem.getPointer();
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
engine.loadBasicTwiddles(basic_twiddles, inv);
|
||||
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (twiddle_stride) {
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.loadInternalTwiddles32(internal_twiddles, strided, inv);
|
||||
engine.ntt4_2();
|
||||
engine.SharedData32Columns4_2(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData32Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt16(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
uint32_t twiddle_stride,
|
||||
bool strided,
|
||||
uint32_t stage_num,
|
||||
bool inv,
|
||||
bool dit)
|
||||
{
|
||||
NTTEngine<E, S> engine;
|
||||
stage_metadata s_meta;
|
||||
|
||||
SharedMemory<E> smem;
|
||||
E* shmem = smem.getPointer();
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
engine.loadBasicTwiddles(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
|
||||
engine.ntt8win();
|
||||
engine.twiddlesInternal();
|
||||
engine.SharedData16Columns8(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData16Rows2_4(shmem, false, false, strided); // load
|
||||
engine.ntt2_4();
|
||||
if (twiddle_stride) {
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt16dit(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
uint32_t twiddle_stride,
|
||||
bool strided,
|
||||
uint32_t stage_num,
|
||||
bool inv,
|
||||
bool dit)
|
||||
{
|
||||
NTTEngine<E, S> engine;
|
||||
stage_metadata s_meta;
|
||||
|
||||
SharedMemory<E> smem;
|
||||
E* shmem = smem.getPointer();
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
engine.loadBasicTwiddles(basic_twiddles, inv);
|
||||
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (twiddle_stride) {
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.loadInternalTwiddles16(internal_twiddles, strided, inv);
|
||||
engine.ntt2_4();
|
||||
engine.SharedData16Columns2_4(shmem, true, false, strided); // store
|
||||
__syncthreads();
|
||||
engine.SharedData16Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__global__ void normalize_kernel(E* data, S norm_factor)
|
||||
{
|
||||
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
data[tid] = data[tid] * norm_factor;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
__global__ void generate_base_table(S basic_root, S* base_table, uint32_t skip)
|
||||
{
|
||||
S w = basic_root;
|
||||
S t = S::one();
|
||||
for (int i = 0; i < 64; i += skip) {
|
||||
base_table[i] = t;
|
||||
t = t * w;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
__global__ void generate_basic_twiddles(S basic_root, S* w6_table, S* basic_twiddles)
|
||||
{
|
||||
S w0 = basic_root * basic_root;
|
||||
S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1);
|
||||
S w2 = (basic_root - w0 * basic_root) * S::inv_log_size(1);
|
||||
basic_twiddles[0] = w0;
|
||||
basic_twiddles[1] = w1;
|
||||
basic_twiddles[2] = w2;
|
||||
S basic_inv = w6_table[64 - 8];
|
||||
w0 = basic_inv * basic_inv;
|
||||
w1 = (basic_inv + w0 * basic_inv) * S::inv_log_size(1);
|
||||
w2 = (basic_inv - w0 * basic_inv) * S::inv_log_size(1);
|
||||
basic_twiddles[3] = w0;
|
||||
basic_twiddles[4] = w1;
|
||||
basic_twiddles[5] = w2;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
__global__ void generate_twiddle_combinations_generic(
|
||||
S* w6_table, S* w12_table, S* w18_table, S* w24_table, S* w30_table, S* external_twiddles, uint32_t log_size)
|
||||
{
|
||||
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
uint32_t exp = tid << (30 - log_size);
|
||||
S w6, w12, w18, w24, w30;
|
||||
w6 = w6_table[exp >> 24];
|
||||
w12 = w12_table[((exp >> 18) & 0x3f)];
|
||||
w18 = w18_table[((exp >> 12) & 0x3f)];
|
||||
w24 = w24_table[((exp >> 6) & 0x3f)];
|
||||
w30 = w30_table[(exp & 0x3f)];
|
||||
S t = w6 * w12 * w18 * w24 * w30;
|
||||
external_twiddles[tid] = t;
|
||||
}
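// (Editorial clarification) Each external twiddle omega^tid is assembled from five 64-entry base
// tables: the exponent is scaled to 30 bits and split into five 6-bit digits, and each table holds
// the powers for one digit position, so filling the full table costs four field multiplications
// per entry instead of a full exponentiation.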
|
||||
|
||||
template <typename S>
|
||||
__global__ void set_value(S* arr, int idx, S val)
|
||||
{
|
||||
arr[idx] = val;
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
cudaError_t generate_external_twiddles_generic(
|
||||
const S& basic_root,
|
||||
S* external_twiddles,
|
||||
S*& internal_twiddles,
|
||||
S*& basic_twiddles,
|
||||
uint32_t log_size,
|
||||
cudaStream_t& stream)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
const int n = pow(2, log_size);
|
||||
CHK_IF_RETURN(cudaMallocAsync(&basic_twiddles, 6 * sizeof(S), stream));
|
||||
|
||||
S* w6_table;
|
||||
S* w12_table;
|
||||
S* w18_table;
|
||||
S* w24_table;
|
||||
S* w30_table;
|
||||
CHK_IF_RETURN(cudaMallocAsync(&w6_table, sizeof(S) * 64, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&w12_table, sizeof(S) * 64, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&w18_table, sizeof(S) * 64, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&w24_table, sizeof(S) * 64, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&w30_table, sizeof(S) * 64, stream));
|
||||
|
||||
// Note: for compatibility with radix-2 INTT, need ONE in last element (in addition to first element)
|
||||
set_value<<<1, 1, 0, stream>>>(external_twiddles, n /*last element idx*/, S::one());
|
||||
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
S temp_root = basic_root;
|
||||
generate_base_table<<<1, 1, 0, stream>>>(basic_root, w30_table, 1 << (30 - log_size));
|
||||
|
||||
if (log_size > 24)
|
||||
for (int i = 0; i < 6 - (30 - log_size); i++)
|
||||
temp_root = temp_root * temp_root;
|
||||
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w24_table, 1 << (log_size > 24 ? 0 : 24 - log_size));
|
||||
|
||||
if (log_size > 18)
|
||||
for (int i = 0; i < 6 - (log_size > 24 ? 0 : 24 - log_size); i++)
|
||||
temp_root = temp_root * temp_root;
|
||||
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w18_table, 1 << (log_size > 18 ? 0 : 18 - log_size));
|
||||
|
||||
if (log_size > 12)
|
||||
for (int i = 0; i < 6 - (log_size > 18 ? 0 : 18 - log_size); i++)
|
||||
temp_root = temp_root * temp_root;
|
||||
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w12_table, 1 << (log_size > 12 ? 0 : 12 - log_size));
|
||||
|
||||
if (log_size > 6)
|
||||
for (int i = 0; i < 6 - (log_size > 12 ? 0 : 12 - log_size); i++)
|
||||
temp_root = temp_root * temp_root;
|
||||
generate_base_table<<<1, 1, 0, stream>>>(temp_root, w6_table, 1 << (log_size > 6 ? 0 : 6 - log_size));
|
||||
|
||||
if (log_size > 2)
|
||||
for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++)
|
||||
temp_root = temp_root * temp_root;
|
||||
generate_basic_twiddles<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles);
|
||||
|
||||
const int NOF_BLOCKS = (log_size >= 8) ? (1 << (log_size - 8)) : 1;
|
||||
const int NOF_THREADS = (log_size >= 8) ? 256 : (1 << log_size);
|
||||
generate_twiddle_combinations_generic<<<NOF_BLOCKS, NOF_THREADS, 0, stream>>>(
|
||||
w6_table, w12_table, w18_table, w24_table, w30_table, external_twiddles, log_size);
|
||||
|
||||
internal_twiddles = w6_table;
|
||||
|
||||
CHK_IF_RETURN(cudaFreeAsync(w12_table, stream));
|
||||
CHK_IF_RETURN(cudaFreeAsync(w18_table, stream));
|
||||
CHK_IF_RETURN(cudaFreeAsync(w24_table, stream));
|
||||
CHK_IF_RETURN(cudaFreeAsync(w30_table, stream));
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
cudaError_t large_ntt(
|
||||
E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
bool inv,
|
||||
bool normalize,
|
||||
bool dit,
|
||||
cudaStream_t cuda_stream)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
if (log_size == 1 || log_size == 2 || log_size == 3 || log_size == 7) {
|
||||
throw IcicleError(IcicleError_t::InvalidArgument, "size not implemented for mixed-radix-NTT");
|
||||
}
|
||||
|
||||
if (log_size == 4) {
|
||||
if (dit) {
|
||||
ntt16dit<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
} else { // dif
|
||||
ntt16<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<1, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 5) {
|
||||
if (dit) {
|
||||
ntt32dit<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
} else { // dif
|
||||
ntt32<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<1, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 6) {
|
||||
ntt64<<<1, 8, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
if (normalize) normalize_kernel<<<1, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 8) {
|
||||
if (dit) {
|
||||
ntt16dit<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
ntt16dit<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1,
|
||||
inv,
|
||||
dit); // we need threads 32+ although 16-31 are idle
|
||||
} else { // dif
|
||||
ntt16<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, inv,
|
||||
dit); // we need threads 32+ although 16-31 are idle
|
||||
ntt16<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv,
|
||||
dit);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<1, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
// general case:
|
||||
if (dit) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
|
||||
uint32_t stride_log = 0;
|
||||
for (int j = 0; j < i; j++)
|
||||
stride_log += STAGE_SIZES_HOST[log_size][j];
|
||||
if (stage_size == 6)
|
||||
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
else if (stage_size == 5)
|
||||
ntt32dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
else if (stage_size == 4)
|
||||
ntt16dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
}
|
||||
} else { // dif
|
||||
bool first_run = false, prev_stage = false;
|
||||
for (int i = 4; i >= 0; i--) {
|
||||
uint32_t stage_size = STAGE_SIZES_HOST[log_size][i];
|
||||
uint32_t stride_log = 0;
|
||||
for (int j = 0; j < i; j++)
|
||||
stride_log += STAGE_SIZES_HOST[log_size][j];
|
||||
first_run = stage_size && !prev_stage;
|
||||
if (stage_size == 6)
|
||||
ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
else if (stage_size == 5)
|
||||
ntt32<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
else if (stage_size == 4)
|
||||
ntt16<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit);
|
||||
prev_stage = stage_size;
|
||||
}
|
||||
}
|
||||
if (normalize) normalize_kernel<<<1 << (log_size - 8), 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
cudaError_t mixed_radix_ntt(
|
||||
E* d_input,
|
||||
E* d_output,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
bool is_inverse,
|
||||
Ordering ordering,
|
||||
cudaStream_t cuda_stream)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
// TODO: can we support all orderings? Note that reversal is generally digit reverse (generalization of bit reverse)
|
||||
if (ordering != Ordering::kNN) {
|
||||
throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only");
|
||||
}
|
||||
|
||||
const int logn = int(log2(ntt_size));
|
||||
|
||||
const int NOF_BLOCKS = (1 << (max(logn, 6) - 6));
|
||||
const int NOF_THREADS = min(64, 1 << logn);
|
||||
|
||||
const bool reverse_input = ordering == Ordering::kNN;
|
||||
const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN;
|
||||
bool is_normalize = is_inverse;
|
||||
|
||||
if (reverse_input) {
|
||||
// Note: fusing reorder with normalize for INTT
|
||||
const bool is_reverse_in_place = (d_input == d_output);
|
||||
if (is_reverse_in_place) {
|
||||
reorder_digits_inplace_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
|
||||
} else {
|
||||
reorder_digits_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn));
|
||||
}
|
||||
is_normalize = false;
|
||||
}
|
||||
|
||||
// inplace ntt
|
||||
CHK_IF_RETURN(large_ntt(
|
||||
d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, is_inverse,
|
||||
is_normalize, is_dit, cuda_stream));
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
// Explicit instantiation for scalar type
|
||||
template cudaError_t generate_external_twiddles_generic(
|
||||
const curve_config::scalar_t& basic_root,
|
||||
curve_config::scalar_t* external_twiddles,
|
||||
curve_config::scalar_t*& internal_twiddles,
|
||||
curve_config::scalar_t*& basic_twiddles,
|
||||
uint32_t log_size,
|
||||
cudaStream_t& stream);
|
||||
|
||||
template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
|
||||
curve_config::scalar_t* d_input,
|
||||
curve_config::scalar_t* d_output,
|
||||
curve_config::scalar_t* external_twiddles,
|
||||
curve_config::scalar_t* internal_twiddles,
|
||||
curve_config::scalar_t* basic_twiddles,
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
bool is_inverse,
|
||||
Ordering ordering,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
} // namespace ntt
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "utils/sharedmem.cuh"
|
||||
#include "utils/utils_kernels.cuh"
|
||||
#include "utils/utils.h"
|
||||
#include "appUtils/ntt/ntt_impl.cuh"
|
||||
|
||||
namespace ntt {
|
||||
|
||||
@@ -25,7 +26,10 @@ namespace ntt {
|
||||
int idx = threadId % n;
|
||||
int batch_idx = threadId / n;
|
||||
int idx_reversed = __brev(idx) >> (32 - logn);
|
||||
arr_reversed[batch_idx * n + idx_reversed] = arr[batch_idx * n + idx];
|
||||
|
||||
E val = arr[batch_idx * n + idx];
|
||||
if (arr == arr_reversed) { __syncthreads(); } // for in-place (when pointers arr==arr_reversed)
|
||||
arr_reversed[batch_idx * n + idx_reversed] = val;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,25 +361,24 @@ namespace ntt {
|
||||
template <typename S>
|
||||
class Domain
|
||||
{
|
||||
static int max_size;
|
||||
static S* twiddles;
|
||||
static std::unordered_map<S, int> coset_index;
|
||||
static inline int max_size = 0;
|
||||
static inline int max_log_size = 0;
|
||||
static inline S* twiddles = nullptr;
|
||||
static inline std::unordered_map<S, int> coset_index = {};
|
||||
|
||||
static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT
|
||||
static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT
|
||||
|
||||
public:
|
||||
template <typename U>
|
||||
friend cudaError_t InitDomain<U>(U primitive_root, device_context::DeviceContext& ctx);
|
||||
|
||||
static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
|
||||
template <typename U, typename E>
|
||||
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
|
||||
};
|
||||
|
||||
template <typename S>
|
||||
int Domain<S>::max_size = 0;
|
||||
template <typename S>
|
||||
S* Domain<S>::twiddles = nullptr;
|
||||
template <typename S>
|
||||
std::unordered_map<S, int> Domain<S>::coset_index = {};
|
||||
|
||||
template <typename S>
|
||||
cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx)
|
||||
{
|
||||
@@ -385,34 +388,70 @@ namespace ntt {
|
||||
// please note that this is not thread-safe at all,
|
||||
// but it's a singleton that is supposed to be initialized once per program lifetime
|
||||
if (!Domain<S>::twiddles) {
|
||||
bool found_logn = false;
|
||||
S omega = primitive_root;
|
||||
unsigned omegas_count = S::get_omegas_count();
|
||||
for (int i = 0; i < omegas_count; i++)
|
||||
for (int i = 0; i < omegas_count; i++) {
|
||||
omega = S::sqr(omega);
|
||||
if (!found_logn) {
|
||||
++Domain<S>::max_log_size;
|
||||
found_logn = omega == S::one();
|
||||
if (found_logn) break;
|
||||
}
|
||||
}
|
||||
Domain<S>::max_size = (int)pow(2, Domain<S>::max_log_size);
|
||||
if (omega != S::one()) {
|
||||
std::cerr << "Primitive root provided to the InitDomain function is not in the subgroup" << '\n';
|
||||
throw -1;
|
||||
throw IcicleError(
|
||||
IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup");
|
||||
}
|
||||
|
||||
std::vector<S> h_twiddles;
|
||||
h_twiddles.push_back(S::one());
|
||||
int n = 1;
|
||||
do {
|
||||
Domain<S>::coset_index[h_twiddles.at(n - 1)] = n - 1;
|
||||
h_twiddles.push_back(h_twiddles.at(n - 1) * primitive_root);
|
||||
} while (h_twiddles.at(n++) != S::one());
|
||||
|
||||
CHK_IF_RETURN(cudaMallocAsync(&Domain<S>::twiddles, n * sizeof(S), ctx.stream));
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpyAsync(Domain<S>::twiddles, &h_twiddles.front(), n * sizeof(S), cudaMemcpyHostToDevice, ctx.stream));
|
||||
|
||||
Domain<S>::max_size = n - 1;
|
||||
// allocate and calculate twiddles on GPU
|
||||
// Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
|
||||
// Managed allocation allows the host to read the log(n) elements it needs without copying all n twiddle factors back to host
|
||||
CHK_IF_RETURN(cudaMallocManaged(&Domain<S>::twiddles, (Domain<S>::max_size + 1) * sizeof(S)));
|
||||
CHK_IF_RETURN(generate_external_twiddles_generic(
|
||||
primitive_root, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles,
|
||||
Domain<S>::max_log_size, ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream));
|
||||
|
||||
const bool is_map_only_powers_of_primitive_root = true;
|
||||
if (is_map_only_powers_of_primitive_root) {
|
||||
// populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8
|
||||
// etc.)
|
||||
Domain<S>::coset_index[S::one()] = 0;
|
||||
for (int i = 0; i < Domain<S>::max_log_size; ++i) {
|
||||
const int index = (int)pow(2, i);
|
||||
Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
|
||||
}
|
||||
} else {
|
||||
// populate all values
|
||||
for (int i = 0; i < Domain<S>::max_size; ++i) {
|
||||
Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
template <typename S>
|
||||
cudaError_t Domain<S>::ReleaseDomain(device_context::DeviceContext& ctx)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
max_size = 0;
|
||||
max_log_size = 0;
|
||||
cudaFreeAsync(twiddles, ctx.stream);
|
||||
twiddles = nullptr;
|
||||
cudaFreeAsync(internal_twiddles, ctx.stream);
|
||||
internal_twiddles = nullptr;
|
||||
cudaFreeAsync(basic_twiddles, ctx.stream);
|
||||
basic_twiddles = nullptr;
|
||||
coset_index.clear();
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
template <typename S, typename E>
|
||||
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
|
||||
{
|
||||
@@ -431,6 +470,20 @@ namespace ntt {
|
||||
bool are_inputs_on_device = config.are_inputs_on_device;
|
||||
bool are_outputs_on_device = config.are_outputs_on_device;
|
||||
|
||||
E* d_input;
|
||||
if (are_inputs_on_device) {
|
||||
d_input = input;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
E* d_output;
|
||||
if (are_outputs_on_device) {
|
||||
d_output = output;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
|
||||
}
|
||||
|
||||
S* coset = nullptr;
|
||||
int coset_index = 0;
|
||||
try {
|
||||
@@ -448,45 +501,44 @@ namespace ntt {
|
||||
h_coset.clear();
|
||||
}
|
||||
|
||||
E* d_input;
|
||||
if (are_inputs_on_device) {
|
||||
d_input = input;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
|
||||
}
|
||||
E* d_output;
|
||||
if (are_outputs_on_device) {
|
||||
d_output = output;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
|
||||
}
|
||||
const bool is_small_ntt = logn < 16; // below this cutoff radix-2 is used; mixed-radix is faster above it
|
||||
const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
|
||||
const bool is_batch_ntt = batch_size > 1; // batch not supported by mixed-radix algorithm yet
|
||||
const bool is_NN = config.ordering == Ordering::kNN; // TODO Yuval: relax this limitation
|
||||
const bool is_radix2_algorithm = config.is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;
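// (Editorial summary) Net effect of the flags above: the new mixed-radix path is taken only for a
// single (non-batched) NTT of size >= 2^16, with NN ordering and no coset, unless the caller
// forces radix-2 via config.is_force_radix2; every other case falls back to the existing radix-2
// implementation.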
|
||||
|
||||
bool ct_butterfly = true;
|
||||
bool reverse_input = false;
|
||||
switch (config.ordering) {
|
||||
case Ordering::kNN:
|
||||
reverse_input = true;
|
||||
break;
|
||||
case Ordering::kNR:
|
||||
ct_butterfly = false;
|
||||
break;
|
||||
case Ordering::kRR:
|
||||
reverse_input = true;
|
||||
ct_butterfly = false;
|
||||
break;
|
||||
if (is_radix2_algorithm) {
|
||||
bool ct_butterfly = true;
|
||||
bool reverse_input = false;
|
||||
switch (config.ordering) {
|
||||
case Ordering::kNN:
|
||||
reverse_input = true;
|
||||
break;
|
||||
case Ordering::kNR:
|
||||
ct_butterfly = false;
|
||||
break;
|
||||
case Ordering::kRR:
|
||||
reverse_input = true;
|
||||
ct_butterfly = false;
|
||||
break;
|
||||
}
|
||||
|
||||
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
|
||||
|
||||
CHK_IF_RETURN(ntt_inplace_batch_template(
|
||||
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
|
||||
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
|
||||
|
||||
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
|
||||
} else { // mixed-radix algorithm
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, Domain<S>::twiddles, Domain<S>::internal_twiddles, Domain<S>::basic_twiddles, size,
|
||||
Domain<S>::max_log_size, dir == NTTDir::kInverse, config.ordering, stream));
|
||||
}
|
||||
|
||||
if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
|
||||
|
||||
CHK_IF_RETURN(ntt_inplace_batch_template(
|
||||
reverse_input ? d_output : d_input, size, Domain<S>::twiddles, Domain<S>::max_size, batch_size, logn,
|
||||
dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output));
|
||||
|
||||
if (!are_outputs_on_device)
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
|
||||
|
||||
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
|
||||
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
|
||||
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
|
||||
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
|
||||
@@ -506,6 +558,7 @@ namespace ntt {
|
||||
false, // are_inputs_on_device
|
||||
false, // are_outputs_on_device
|
||||
false, // is_async
|
||||
false, // is_force_radix2
|
||||
};
|
||||
return config;
|
||||
}
|
||||
|
||||
@@ -80,6 +80,8 @@ namespace ntt {
|
||||
* non-blocking and you'd need to synchronize it explicitly by running
|
||||
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT
|
||||
* function will block the current CPU thread. */
|
||||
bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects
|
||||
radix-2 or mixed-radix algorithm based on heuristics). */
|
||||
};
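/* (Editorial usage sketch, assuming the default config shown in this diff)
 *   auto config = ntt::DefaultNTTConfig<scalar_t>();
 *   config.is_force_radix2 = true;   // opt out of the mixed-radix heuristic, e.g. for benchmarking
 *   ntt::NTT(input, size, ntt::NTTDir::kForward, config, output);
 */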
|
||||
|
||||
/**
|
||||
|
||||
33
icicle/appUtils/ntt/ntt_impl.cuh
Normal file
@@ -0,0 +1,33 @@
#pragma once
#ifndef _NTT_IMPL_H
#define _NTT_IMPL_H

#include <stdint.h>
#include "appUtils/ntt/ntt.cuh" // for enum Ordering

namespace ntt {

  template <typename S>
  cudaError_t generate_external_twiddles_generic(
    const S& basic_root,
    S* external_twiddles,
    S*& internal_twiddles,
    S*& basic_twiddles,
    uint32_t log_size,
    cudaStream_t& stream);

  template <typename E, typename S>
  cudaError_t mixed_radix_ntt(
    E* d_input,
    E* d_output,
    S* external_twiddles,
    S* internal_twiddles,
    S* basic_twiddles,
    int ntt_size,
    int max_logn,
    bool is_inverse,
    Ordering ordering,
    cudaStream_t cuda_stream);

} // namespace ntt
#endif //_NTT_IMPL_H
155
icicle/appUtils/ntt/tests/verification.cu
Normal file
@@ -0,0 +1,155 @@
|
||||
|
||||
#define CURVE_ID BLS12_381
|
||||
|
||||
#include "primitives/field.cuh"
|
||||
#include "primitives/projective.cuh"
|
||||
#include "utils/cuda_utils.cuh"
|
||||
#include <chrono>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "curves/curve_config.cuh"
|
||||
#include "ntt/ntt.cu"
|
||||
#include "ntt/ntt_impl.cuh"
|
||||
#include <memory>
|
||||
|
||||
typedef curve_config::scalar_t test_scalar;
|
||||
typedef curve_config::scalar_t test_data;
|
||||
#include "kernel_ntt.cu"
|
||||
|
||||
void random_samples(test_data* res, uint32_t count)
|
||||
{
|
||||
for (int i = 0; i < count; i++)
|
||||
res[i] = i < 1000 ? test_data::rand_host() : res[i - 1000];
|
||||
}
|
||||
|
||||
void incremental_values(test_scalar* res, uint32_t count)
|
||||
{
|
||||
for (int i = 0; i < count; i++) {
|
||||
res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
|
||||
float icicle_time, new_time;
|
||||
|
||||
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // first command-line argument is the log-size
|
||||
int NTT_SIZE = 1 << NTT_LOG_SIZE;
|
||||
bool INPLACE = (argc > 2) ? atoi(argv[2]) : true;
|
||||
int INV = (argc > 3) ? atoi(argv[3]) : true;
|
||||
|
||||
const ntt::Ordering ordering = ntt::Ordering::kNN;
|
||||
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
|
||||
: ordering == ntt::Ordering::kNR ? "NR"
|
||||
: ordering == ntt::Ordering::kRN ? "RN"
|
||||
: "RR";
|
||||
|
||||
printf("running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV);
|
||||
|
||||
cudaFree(nullptr); // init GPU context (warmup)
|
||||
|
||||
// init domain
|
||||
auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
|
||||
ntt_config.ordering = ordering;
|
||||
ntt_config.are_inputs_on_device = true;
|
||||
ntt_config.are_outputs_on_device = true;
|
||||
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
|
||||
CHK_IF_RETURN(cudaEventCreate(&new_start));
|
||||
CHK_IF_RETURN(cudaEventCreate(&new_stop));
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
|
||||
ntt::InitDomain(basic_root, ntt_config.ctx);
|
||||
auto stop = std::chrono::high_resolution_clock::now();
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
|
||||
std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;
|
||||
|
||||
// cpu allocation
|
||||
auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE);
|
||||
auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE);
|
||||
auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE);
|
||||
|
||||
// gpu allocation
|
||||
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE));
|
||||
|
||||
// init inputs
|
||||
incremental_values(CpuScalars.get(), NTT_SIZE);
|
||||
CHK_IF_RETURN(cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
|
||||
|
||||
// inplace
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
// run ntt
|
||||
auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
|
||||
// NEW
|
||||
CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
|
||||
ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a new ntt)
|
||||
for (size_t i = 0; i < iterations; i++) {
|
||||
ntt::NTT(
|
||||
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
|
||||
GpuOutputNew);
|
||||
}
|
||||
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
// OLD
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
|
||||
ntt_config.is_force_radix2 = true;
|
||||
for (size_t i = 0; i < iterations; i++) {
|
||||
ntt::NTT(GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld);
|
||||
}
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
if (is_print) {
|
||||
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
|
||||
printf("New Runtime=%0.3f MS\n", new_time / iterations);
|
||||
}
|
||||
|
||||
return CHK_LAST();
|
||||
};
|
||||
|
||||
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
|
||||
int count = INPLACE ? 1 : 10;
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
CHK_IF_RETURN(benchmark(true /*=print*/, count));
|
||||
|
||||
// verify
|
||||
CHK_IF_RETURN(cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
|
||||
CHK_IF_RETURN(cudaMemcpy(CpuOutputOld.get(), GpuOutputOld, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
|
||||
|
||||
bool success = true;
|
||||
for (int i = 0; i < NTT_SIZE; i++) {
|
||||
if (CpuOutputNew[i] != CpuOutputOld[i]) {
|
||||
success = false;
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
|
||||
break;
|
||||
} else {
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
|
||||
// break;
|
||||
}
|
||||
}
|
||||
const char* success_str = success ? "SUCCESS!" : "FAIL!";
|
||||
printf("%s\n", success_str);
|
||||
|
||||
CHK_IF_RETURN(cudaFree(GpuScalars));
|
||||
CHK_IF_RETURN(cudaFree(GpuOutputOld));
|
||||
CHK_IF_RETURN(cudaFree(GpuOutputNew));
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
542
icicle/appUtils/ntt/thread_ntt.cu
Normal file
@@ -0,0 +1,542 @@
#ifndef T_NTT
#define T_NTT
#pragma once

#include <stdio.h>
#include <stdint.h>
#include "curves/curve_config.cuh"

struct stage_metadata {
  uint32_t th_stride;
  uint32_t ntt_block_size;
  uint32_t ntt_block_id;
  uint32_t ntt_inp_id;
};
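
// Descriptive note (inferred from the loaders below): th_stride is the stride between the 8
// elements a thread owns, ntt_block_size is the size of the sub-NTT tile (16/32/64), ntt_block_id
// selects the tile this thread works on, and ntt_inp_id is the thread's input offset within that tile.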

uint32_t constexpr STAGE_SIZES_HOST[31][5] = {
  {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
  {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
  {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
  {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
  {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};

__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = {
  {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0},
  {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0},
  {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0},
  {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6},
  {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}};
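
// Stage-size tables (descriptive note): row n lists the log-radix of each kernel stage used to
// decompose an NTT of size 2^n, so the non-zero entries of a row sum to n. For example, row 20 is
// {5, 5, 5, 5}: a 2^20-point NTT runs as four radix-32 stages (entries 6/5/4 correspond to
// 64/32/16-point sub-NTTs). An all-zero row (e.g. n = 7) presumably means no mixed-radix
// decomposition is defined for that size and a different code path handles it.
//
// A minimal host-side sanity check of the table (illustrative sketch only, not part of this file):
//
//   for (uint32_t n = 4; n <= 30; n++) {
//     uint32_t sum = 0;
//     for (uint32_t s = 0; s < 5; s++) sum += STAGE_SIZES_HOST[n][s];
//     assert(sum == n || sum == 0); // zero rows are the unsupported sizes
//   }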

template <typename E, typename S>
class NTTEngine
{
public:
  E X[8];
  S WB[3];
  S WI[7];
  S WE[8];
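
  // Descriptive note: each thread keeps its 8 butterfly inputs in registers (X) together with three
  // groups of twiddle factors: WB holds the fixed constants used inside the radix-4/8 butterflies
  // (loadBasicTwiddles picks the forward or inverse set), WI holds the twiddles applied between the
  // butterfly rounds of a 16/32/64-point tile, and WE holds the external twiddles applied between
  // kernel stages.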

  __device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles, bool inv)
  {
#pragma unroll
    for (int i = 0; i < 3; i++) {
      WB[i] = basic_twiddles[inv ? i + 3 : i];
    }
  }

  __device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride, bool inv)
  {
#pragma unroll
    for (int i = 0; i < 7; i++) {
      uint32_t exp = ((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1);
      WI[i] = data[(inv && exp) ? 64 - exp : exp]; // if exp = 0 we also take exp and not 64-exp
    }
  }

  __device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride, bool inv)
  {
#pragma unroll
    for (int i = 0; i < 7; i++) {
      uint32_t exp = 2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1);
      WI[i] = data[(inv && exp) ? 64 - exp : exp];
    }
  }

  __device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride, bool inv)
  {
#pragma unroll
    for (int i = 0; i < 7; i++) {
      uint32_t exp = 4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1);
      WI[i] = data[(inv && exp) ? 64 - exp : exp];
    }
  }

  __device__ __forceinline__ void loadExternalTwiddlesGeneric64(
    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
  {
#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      uint32_t exp = (s_meta.ntt_inp_id + 8 * i) * (s_meta.ntt_block_id & (tw_order - 1))
                     << (tw_log_size - tw_log_order - 6);
      WE[i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
    }
  }

  __device__ __forceinline__ void loadExternalTwiddlesGeneric32(
    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
  {
#pragma unroll
    for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 4; i++) {
        uint32_t exp = (s_meta.ntt_inp_id * 2 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
                       << (tw_log_size - tw_log_order - 5);
        WE[4 * j + i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
      }
    }
  }

  __device__ __forceinline__ void loadExternalTwiddlesGeneric16(
    E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
  {
#pragma unroll
    for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 2; i++) {
        uint32_t exp = (s_meta.ntt_inp_id * 4 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
                       << (tw_log_size - tw_log_order - 4);
        WE[2 * j + i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp];
      }
    }
  }

  __device__ __forceinline__ void loadGlobalData(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
    }

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      X[i] = data[s_meta.th_stride * i * data_stride];
    }
  }

  __device__ __forceinline__ void storeGlobalData(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
    }

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      data[s_meta.th_stride * i * data_stride] = X[i];
    }
  }

  __device__ __forceinline__ void loadGlobalData32(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
    }

#pragma unroll
    for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 4; i++) {
        X[4 * j + i] = data[(8 * i + j) * data_stride];
      }
    }
  }

  __device__ __forceinline__ void storeGlobalData32(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
    }

#pragma unroll
    for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 4; i++) {
        data[(8 * i + j) * data_stride] = X[4 * j + i];
      }
    }
  }

  __device__ __forceinline__ void loadGlobalData16(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
    }

#pragma unroll
    for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 2; i++) {
        X[2 * j + i] = data[(8 * i + j) * data_stride];
      }
    }
  }

  __device__ __forceinline__ void storeGlobalData16(
    E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
  {
    if (strided) {
      data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
              (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size;
    } else {
      data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
    }

#pragma unroll
    for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 2; i++) {
        data[(8 * i + j) * data_stride] = X[2 * j + i];
      }
    }
  }

  __device__ __forceinline__ void ntt4_2()
  {
#pragma unroll
    for (int i = 0; i < 2; i++) {
      ntt4(X[4 * i], X[4 * i + 1], X[4 * i + 2], X[4 * i + 3]);
    }
  }

  __device__ __forceinline__ void ntt2_4()
  {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      ntt2(X[2 * i], X[2 * i + 1]);
    }
  }

  __device__ __forceinline__ void ntt2(E& X0, E& X1)
  {
    E T;

    T = X0 + X1;
    X1 = X0 - X1;
    X0 = T;
  }

  __device__ __forceinline__ void ntt4(E& X0, E& X1, E& X2, E& X3)
  {
    E T;

    T = X0 + X2;
    X2 = X0 - X2;
    X0 = X1 + X3;
    X1 = X1 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3

    X1 = X1 * WB[0];

    X3 = X2 - X1;
    X1 = X2 + X1;
    X2 = T - X0;
    X0 = T + X0;
  }
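
  // Worked form of ntt4 (descriptive note): writing w for WB[0], which plays the role of the
  // primitive 4th root of unity here, the code above leaves, up to the chosen root/ordering
  // convention,
  //   X0 = x0 + x1 + x2 + x3
  //   X1 = (x0 - x2) + w * (x1 - x3)
  //   X2 = (x0 + x2) - (x1 + x3)
  //   X3 = (x0 - x2) - w * (x1 - x3)
  // i.e. a 4-point butterfly evaluated entirely in registers.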

  // rbo version
  __device__ __forceinline__ void ntt4rbo(E& X0, E& X1, E& X2, E& X3)
  {
    E T;

    T = X0 - X1;
    X0 = X0 + X1;
    X1 = X2 + X3;
    X3 = X2 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3

    X3 = X3 * WB[0];

    X2 = X0 - X1;
    X0 = X0 + X1;
    X1 = T + X3;
    X3 = T - X3;
  }

  __device__ __forceinline__ void ntt8(E& X0, E& X1, E& X2, E& X3, E& X4, E& X5, E& X6, E& X7)
  {
    E T;

    // out of 56,623,104 possible mappings, we have:
    T = X3 - X7;
    X7 = X3 + X7;
    X3 = X1 - X5;
    X5 = X1 + X5;
    X1 = X2 + X6;
    X2 = X2 - X6;
    X6 = X0 + X4;
    X0 = X0 - X4;

    T = T * WB[1];
    X2 = X2 * WB[1];

    X4 = X6 + X1;
    X6 = X6 - X1;
    X1 = X3 + T;
    X3 = X3 - T;
    T = X5 + X7;
    X5 = X5 - X7;
    X7 = X0 + X2;
    X0 = X0 - X2;

    X1 = X1 * WB[0];
    X5 = X5 * WB[1];
    X3 = X3 * WB[2];

    X2 = X6 + X5;
    X6 = X6 - X5;
    X5 = X7 - X1;
    X1 = X7 + X1;
    X7 = X0 - X3;
    X3 = X0 + X3;
    X0 = X4 + T;
    X4 = X4 - T;
  }

  __device__ __forceinline__ void ntt8win()
  {
    E T;

    T = X[3] - X[7];
    X[7] = X[3] + X[7];
    X[3] = X[1] - X[5];
    X[5] = X[1] + X[5];
    X[1] = X[2] + X[6];
    X[2] = X[2] - X[6];
    X[6] = X[0] + X[4];
    X[0] = X[0] - X[4];

    X[2] = X[2] * WB[0];

    X[4] = X[6] + X[1];
    X[6] = X[6] - X[1];
    X[1] = X[3] + T;
    X[3] = X[3] - T;
    T = X[5] + X[7];
    X[5] = X[5] - X[7];
    X[7] = X[0] + X[2];
    X[0] = X[0] - X[2];

    X[1] = X[1] * WB[1];
    X[5] = X[5] * WB[0];
    X[3] = X[3] * WB[2];

    X[2] = X[6] + X[5];
    X[6] = X[6] - X[5];

    X[5] = X[1] + X[3];
    X[3] = X[1] - X[3];

    X[1] = X[7] + X[5];
    X[5] = X[7] - X[5];
    X[7] = X[0] - X[3];
    X[3] = X[0] + X[3];
    X[0] = X[4] + T;
    X[4] = X[4] - T;
  }
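
  // Descriptive note: ntt8win is an alternative 8-point butterfly (the "win" in the name suggests a
  // Winograd-style factorization) that gets by with four twiddle multiplications where ntt8 uses
  // five, at the cost of a few extra additions, and it operates directly on the X[] registers
  // instead of taking reference parameters.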

  __device__ __forceinline__ void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
    uint32_t column_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 64 + i * 8 + column_id] = X[i];
      } else {
        X[i] = shmem[ntt_id * 64 + i * 8 + column_id];
      }
    }
  }

  __device__ __forceinline__ void SharedData64Rows8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
    uint32_t row_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 64 + row_id * 8 + i] = X[i];
      } else {
        X[i] = shmem[ntt_id * 64 + row_id * 8 + i];
      }
    }
  }
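
  // Descriptive note on the SharedData* helpers: each Columns/Rows pair moves the 8 per-thread
  // registers through shared memory so that consecutive butterfly rounds see the data transposed:
  // one round writes its results as a column of the 64/32/16-point tile and the next reads its
  // inputs back as a row (or vice versa). This is what stitches the small radix butterflies into a
  // full 16/32/64-point sub-NTT.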

  __device__ __forceinline__ void SharedData32Columns8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
    uint32_t column_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 32 + i * 4 + column_id] = X[i];
      } else {
        X[i] = shmem[ntt_id * 32 + i * 4 + column_id];
      }
    }
  }

  __device__ __forceinline__ void SharedData32Rows8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
    uint32_t row_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 32 + row_id * 8 + i] = X[i];
      } else {
        X[i] = shmem[ntt_id * 32 + row_id * 8 + i];
      }
    }
  }

  __device__ __forceinline__ void SharedData32Columns4_2(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
    uint32_t column_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;

#pragma unroll
    for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 4; i++) {
        if (store) {
          shmem[ntt_id * 32 + i * 8 + column_id + j] = X[4 * j + i];
        } else {
          X[4 * j + i] = shmem[ntt_id * 32 + i * 8 + column_id + j];
        }
      }
    }
  }

  __device__ __forceinline__ void SharedData32Rows4_2(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
    uint32_t row_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;

#pragma unroll
    for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 4; i++) {
        if (store) {
          shmem[ntt_id * 32 + row_id * 4 + 4 * j + i] = X[4 * j + i];
        } else {
          X[4 * j + i] = shmem[ntt_id * 32 + row_id * 4 + 4 * j + i];
        }
      }
    }
  }

  __device__ __forceinline__ void SharedData16Columns8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
    uint32_t column_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 16 + i * 2 + column_id] = X[i];
      } else {
        X[i] = shmem[ntt_id * 16 + i * 2 + column_id];
      }
    }
  }

  __device__ __forceinline__ void SharedData16Rows8(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
    uint32_t row_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;

#pragma unroll
    for (uint32_t i = 0; i < 8; i++) {
      if (store) {
        shmem[ntt_id * 16 + row_id * 8 + i] = X[i];
      } else {
        X[i] = shmem[ntt_id * 16 + row_id * 8 + i];
      }
    }
  }

  __device__ __forceinline__ void SharedData16Columns2_4(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
    uint32_t column_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;

#pragma unroll
    for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 2; i++) {
        if (store) {
          shmem[ntt_id * 16 + i * 8 + column_id + j] = X[2 * j + i];
        } else {
          X[2 * j + i] = shmem[ntt_id * 16 + i * 8 + column_id + j];
        }
      }
    }
  }

  __device__ __forceinline__ void SharedData16Rows2_4(E* shmem, bool store, bool high_bits, bool stride)
  {
    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
    uint32_t row_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;

#pragma unroll
    for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
      for (uint32_t i = 0; i < 2; i++) {
        if (store) {
          shmem[ntt_id * 16 + row_id * 2 + 2 * j + i] = X[2 * j + i];
        } else {
          X[2 * j + i] = shmem[ntt_id * 16 + row_id * 2 + 2 * j + i];
        }
      }
    }
  }

  __device__ __forceinline__ void twiddlesInternal()
  {
#pragma unroll
    for (int i = 1; i < 8; i++) {
      X[i] = X[i] * WI[i - 1];
    }
  }

  __device__ __forceinline__ void twiddlesExternal()
  {
#pragma unroll
    for (int i = 0; i < 8; i++) {
      X[i] = X[i] * WE[i];
    }
  }
};

#endif
@@ -75,11 +75,19 @@ public:
    return Field{omega_inv.storages[logn - 1]};
  }

  static HOST_INLINE Field inv_log_size(uint32_t logn)
  static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
  {
    if (logn == 0) { return Field{CONFIG::one}; }

#ifndef __CUDA_ARCH__
    if (logn > CONFIG::omegas_count) THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid inv index");
#else
    if (logn > CONFIG::omegas_count) {
      printf(
        "CUDA ERROR: field.cuh: error on inv_log_size(logn): logn(=%u) > omegas_count (=%u)", logn,
        CONFIG::omegas_count);
      assert(false);
    }
#endif // __CUDA_ARCH__
    storage_array<CONFIG::omegas_count, TLC> const inv = CONFIG::inv;
    return Field{inv.storages[logn - 1]};
  }
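
Making inv_log_size HOST_DEVICE_INLINE lets device code look up this factor directly. A minimal sketch of the kind of call site this enables, assuming inv_log_size(logn) returns the modular inverse of 2^logn (the usual normalization of an inverse NTT) and that E and S are scalar field types; the kernel below is illustrative only, not part of this commit:

template <typename E, typename S>
__global__ void normalize_kernel(E* elements, uint32_t logn) // hypothetical helper
{
  uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < (1u << logn)) elements[tid] = elements[tid] * S::inv_log_size(logn); // scale by (2^logn)^-1
}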

@@ -143,11 +143,15 @@ public:
    return res;
  }

  friend HOST_DEVICE_INLINE Projective operator*(const Projective& point, SCALAR_FF scalar) { return scalar * point; }

  friend HOST_DEVICE_INLINE bool operator==(const Projective& p1, const Projective& p2)
  {
    return (p1.x * p2.z == p2.x * p1.z) && (p1.y * p2.z == p2.y * p1.z);
  }

  friend HOST_DEVICE_INLINE bool operator!=(const Projective& p1, const Projective& p2) { return !(p1 == p2); }

  friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Projective& point)
  {
    os << "Point { x: " << point.x << "; y: " << point.y << "; z: " << point.z << " }";
@@ -57,6 +57,8 @@ pub struct NTTConfig<'a, S> {
    /// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
    /// it explicitly by running `stream.synchronize()`. If set to `false`, the NTT function will block the current CPU thread.
    pub is_async: bool,
    /// Explicitly select the radix-2 NTT algorithm. Default value: `false` (the implementation chooses between the radix-2 and mixed-radix algorithms based on heuristics).
    pub is_force_radix2: bool,
}

impl<'a, S: FieldImpl> NTTConfig<'a, S> {
@@ -70,6 +72,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
            are_inputs_on_device: false,
            are_outputs_on_device: false,
            is_async: false,
            is_force_radix2: false,
        }
    }
}