mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
Devmode to Reduce compilation time (including G2 and ECNTT) (#395)
devmode to reduce compilation time
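The diff below threads every __forceinline__ qualifier and #pragma unroll directive through new INLINE_MACRO/UNROLL macros (defined in the new file icicle/common.cuh further down), so a DEVMODE or DEBUG build can switch off forced inlining and loop unrolling and compile much faster. A minimal sketch of the intended effect on kernel code, assuming only the macro semantics this commit introduces (sum8 and E are illustrative names, not part of the diff):

#include "common.cuh" // DEVICE_INLINE and UNROLL, added by this commit

// Release build: DEVICE_INLINE expands to __device__ __forceinline__ and UNROLL to an
// unroll pragma. DEVMODE/DEBUG build: plain __device__ and no unrolling, so nvcc has
// far less code to generate and compile times drop.
template <typename E>
DEVICE_INLINE E sum8(const E* x)
{
  E acc = x[0];
  UNROLL
  for (int i = 1; i < 8; i++)
    acc = acc + x[i];
  return acc;
}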
@@ -69,7 +69,7 @@ else()
 endif()

 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 set(CMAKE_CUDA_FLAGS_RELEASE "")
-set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
+set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -lineinfo")
 include_directories("${CMAKE_SOURCE_DIR}")

@@ -92,6 +92,10 @@ if (NOT IS_CURVE_SUPPORTED)
   message( FATAL_ERROR "The value of CURVE variable: ${CURVE} is not one of the supported curves: ${SUPPORTED_CURVES}" )
 endif ()

+if (DEVMODE STREQUAL "ON")
+  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O0 --ptxas-options=-O0 --ptxas-options=-allow-expensive-optimizations=false -DDEVMODE=ON")
+endif ()
+
 if (G2_DEFINED STREQUAL "ON")
   set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DG2_DEFINED=ON")
 endif ()

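A developer build would then presumably be configured with something like cmake -DCURVE=<curve> -DDEVMODE=ON ..: the added block disables ptxas optimization (-O0, --ptxas-options=-O0, no expensive optimizations) and defines DEVMODE for the CUDA sources, which the new common.cuh uses to neutralize forced inlining and unrolling.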
@@ -86,7 +86,7 @@ public:
   {
     Dummy_Projective res = zero();
 #ifdef CUDA_ARCH
-#pragma unroll
+    UNROLL
 #endif
     for (int i = 0; i < Dummy_Scalar::NBITS; i++) {
       if (i > 0) { res = res + res; }

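The loop wrapped here is the classic MSB-first double-and-add scalar multiplication: double the accumulator on every iteration after the first, and add the base point when the current bit is set. A self-contained sketch of the full pattern, assuming an illustrative get_bit accessor and Point/Scalar stand-ins that are not part of this diff:

// Hedged sketch of double-and-add; Point, Scalar and get_bit are illustrative.
template <typename Point, typename Scalar>
__device__ Point scalar_mul(const Point& p, const Scalar& s)
{
  Point res = Point::zero();
  for (int i = 0; i < Scalar::NBITS; i++) {
    if (i > 0) res = res + res;                           // shift accumulated bits up
    if (s.get_bit(Scalar::NBITS - 1 - i)) res = res + p;  // fold in the current bit
  }
  return res;
}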
@@ -4,3 +4,10 @@ build_verification:

 test_verification: build_verification
	work/test_verification
+
+build_verification_ecntt:
+	mkdir -p work
+	nvcc -o work/test_verification_ecntt -I. -I.. -I../.. -I../ntt tests/verification.cu -std=c++17 -DECNTT_DEFINED
+
+test_verification_ecntt: build_verification_ecntt
+	work/test_verification_ecntt

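Assuming the existing targets' conventions, make test_verification_ecntt builds tests/verification.cu with -DECNTT_DEFINED, which (per the verification.cu hunk further down) switches the test data type to curve_config::projective_t so the NTT is exercised over curve points rather than field elements.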
@@ -2,9 +2,9 @@
 #include "appUtils/ntt/thread_ntt.cu"
 #include "curves/curve_config.cuh"
 #include "utils/sharedmem.cuh"
-#include "appUtils/ntt/ntt.cuh" // for Ordering
+#include "appUtils/ntt/ntt.cuh" // for ntt::Ordering

-namespace ntt {
+namespace mxntt {

   static inline __device__ uint32_t dig_rev(uint32_t num, uint32_t log_size, bool dit, bool fast_tw)
   {

@@ -907,7 +907,7 @@ namespace ntt {
     bool columns_batch,
     bool is_inverse,
     bool fast_tw,
-    Ordering ordering,
+    ntt::Ordering ordering,
     S* arbitrary_coset,
     int coset_gen_index,
     cudaStream_t cuda_stream)

@@ -925,30 +925,30 @@ namespace ntt {
     eRevType reverse_input = None, reverse_output = None, reverse_coset = None;
     bool dit = false;
     switch (ordering) {
-    case Ordering::kNN:
+    case ntt::Ordering::kNN:
       reverse_input = eRevType::NaturalToMixedRev;
       dit = true;
       break;
-    case Ordering::kRN:
+    case ntt::Ordering::kRN:
       reverse_input = eRevType::RevToMixedRev;
       dit = true;
       reverse_coset = is_inverse ? eRevType::None : eRevType::NaturalToRev;
       break;
-    case Ordering::kNR:
+    case ntt::Ordering::kNR:
       reverse_output = eRevType::MixedRevToRev;
       reverse_coset = is_inverse ? eRevType::NaturalToRev : eRevType::None;
       break;
-    case Ordering::kRR:
+    case ntt::Ordering::kRR:
       reverse_input = eRevType::RevToMixedRev;
       dit = true;
       reverse_output = eRevType::NaturalToRev;
       reverse_coset = eRevType::NaturalToRev;
       break;
-    case Ordering::kMN:
+    case ntt::Ordering::kMN:
       dit = true;
       reverse_coset = is_inverse ? None : eRevType::NaturalToMixedRev;
       break;
-    case Ordering::kNM:
+    case ntt::Ordering::kNM:
       reverse_coset = is_inverse ? eRevType::NaturalToMixedRev : eRevType::None;
       break;
     }

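In short, each ntt::Ordering value selects which digit-reversal permutations (eRevType) to apply to the input, the output, and the coset indices, and whether the kernels run decimation-in-time (dit = true); the logic is unchanged, only the enum is now qualified with ntt:: because this code has moved into the mxntt namespace.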
@@ -1025,9 +1025,27 @@ namespace ntt {
     bool columns_batch,
     bool is_inverse,
     bool fast_tw,
-    Ordering ordering,
+    ntt::Ordering ordering,
     curve_config::scalar_t* arbitrary_coset,
     int coset_gen_index,
     cudaStream_t cuda_stream);

-} // namespace ntt
+  // TODO: we may reintroduce mixed-radix ECNTT based on upcoming benching PR
+  // #if defined(ECNTT_DEFINED)
+  // template cudaError_t mixed_radix_ntt<curve_config::projective_t, curve_config::scalar_t>(
+  //   curve_config::projective_t* d_input,
+  //   curve_config::projective_t* d_output,
+  //   curve_config::scalar_t* external_twiddles,
+  //   curve_config::scalar_t* internal_twiddles,
+  //   curve_config::scalar_t* basic_twiddles,
+  //   int ntt_size,
+  //   int max_logn,
+  //   int batch_size,
+  //   bool columns_batch,
+  //   bool is_inverse,
+  //   bool fast_tw,
+  //   ntt::Ordering ordering,
+  //   curve_config::scalar_t* arbitrary_coset,
+  //   int coset_gen_index,
+  //   cudaStream_t cuda_stream);
+  // #endif // ECNTT_DEFINED
+} // namespace mxntt

@@ -2,12 +2,14 @@

 #include <unordered_map>
 #include <vector>
+#include <type_traits>

 #include "curves/curve_config.cuh"
 #include "utils/sharedmem.cuh"
 #include "utils/utils_kernels.cuh"
 #include "utils/utils.h"
 #include "appUtils/ntt/ntt_impl.cuh"
+#include "appUtils/ntt/ntt.cuh" // for ntt::Ordering

 #include <mutex>

@@ -112,7 +114,7 @@ namespace ntt {
     uint32_t l = threadIdx.x;

     if (l < loop_limit) {
-#pragma unroll
+      UNROLL
       for (; ss < logn; ss++) {
         int s = logn - ss - 1;
        bool is_beginning = ss == 0;

@@ -184,7 +186,7 @@ namespace ntt {
     uint32_t l = threadIdx.x;

     if (l < loop_limit) {
-#pragma unroll
+      UNROLL
       for (; s < logn; s++) // TODO: this loop also can be unrolled
       {
         uint32_t ntw_i = task % chunks;

@@ -448,7 +450,7 @@ namespace ntt {
     // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements
     // Managed allocation allows host to read the elements (logn) without copying all (n) TFs back to host
     CHK_IF_RETURN(cudaMallocManaged(&domain.twiddles, (domain.max_size + 1) * sizeof(S)));
-    CHK_IF_RETURN(generate_external_twiddles_generic(
+    CHK_IF_RETURN(mxntt::generate_external_twiddles_generic(
       primitive_root, domain.twiddles, domain.internal_twiddles, domain.basic_twiddles, domain.max_log_size,
       ctx.stream));

@@ -458,7 +460,7 @@ namespace ntt {
     CHK_IF_RETURN(cudaMallocAsync(&domain.fast_external_twiddles_inv, domain.max_size * sizeof(S) * 2, ctx.stream));

     // fast-twiddles forward NTT
-    CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
+    CHK_IF_RETURN(mxntt::generate_external_twiddles_fast_twiddles_mode(
       primitive_root, domain.fast_external_twiddles, domain.fast_internal_twiddles, domain.fast_basic_twiddles,
       domain.max_log_size, ctx.stream));

@@ -466,7 +468,7 @@ namespace ntt {
     S primitive_root_inv;
     CHK_IF_RETURN(cudaMemcpyAsync(
       &primitive_root_inv, &domain.twiddles[domain.max_size - 1], sizeof(S), cudaMemcpyDeviceToHost, ctx.stream));
-    CHK_IF_RETURN(generate_external_twiddles_fast_twiddles_mode(
+    CHK_IF_RETURN(mxntt::generate_external_twiddles_fast_twiddles_mode(
       primitive_root_inv, domain.fast_external_twiddles_inv, domain.fast_internal_twiddles_inv,
       domain.fast_basic_twiddles_inv, domain.max_log_size, ctx.stream));
   }

@@ -526,7 +528,7 @@ namespace ntt {
   }

   template <typename S>
-  static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
+  static bool is_choosing_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
   {
     const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
     if (!is_mixed_radix_alg_supported && config.columns_batch)

@@ -668,7 +670,7 @@ namespace ntt {
         d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
         coset_index, stream));
     } else {
-      const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
+      const bool is_radix2_algorithm = is_choosing_radix2_algorithm(logn, batch_size, config);
       if (is_radix2_algorithm) {
         CHK_IF_RETURN(ntt::radix2_ntt(
           d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,

@@ -685,7 +687,7 @@ namespace ntt {
         S* basic_twiddles = is_fast_twiddles_enabled
           ? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
           : domain.basic_twiddles;
-        CHK_IF_RETURN(ntt::mixed_radix_ntt(
+        CHK_IF_RETURN(mxntt::mixed_radix_ntt(
           d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
           config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
       }

@@ -748,7 +750,6 @@ namespace ntt {
   }

-#if defined(ECNTT_DEFINED)

   /**
    * Extern "C" version of [NTT](@ref NTT) function with the following values of template parameters
    * (where the curve is given by `-DCURVE` env variable during build):

@@ -5,7 +5,7 @@
 #include <stdint.h>
 #include "appUtils/ntt/ntt.cuh" // for enum Ordering

-namespace ntt {
+namespace mxntt {

   template <typename S>
   cudaError_t generate_external_twiddles_generic(

@@ -38,10 +38,10 @@ namespace ntt {
     bool columns_batch,
     bool is_inverse,
     bool fast_tw,
-    Ordering ordering,
+    ntt::Ordering ordering,
     S* arbitrary_coset,
     int coset_gen_index,
     cudaStream_t cuda_stream);

-} // namespace ntt
+} // namespace mxntt
 #endif //_NTT_IMPL_H

@@ -12,8 +12,14 @@
 #include "ntt/ntt_impl.cuh"
 #include <memory>

+#ifdef ECNTT_DEFINED
+typedef curve_config::scalar_t test_scalar;
+typedef curve_config::projective_t test_data;
+#else
 typedef curve_config::scalar_t test_scalar;
 typedef curve_config::scalar_t test_data;
+#endif

 #include "kernel_ntt.cu"

 void random_samples(test_data* res, uint32_t count)

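Net effect of the new #ifdef block: test_scalar stays curve_config::scalar_t either way, but with ECNTT_DEFINED the transformed data becomes curve_config::projective_t, i.e. the verification test runs the NTT butterflies over elliptic-curve points.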
@@ -51,116 +51,113 @@ public:
   S WI[7];
   S WE[8];

-  __device__ __forceinline__ void loadBasicTwiddles(S* basic_twiddles)
+  DEVICE_INLINE void loadBasicTwiddles(S* basic_twiddles)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 3; i++) {
       WB[i] = basic_twiddles[i];
     }
   }

-  __device__ __forceinline__ void loadBasicTwiddlesGeneric(S* basic_twiddles, bool inv)
+  DEVICE_INLINE void loadBasicTwiddlesGeneric(S* basic_twiddles, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 3; i++) {
       WB[i] = basic_twiddles[inv ? i + 3 : i];
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddles64(S* data, bool stride)
+  DEVICE_INLINE void loadInternalTwiddles64(S* data, bool stride)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       WI[i] = data[((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1)];
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddles32(S* data, bool stride)
+  DEVICE_INLINE void loadInternalTwiddles32(S* data, bool stride)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       WI[i] = data[2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1)];
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddles16(S* data, bool stride)
+  DEVICE_INLINE void loadInternalTwiddles16(S* data, bool stride)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       WI[i] = data[4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1)];
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddlesGeneric64(S* data, bool stride, bool inv)
+  DEVICE_INLINE void loadInternalTwiddlesGeneric64(S* data, bool stride, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       uint32_t exp = ((stride ? (threadIdx.x >> 3) : (threadIdx.x)) & 0x7) * (i + 1);
       WI[i] = data[(inv && exp) ? 64 - exp : exp]; // if exp = 0 we also take exp and not 64-exp
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddlesGeneric32(S* data, bool stride, bool inv)
+  DEVICE_INLINE void loadInternalTwiddlesGeneric32(S* data, bool stride, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       uint32_t exp = 2 * ((stride ? (threadIdx.x >> 4) : (threadIdx.x)) & 0x3) * (i + 1);
       WI[i] = data[(inv && exp) ? 64 - exp : exp];
     }
   }

-  __device__ __forceinline__ void loadInternalTwiddlesGeneric16(S* data, bool stride, bool inv)
+  DEVICE_INLINE void loadInternalTwiddlesGeneric16(S* data, bool stride, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 7; i++) {
       uint32_t exp = 4 * ((stride ? (threadIdx.x >> 5) : (threadIdx.x)) & 0x1) * (i + 1);
       WI[i] = data[(inv && exp) ? 64 - exp : exp];
     }
   }

-  __device__ __forceinline__ void
-  loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
+  DEVICE_INLINE void loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
   {
     data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       WE[i] = data[8 * i * tw_order + (1 << tw_log_order + 6) - 1];
     }
   }

-  __device__ __forceinline__ void
-  loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
+  DEVICE_INLINE void loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
   {
     data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         WE[4 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 5) - 1];
       }
     }
   }

-  __device__ __forceinline__ void
-  loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
+  DEVICE_INLINE void loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
   {
     data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         WE[2 * j + i] = data[(8 * i + j) * tw_order + (1 << tw_log_order + 4) - 1];
       }
     }
   }

-  __device__ __forceinline__ void loadExternalTwiddlesGeneric64(
+  DEVICE_INLINE void loadExternalTwiddlesGeneric64(
     S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       uint32_t exp = (s_meta.ntt_inp_id + 8 * i) * (s_meta.ntt_block_id & (tw_order - 1))
                      << (tw_log_size - tw_log_order - 6);

@@ -168,12 +165,12 @@ public:
     }
   }

-  __device__ __forceinline__ void loadExternalTwiddlesGeneric32(
+  DEVICE_INLINE void loadExternalTwiddlesGeneric32(
     S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         uint32_t exp = (s_meta.ntt_inp_id * 2 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
                        << (tw_log_size - tw_log_order - 5);

@@ -182,12 +179,12 @@ public:
     }
   }

-  __device__ __forceinline__ void loadExternalTwiddlesGeneric16(
+  DEVICE_INLINE void loadExternalTwiddlesGeneric16(
     S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta, uint32_t tw_log_size, bool inv)
   {
-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         uint32_t exp = (s_meta.ntt_inp_id * 4 + 8 * i + j) * (s_meta.ntt_block_id & (tw_order - 1))
                        << (tw_log_size - tw_log_order - 4);

@@ -196,7 +193,7 @@ public:
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -206,13 +203,13 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       X[i] = data[s_meta.th_stride * i * data_stride];
     }
   }

-  __device__ __forceinline__ void loadGlobalDataColumnBatch(
+  DEVICE_INLINE void loadGlobalDataColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +

@@ -220,13 +217,13 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -236,13 +233,13 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       data[s_meta.th_stride * i * data_stride] = X[i];
     }
   }

-  __device__ __forceinline__ void storeGlobalDataColumnBatch(
+  DEVICE_INLINE void storeGlobalDataColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +

@@ -250,13 +247,13 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -266,16 +263,16 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         X[4 * j + i] = data[(8 * i + j) * data_stride];
       }
     }
   }

-  __device__ __forceinline__ void loadGlobalData32ColumnBatch(
+  DEVICE_INLINE void loadGlobalData32ColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +

@@ -283,16 +280,16 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
       }
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -302,16 +299,16 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 2;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         data[(8 * i + j) * data_stride] = X[4 * j + i];
       }
     }
   }

-  __device__ __forceinline__ void storeGlobalData32ColumnBatch(
+  DEVICE_INLINE void storeGlobalData32ColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +

@@ -319,16 +316,16 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
       }
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -338,16 +335,16 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         X[2 * j + i] = data[(8 * i + j) * data_stride];
       }
     }
   }

-  __device__ __forceinline__ void loadGlobalData16ColumnBatch(
+  DEVICE_INLINE void loadGlobalData16ColumnBatch(
     const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +

@@ -355,16 +352,16 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
       }
     }
   }

-  __device__ __forceinline__ void
+  DEVICE_INLINE void
   storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
   {
     if (strided) {

@@ -374,16 +371,16 @@ public:
       data += s_meta.ntt_block_id * s_meta.ntt_block_size + s_meta.ntt_inp_id * 4;
     }

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         data[(8 * i + j) * data_stride] = X[2 * j + i];
       }
     }
   }

-  __device__ __forceinline__ void storeGlobalData16ColumnBatch(
+  DEVICE_INLINE void storeGlobalData16ColumnBatch(
     E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
   {
     data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +

@@ -391,32 +388,32 @@ public:
               batch_size +
             s_meta.batch_id;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
       }
     }
   }

-  __device__ __forceinline__ void ntt4_2()
+  DEVICE_INLINE void ntt4_2()
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 2; i++) {
       ntt4(X[4 * i], X[4 * i + 1], X[4 * i + 2], X[4 * i + 3]);
     }
   }

-  __device__ __forceinline__ void ntt2_4()
+  DEVICE_INLINE void ntt2_4()
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 4; i++) {
       ntt2(X[2 * i], X[2 * i + 1]);
     }
   }

-  __device__ __forceinline__ void ntt2(E& X0, E& X1)
+  DEVICE_INLINE void ntt2(E& X0, E& X1)
   {
     E T;

@@ -425,7 +422,7 @@ public:
     X0 = T;
   }

-  __device__ __forceinline__ void ntt4(E& X0, E& X1, E& X2, E& X3)
+  DEVICE_INLINE void ntt4(E& X0, E& X1, E& X2, E& X3)
   {
     E T;

@@ -443,7 +440,7 @@ public:
   }

   // rbo version
-  __device__ __forceinline__ void ntt4rbo(E& X0, E& X1, E& X2, E& X3)
+  DEVICE_INLINE void ntt4rbo(E& X0, E& X1, E& X2, E& X3)
   {
     E T;

@@ -460,7 +457,7 @@ public:
     X3 = T - X3;
   }

-  __device__ __forceinline__ void ntt8(E& X0, E& X1, E& X2, E& X3, E& X4, E& X5, E& X6, E& X7)
+  DEVICE_INLINE void ntt8(E& X0, E& X1, E& X2, E& X3, E& X4, E& X5, E& X6, E& X7)
   {
     E T;

@@ -500,7 +497,7 @@ public:
     X4 = X4 - T;
   }

-  __device__ __forceinline__ void ntt8win()
+  DEVICE_INLINE void ntt8win()
   {
     E T;

@@ -542,12 +539,12 @@ public:
     X[4] = X[4] - T;
   }

-  __device__ __forceinline__ void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData64Columns8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
     uint32_t column_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 64 + i * 8 + column_id] = X[i];

@@ -557,12 +554,12 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData64Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData64Rows8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x7 : threadIdx.x >> 3;
     uint32_t row_id = stride ? threadIdx.x >> 3 : threadIdx.x & 0x7;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 64 + row_id * 8 + i] = X[i];

@@ -572,12 +569,12 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData32Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData32Columns8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
     uint32_t column_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 32 + i * 4 + column_id] = X[i];

@@ -587,12 +584,12 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData32Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData32Rows8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
     uint32_t row_id = stride ? threadIdx.x >> 4 : threadIdx.x & 0x3;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 32 + row_id * 8 + i] = X[i];

@@ -602,14 +599,14 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData32Columns4_2(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData32Columns4_2(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
     uint32_t column_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         if (store) {
           shmem[ntt_id * 32 + i * 8 + column_id + j] = X[4 * j + i];

@@ -620,14 +617,14 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData32Rows4_2(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData32Rows4_2(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0xf : threadIdx.x >> 2;
     uint32_t row_id = (stride ? threadIdx.x >> 4 : threadIdx.x & 0x3) * 2;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 2; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 4; i++) {
         if (store) {
           shmem[ntt_id * 32 + row_id * 4 + 4 * j + i] = X[4 * j + i];

@@ -638,12 +635,12 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData16Columns8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData16Columns8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
     uint32_t column_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 16 + i * 2 + column_id] = X[i];

@@ -653,12 +650,12 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData16Rows8(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData16Rows8(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
     uint32_t row_id = stride ? threadIdx.x >> 5 : threadIdx.x & 0x1;

-#pragma unroll
+    UNROLL
     for (uint32_t i = 0; i < 8; i++) {
       if (store) {
         shmem[ntt_id * 16 + row_id * 8 + i] = X[i];

@@ -668,14 +665,14 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData16Columns2_4(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData16Columns2_4(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
     uint32_t column_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         if (store) {
           shmem[ntt_id * 16 + i * 8 + column_id + j] = X[2 * j + i];

@@ -686,14 +683,14 @@ public:
     }
   }

-  __device__ __forceinline__ void SharedData16Rows2_4(E* shmem, bool store, bool high_bits, bool stride)
+  DEVICE_INLINE void SharedData16Rows2_4(E* shmem, bool store, bool high_bits, bool stride)
   {
     uint32_t ntt_id = stride ? threadIdx.x & 0x1f : threadIdx.x >> 1;
     uint32_t row_id = (stride ? threadIdx.x >> 5 : threadIdx.x & 0x1) * 4;

-#pragma unroll
+    UNROLL
     for (uint32_t j = 0; j < 4; j++) {
-#pragma unroll
+      UNROLL
       for (uint32_t i = 0; i < 2; i++) {
         if (store) {
           shmem[ntt_id * 16 + row_id * 2 + 2 * j + i] = X[2 * j + i];

@@ -704,17 +701,17 @@ public:
     }
   }

-  __device__ __forceinline__ void twiddlesInternal()
+  DEVICE_INLINE void twiddlesInternal()
   {
-#pragma unroll
+    UNROLL
     for (int i = 1; i < 8; i++) {
       X[i] = X[i] * WI[i - 1];
     }
   }

-  __device__ __forceinline__ void twiddlesExternal()
+  DEVICE_INLINE void twiddlesExternal()
   {
-#pragma unroll
+    UNROLL
     for (int i = 0; i < 8; i++) {
       X[i] = X[i] * WE[i];
     }

@@ -31,7 +31,7 @@ namespace poseidon {
   }

   template <typename S>
-  __device__ __forceinline__ S sbox_alpha_five(S element)
+  DEVICE_INLINE S sbox_alpha_five(S element)
   {
     S result = S::sqr(element);
     result = S::sqr(result);

@@ -46,7 +46,7 @@ namespace poseidon {
     __syncthreads();

     typename S::Wide element_wide = S::mul_wide(shared_states[vec_number * T], matrix[element_number]);
-#pragma unroll
+    UNROLL
     for (int i = 1; i < T; i++) {
       element_wide = element_wide + S::mul_wide(shared_states[vec_number * T + i], matrix[i * T + element_number]);
     }

@@ -117,14 +117,14 @@ namespace poseidon {

     typename S::Wide state_0_wide = S::mul_wide(element, sparse_matrix[0]);

-#pragma unroll
+    UNROLL
     for (int i = 1; i < T; i++) {
       state_0_wide = state_0_wide + S::mul_wide(state[i], sparse_matrix[i]);
     }

     state[0] = S::reduce(state_0_wide);

-#pragma unroll
+    UNROLL
     for (int i = 1; i < T; i++) {
       state[i] = state[i] + (element * sparse_matrix[T + i - 1]);
     }

@@ -138,7 +138,7 @@ namespace poseidon {
     if (idx >= number_of_states) { return; }

     S state[T];
-#pragma unroll
+    UNROLL
     for (int i = 0; i < T; i++) {
       state[i] = states[idx * T + i];
     }
@@ -148,7 +148,7 @@ namespace poseidon {
       rc_offset++;
     }

-#pragma unroll
+    UNROLL
     for (int i = 0; i < T; i++) {
       states[idx * T + i] = state[i];
     }

New file: icicle/common.cuh (11 lines)

@@ -0,0 +1,11 @@
+#if defined(DEVMODE) || defined(DEBUG)
+#define INLINE_MACRO
+#define UNROLL
+#else
+#define INLINE_MACRO __forceinline__
+#define UNROLL #pragma unroll
+#endif
+
+#define HOST_INLINE __host__ INLINE_MACRO
+#define DEVICE_INLINE __device__ INLINE_MACRO
+#define HOST_DEVICE_INLINE __host__ __device__ INLINE_MACRO

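One portability caveat, not addressed by the diff: `#define UNROLL #pragma unroll` asks a macro expansion to produce a preprocessor directive, which standard C++ does not re-scan as a directive, so some host compilers will reject the expanded `#` token. A hypothetical alternative spelling using the standard _Pragma operator (C99/C++11), with the same intended behavior as the macros above:

// Hypothetical variant of common.cuh; NOT what the commit ships.
#if defined(DEVMODE) || defined(DEBUG)
#define INLINE_MACRO
#define UNROLL
#else
#define INLINE_MACRO __forceinline__
#define UNROLL _Pragma("unroll")
#endif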
@@ -1,10 +1,7 @@
 #pragma once

 #include "field.cuh"

-#define HOST_INLINE __host__ __forceinline__
-#define DEVICE_INLINE __device__ __forceinline__
-#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
+#include "common.cuh"

 template <typename CONFIG>
 class ExtensionField

@@ -28,10 +28,6 @@
 #include <sstream>
 #include <string>

-#define HOST_INLINE __host__ __forceinline__
-#define DEVICE_INLINE __device__ __forceinline__
-#define HOST_DEVICE_INLINE __host__ __device__ __forceinline__
-
 template <class CONFIG>
 class Field
 {

@@ -130,7 +126,7 @@ public:
   {
     Field out{};
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 0; i < TLC; i++)
       out.limbs_storage.limbs[i] = xs.limbs_storage.limbs[i];

@@ -141,7 +137,7 @@ public:
   {
     Field out{};
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 0; i < TLC; i++)
       out.limbs_storage.limbs[i] = xs.limbs_storage.limbs[i + TLC];

@@ -152,7 +148,7 @@ public:
   {
     Field out{};
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 0; i < TLC; i++) {
 #ifdef __CUDA_ARCH__

@@ -244,7 +240,7 @@ public:
   }

   template <bool SUBTRACT, bool CARRY_OUT>
-  static constexpr __device__ __forceinline__ uint32_t
+  static constexpr DEVICE_INLINE uint32_t
   add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r, size_t n = (TLC >> 1))
   {
     r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]);

@@ -327,7 +323,7 @@ public:

   static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC)
   {
-#pragma unroll
+    UNROLL
     for (size_t i = 0; i < n; i += 2) {
       acc[i] = ptx::mul_lo(a[i], bi);
       acc[i + 1] = ptx::mul_hi(a[i], bi);

@@ -336,7 +332,7 @@ public:

   static DEVICE_INLINE void mul_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC, size_t start_i = 0)
   {
-#pragma unroll
+    UNROLL
     for (size_t i = start_i; i < n; i += 2) {
       acc[i] = ptx::mul_lo(a[i], bi);
       acc[i + 1] = ptx::mul_hi(a[i], bi);

@@ -344,14 +340,14 @@ public:
   }

   template <bool CARRY_IN = false>
-  static __device__ __forceinline__ void
+  static DEVICE_INLINE void
   cmad_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC, uint32_t optional_carry = 0)
   {
     if (CARRY_IN) ptx::add_cc(UINT32_MAX, optional_carry);
     acc[0] = CARRY_IN ? ptx::madc_lo_cc(a[0], bi, acc[0]) : ptx::mad_lo_cc(a[0], bi, acc[0]);
     acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]);

-#pragma unroll
+    UNROLL
     for (size_t i = 2; i < n; i += 2) {
       acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]);
       acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]);

@@ -359,7 +355,7 @@ public:
   }

   template <bool EVEN_PHASE>
-  static __device__ __forceinline__ void cmad_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC)
+  static DEVICE_INLINE void cmad_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC)
   {
     if (EVEN_PHASE) {
       acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]);

@@ -368,14 +364,14 @@ public:
       acc[1] = ptx::mad_hi_cc(a[0], bi, acc[1]);
     }

-#pragma unroll
+    UNROLL
     for (size_t i = 2; i < n; i += 2) {
       acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]);
       acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]);
     }
   }

-  static __device__ __forceinline__ void cmad_n_lsb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC)
+  static DEVICE_INLINE void cmad_n_lsb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC)
   {
     if (n > 1)
       acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]);

@@ -383,7 +379,7 @@ public:
       acc[0] = ptx::mad_lo(a[0], bi, acc[0]);

     size_t i;
-#pragma unroll
+    UNROLL
     for (i = 1; i < n - 1; i += 2) {
       acc[i] = ptx::madc_hi_cc(a[i - 1], bi, acc[i]);
       if (i == n - 2)

@@ -395,7 +391,7 @@ public:
   }

   template <bool CARRY_OUT = false, bool CARRY_IN = false>
-  static __device__ __forceinline__ uint32_t mad_row(
+  static DEVICE_INLINE uint32_t mad_row(
     uint32_t* odd,
     uint32_t* even,
     const uint32_t* a,

@@ -420,8 +416,7 @@ public:
   }

   template <bool EVEN_PHASE>
-  static __device__ __forceinline__ void
-  mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC)
+  static DEVICE_INLINE void mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC)
   {
     cmad_n_msb<!EVEN_PHASE>(odd, EVEN_PHASE ? a : (a + 1), bi, n - 2);
     odd[EVEN_PHASE ? (n - 1) : (n - 2)] = ptx::madc_lo_cc(a[n - 1], bi, 0);

@@ -430,8 +425,7 @@ public:
     odd[EVEN_PHASE ? n : (n - 1)] = ptx::addc(odd[EVEN_PHASE ? n : (n - 1)], 0);
   }

-  static __device__ __forceinline__ void
-  mad_row_lsb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC)
+  static DEVICE_INLINE void mad_row_lsb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC)
   {
     // bi here is constant so we can do a compile-time check for zero (which does happen once for bls12-381 scalar field
     // modulus)

@@ -442,12 +436,12 @@ public:
     return;
   }

-  static __device__ __forceinline__ uint32_t
+  static DEVICE_INLINE uint32_t
   mul_n_and_add(uint32_t* acc, const uint32_t* a, uint32_t bi, uint32_t* extra, size_t n = (TLC >> 1))
   {
     acc[0] = ptx::mad_lo_cc(a[0], bi, extra[0]);

-#pragma unroll
+    UNROLL
     for (size_t i = 1; i < n - 1; i += 2) {
       acc[i] = ptx::madc_hi_cc(a[i - 1], bi, extra[i]);
       acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, extra[i + 1]);

@@ -470,8 +464,7 @@ public:
   * \cdot b_0}{2^{32}}} + \dots + \floor{\frac{a_0 \cdot b_{TLC - 2}}{2^{32}}}) \leq 2^{64} + 2\cdot 2^{96} + \dots +
   * (TLC - 2) \cdot 2^{32(TLC - 1)} + (TLC - 1) \cdot 2^{32(TLC - 1)} \leq 2(TLC - 1) \cdot 2^{32(TLC - 1)}\f$.
   */
-  static __device__ __forceinline__ void
-  multiply_msb_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
+  static DEVICE_INLINE void multiply_msb_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs)
   {
     const uint32_t* a = as.limbs;
     const uint32_t* b = bs.limbs;

@@ -482,7 +475,7 @@ public:
     odd[TLC - 2] = ptx::mul_lo(a[TLC - 1], b[0]);
     odd[TLC - 1] = ptx::mul_hi(a[TLC - 1], b[0]);
     size_t i;
-#pragma unroll
+    UNROLL
     for (i = 2; i < TLC - 1; i += 2) {
       mad_row_msb<true>(&even[TLC - 2], &odd[TLC - 2], &a[TLC - i - 1], b[i - 1], i + 1);
       mad_row_msb<false>(&odd[TLC - 2], &even[TLC - 2], &a[TLC - i - 2], b[i], i + 2);

@@ -503,7 +496,7 @@ public:
   * is excluded if \f$ i + j > TLC - 1 \f$ and only the lower half is included if \f$ i + j = TLC - 1 \f$. All other
   * limb products are included.
   */
-  static __device__ __forceinline__ void
+  static DEVICE_INLINE void
   multiply_and_add_lsb_raw_device(const ff_storage& as, const ff_storage& bs, ff_storage& cs, ff_storage& rs)
   {
     const uint32_t* a = as.limbs;

@@ -523,7 +516,7 @@ public:
       mul_n(odd, a + 1, b[0], TLC - 1);
     }
     mad_row_lsb(&even[2], &odd[0], a, b[1], TLC - 1);
-#pragma unroll
+    UNROLL
     for (i = 2; i < TLC - 1; i += 2) {
       mad_row_lsb(&odd[i], &even[i], a, b[i], TLC - i);
       mad_row_lsb(&even[i + 2], &odd[i], a, b[i + 1], TLC - i - 1);

@@ -545,7 +538,7 @@ public:
   * that the top bit of \f$ a_{hi} \f$ and \f$ b_{hi} \f$ are unset. This ensures correctness by allowing to keep the
   * result inside TLC limbs and ignore the carries from the highest limb.
   */
-  static __device__ __forceinline__ void
+  static DEVICE_INLINE void
   multiply_and_add_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even, uint32_t* in1, uint32_t* in2)
   {
     __align__(16) uint32_t odd[TLC - 2];

@@ -553,7 +546,7 @@ public:
     uint32_t carry = mul_n_and_add(odd, a + 1, b[0], &in2[1]);

     size_t i;
-#pragma unroll
+    UNROLL
     for (i = 2; i < ((TLC >> 1) - 1); i += 2) {
       carry = mad_row<true, false>(
         &even[i], &odd[i - 2], a, b[i - 1], TLC >> 1, in1[(TLC >> 1) + i - 2], in1[(TLC >> 1) + i - 1], carry);

@@ -574,7 +567,7 @@ public:
   * This method multiplies `a` and `b` and writes the result into `even`. It assumes that `a` and `b` are TLC/2 limbs
   * long. The usual schoolbook algorithm is used.
   */
-  static __device__ __forceinline__ void multiply_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even)
+  static DEVICE_INLINE void multiply_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even)
   {
     __align__(16) uint32_t odd[TLC - 2];
     mul_n(even, a, b[0], TLC >> 1);

@@ -582,7 +575,7 @@ public:
     mad_row(&even[2], &odd[0], a, b[1], TLC >> 1);

     size_t i;
-#pragma unroll
+    UNROLL
     for (i = 2; i < ((TLC >> 1) - 1); i += 2) {
       mad_row(&odd[i], &even[i], a, b[i], TLC >> 1);
       mad_row(&even[i + 2], &odd[i], a, b[i + 1], TLC >> 1);

@@ -817,7 +810,7 @@ public:
     const uint32_t* x = xs.limbs_storage.limbs;
     const uint32_t* y = ys.limbs_storage.limbs;
     uint32_t limbs_or = x[0] ^ y[0];
-#pragma unroll
+    UNROLL
     for (unsigned i = 1; i < TLC; i++)
       limbs_or |= x[i] ^ y[i];
     return limbs_or == 0;

@@ -836,7 +829,7 @@ public:
     Field mul = multiplier;
     static bool is_u32 = true;
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 1; i < TLC; i++)
       is_u32 &= (mul.limbs_storage.limbs[i] == 0);

@@ -852,7 +845,7 @@ public:
     T temp = xs;
     bool is_zero = true;
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 0; i < 32; i++) {
       if (multiplier & (1 << i)) {

@@ -902,7 +895,7 @@ public:
     Field rs = {};
     uint32_t* r = rs.limbs_storage.limbs;
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (unsigned i = 0; i < TLC - 1; i++) {
 #ifdef __CUDA_ARCH__

@@ -165,7 +165,7 @@ public:
   {
     Projective res = zero();
 #ifdef __CUDA_ARCH__
-#pragma unroll
+    UNROLL
 #endif
     for (int i = 0; i < SCALAR_FF::NBITS; i++) {
       if (i > 0) { res = res + res; }

@@ -4,7 +4,7 @@

 #include <cstdint>
 #include <cuda_runtime.h>

+#include "common.cuh"
 namespace host_math {

 // return x + y with uint32_t operands

@@ -67,9 +67,9 @@ namespace host_math {
   struct carry_chain {
     unsigned index;

-    constexpr __host__ __forceinline__ carry_chain() : index(0) {}
+    constexpr HOST_INLINE carry_chain() : index(0) {}

-    __host__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y, uint32_t& carry)
+    HOST_INLINE uint32_t add(const uint32_t x, const uint32_t y, uint32_t& carry)
     {
       index++;
       if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT)
@@ -82,7 +82,7 @@ namespace host_math {
       return host_math::addc(x, y, carry);
     }

-    __host__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y, uint32_t& carry)
+    HOST_INLINE uint32_t sub(const uint32_t x, const uint32_t y, uint32_t& carry)
     {
       index++;
       if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT)

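For context, carry_chain exists so that a run of limb additions can share carry state: index counts the operations issued, and the OPS_COUNT, CARRY_IN and CARRY_OUT template parameters (visible in the conditions above) let the first and last add/sub in a chain skip redundant carry propagation. The commit only swaps its __host__ __forceinline__ qualifiers for the new HOST_INLINE macro.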
@@ -23,5 +23,5 @@ rayon = "1.8.1"
 default = []
 arkworks = ["ark-ff", "ark-ec", "ark-poly", "ark-std"]
 g2 = []
-# TODO: impl EC NTT
 ec_ntt = []
+devmode = []

@@ -29,4 +29,6 @@ default = []
 bw6-761 = []
 bw6-761-g2 = ["bw6-761"]
 g2 = ["icicle-core/g2"]
+ec_ntt = ["icicle-core/ec_ntt"]
+devmode = ["icicle-core/devmode"]
 arkworks = ["ark-bls12-377", "icicle-core/arkworks"]

@@ -15,6 +15,12 @@ fn main() {
     #[cfg(feature = "g2")]
     config.define("G2_DEFINED", "ON");

+    #[cfg(feature = "ec_ntt")]
+    config.define("ECNTT_DEFINED", "ON");
+
+    #[cfg(feature = "devmode")]
+    config.define("DEVMODE", "ON");
+
     // Build
     let out_dir = config
         .build_target("icicle")

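From the Rust side, the new devmode cargo feature is presumably enabled with something like cargo build --features devmode, which makes build.rs forward DEVMODE=ON to CMake; the ec_ntt feature likewise forwards ECNTT_DEFINED.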
@@ -35,6 +41,12 @@ fn main() {
     #[cfg(feature = "bw6-761-g2")]
     config.define("G2_DEFINED", "ON");

+    #[cfg(feature = "ec_ntt")]
+    config.define("ECNTT_DEFINED", "OFF");
+
+    #[cfg(feature = "devmode")]
+    config.define("DEVMODE", "ON");
+
     // Build
     let out_dir = config
         .build_target("icicle")

@@ -27,4 +27,6 @@ icicle-bls12-381 = { path = ".", features = ["arkworks"] }
 [features]
 default = []
 g2 = ["icicle-core/g2"]
+ec_ntt = ["icicle-core/ec_ntt"]
+devmode = ["icicle-core/devmode"]
 arkworks = ["ark-bls12-381", "icicle-core/arkworks"]

@@ -15,6 +15,12 @@ fn main() {
     #[cfg(feature = "g2")]
     config.define("G2_DEFINED", "ON");

+    #[cfg(feature = "ec_ntt")]
+    config.define("ECNTT_DEFINED", "ON");
+
+    #[cfg(feature = "devmode")]
+    config.define("DEVMODE", "ON");
+
     // Build
     let out_dir = config
         .build_target("icicle")

@@ -27,4 +27,6 @@ icicle-bn254 = { path = ".", features = ["arkworks"] }
 [features]
 default = []
 g2 = ["icicle-core/g2"]
+ec_ntt = ["icicle-core/ec_ntt"]
+devmode = ["icicle-core/devmode"]
 arkworks = ["ark-bn254", "icicle-core/arkworks"]

@@ -15,6 +15,12 @@ fn main() {
     #[cfg(feature = "g2")]
     config.define("G2_DEFINED", "ON");

+    #[cfg(feature = "ec_ntt")]
+    config.define("ECNTT_DEFINED", "ON");
+
+    #[cfg(feature = "devmode")]
+    config.define("DEVMODE", "ON");
+
     // Build
     let out_dir = config
         .build_target("icicle")

@@ -28,4 +28,5 @@ icicle-bw6-761 = { path = ".", features = ["arkworks"] }
 [features]
 default = []
 g2 = ["icicle-bls12-377/bw6-761-g2"]
+devmode = ["icicle-core/devmode"]
 arkworks = ["ark-bw6-761", "icicle-core/arkworks", "icicle-bls12-377/arkworks"]

@@ -25,4 +25,6 @@ icicle-grumpkin = { path = ".", features = ["arkworks"] }

 [features]
 default = []
+ec_ntt = ["icicle-core/ec_ntt"]
+devmode = ["icicle-core/devmode"]
 arkworks = ["ark-grumpkin-test", "icicle-core/arkworks"]

@@ -4,12 +4,20 @@ fn main() {
     println!("cargo:rerun-if-env-changed=CXXFLAGS");
     println!("cargo:rerun-if-changed=../../../../icicle");

-    let out_dir = Config::new("../../../../icicle")
-        .define("BUILD_TESTS", "OFF") //TODO: feature
-        .define("CURVE", "grumpkin")
-        .define("CMAKE_BUILD_TYPE", "Release")
-        .build_target("icicle")
-        .build();
+    // Base config
+    let mut config = Config::new("../../../../icicle");
+    config
+        .define("BUILD_TESTS", "OFF")
+        .define("CURVE", "grumpkin")
+        .define("CMAKE_BUILD_TYPE", "Release");
+
+    #[cfg(feature = "devmode")]
+    config.define("DEVMODE", "ON");
+
+    // Build
+    let out_dir = config
+        .build_target("icicle")
+        .build();

     println!("cargo:rustc-link-search={}/build", out_dir.display());