From 382bec4ad38599198a7b548192a69759f4611f26 Mon Sep 17 00:00:00 2001 From: yshekel Date: Thu, 8 Feb 2024 13:52:00 +0200 Subject: [PATCH] Mixed-radix NTT algorithm Co-authored-by: hadaringonyama --- .gitignore | 2 + examples/c++/ntt/example.cu | 39 +- .../.devcontainer/Dockerfile | 25 + .../.devcontainer/devcontainer.json | 22 + .../polynomial_multiplication/CMakeLists.txt | 26 + .../c++/polynomial_multiplication/compile.sh | 11 + .../c++/polynomial_multiplication/example.cu | 114 ++++ examples/c++/polynomial_multiplication/run.sh | 3 + examples/rust/msm/src/main.rs | 2 +- icicle/CMakeLists.txt | 1 + icicle/appUtils/ntt/Makefile | 6 + icicle/appUtils/ntt/kernel_ntt.cu | 640 ++++++++++++++++++ icicle/appUtils/ntt/ntt.cu | 173 +++-- icicle/appUtils/ntt/ntt.cuh | 2 + icicle/appUtils/ntt/ntt_impl.cuh | 33 + icicle/appUtils/ntt/tests/verification.cu | 155 +++++ icicle/appUtils/ntt/thread_ntt.cu | 542 +++++++++++++++ icicle/primitives/field.cuh | 12 +- icicle/primitives/projective.cuh | 4 + wrappers/rust/icicle-core/src/ntt/mod.rs | 3 + 20 files changed, 1734 insertions(+), 81 deletions(-) create mode 100644 examples/c++/polynomial_multiplication/.devcontainer/Dockerfile create mode 100644 examples/c++/polynomial_multiplication/.devcontainer/devcontainer.json create mode 100644 examples/c++/polynomial_multiplication/CMakeLists.txt create mode 100755 examples/c++/polynomial_multiplication/compile.sh create mode 100644 examples/c++/polynomial_multiplication/example.cu create mode 100755 examples/c++/polynomial_multiplication/run.sh create mode 100644 icicle/appUtils/ntt/Makefile create mode 100644 icicle/appUtils/ntt/kernel_ntt.cu create mode 100644 icicle/appUtils/ntt/ntt_impl.cuh create mode 100644 icicle/appUtils/ntt/tests/verification.cu create mode 100644 icicle/appUtils/ntt/thread_ntt.cu diff --git a/.gitignore b/.gitignore index 244bd7c0..561bd946 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,5 @@ **/icicle/build/ **/wrappers/rust/icicle-cuda-runtime/src/bindings.rs **/build +**/icicle/appUtils/large_ntt/work +icicle/appUtils/large_ntt/work/test_ntt diff --git a/examples/c++/ntt/example.cu b/examples/c++/ntt/example.cu index 30e5be50..b72cc555 100644 --- a/examples/c++/ntt/example.cu +++ b/examples/c++/ntt/example.cu @@ -5,28 +5,31 @@ #define CURVE_ID 1 // include NTT template #include "appUtils/ntt/ntt.cu" +#include "appUtils/ntt/kernel_ntt.cu" using namespace curve_config; // Operate on scalars typedef scalar_t S; typedef scalar_t E; -void print_elements(const unsigned n, E * elements ) { +void print_elements(const unsigned n, E* elements) +{ for (unsigned i = 0; i < n; i++) { - std::cout << i << ": " << elements[i] << std::endl; + std::cout << i << ": " << elements[i] << std::endl; } } -void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elements ) { +void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E* elements) +{ // Lowest Harmonics - for (unsigned i = 0; i < ntt_size; i=i+1) { + for (unsigned i = 0; i < ntt_size; i = i + 1) { elements[i] = E::one(); } // print_elements(ntt_size, elements ); // Highest Harmonics - for (unsigned i = 1*ntt_size; i < 2*ntt_size; i=i+2) { - elements[i] = E::one(); - elements[i+1] = E::neg(scalar_t::one()); + for (unsigned i = 1 * ntt_size; i < 2 * ntt_size; i = i + 2) { + elements[i] = E::one(); + elements[i + 1] = E::neg(scalar_t::one()); } // print_elements(ntt_size, &elements[1*ntt_size] ); } @@ -34,7 +37,7 @@ void initialize_input(const unsigned ntt_size, const unsigned nof_ntts, E * elem int 
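// Expected spectra (validated below): the all-ones batch ("lowest harmonic")
// transforms to amplitude = ntt_size at index 0, and the alternating +1/-1
// batch ("highest harmonic") to the same amplitude at index ntt_size/2.
// E.g. for ntt_size = 4: NTT([1,1,1,1]) = [4,0,0,0], NTT([1,-1,1,-1]) = [0,0,4,0].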
validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* elements) { int nof_errors = 0; - E amplitude = E::from((uint32_t) ntt_size); + E amplitude = E::from((uint32_t)ntt_size); // std::cout << "Amplitude: " << amplitude << std::endl; // Lowest Harmonics if (elements[0] != amplitude) { @@ -44,8 +47,8 @@ int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* element } else { std::cout << "Validated lowest harmonics" << std::endl; } - // Highest Harmonics - if (elements[1*ntt_size+ntt_size/2] != amplitude) { + // Highest Harmonics + if (elements[1 * ntt_size + ntt_size / 2] != amplitude) { ++nof_errors; std::cout << "Error in highest harmonics! " << std::endl; // print_elements(ntt_size, &elements[1*ntt_size] ); @@ -66,24 +69,24 @@ int main(int argc, char* argv[]) const unsigned nof_ntts = 2; std::cout << "Number of NTTs: " << nof_ntts << std::endl; const unsigned batch_size = nof_ntts * ntt_size; - + std::cout << "Generating input data for lowest and highest harmonics" << std::endl; E* input; - input = (E*) malloc(sizeof(E) * batch_size); - initialize_input(ntt_size, nof_ntts, input ); + input = (E*)malloc(sizeof(E) * batch_size); + initialize_input(ntt_size, nof_ntts, input); E* output; - output = (E*) malloc(sizeof(E) * batch_size); - + output = (E*)malloc(sizeof(E) * batch_size); + std::cout << "Running NTT with on-host data" << std::endl; cudaStream_t stream; cudaStreamCreate(&stream); // Create a device context auto ctx = device_context::get_default_device_context(); // the next line is valid only for CURVE_ID 1 (will add support for other curves soon) - S rou = S{ {0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1} }; + S rou = S{{0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1}}; ntt::InitDomain(rou, ctx); // Create an NTTConfig instance - ntt::NTTConfig config=ntt::DefaultNTTConfig(); + ntt::NTTConfig config = ntt::DefaultNTTConfig(); config.batch_size = nof_ntts; config.ctx.stream = stream; auto begin0 = std::chrono::high_resolution_clock::now(); @@ -91,7 +94,7 @@ int main(int argc, char* argv[]) auto end0 = std::chrono::high_resolution_clock::now(); auto elapsed0 = std::chrono::duration_cast(end0 - begin0); printf("On-device runtime: %.3f seconds\n", elapsed0.count() * 1e-9); - validate_output(ntt_size, nof_ntts, output ); + validate_output(ntt_size, nof_ntts, output); cudaStreamDestroy(stream); free(input); free(output); diff --git a/examples/c++/polynomial_multiplication/.devcontainer/Dockerfile b/examples/c++/polynomial_multiplication/.devcontainer/Dockerfile new file mode 100644 index 00000000..64188da9 --- /dev/null +++ b/examples/c++/polynomial_multiplication/.devcontainer/Dockerfile @@ -0,0 +1,25 @@ +# Make sure NVIDIA Container Toolkit is installed on your host + +# Use the specified base image +FROM nvidia/cuda:12.0.0-devel-ubuntu22.04 + +# Update and install dependencies +RUN apt-get update && apt-get install -y \ + cmake \ + curl \ + build-essential \ + git \ + libboost-all-dev \ + && rm -rf /var/lib/apt/lists/* + +# Clone Icicle from a GitHub repository +RUN git clone https://github.com/ingonyama-zk/icicle.git /icicle + +# Set the working directory in the container +WORKDIR /icicle-example + +# Specify the default command for the container +CMD ["/bin/bash"] + + + diff --git a/examples/c++/polynomial_multiplication/.devcontainer/devcontainer.json b/examples/c++/polynomial_multiplication/.devcontainer/devcontainer.json new file mode 100644 index 
00000000..490fe90a
--- /dev/null
+++ b/examples/c++/polynomial_multiplication/.devcontainer/devcontainer.json
@@ -0,0 +1,22 @@
+{
+    "name": "Icicle Examples: polynomial multiplication",
+    "build": {
+        "dockerfile": "Dockerfile"
+    },
+    "runArgs": [
+        "--gpus",
+        "all"
+    ],
+    "postCreateCommand": [
+        "nvidia-smi"
+    ],
+    "customizations": {
+        "vscode": {
+            "extensions": [
+                "ms-vscode.cmake-tools",
+                "ms-python.python",
+                "ms-vscode.cpptools"
+            ]
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/c++/polynomial_multiplication/CMakeLists.txt b/examples/c++/polynomial_multiplication/CMakeLists.txt
new file mode 100644
index 00000000..03873910
--- /dev/null
+++ b/examples/c++/polynomial_multiplication/CMakeLists.txt
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 3.18)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CUDA_STANDARD 17)
+set(CMAKE_CUDA_STANDARD_REQUIRED TRUE)
+set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
+if (${CMAKE_VERSION} VERSION_LESS "3.24.0")
+  set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH})
+else()
+  set(CMAKE_CUDA_ARCHITECTURES native) # requires CMake 3.24+; earlier versions ignore `native`, so CUDA_ARCH must be passed in explicitly
+endif ()
+project(icicle LANGUAGES CUDA CXX)
+
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
+set(CMAKE_CUDA_FLAGS_RELEASE "")
+set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
+# change the path to your Icicle location
+include_directories("../../../icicle")
+add_executable(
+  example
+  example.cu
+)
+
+find_library(NVML_LIBRARY nvidia-ml PATHS /usr/local/cuda-12.0/targets/x86_64-linux/lib/stubs/ )
+target_link_libraries(example ${NVML_LIBRARY})
+set_target_properties(example PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
+
diff --git a/examples/c++/polynomial_multiplication/compile.sh b/examples/c++/polynomial_multiplication/compile.sh
new file mode 100755
index 00000000..a7ba1621
--- /dev/null
+++ b/examples/c++/polynomial_multiplication/compile.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# Exit immediately on error
+set -e
+
+rm -rf build
+mkdir -p build
+cmake -S . -B build
+cmake --build build
+
+
diff --git a/examples/c++/polynomial_multiplication/example.cu b/examples/c++/polynomial_multiplication/example.cu
new file mode 100644
index 00000000..7100fee9
--- /dev/null
+++ b/examples/c++/polynomial_multiplication/example.cu
@@ -0,0 +1,114 @@
+#define CURVE_ID BLS12_381
+
+#include <iostream>
+#include <memory>
+#include <cstdlib>
+
+#include "curves/curve_config.cuh"
+#include "appUtils/ntt/ntt.cu"
+#include "appUtils/ntt/kernel_ntt.cu"
+#include "utils/vec_ops.cu"
+#include "utils/error_handler.cuh"
+#include <cstdint>
+
+typedef curve_config::scalar_t test_scalar;
+typedef curve_config::scalar_t test_data;
+
+void random_samples(test_data* res, uint32_t count)
+{
+  for (int i = 0; i < count; i++)
+    res[i] = i < 1000 ? test_data::rand_host() : res[i - 1000];
+}
+
+void incremental_values(test_scalar* res, uint32_t count)
+{
+  for (int i = 0; i < count; i++) {
+    res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
+  }
+}
+
+// calculating polynomial multiplication A*B via NTT, pointwise multiplication and INTT
+// (1) allocate A,B on CPU.
Randomize first half, zero second half +// (2) allocate NttAGpu, NttBGpu on GPU +// (3) calc NTT for A and for B from cpu to GPU +// (4) multiply MulGpu = NttAGpu * NttBGpu (pointwise) +// (5) INTT MulGpu inplace + +int main(int argc, char** argv) +{ + cudaEvent_t start, stop; + float measured_time; + + int NTT_LOG_SIZE = 23; + int NTT_SIZE = 1 << NTT_LOG_SIZE; + + CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context + + // init domain + auto ntt_config = ntt::DefaultNTTConfig(); + ntt_config.ordering = ntt::Ordering::kNN; // TODO: use NR for forward and RN for backward + ntt_config.is_force_radix2 = (argc > 1) ? atoi(argv[1]) : false; + + const char* ntt_alg_str = ntt_config.is_force_radix2 ? "Radix-2" : "Mixed-Radix"; + std::cout << "Polynomial multiplication with " << ntt_alg_str << " NTT: "; + + CHK_IF_RETURN(cudaEventCreate(&start)); + CHK_IF_RETURN(cudaEventCreate(&stop)); + + const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE); + ntt::InitDomain(basic_root, ntt_config.ctx); + + // (1) cpu allocation + auto CpuA = std::make_unique(NTT_SIZE); + auto CpuB = std::make_unique(NTT_SIZE); + random_samples(CpuA.get(), NTT_SIZE >> 1); // second half zeros + random_samples(CpuB.get(), NTT_SIZE >> 1); // second half zeros + + test_data *GpuA, *GpuB, *MulGpu; + + auto benchmark = [&](bool print, int iterations = 1) { + // start recording + CHK_IF_RETURN(cudaEventRecord(start, ntt_config.ctx.stream)); + + for (int iter = 0; iter < iterations; ++iter) { + // (2) gpu input allocation + CHK_IF_RETURN(cudaMallocAsync(&GpuA, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaMallocAsync(&GpuB, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream)); + + // (3) NTT for A,B from cpu to gpu + ntt_config.are_inputs_on_device = false; + ntt_config.are_outputs_on_device = true; + CHK_IF_RETURN(ntt::NTT(CpuA.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuA)); + CHK_IF_RETURN(ntt::NTT(CpuB.get(), NTT_SIZE, ntt::NTTDir::kForward, ntt_config, GpuB)); + + // (4) multiply A,B + CHK_IF_RETURN(cudaMallocAsync(&MulGpu, sizeof(test_data) * NTT_SIZE, ntt_config.ctx.stream)); + CHK_IF_RETURN( + vec_ops::Mul(GpuA, GpuB, NTT_SIZE, true /*=is_on_device*/, false /*=is_montgomery*/, ntt_config.ctx, MulGpu)); + + // (5) INTT (in place) + ntt_config.are_inputs_on_device = true; + ntt_config.are_outputs_on_device = true; + CHK_IF_RETURN(ntt::NTT(MulGpu, NTT_SIZE, ntt::NTTDir::kInverse, ntt_config, MulGpu)); + + CHK_IF_RETURN(cudaFreeAsync(GpuA, ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaFreeAsync(GpuB, ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaFreeAsync(MulGpu, ntt_config.ctx.stream)); + } + + CHK_IF_RETURN(cudaEventRecord(stop, ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaEventElapsedTime(&measured_time, start, stop)); + + if (print) { std::cout << measured_time / iterations << " MS" << std::endl; } + + return CHK_LAST(); + }; + + benchmark(false); // warmup + benchmark(true, 20); + + CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); + + return 0; +} \ No newline at end of file diff --git a/examples/c++/polynomial_multiplication/run.sh b/examples/c++/polynomial_multiplication/run.sh new file mode 100755 index 00000000..b57c9a02 --- /dev/null +++ b/examples/c++/polynomial_multiplication/run.sh @@ -0,0 +1,3 @@ +#!/bin/bash +./build/example 1 # radix2 +./build/example 0 # mixed-radix diff --git a/examples/rust/msm/src/main.rs b/examples/rust/msm/src/main.rs index 6f142642..219ea69d 100644 --- 
a/examples/rust/msm/src/main.rs +++ b/examples/rust/msm/src/main.rs @@ -53,7 +53,7 @@ struct Args { lower_bound_log_size: u8, /// Upper bound of MSM sizes to run for - #[arg(short, long, default_value_t = 23)] + #[arg(short, long, default_value_t = 22)] upper_bound_log_size: u8, } diff --git a/icicle/CMakeLists.txt b/icicle/CMakeLists.txt index 56204f48..6e6e3f31 100644 --- a/icicle/CMakeLists.txt +++ b/icicle/CMakeLists.txt @@ -108,6 +108,7 @@ if (NOT BUILD_TESTS) primitives/projective.cu appUtils/msm/msm.cu appUtils/ntt/ntt.cu + appUtils/ntt/kernel_ntt.cu ${ICICLE_SOURCES} ) set_target_properties(icicle PROPERTIES OUTPUT_NAME "ingo_${CURVE}") diff --git a/icicle/appUtils/ntt/Makefile b/icicle/appUtils/ntt/Makefile new file mode 100644 index 00000000..3aadb882 --- /dev/null +++ b/icicle/appUtils/ntt/Makefile @@ -0,0 +1,6 @@ +build_verification: + mkdir -p work + nvcc -o work/test_verification -I. -I.. -I../.. -I../ntt tests/verification.cu -std=c++17 + +test_verification: build_verification + work/test_verification diff --git a/icicle/appUtils/ntt/kernel_ntt.cu b/icicle/appUtils/ntt/kernel_ntt.cu new file mode 100644 index 00000000..abbc8b07 --- /dev/null +++ b/icicle/appUtils/ntt/kernel_ntt.cu @@ -0,0 +1,640 @@ + +#include "appUtils/ntt/thread_ntt.cu" +#include "curves/curve_config.cuh" +#include "utils/sharedmem.cuh" +#include "appUtils/ntt/ntt.cuh" // for Ordering + +namespace ntt { + + static __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit) + { + uint32_t rev_num = 0, temp, dig_len; + if (dit) { + for (int i = 4; i >= 0; i--) { + dig_len = STAGE_SIZES_DEVICE[log_size][i]; + temp = num & ((1 << dig_len) - 1); + num = num >> dig_len; + rev_num = rev_num << dig_len; + rev_num = rev_num | temp; + } + } else { + for (int i = 0; i < 5; i++) { + dig_len = STAGE_SIZES_DEVICE[log_size][i]; + temp = num & ((1 << dig_len) - 1); + num = num >> dig_len; + rev_num = rev_num << dig_len; + rev_num = rev_num | temp; + } + } + return rev_num; + } + + // Note: the following reorder kernels are fused with normalization for INTT + template + static __global__ void + reorder_digits_inplace_kernel(E* arr, uint32_t log_size, bool dit, bool is_normalize, S inverse_N) + { + // launch N threads + // each thread starts from one index and calculates the corresponding group + // if its index is the smallest number in the group -> do the memory transformation + // else --> do nothing + + const uint32_t idx = blockDim.x * blockIdx.x + threadIdx.x; + uint32_t next_element = idx; + uint32_t group[MAX_GROUP_SIZE]; + group[0] = idx; + + uint32_t i = 1; + for (; i < MAX_GROUP_SIZE;) { + next_element = dig_rev(next_element, log_size, dit); + if (next_element < idx) return; // not handling this group + if (next_element == idx) break; // calculated whole group + group[i++] = next_element; + } + + if (i == 1) { // single element in group --> nothing to do (except maybe normalize for INTT) + if (is_normalize) { arr[idx] = arr[idx] * inverse_N; } + return; + } + --i; + // reaching here means I am handling this group + const E last_element_in_group = arr[group[i]]; + for (; i > 0; --i) { + arr[group[i]] = is_normalize ? (arr[group[i - 1]] * inverse_N) : arr[group[i - 1]]; + } + arr[idx] = is_normalize ? 
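// Closing the cycle rotation: each group index received its predecessor's
// value above, and the first index now takes the value saved from the last
// one. Exactly one thread (the one holding the cycle's minimal index)
// rotates each permutation cycle, so the in-place reorder is race-free.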
(last_element_in_group * inverse_N) : last_element_in_group; + } + + template + __launch_bounds__(64) __global__ + void reorder_digits_kernel(E* arr, E* arr_reordered, uint32_t log_size, bool dit, bool is_normalize, S inverse_N) + { + uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x; + uint32_t rd = tid; + uint32_t wr = dig_rev(tid, log_size, dit); + arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd]; + } + + template + __launch_bounds__(64) __global__ void ntt64( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + uint32_t tw_log_size, + uint32_t data_stride, + uint32_t log_data_stride, + uint32_t twiddle_stride, + bool strided, + uint32_t stage_num, + bool inv, + bool dit) + { + NTTEngine engine; + stage_metadata s_meta; + SharedMemory smem; + E* shmem = smem.getPointer(); + + s_meta.th_stride = 8; + s_meta.ntt_block_size = 64; + s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3)); + s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7); + + engine.loadBasicTwiddles(basic_twiddles, inv); + engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + if (twiddle_stride && dit) { + engine.loadExternalTwiddlesGeneric64( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.loadInternalTwiddles64(internal_twiddles, strided, inv); + +#pragma unroll 1 + for (uint32_t phase = 0; phase < 2; phase++) { + engine.ntt8win(); + if (phase == 0) { + engine.SharedData64Columns8(shmem, true, false, strided); // store + __syncthreads(); + engine.SharedData64Rows8(shmem, false, false, strided); // load + engine.twiddlesInternal(); + } + } + + if (twiddle_stride && !dit) { + engine.loadExternalTwiddlesGeneric64( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + } + + template + __launch_bounds__(64) __global__ void ntt32( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + uint32_t tw_log_size, + uint32_t data_stride, + uint32_t log_data_stride, + uint32_t twiddle_stride, + bool strided, + uint32_t stage_num, + bool inv, + bool dit) + { + NTTEngine engine; + stage_metadata s_meta; + + SharedMemory smem; + E* shmem = smem.getPointer(); + + s_meta.th_stride = 4; + s_meta.ntt_block_size = 32; + s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); + s_meta.ntt_inp_id = strided ? 
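// Thread-to-data mapping: 16 size-32 sub-NTTs are packed per 64-thread CUDA
// block, 4 threads (8 registers each, th_stride = 4) per sub-NTT. Radix-32
// is decomposed as 8x4: a Winograd radix-8 pass (ntt8win), internal
// twiddles, a shared-memory transpose, then two radix-4 butterflies (ntt4_2).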
(threadIdx.x >> 4) : (threadIdx.x & 0x3); + + engine.loadBasicTwiddles(basic_twiddles, inv); + engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + engine.loadInternalTwiddles32(internal_twiddles, strided, inv); + engine.ntt8win(); + engine.twiddlesInternal(); + engine.SharedData32Columns8(shmem, true, false, strided); // store + __syncthreads(); + engine.SharedData32Rows4_2(shmem, false, false, strided); // load + engine.ntt4_2(); + if (twiddle_stride) { + engine.loadExternalTwiddlesGeneric32( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta); + } + + template + __launch_bounds__(64) __global__ void ntt32dit( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + int32_t tw_log_size, + uint32_t data_stride, + uint32_t log_data_stride, + uint32_t twiddle_stride, + bool strided, + uint32_t stage_num, + bool inv, + bool dit) + { + NTTEngine engine; + stage_metadata s_meta; + + SharedMemory smem; + E* shmem = smem.getPointer(); + + s_meta.th_stride = 4; + s_meta.ntt_block_size = 32; + s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2)); + s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3); + + engine.loadBasicTwiddles(basic_twiddles, inv); + engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta); + if (twiddle_stride) { + engine.loadExternalTwiddlesGeneric32( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.loadInternalTwiddles32(internal_twiddles, strided, inv); + engine.ntt4_2(); + engine.SharedData32Columns4_2(shmem, true, false, strided); // store + __syncthreads(); + engine.SharedData32Rows8(shmem, false, false, strided); // load + engine.twiddlesInternal(); + engine.ntt8win(); + engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + } + + template + __launch_bounds__(64) __global__ void ntt16( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + uint32_t tw_log_size, + uint32_t data_stride, + uint32_t log_data_stride, + uint32_t twiddle_stride, + bool strided, + uint32_t stage_num, + bool inv, + bool dit) + { + NTTEngine engine; + stage_metadata s_meta; + + SharedMemory smem; + E* shmem = smem.getPointer(); + + s_meta.th_stride = 2; + s_meta.ntt_block_size = 16; + s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); + s_meta.ntt_inp_id = strided ? 
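// Same pattern at radix-16: 32 size-16 sub-NTTs per 64-thread CUDA block,
// 2 threads per sub-NTT, decomposed as 8x2 (ntt8win, shared-memory
// transpose, then four radix-2 butterflies via ntt2_4).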
(threadIdx.x >> 5) : (threadIdx.x & 0x1); + + engine.loadBasicTwiddles(basic_twiddles, inv); + engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta); + engine.loadInternalTwiddles16(internal_twiddles, strided, inv); + engine.ntt8win(); + engine.twiddlesInternal(); + engine.SharedData16Columns8(shmem, true, false, strided); // store + __syncthreads(); + engine.SharedData16Rows2_4(shmem, false, false, strided); // load + engine.ntt2_4(); + if (twiddle_stride) { + engine.loadExternalTwiddlesGeneric16( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta); + } + + template + __launch_bounds__(64) __global__ void ntt16dit( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + uint32_t tw_log_size, + uint32_t data_stride, + uint32_t log_data_stride, + uint32_t twiddle_stride, + bool strided, + uint32_t stage_num, + bool inv, + bool dit) + { + NTTEngine engine; + stage_metadata s_meta; + + SharedMemory smem; + E* shmem = smem.getPointer(); + + s_meta.th_stride = 2; + s_meta.ntt_block_size = 16; + s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1)); + s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1); + + engine.loadBasicTwiddles(basic_twiddles, inv); + engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta); + if (twiddle_stride) { + engine.loadExternalTwiddlesGeneric16( + external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv); + engine.twiddlesExternal(); + } + engine.loadInternalTwiddles16(internal_twiddles, strided, inv); + engine.ntt2_4(); + engine.SharedData16Columns2_4(shmem, true, false, strided); // store + __syncthreads(); + engine.SharedData16Rows8(shmem, false, false, strided); // load + engine.twiddlesInternal(); + engine.ntt8win(); + engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta); + } + + template + __global__ void normalize_kernel(E* data, S norm_factor) + { + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + data[tid] = data[tid] * norm_factor; + } + + template + __global__ void generate_base_table(S basic_root, S* base_table, uint32_t skip) + { + S w = basic_root; + S t = S::one(); + for (int i = 0; i < 64; i += skip) { + base_table[i] = t; + t = t * w; + } + } + + template + __global__ void generate_basic_twiddles(S basic_root, S* w6_table, S* basic_twiddles) + { + S w0 = basic_root * basic_root; + S w1 = (basic_root + w0 * basic_root) * S::inv_log_size(1); + S w2 = (basic_root - w0 * basic_root) * S::inv_log_size(1); + basic_twiddles[0] = w0; + basic_twiddles[1] = w1; + basic_twiddles[2] = w2; + S basic_inv = w6_table[64 - 8]; + w0 = basic_inv * basic_inv; + w1 = (basic_inv + w0 * basic_inv) * S::inv_log_size(1); + w2 = (basic_inv - w0 * basic_inv) * S::inv_log_size(1); + basic_twiddles[3] = w0; + basic_twiddles[4] = w1; + basic_twiddles[5] = w2; + } + + template + __global__ void generate_twiddle_combinations_generic( + S* w6_table, S* w12_table, S* w18_table, S* w24_table, S* w30_table, S* external_twiddles, uint32_t log_size) + { + uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint32_t exp = tid << (30 - log_size); + S w6, w12, w18, w24, w30; + w6 = w6_table[exp >> 24]; + w12 = w12_table[((exp >> 18) & 0x3f)]; + w18 = w18_table[((exp >> 12) & 0x3f)]; + w24 = w24_table[((exp >> 
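// The (up to) 30-bit exponent is split into five 6-bit digits; each table
// holds the 64 powers of a root of order 2^6,...,2^30 respectively, so
//   w^exp = w6[exp>>24] * w12[(exp>>18)&63] * w18[(exp>>12)&63]
//         * w24[(exp>>6)&63] * w30[exp&63],
// i.e. five lookups and four multiplies replace a full exponentiation.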
6) & 0x3f)]; + w30 = w30_table[(exp & 0x3f)]; + S t = w6 * w12 * w18 * w24 * w30; + external_twiddles[tid] = t; + } + + template + __global__ void set_value(S* arr, int idx, S val) + { + arr[idx] = val; + } + + template + cudaError_t generate_external_twiddles_generic( + const S& basic_root, + S* external_twiddles, + S*& internal_twiddles, + S*& basic_twiddles, + uint32_t log_size, + cudaStream_t& stream) + { + CHK_INIT_IF_RETURN(); + + const int n = pow(2, log_size); + CHK_IF_RETURN(cudaMallocAsync(&basic_twiddles, 6 * sizeof(S), stream)); + + S* w6_table; + S* w12_table; + S* w18_table; + S* w24_table; + S* w30_table; + CHK_IF_RETURN(cudaMallocAsync(&w6_table, sizeof(S) * 64, stream)); + CHK_IF_RETURN(cudaMallocAsync(&w12_table, sizeof(S) * 64, stream)); + CHK_IF_RETURN(cudaMallocAsync(&w18_table, sizeof(S) * 64, stream)); + CHK_IF_RETURN(cudaMallocAsync(&w24_table, sizeof(S) * 64, stream)); + CHK_IF_RETURN(cudaMallocAsync(&w30_table, sizeof(S) * 64, stream)); + + // Note: for compatibility with radix-2 INTT, need ONE in last element (in addition to first element) + set_value<<<1, 1, 0, stream>>>(external_twiddles, n /*last element idx*/, S::one()); + + cudaStreamSynchronize(stream); + + S temp_root = basic_root; + generate_base_table<<<1, 1, 0, stream>>>(basic_root, w30_table, 1 << (30 - log_size)); + + if (log_size > 24) + for (int i = 0; i < 6 - (30 - log_size); i++) + temp_root = temp_root * temp_root; + generate_base_table<<<1, 1, 0, stream>>>(temp_root, w24_table, 1 << (log_size > 24 ? 0 : 24 - log_size)); + + if (log_size > 18) + for (int i = 0; i < 6 - (log_size > 24 ? 0 : 24 - log_size); i++) + temp_root = temp_root * temp_root; + generate_base_table<<<1, 1, 0, stream>>>(temp_root, w18_table, 1 << (log_size > 18 ? 0 : 18 - log_size)); + + if (log_size > 12) + for (int i = 0; i < 6 - (log_size > 18 ? 0 : 18 - log_size); i++) + temp_root = temp_root * temp_root; + generate_base_table<<<1, 1, 0, stream>>>(temp_root, w12_table, 1 << (log_size > 12 ? 0 : 12 - log_size)); + + if (log_size > 6) + for (int i = 0; i < 6 - (log_size > 12 ? 0 : 12 - log_size); i++) + temp_root = temp_root * temp_root; + generate_base_table<<<1, 1, 0, stream>>>(temp_root, w6_table, 1 << (log_size > 6 ? 0 : 6 - log_size)); + + if (log_size > 2) + for (int i = 0; i < 3 - (log_size > 6 ? 0 : 6 - log_size); i++) + temp_root = temp_root * temp_root; + generate_basic_twiddles<<<1, 1, 0, stream>>>(temp_root, w6_table, basic_twiddles); + + const int NOF_BLOCKS = (log_size >= 8) ? (1 << (log_size - 8)) : 1; + const int NOF_THREADS = (log_size >= 8) ? 
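// Launch one thread per external twiddle: 2^log_size threads in blocks of
// 256 (a single smaller block when log_size < 8).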
256 : (1 << log_size); + generate_twiddle_combinations_generic<<>>( + w6_table, w12_table, w18_table, w24_table, w30_table, external_twiddles, log_size); + + internal_twiddles = w6_table; + + CHK_IF_RETURN(cudaFreeAsync(w12_table, stream)); + CHK_IF_RETURN(cudaFreeAsync(w18_table, stream)); + CHK_IF_RETURN(cudaFreeAsync(w24_table, stream)); + CHK_IF_RETURN(cudaFreeAsync(w30_table, stream)); + + return CHK_LAST(); + } + + template + cudaError_t large_ntt( + E* in, + E* out, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + uint32_t log_size, + uint32_t tw_log_size, + bool inv, + bool normalize, + bool dit, + cudaStream_t cuda_stream) + { + CHK_INIT_IF_RETURN(); + + if (log_size == 1 || log_size == 2 || log_size == 3 || log_size == 7) { + throw IcicleError(IcicleError_t::InvalidArgument, "size not implemented for mixed-radix-NTT"); + } + + if (log_size == 4) { + if (dit) { + ntt16dit<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + } else { // dif + ntt16<<<1, 2, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + } + if (normalize) normalize_kernel<<<1, 16, 0, cuda_stream>>>(out, S::inv_log_size(4)); + return CHK_LAST(); + } + + if (log_size == 5) { + if (dit) { + ntt32dit<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + } else { // dif + ntt32<<<1, 4, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + } + if (normalize) normalize_kernel<<<1, 32, 0, cuda_stream>>>(out, S::inv_log_size(5)); + return CHK_LAST(); + } + + if (log_size == 6) { + ntt64<<<1, 8, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + if (normalize) normalize_kernel<<<1, 64, 0, cuda_stream>>>(out, S::inv_log_size(6)); + return CHK_LAST(); + } + + if (log_size == 8) { + if (dit) { + ntt16dit<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + ntt16dit<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>( + out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, + inv, + dit); // we need threads 32+ although 16-31 are idle + } else { // dif + ntt16<<<1, 64, 8 * 64 * sizeof(E), cuda_stream>>>( + in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 16, 4, 16, true, 1, inv, + dit); // we need threads 32+ although 16-31 are idle + ntt16<<<1, 32, 8 * 64 * sizeof(E), cuda_stream>>>( + out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, 1, 0, 0, false, 0, inv, + dit); + } + if (normalize) normalize_kernel<<<1, 256, 0, cuda_stream>>>(out, S::inv_log_size(8)); + return CHK_LAST(); + } + + // general case: + if (dit) { + for (int i = 0; i < 5; i++) { + uint32_t stage_size = STAGE_SIZES_HOST[log_size][i]; + uint32_t stride_log = 0; + for (int j = 0; j < i; j++) + stride_log += STAGE_SIZES_HOST[log_size][j]; + if (stage_size == 6) + ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + i ? 
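// DIT stage chaining: only the first stage reads `in`; later stages run
// in-place on `out`. data_stride is 2^stride_log, the product of all
// previous stage radixes, and external twiddles are skipped on stage 0.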
out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + else if (stage_size == 5) + ntt32dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + else if (stage_size == 4) + ntt16dit<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + } + } else { // dif + bool first_run = false, prev_stage = false; + for (int i = 4; i >= 0; i--) { + uint32_t stage_size = STAGE_SIZES_HOST[log_size][i]; + uint32_t stride_log = 0; + for (int j = 0; j < i; j++) + stride_log += STAGE_SIZES_HOST[log_size][j]; + first_run = stage_size && !prev_stage; + if (stage_size == 6) + ntt64<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + else if (stage_size == 5) + ntt32<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + else if (stage_size == 4) + ntt16<<<1 << (log_size - 9), 64, 8 * 64 * sizeof(E), cuda_stream>>>( + first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, + 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit); + prev_stage = stage_size; + } + } + if (normalize) normalize_kernel<<<1 << (log_size - 8), 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size)); + + return CHK_LAST(); + } + + template + cudaError_t mixed_radix_ntt( + E* d_input, + E* d_output, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + int ntt_size, + int max_logn, + bool is_inverse, + Ordering ordering, + cudaStream_t cuda_stream) + { + CHK_INIT_IF_RETURN(); + + // TODO: can we support all orderings? 
Note that reversal is generally digit reverse (generalization of bit reverse) + if (ordering != Ordering::kNN) { + throw IcicleError(IcicleError_t::InvalidArgument, "Mixed-Radix NTT supports NN ordering only"); + } + + const int logn = int(log2(ntt_size)); + + const int NOF_BLOCKS = (1 << (max(logn, 6) - 6)); + const int NOF_THREADS = min(64, 1 << logn); + + const bool reverse_input = ordering == Ordering::kNN; + const bool is_dit = ordering == Ordering::kNN || ordering == Ordering::kRN; + bool is_normalize = is_inverse; + + if (reverse_input) { + // Note: fusing reorder with normalize for INTT + const bool is_reverse_in_place = (d_input == d_output); + if (is_reverse_in_place) { + reorder_digits_inplace_kernel<<>>( + d_output, logn, is_dit, is_normalize, S::inv_log_size(logn)); + } else { + reorder_digits_kernel<<>>( + d_input, d_output, logn, is_dit, is_normalize, S::inv_log_size(logn)); + } + is_normalize = false; + } + + // inplace ntt + CHK_IF_RETURN(large_ntt( + d_output, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, is_inverse, + is_normalize, is_dit, cuda_stream)); + + return CHK_LAST(); + } + + // Explicit instantiation for scalar type + template cudaError_t generate_external_twiddles_generic( + const curve_config::scalar_t& basic_root, + curve_config::scalar_t* external_twiddles, + curve_config::scalar_t*& internal_twiddles, + curve_config::scalar_t*& basic_twiddles, + uint32_t log_size, + cudaStream_t& stream); + + template cudaError_t mixed_radix_ntt( + curve_config::scalar_t* d_input, + curve_config::scalar_t* d_output, + curve_config::scalar_t* external_twiddles, + curve_config::scalar_t* internal_twiddles, + curve_config::scalar_t* basic_twiddles, + int ntt_size, + int max_logn, + bool is_inverse, + Ordering ordering, + cudaStream_t cuda_stream); + +} // namespace ntt diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu index c1581531..d437e71d 100644 --- a/icicle/appUtils/ntt/ntt.cu +++ b/icicle/appUtils/ntt/ntt.cu @@ -7,6 +7,7 @@ #include "utils/sharedmem.cuh" #include "utils/utils_kernels.cuh" #include "utils/utils.h" +#include "appUtils/ntt/ntt_impl.cuh" namespace ntt { @@ -25,7 +26,10 @@ namespace ntt { int idx = threadId % n; int batch_idx = threadId / n; int idx_reversed = __brev(idx) >> (32 - logn); - arr_reversed[batch_idx * n + idx_reversed] = arr[batch_idx * n + idx]; + + E val = arr[batch_idx * n + idx]; + if (arr == arr_reversed) { __syncthreads(); } // for in-place (when pointers arr==arr_reversed) + arr_reversed[batch_idx * n + idx_reversed] = val; } } @@ -357,25 +361,24 @@ namespace ntt { template class Domain { - static int max_size; - static S* twiddles; - static std::unordered_map coset_index; + static inline int max_size = 0; + static inline int max_log_size = 0; + static inline S* twiddles = nullptr; + static inline std::unordered_map coset_index = {}; + + static inline S* internal_twiddles = nullptr; // required by mixed-radix NTT + static inline S* basic_twiddles = nullptr; // required by mixed-radix NTT public: template friend cudaError_t InitDomain(U primitive_root, device_context::DeviceContext& ctx); + static cudaError_t ReleaseDomain(device_context::DeviceContext& ctx); + template friend cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig& config, E* output); }; - template - int Domain::max_size = 0; - template - S* Domain::twiddles = nullptr; - template - std::unordered_map Domain::coset_index = {}; - template cudaError_t InitDomain(S primitive_root, device_context::DeviceContext& ctx) { @@ 
-385,34 +388,70 @@ namespace ntt { // please note that this is not thread-safe at all, // but it's a singleton that is supposed to be initialized once per program lifetime if (!Domain::twiddles) { + bool found_logn = false; S omega = primitive_root; unsigned omegas_count = S::get_omegas_count(); - for (int i = 0; i < omegas_count; i++) + for (int i = 0; i < omegas_count; i++) { omega = S::sqr(omega); + if (!found_logn) { + ++Domain::max_log_size; + found_logn = omega == S::one(); + if (found_logn) break; + } + } + Domain::max_size = (int)pow(2, Domain::max_log_size); if (omega != S::one()) { - std::cerr << "Primitive root provided to the InitDomain function is not in the subgroup" << '\n'; - throw -1; + throw IcicleError( + IcicleError_t::InvalidArgument, "Primitive root provided to the InitDomain function is not in the subgroup"); } - std::vector h_twiddles; - h_twiddles.push_back(S::one()); - int n = 1; - do { - Domain::coset_index[h_twiddles.at(n - 1)] = n - 1; - h_twiddles.push_back(h_twiddles.at(n - 1) * primitive_root); - } while (h_twiddles.at(n++) != S::one()); - - CHK_IF_RETURN(cudaMallocAsync(&Domain::twiddles, n * sizeof(S), ctx.stream)); - CHK_IF_RETURN( - cudaMemcpyAsync(Domain::twiddles, &h_twiddles.front(), n * sizeof(S), cudaMemcpyHostToDevice, ctx.stream)); - - Domain::max_size = n - 1; + // allocate and calculate twiddles on GPU + // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements + // Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host + CHK_IF_RETURN(cudaMallocManaged(&Domain::twiddles, (Domain::max_size + 1) * sizeof(S))); + CHK_IF_RETURN(generate_external_twiddles_generic( + primitive_root, Domain::twiddles, Domain::internal_twiddles, Domain::basic_twiddles, + Domain::max_log_size, ctx.stream)); CHK_IF_RETURN(cudaStreamSynchronize(ctx.stream)); + + const bool is_map_only_powers_of_primitive_root = true; + if (is_map_only_powers_of_primitive_root) { + // populate the coset_index map. Note that only powers of the primitive-root are stored (1, PR, PR^2, PR^4, PR^8 + // etc.) 
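+        // Only the log(n) power-of-two indices are inserted, on the
+        // assumption that coset generators passed by callers are powers of
+        // the primitive root. Host-side reads of Domain::twiddles[] here are
+        // valid because the table is a managed allocation (see above).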
+      Domain<S>::coset_index[S::one()] = 0;
+      for (int i = 0; i < Domain<S>::max_log_size; ++i) {
+        const int index = (int)pow(2, i);
+        Domain<S>::coset_index[Domain<S>::twiddles[index]] = index;
+      }
+    } else {
+      // populate all values
+      for (int i = 0; i < Domain<S>::max_size; ++i) {
+        Domain<S>::coset_index[Domain<S>::twiddles[i]] = i;
+      }
+    }
   }
 
   return CHK_LAST();
 }
 
+  template <typename S>
+  cudaError_t Domain<S>::ReleaseDomain(device_context::DeviceContext& ctx)
+  {
+    CHK_INIT_IF_RETURN();
+
+    max_size = 0;
+    max_log_size = 0;
+    cudaFreeAsync(twiddles, ctx.stream);
+    twiddles = nullptr;
+    cudaFreeAsync(internal_twiddles, ctx.stream);
+    internal_twiddles = nullptr;
+    cudaFreeAsync(basic_twiddles, ctx.stream);
+    basic_twiddles = nullptr;
+    coset_index.clear();
+
+    return CHK_LAST();
+  }
+
  template <typename S, typename E>
  cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
  {
@@ -431,6 +470,20 @@ namespace ntt {
     bool are_inputs_on_device = config.are_inputs_on_device;
     bool are_outputs_on_device = config.are_outputs_on_device;
 
+    E* d_input;
+    if (are_inputs_on_device) {
+      d_input = input;
+    } else {
+      CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
+      CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
+    }
+    E* d_output;
+    if (are_outputs_on_device) {
+      d_output = output;
+    } else {
+      CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
+    }
+
     S* coset = nullptr;
     int coset_index = 0;
     try {
@@ -448,45 +501,44 @@ namespace ntt {
       h_coset.clear();
     }
 
-    E* d_input;
-    if (are_inputs_on_device) {
-      d_input = input;
-    } else {
-      CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
-      CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
-    }
-    E* d_output;
-    if (are_outputs_on_device) {
-      d_output = output;
-    } else {
-      CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
-    }
+    const bool is_small_ntt = logn < 16;                  // cutoff point where mixed-radix is faster than radix-2
+    const bool is_on_coset = (coset_index != 0) || coset; // coset not supported by mixed-radix algorithm yet
+    const bool is_batch_ntt = batch_size > 1;             // batch not supported by mixed-radix algorithm yet
+    const bool is_NN = config.ordering == Ordering::kNN;  // TODO Yuval: relax this limitation
+    const bool is_radix2_algorithm = config.is_force_radix2 || is_batch_ntt || is_small_ntt || is_on_coset || !is_NN;
 
-    bool ct_butterfly = true;
-    bool reverse_input = false;
-    switch (config.ordering) {
-    case Ordering::kNN:
-      reverse_input = true;
-      break;
-    case Ordering::kNR:
-      ct_butterfly = false;
-      break;
-    case Ordering::kRR:
-      reverse_input = true;
-      ct_butterfly = false;
-      break;
+    if (is_radix2_algorithm) {
+      bool ct_butterfly = true;
+      bool reverse_input = false;
+      switch (config.ordering) {
+      case Ordering::kNN:
+        reverse_input = true;
+        break;
+      case Ordering::kNR:
+        ct_butterfly = false;
+        break;
+      case Ordering::kRR:
+        reverse_input = true;
+        ct_butterfly = false;
+        break;
+      }
+
+      if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output);
+
+      CHK_IF_RETURN(ntt_inplace_batch_template(
+        reverse_input ?
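// the butterfly stages run in-place on d_output when the input was already
// bit-reversed into it above, and read straight from d_input otherwise: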
d_output : d_input, size, Domain::twiddles, Domain::max_size, batch_size, logn, + dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output)); + + if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream)); + } else { // mixed-radix algorithm + CHK_IF_RETURN(ntt::mixed_radix_ntt( + d_input, d_output, Domain::twiddles, Domain::internal_twiddles, Domain::basic_twiddles, size, + Domain::max_log_size, dir == NTTDir::kInverse, config.ordering, stream)); } - if (reverse_input) reverse_order_batch(d_input, size, logn, batch_size, stream, d_output); - - CHK_IF_RETURN(ntt_inplace_batch_template( - reverse_input ? d_output : d_input, size, Domain::twiddles, Domain::max_size, batch_size, logn, - dir == NTTDir::kInverse, ct_butterfly, coset, coset_index, stream, d_output)); - if (!are_outputs_on_device) CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream)); - if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream)); if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream)); if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream)); if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream)); @@ -506,6 +558,7 @@ namespace ntt { false, // are_inputs_on_device false, // are_outputs_on_device false, // is_async + false, // is_force_radix2 }; return config; } diff --git a/icicle/appUtils/ntt/ntt.cuh b/icicle/appUtils/ntt/ntt.cuh index fec1fe4d..e82de4ef 100644 --- a/icicle/appUtils/ntt/ntt.cuh +++ b/icicle/appUtils/ntt/ntt.cuh @@ -80,6 +80,8 @@ namespace ntt { * non-blocking and you'd need to synchronize it explicitly by running * `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, the NTT * function will block the current CPU thread. */ + bool is_force_radix2; /**< Explicitly select radix-2 NTT algorithm. Default value: false (the implementation selects + radix-2 or mixed-radix algorithm based on heuristics). */ }; /** diff --git a/icicle/appUtils/ntt/ntt_impl.cuh b/icicle/appUtils/ntt/ntt_impl.cuh new file mode 100644 index 00000000..928ffd31 --- /dev/null +++ b/icicle/appUtils/ntt/ntt_impl.cuh @@ -0,0 +1,33 @@ +#pragma once +#ifndef _NTT_IMPL_H +#define _NTT_IMPL_H + +#include +#include "appUtils/ntt/ntt.cuh" // for enum Ordering + +namespace ntt { + + template + cudaError_t generate_external_twiddles_generic( + const S& basic_root, + S* external_twiddles, + S*& internal_twiddles, + S*& basic_twiddles, + uint32_t log_size, + cudaStream_t& stream); + + template + cudaError_t mixed_radix_ntt( + E* d_input, + E* d_output, + S* external_twiddles, + S* internal_twiddles, + S* basic_twiddles, + int ntt_size, + int max_logn, + bool is_inverse, + Ordering ordering, + cudaStream_t cuda_stream); + +} // namespace ntt +#endif //_NTT_IMPL_H \ No newline at end of file diff --git a/icicle/appUtils/ntt/tests/verification.cu b/icicle/appUtils/ntt/tests/verification.cu new file mode 100644 index 00000000..a957fee5 --- /dev/null +++ b/icicle/appUtils/ntt/tests/verification.cu @@ -0,0 +1,155 @@ + +#define CURVE_ID BLS12_381 + +#include "primitives/field.cuh" +#include "primitives/projective.cuh" +#include "utils/cuda_utils.cuh" +#include +#include +#include + +#include "curves/curve_config.cuh" +#include "ntt/ntt.cu" +#include "ntt/ntt_impl.cuh" +#include + +typedef curve_config::scalar_t test_scalar; +typedef curve_config::scalar_t test_data; +#include "kernel_ntt.cu" + +void random_samples(test_data* res, uint32_t count) +{ + for (int i = 0; i < count; i++) + res[i] = i < 1000 ? 
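// Host-side sampling is presumably the setup bottleneck at large sizes:
// only the first 1000 entries are fresh, the rest repeat with a lag of 1000.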
test_data::rand_host() : res[i - 1000];
+}
+
+void incremental_values(test_scalar* res, uint32_t count)
+{
+  for (int i = 0; i < count; i++) {
+    res[i] = i ? res[i - 1] + test_scalar::one() * test_scalar::omega(4) : test_scalar::zero();
+  }
+}
+
+int main(int argc, char** argv)
+{
+  cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
+  float icicle_time, new_time;
+
+  int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19; // first argument is the log-size
+  int NTT_SIZE = 1 << NTT_LOG_SIZE;
+  bool INPLACE = (argc > 2) ? atoi(argv[2]) : true;
+  bool INV = (argc > 3) ? atoi(argv[3]) : true;
+
+  const ntt::Ordering ordering = ntt::Ordering::kNN;
+  const char* ordering_str = ordering == ntt::Ordering::kNN   ? "NN"
+                             : ordering == ntt::Ordering::kNR ? "NR"
+                             : ordering == ntt::Ordering::kRN ? "RN"
+                                                              : "RR";
+
+  printf("running ntt 2^%d, ordering=%s, inplace=%d, inverse=%d\n", NTT_LOG_SIZE, ordering_str, INPLACE, INV);
+
+  cudaFree(nullptr); // init GPU context (warmup)
+
+  // init domain
+  auto ntt_config = ntt::DefaultNTTConfig<test_scalar>();
+  ntt_config.ordering = ordering;
+  ntt_config.are_inputs_on_device = true;
+  ntt_config.are_outputs_on_device = true;
+
+  CHK_IF_RETURN(cudaEventCreate(&icicle_start));
+  CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
+  CHK_IF_RETURN(cudaEventCreate(&new_start));
+  CHK_IF_RETURN(cudaEventCreate(&new_stop));
+
+  auto start = std::chrono::high_resolution_clock::now();
+  const test_scalar basic_root = test_scalar::omega(NTT_LOG_SIZE);
+  ntt::InitDomain(basic_root, ntt_config.ctx);
+  auto stop = std::chrono::high_resolution_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count();
+  std::cout << "initDomain took: " << duration / 1000 << " MS" << std::endl;
+
+  // cpu allocation
+  auto CpuScalars = std::make_unique<test_data[]>(NTT_SIZE);
+  auto CpuOutputOld = std::make_unique<test_data[]>(NTT_SIZE);
+  auto CpuOutputNew = std::make_unique<test_data[]>(NTT_SIZE);
+
+  // gpu allocation
+  test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
+  CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE));
+  CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE));
+  CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE));
+
+  // init inputs
+  incremental_values(CpuScalars.get(), NTT_SIZE);
+  CHK_IF_RETURN(cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
+
+  // inplace
+  if (INPLACE) {
+    CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
+  }
+
+  // run ntt
+  auto benchmark = [&](bool is_print, int iterations) -> cudaError_t {
+    // NEW
+    CHK_IF_RETURN(cudaEventRecord(new_start, ntt_config.ctx.stream));
+    ntt_config.is_force_radix2 = false; // mixed-radix ntt (a.k.a. new ntt)
+    for (int i = 0; i < iterations; i++) {
+      ntt::NTT(
+        INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
+        GpuOutputNew);
+    }
+    CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
+    CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
+    CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
+    if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
+
+    // OLD
+    CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
+    ntt_config.is_force_radix2 = true;
+    for (int i = 0; i < iterations; i++) {
+      ntt::NTT(GpuScalars, NTT_SIZE, INV ?
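// Reference run: radix-2 is forced via is_force_radix2 and always reads
// GpuScalars out-of-place, so the INPLACE flag exercises only the new path.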
ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputOld); + } + CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream)); + CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop)); + if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); } + + if (is_print) { + printf("Old Runtime=%0.3f MS\n", icicle_time / iterations); + printf("New Runtime=%0.3f MS\n", new_time / iterations); + } + + return CHK_LAST(); + }; + + CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup + int count = INPLACE ? 1 : 10; + if (INPLACE) { + CHK_IF_RETURN(cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice)); + } + CHK_IF_RETURN(benchmark(true /*=print*/, count)); + + // verify + CHK_IF_RETURN(cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost)); + CHK_IF_RETURN(cudaMemcpy(CpuOutputOld.get(), GpuOutputOld, NTT_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost)); + + bool success = true; + for (int i = 0; i < NTT_SIZE; i++) { + if (CpuOutputNew[i] != CpuOutputOld[i]) { + success = false; + // std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl; + break; + } else { + // std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl; + // break; + } + } + const char* success_str = success ? "SUCCESS!" : "FAIL!"; + printf("%s\n", success_str); + + CHK_IF_RETURN(cudaFree(GpuScalars)); + CHK_IF_RETURN(cudaFree(GpuOutputOld)); + CHK_IF_RETURN(cudaFree(GpuOutputNew)); + + return CHK_LAST(); +} \ No newline at end of file diff --git a/icicle/appUtils/ntt/thread_ntt.cu b/icicle/appUtils/ntt/thread_ntt.cu new file mode 100644 index 00000000..f39a9cf7 --- /dev/null +++ b/icicle/appUtils/ntt/thread_ntt.cu @@ -0,0 +1,542 @@ +#ifndef T_NTT +#define T_NTT +#pragma once + +#include +#include +#include "curves/curve_config.cuh" + +struct stage_metadata { + uint32_t th_stride; + uint32_t ntt_block_size; + uint32_t ntt_block_id; + uint32_t ntt_inp_id; +}; + +uint32_t constexpr STAGE_SIZES_HOST[31][5] = { + {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0}, + {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0}, + {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6}, + {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}}; + +__device__ constexpr uint32_t STAGE_SIZES_DEVICE[31][5] = { + {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {4, 0, 0, 0, 0}, {5, 0, 0, 0, 0}, {6, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, {4, 4, 0, 0, 0}, {5, 4, 0, 0, 0}, {5, 5, 0, 0, 0}, {6, 5, 0, 0, 0}, {6, 6, 0, 0, 0}, {4, 5, 4, 0, 0}, + {4, 6, 4, 0, 0}, {5, 5, 5, 0, 0}, {6, 4, 6, 0, 0}, {6, 5, 6, 0, 0}, {6, 6, 6, 0, 0}, {6, 5, 4, 4, 0}, {5, 5, 5, 5, 0}, + {6, 5, 5, 5, 0}, {6, 5, 5, 6, 0}, {6, 6, 6, 5, 0}, {6, 6, 6, 6, 0}, {5, 5, 5, 5, 5}, {6, 5, 4, 5, 6}, {6, 5, 5, 5, 6}, + {6, 5, 6, 5, 6}, {6, 6, 5, 6, 6}, {6, 6, 6, 6, 6}}; + +template +class NTTEngine +{ +public: + E X[8]; + S WB[3]; + S WI[7]; + S WE[8]; + + __device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles, bool inv) + { +#pragma unroll + for (int i = 0; i < 3; i++) { + WB[i] = basic_twiddles[inv ? 
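// basic_twiddles layout: slots [0..2] hold the forward constants, slots
// [3..5] the matching constants derived from the inverse root
// (see generate_basic_twiddles).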
i + 3 : i]; + } + } + + __device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride, bool inv) + { +#pragma unroll + for (int i = 0; i < 7; i++) { + uint32_t exp = ((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1); + WI[i] = data[(inv && exp) ? 64 - exp : exp]; // if exp = 0 we also take exp and not 64-exp + } + } + + __device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride, bool inv) + { +#pragma unroll + for (int i = 0; i < 7; i++) { + uint32_t exp = 2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1); + WI[i] = data[(inv && exp) ? 64 - exp : exp]; + } + } + + __device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride, bool inv) + { +#pragma unroll + for (int i = 0; i < 7; i++) { + uint32_t exp = 4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1); + WI[i] = data[(inv && exp) ? 64 - exp : exp]; + } + } + + __device__ __forceinline__ void loadExternalTwiddlesGeneric64( + E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv) + { +#pragma unroll + for (uint32_t i = 0; i < 8; i++) { + uint32_t exp = (s_meta.ntt_inp_id + 8 * i) * (s_meta.ntt_block_id & (tw_order - 1)) + << (tw_log_size - tw_log_order - 6); + WE[i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp]; + } + } + + __device__ __forceinline__ void loadExternalTwiddlesGeneric32( + E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#pragma unroll + for (uint32_t i = 0; i < 4; i++) { + uint32_t exp = (s_meta.ntt_inp_id * 2 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1)) + << (tw_log_size - tw_log_order - 5); + WE[4 * j + i] = data[(inv && exp) ? ((1 << tw_log_size) - exp) : exp]; + } + } + } + + __device__ __forceinline__ void loadExternalTwiddlesGeneric16( + E* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv) + { +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + uint32_t exp = (s_meta.ntt_inp_id * 4 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1)) + << (tw_log_size - tw_log_order - 4); + WE[2 * j + i] = data[(inv && exp) ? 
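// Inverse twiddles are read as w^(N - exp) since w^(N - exp) = w^(-exp);
// exp == 0 must stay at index 0, hence the (inv && exp) guard.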
((1 << tw_log_size) - exp) : exp]; + } + } + } + + __device__ __forceinline__ void loadGlobalData( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id; + } + +#pragma unroll + for (uint32_t i = 0; i < 8; i++) { + X[i] = data[s_meta.th_stride * i * data_stride]; + } + } + + __device__ __forceinline__ void storeGlobalData( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id; + } + +#pragma unroll + for (uint32_t i = 0; i < 8; i++) { + data[s_meta.th_stride * i * data_stride] = X[i]; + } + } + + __device__ __forceinline__ void loadGlobalData32( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2; + } + +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#pragma unroll + for (uint32_t i = 0; i < 4; i++) { + X[4 * j + i] = data[(8 * i + j) * data_stride]; + } + } + } + + __device__ __forceinline__ void storeGlobalData32( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2; + } + +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { +#pragma unroll + for (uint32_t i = 0; i < 4; i++) { + data[(8 * i + j) * data_stride] = X[4 * j + i]; + } + } + } + + __device__ __forceinline__ void loadGlobalData16( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4; + } + +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + X[2 * j + i] = data[(8 * i + j) * data_stride]; + } + } + } + + __device__ __forceinline__ void storeGlobalData16( + E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta) + { + if (strided) { + data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 + + (s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size; + } else { + data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4; + } + +#pragma unroll + for (uint32_t j = 
+  __device__ __forceinline__ void ntt4_2()
+  {
+#pragma unroll
+    for (int i = 0; i < 2; i++) {
+      ntt4(X[4 * i], X[4 * i + 1], X[4 * i + 2], X[4 * i + 3]);
+    }
+  }
+
+  __device__ __forceinline__ void ntt2_4()
+  {
+#pragma unroll
+    for (int i = 0; i < 4; i++) {
+      ntt2(X[2 * i], X[2 * i + 1]);
+    }
+  }
+
+  __device__ __forceinline__ void ntt2(E& X0, E& X1)
+  {
+    E T;
+
+    T = X0 + X1;
+    X1 = X0 - X1;
+    X0 = T;
+  }
+
+  __device__ __forceinline__ void ntt4(E& X0, E& X1, E& X2, E& X3)
+  {
+    E T;
+
+    T = X0 + X2;
+    X2 = X0 - X2;
+    X0 = X1 + X3;
+    X1 = X1 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3
+
+    X1 = X1 * WB[0];
+
+    X3 = X2 - X1;
+    X1 = X2 + X1;
+    X2 = T - X0;
+    X0 = T + X0;
+  }
+
+  // reverse-bit-order (rbo) version
+  __device__ __forceinline__ void ntt4rbo(E& X0, E& X1, E& X2, E& X3)
+  {
+    E T;
+
+    T = X0 - X1;
+    X0 = X0 + X1;
+    X1 = X2 + X3;
+    X3 = X2 - X3; // T has X0, X0 has X1, X2 has X2, X1 has X3
+
+    X3 = X3 * WB[0];
+
+    X2 = X0 - X1;
+    X0 = X0 + X1;
+    X1 = T + X3;
+    X3 = T - X3;
+  }
+
+  __device__ __forceinline__ void ntt8(E& X0, E& X1, E& X2, E& X3, E& X4, E& X5, E& X6, E& X7)
+  {
+    E T;
+
+    // out of 56,623,104 possible mappings, we have:
+    T = X3 - X7;
+    X7 = X3 + X7;
+    X3 = X1 - X5;
+    X5 = X1 + X5;
+    X1 = X2 + X6;
+    X2 = X2 - X6;
+    X6 = X0 + X4;
+    X0 = X0 - X4;
+
+    T = T * WB[1];
+    X2 = X2 * WB[1];
+
+    X4 = X6 + X1;
+    X6 = X6 - X1;
+    X1 = X3 + T;
+    X3 = X3 - T;
+    T = X5 + X7;
+    X5 = X5 - X7;
+    X7 = X0 + X2;
+    X0 = X0 - X2;
+
+    X1 = X1 * WB[0];
+    X5 = X5 * WB[1];
+    X3 = X3 * WB[2];
+
+    X2 = X6 + X5;
+    X6 = X6 - X5;
+    X5 = X7 - X1;
+    X1 = X7 + X1;
+    X7 = X0 - X3;
+    X3 = X0 + X3;
+    X0 = X4 + T;
+    X4 = X4 - T;
+  }
+
+  __device__ __forceinline__ void ntt8win()
+  {
+    E T;
+
+    T = X[3] - X[7];
+    X[7] = X[3] + X[7];
+    X[3] = X[1] - X[5];
+    X[5] = X[1] + X[5];
+    X[1] = X[2] + X[6];
+    X[2] = X[2] - X[6];
+    X[6] = X[0] + X[4];
+    X[0] = X[0] - X[4];
+
+    X[2] = X[2] * WB[0];
+
+    X[4] = X[6] + X[1];
+    X[6] = X[6] - X[1];
+    X[1] = X[3] + T;
+    X[3] = X[3] - T;
+    T = X[5] + X[7];
+    X[5] = X[5] - X[7];
+    X[7] = X[0] + X[2];
+    X[0] = X[0] - X[2];
+
+    X[1] = X[1] * WB[1];
+    X[5] = X[5] * WB[0];
+    X[3] = X[3] * WB[2];
+
+    X[2] = X[6] + X[5];
+    X[6] = X[6] - X[5];
+
+    X[5] = X[1] + X[3];
+    X[3] = X[1] - X[3];
+
+    X[1] = X[7] + X[5];
+    X[5] = X[7] - X[5];
+    X[7] = X[0] - X[3];
+    X[3] = X[0] + X[3];
+    X[0] = X[4] + T;
+    X[4] = X[4] - T;
+  }
+
+  __device__ __forceinline__ void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
+    uint32_t column_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 64 + i * 8 + column_id] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 64 + i * 8 + column_id];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData64Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
+    uint32_t row_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 64 + row_id * 8 + i] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 64 + row_id * 8 + i];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData32Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
+    uint32_t column_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 32 + i * 4 + column_id] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 32 + i * 4 + column_id];
+      }
+    }
+  }
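[Editor's note: reading off the dataflow, ntt4 computes the natural-order 4-point transform Y_k = sum_j x_j * w^(jk), with WB[0] playing the role of the order-4 root w (w^2 = -1); ntt8win saves one of the two first-layer multiplications relative to ntt8, presumably a Winograd-style variant, hence the name. Below is a self-contained toy check of the ntt4 dataflow, not part of the patch, over GF(17) with w = 4; the small prime and root are assumptions chosen only because 4 has order 4 mod 17, while the kernel works in the curve's scalar field.]

#include <cstdint>
#include <cstdio>

// Toy check of the ntt4 butterfly over GF(17), where 4 is a primitive
// 4th root of unity (4^2 = 16 = -1 mod 17). Plain modular integers stand
// in for the field type E; the statement order mirrors ntt4 exactly.
static const uint32_t P = 17;
static uint32_t add(uint32_t a, uint32_t b) { return (a + b) % P; }
static uint32_t sub(uint32_t a, uint32_t b) { return (a + P - b) % P; }
static uint32_t mul(uint32_t a, uint32_t b) { return (a * b) % P; }

int main()
{
  uint32_t W = 4; // plays the role of WB[0]
  uint32_t X0 = 1, X1 = 2, X2 = 3, X3 = 4;

  uint32_t T = add(X0, X2);
  X2 = sub(X0, X2);
  X0 = add(X1, X3);
  X1 = sub(X1, X3);

  X1 = mul(X1, W);

  X3 = sub(X2, X1);
  X1 = add(X2, X1);
  X2 = sub(T, X0);
  X0 = add(T, X0);

  // Expected natural-order NTT of (1,2,3,4) over GF(17): 10 7 15 6
  printf("%u %u %u %u\n", X0, X1, X2, X3);
  return 0;
}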
+  __device__ __forceinline__ void SharedData32Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
+    uint32_t row_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 32 + row_id * 8 + i] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 32 + row_id * 8 + i];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData32Columns4_2(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
+    uint32_t column_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;
+
+#pragma unroll
+    for (uint32_t j = 0; j < 2; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 4; i++) {
+        if (store) {
+          shmem[ntt_id * 32 + i * 8 + column_id + j] = X[4 * j + i];
+        } else {
+          X[4 * j + i] = shmem[ntt_id * 32 + i * 8 + column_id + j];
+        }
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData32Rows4_2(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
+    uint32_t row_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;
+
+#pragma unroll
+    for (uint32_t j = 0; j < 2; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 4; i++) {
+        if (store) {
+          shmem[ntt_id * 32 + row_id * 4 + 4 * j + i] = X[4 * j + i];
+        } else {
+          X[4 * j + i] = shmem[ntt_id * 32 + row_id * 4 + 4 * j + i];
+        }
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData16Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
+    uint32_t column_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 16 + i * 2 + column_id] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 16 + i * 2 + column_id];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData16Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
+    uint32_t row_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;
+
+#pragma unroll
+    for (uint32_t i = 0; i < 8; i++) {
+      if (store) {
+        shmem[ntt_id * 16 + row_id * 8 + i] = X[i];
+      } else {
+        X[i] = shmem[ntt_id * 16 + row_id * 8 + i];
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData16Columns2_4(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
+    uint32_t column_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;
+
+#pragma unroll
+    for (uint32_t j = 0; j < 4; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 2; i++) {
+        if (store) {
+          shmem[ntt_id * 16 + i * 8 + column_id + j] = X[2 * j + i];
+        } else {
+          X[2 * j + i] = shmem[ntt_id * 16 + i * 8 + column_id + j];
+        }
+      }
+    }
+  }
+
+  __device__ __forceinline__ void SharedData16Rows2_4(E* shmem, bool store, bool high_bits, bool stride)
+  {
+    uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
+    uint32_t row_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;
+
+#pragma unroll
+    for (uint32_t j = 0; j < 4; j++) {
+#pragma unroll
+      for (uint32_t i = 0; i < 2; i++) {
+        if (store) {
+          shmem[ntt_id * 16 + row_id * 2 + 2 * j + i] = X[2 * j + i];
+        } else {
+          X[2 * j + i] = shmem[ntt_id * 16 + row_id * 2 + 2 * j + i];
+        }
+      }
+    }
+  }
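[Editor's note: the paired Columns/Rows helpers realize the transpose between butterfly stages: each thread stores its eight registers as a column of one sub-NTT's shared-memory tile and reads them back as a row, or vice versa. Below is a host model of that round trip, not part of the patch, for the 64-point tile with stride = false and a single sub-NTT of 8 cooperating threads; the loop over t is a hypothetical stand-in for threadIdx.x.]

#include <cstdint>
#include <cstdio>

// Host model of SharedData64Columns8 (store) followed by
// SharedData64Rows8 (load): the round trip is an 8x8 transpose.
int main()
{
  uint32_t shmem[64];
  uint32_t X[8][8]; // X[t][i] models thread t's register file X[0..7]

  for (uint32_t t = 0; t < 8; t++)
    for (uint32_t i = 0; i < 8; i++)
      X[t][i] = t * 8 + i;

  // Columns store: thread t owns column t of the tile.
  for (uint32_t t = 0; t < 8; t++)
    for (uint32_t i = 0; i < 8; i++)
      shmem[i * 8 + t] = X[t][i];

  // Rows load: thread t reads back row t of the tile.
  for (uint32_t t = 0; t < 8; t++)
    for (uint32_t i = 0; i < 8; i++)
      X[t][i] = shmem[t * 8 + i];

  // After the round trip, X[t][i] == i * 8 + t: the data was transposed.
  printf("X[2][5] = %u (expect 42)\n", X[2][5]);
  return 0;
}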
+  __device__ __forceinline__ void twiddlesInternal()
+  {
+#pragma unroll
+    for (int i = 1; i < 8; i++) {
+      X[i] = X[i] * WI[i - 1];
+    }
+  }
+
+  __device__ __forceinline__ void twiddlesExternal()
+  {
+#pragma unroll
+    for (int i = 0; i < 8; i++) {
+      X[i] = X[i] * WE[i];
+    }
+  }
+};
+
+#endif
\ No newline at end of file
diff --git a/icicle/primitives/field.cuh b/icicle/primitives/field.cuh
index 5fd535be..ed6e497d 100644
--- a/icicle/primitives/field.cuh
+++ b/icicle/primitives/field.cuh
@@ -75,11 +75,19 @@ public:
     return Field{omega_inv.storages[logn - 1]};
   }
 
-  static HOST_INLINE Field inv_log_size(uint32_t logn)
+  static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn)
   {
     if (logn == 0) { return Field{CONFIG::one}; }
-
+#ifndef __CUDA_ARCH__
     if (logn > CONFIG::omegas_count) THROW_ICICLE_ERR(IcicleError_t::InvalidArgument, "Field: Invalid inv index");
+#else
+    if (logn > CONFIG::omegas_count) {
+      printf(
+        "CUDA ERROR: field.cuh: error on inv_log_size(logn): logn(=%u) > omegas_count (=%u)", logn,
+        CONFIG::omegas_count);
+      assert(false);
+    }
+#endif // __CUDA_ARCH__
     storage_array<CONFIG::omegas_count, TLC> const inv = CONFIG::inv;
     return Field{inv.storages[logn - 1]};
   }
diff --git a/icicle/primitives/projective.cuh b/icicle/primitives/projective.cuh
index 4aa81609..8a385436 100644
--- a/icicle/primitives/projective.cuh
+++ b/icicle/primitives/projective.cuh
@@ -143,11 +143,15 @@ public:
     return res;
   }
 
+  friend HOST_DEVICE_INLINE Projective operator*(const Projective& point, SCALAR_FF scalar) { return scalar * point; }
+
   friend HOST_DEVICE_INLINE bool operator==(const Projective& p1, const Projective& p2)
   {
     return (p1.x * p2.z == p2.x * p1.z) && (p1.y * p2.z == p2.y * p1.z);
   }
 
+  friend HOST_DEVICE_INLINE bool operator!=(const Projective& p1, const Projective& p2) { return !(p1 == p2); }
+
   friend HOST_INLINE std::ostream& operator<<(std::ostream& os, const Projective& point)
   {
     os << "Point { x: " << point.x << "; y: " << point.y << "; z: " << point.z << " }";
diff --git a/wrappers/rust/icicle-core/src/ntt/mod.rs b/wrappers/rust/icicle-core/src/ntt/mod.rs
index 04f10ba7..24591413 100644
--- a/wrappers/rust/icicle-core/src/ntt/mod.rs
+++ b/wrappers/rust/icicle-core/src/ntt/mod.rs
@@ -57,6 +57,8 @@ pub struct NTTConfig<'a, S> {
     /// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
     /// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
     pub is_async: bool,
+    /// Explicitly select the radix-2 NTT algorithm. Default value: false (the implementation selects between the radix-2 and mixed-radix algorithms based on heuristics).
+    pub is_force_radix2: bool,
 }
 
 impl<'a, S: FieldImpl> NTTConfig<'a, S> {
@@ -70,6 +72,7 @@ impl<'a, S: FieldImpl> NTTConfig<'a, S> {
             are_inputs_on_device: false,
             are_outputs_on_device: false,
             is_async: false,
+            is_force_radix2: false,
         }
     }
 }
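[Editor's note: the field.cuh hunk shows the usual pattern for making an error-checked function callable from both host and device: the host compilation path throws, while the device path, where exceptions are unavailable, reports via printf and trips an assert. Below is a generic standalone sketch of that pattern, not Icicle's actual code; checked_index and this HOST_DEVICE_INLINE definition are hypothetical stand-ins.]

#include <cassert>
#include <cstdio>
#include <stdexcept>

// Stand-in for Icicle's host/device qualifier macro; under nvcc this
// would expand to __host__ __device__ __forceinline__.
#ifdef __CUDACC__
#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
#else
#define HOST_DEVICE_INLINE inline
#endif

HOST_DEVICE_INLINE unsigned checked_index(unsigned logn, unsigned max_count)
{
#ifndef __CUDA_ARCH__
  // Host path: report errors with an exception.
  if (logn > max_count) throw std::invalid_argument("invalid inv index");
#else
  // Device path: no exceptions; print and assert instead.
  if (logn > max_count) {
    printf("CUDA ERROR: logn(=%u) > max_count(=%u)\n", logn, max_count);
    assert(false);
  }
#endif
  return logn - 1;
}

int main()
{
  printf("index for logn=3: %u\n", checked_index(3, 32));
  return 0;
}

[On the Rust side, the new is_force_radix2 field defaults to false, so configs built through the provided constructor keep the heuristic radix-2/mixed-radix selection; setting it to true pins the radix-2 path explicitly.]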