diff --git a/.github/workflows/golang.yml b/.github/workflows/golang.yml index acff2b3e..68bd5deb 100644 --- a/.github/workflows/golang.yml +++ b/.github/workflows/golang.yml @@ -50,7 +50,7 @@ jobs: - name: Build working-directory: ./wrappers/golang if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true' - run: ./build.sh ${{ matrix.curve }} ON # builds a single curve with G2 enabled + run: ./build.sh ${{ matrix.curve }} ON ON # builds a single curve with G2 and ECNTT enabled - name: Upload ICICLE lib artifacts uses: actions/upload-artifact@v4 if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true' diff --git a/icicle/CMakeLists.txt b/icicle/CMakeLists.txt index c824d557..c30956d3 100644 --- a/icicle/CMakeLists.txt +++ b/icicle/CMakeLists.txt @@ -96,6 +96,10 @@ if (G2_DEFINED STREQUAL "ON") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DG2_DEFINED=ON") endif () +if (ECNTT_DEFINED STREQUAL "ON") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON") +endif () + option(BUILD_TESTS "Build tests" OFF) if (NOT BUILD_TESTS) @@ -110,6 +114,9 @@ if (NOT BUILD_TESTS) if (NOT CURVE IN_LIST SUPPORTED_CURVES_WITHOUT_NTT) list(APPEND ICICLE_SOURCES appUtils/ntt/ntt.cu) list(APPEND ICICLE_SOURCES appUtils/ntt/kernel_ntt.cu) + if(ECNTT_DEFINED STREQUAL "ON") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON") + endif() endif() add_library( diff --git a/icicle/appUtils/ntt/ntt.cu b/icicle/appUtils/ntt/ntt.cu index 871185e8..be10309e 100644 --- a/icicle/appUtils/ntt/ntt.cu +++ b/icicle/appUtils/ntt/ntt.cu @@ -644,29 +644,34 @@ namespace ntt { h_coset.clear(); } - const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config); const bool is_inverse = dir == NTTDir::kInverse; - if (is_radix2_algorithm) { + if constexpr (std::is_same_v) { CHK_IF_RETURN(ntt::radix2_ntt( d_input, d_output, domain.twiddles, size, domain.max_size, 
batch_size, is_inverse, config.ordering, coset, coset_index, stream)); } else { - const bool is_on_coset = (coset_index != 0) || coset; - const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset; - S* twiddles = is_fast_twiddles_enabled - ? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles) - : domain.twiddles; - S* internal_twiddles = is_fast_twiddles_enabled - ? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles) - : domain.internal_twiddles; - S* basic_twiddles = is_fast_twiddles_enabled - ? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles) - : domain.basic_twiddles; - - CHK_IF_RETURN(ntt::mixed_radix_ntt( - d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size, - config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream)); + const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config); + if (is_radix2_algorithm) { + CHK_IF_RETURN(ntt::radix2_ntt( + d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset, + coset_index, stream)); + } else { + const bool is_on_coset = (coset_index != 0) || coset; + const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset; + S* twiddles = is_fast_twiddles_enabled + ? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles) + : domain.twiddles; + S* internal_twiddles = is_fast_twiddles_enabled + ? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles) + : domain.internal_twiddles; + S* basic_twiddles = is_fast_twiddles_enabled + ? (is_inverse ? 
domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles) + : domain.basic_twiddles; + CHK_IF_RETURN(ntt::mixed_radix_ntt( + d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size, + config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream)); + } } if (!are_outputs_on_device) diff --git a/wrappers/golang/build.sh b/wrappers/golang/build.sh index 0b51417b..68d2a3c1 100755 --- a/wrappers/golang/build.sh +++ b/wrappers/golang/build.sh @@ -1,12 +1,18 @@ #!/bin/bash G2_DEFINED=OFF +ECNTT_DEFINED=OFF -if [[ $2 ]] +if [[ $2 == "ON" ]] then G2_DEFINED=ON fi +if [[ $3 == "ON" ]] +then + ECNTT_DEFINED=ON +fi + BUILD_DIR=$(realpath "$PWD/../../icicle/build") SUPPORTED_CURVES=("bn254" "bls12_377" "bls12_381" "bw6_761") @@ -22,6 +28,6 @@ mkdir -p build for CURVE in "${BUILD_CURVES[@]}" do - cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build + cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DECNTT_DEFINED=$ECNTT_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build cmake --build build -j8 done \ No newline at end of file diff --git a/wrappers/golang/core/ntt.go b/wrappers/golang/core/ntt.go index b798f730..6f77db32 100644 --- a/wrappers/golang/core/ntt.go +++ b/wrappers/golang/core/ntt.go @@ -10,26 +10,26 @@ type NTTDir int8 const ( KForward NTTDir = iota - KInverse NTTDir = 1 + KInverse ) type Ordering uint32 const ( KNN Ordering = iota - KNR Ordering = 1 - KRN Ordering = 2 - KRR Ordering = 3 - KNM Ordering = 4 - KMN Ordering = 5 + KNR + KRN + KRR + KNM + KMN ) type NttAlgorithm uint32 const ( - Auto NttAlgorithm = iota - Radix2 NttAlgorithm = 1 - MixedRadix NttAlgorithm = 2 + Auto NttAlgorithm = iota + Radix2 + MixedRadix ) type NTTConfig[T any] struct { @@ -47,10 +47,9 @@ type NTTConfig[T any] struct { areOutputsOnDevice bool /// Whether to run the NTT asynchronously. 
If set to `true`, the NTT function will be non-blocking and you'd need to synchronize /// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread. - IsAsync bool - /// Explicitly select the NTT algorithm. - /// Default value: Auto (the implementation selects radix-2 or mixed-radix algorithm based on heuristics). - NttAlgorithm NttAlgorithm + IsAsync bool + NttAlgorithm NttAlgorithm /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation + selects radix-2 or mixed-radix algorithm based on heuristics). */ } func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] { @@ -64,7 +63,7 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] { false, // areInputsOnDevice false, // areOutputsOnDevice false, // IsAsync - Auto, // NttAlgorithm + Auto, } } diff --git a/wrappers/golang/curves/bls12377/include/ntt.h b/wrappers/golang/curves/bls12377/include/ntt.h index 3ff70a7c..23c262e8 100644 --- a/wrappers/golang/curves/bls12377/include/ntt.h +++ b/wrappers/golang/curves/bls12377/include/ntt.h @@ -10,10 +10,11 @@ extern "C" { #endif cudaError_t bls12_377NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output); +cudaError_t bls12_377ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output); cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles); #ifdef __cplusplus } #endif -#endif +#endif \ No newline at end of file diff --git a/wrappers/golang/curves/bls12377/ntt.go b/wrappers/golang/curves/bls12377/ntt.go index 7d8f58d1..4e3d19c9 100644 --- a/wrappers/golang/curves/bls12377/ntt.go +++ b/wrappers/golang/curves/bls12377/ntt.go @@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo return core.FromCudaError(err) } +func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) 
core.IcicleError { + core.NttCheck[T](points, cfg, results) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0]) + } + cPoints := (*C.projective_t)(pointsPointer) + cSize := (C.int)(points.Len() / int(cfg.BatchSize)) + cDir := (C.int)(dir) + cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg)) + + var resultsPointer unsafe.Pointer + if results.IsOnDevice() { + resultsPointer = results.(core.DeviceSlice).AsPointer() + } else { + resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) + } + cResults := (*C.projective_t)(resultsPointer) + + __ret := C.bls12_377ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults) + err := (cr.CudaError)(__ret) + return core.FromCudaError(err) +} + func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError { cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer())) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) diff --git a/wrappers/golang/curves/bls12377/ntt_test.go b/wrappers/golang/curves/bls12377/ntt_test.go index b857e9ce..969eb5e7 100644 --- a/wrappers/golang/curves/bls12377/ntt_test.go +++ b/wrappers/golang/curves/bls12377/ntt_test.go @@ -98,6 +98,26 @@ func TestNtt(t *testing.T) { } } +func TestECNtt(t *testing.T) { + cfg := GetDefaultNttConfig() + initDomain(largestTestSize, cfg) + points := GenerateProjectivePoints(1 << largestTestSize) + + for _, size := range []int{4, 5, 6, 7, 8} { + for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} { + testSize := 1 << size + + pointsCopy := core.HostSliceFromElements[Projective](points[:testSize]) + cfg.Ordering = v + cfg.NttAlgorithm = core.Radix2 + + output := make(core.HostSlice[Projective], testSize) + e := ECNtt(pointsCopy, core.KForward, &cfg, output) + assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed") + } + } +} + func 
TestNttDeviceAsync(t *testing.T) { cfg := GetDefaultNttConfig() scalars := GenerateScalars(1 << largestTestSize) diff --git a/wrappers/golang/curves/bls12381/include/ntt.h b/wrappers/golang/curves/bls12381/include/ntt.h index 12113666..5017cd45 100644 --- a/wrappers/golang/curves/bls12381/include/ntt.h +++ b/wrappers/golang/curves/bls12381/include/ntt.h @@ -10,10 +10,11 @@ extern "C" { #endif cudaError_t bls12_381NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output); +cudaError_t bls12_381ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output); cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles); #ifdef __cplusplus } #endif -#endif +#endif \ No newline at end of file diff --git a/wrappers/golang/curves/bls12381/ntt.go b/wrappers/golang/curves/bls12381/ntt.go index 320603eb..04e1b9c8 100644 --- a/wrappers/golang/curves/bls12381/ntt.go +++ b/wrappers/golang/curves/bls12381/ntt.go @@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo return core.FromCudaError(err) } +func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError { + core.NttCheck[T](points, cfg, results) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0]) + } + cPoints := (*C.projective_t)(pointsPointer) + cSize := (C.int)(points.Len() / int(cfg.BatchSize)) + cDir := (C.int)(dir) + cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg)) + + var resultsPointer unsafe.Pointer + if results.IsOnDevice() { + resultsPointer = results.(core.DeviceSlice).AsPointer() + } else { + resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) + } + cResults := (*C.projective_t)(resultsPointer) + + __ret := 
C.bls12_381ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults) + err := (cr.CudaError)(__ret) + return core.FromCudaError(err) +} + func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError { cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer())) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) diff --git a/wrappers/golang/curves/bls12381/ntt_test.go b/wrappers/golang/curves/bls12381/ntt_test.go index e3fc1d65..f091c0fd 100644 --- a/wrappers/golang/curves/bls12381/ntt_test.go +++ b/wrappers/golang/curves/bls12381/ntt_test.go @@ -98,6 +98,26 @@ func TestNtt(t *testing.T) { } } +func TestECNtt(t *testing.T) { + cfg := GetDefaultNttConfig() + initDomain(largestTestSize, cfg) + points := GenerateProjectivePoints(1 << largestTestSize) + + for _, size := range []int{4, 5, 6, 7, 8} { + for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} { + testSize := 1 << size + + pointsCopy := core.HostSliceFromElements[Projective](points[:testSize]) + cfg.Ordering = v + cfg.NttAlgorithm = core.Radix2 + + output := make(core.HostSlice[Projective], testSize) + e := ECNtt(pointsCopy, core.KForward, &cfg, output) + assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed") + } + } +} + func TestNttDeviceAsync(t *testing.T) { cfg := GetDefaultNttConfig() scalars := GenerateScalars(1 << largestTestSize) diff --git a/wrappers/golang/curves/bn254/include/ntt.h b/wrappers/golang/curves/bn254/include/ntt.h index 3031727d..00eb5877 100644 --- a/wrappers/golang/curves/bn254/include/ntt.h +++ b/wrappers/golang/curves/bn254/include/ntt.h @@ -10,10 +10,11 @@ extern "C" { #endif cudaError_t bn254NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output); +cudaError_t bn254ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output); cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles); #ifdef __cplusplus } 
#endif -#endif +#endif \ No newline at end of file diff --git a/wrappers/golang/curves/bn254/ntt.go b/wrappers/golang/curves/bn254/ntt.go index 360d3f63..fc7f7590 100644 --- a/wrappers/golang/curves/bn254/ntt.go +++ b/wrappers/golang/curves/bn254/ntt.go @@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo return core.FromCudaError(err) } +func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError { + core.NttCheck[T](points, cfg, results) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0]) + } + cPoints := (*C.projective_t)(pointsPointer) + cSize := (C.int)(points.Len() / int(cfg.BatchSize)) + cDir := (C.int)(dir) + cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg)) + + var resultsPointer unsafe.Pointer + if results.IsOnDevice() { + resultsPointer = results.(core.DeviceSlice).AsPointer() + } else { + resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) + } + cResults := (*C.projective_t)(resultsPointer) + + __ret := C.bn254ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults) + err := (cr.CudaError)(__ret) + return core.FromCudaError(err) +} + func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError { cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer())) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) diff --git a/wrappers/golang/curves/bn254/ntt_test.go b/wrappers/golang/curves/bn254/ntt_test.go index fd5c8c38..da6aaca2 100644 --- a/wrappers/golang/curves/bn254/ntt_test.go +++ b/wrappers/golang/curves/bn254/ntt_test.go @@ -98,6 +98,26 @@ func TestNtt(t *testing.T) { } } +func TestECNtt(t *testing.T) { + cfg := GetDefaultNttConfig() + initDomain(largestTestSize, cfg) + points := GenerateProjectivePoints(1 << 
largestTestSize) + + for _, size := range []int{4, 5, 6, 7, 8} { + for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} { + testSize := 1 << size + + pointsCopy := core.HostSliceFromElements[Projective](points[:testSize]) + cfg.Ordering = v + cfg.NttAlgorithm = core.Radix2 + + output := make(core.HostSlice[Projective], testSize) + e := ECNtt(pointsCopy, core.KForward, &cfg, output) + assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed") + } + } +} + func TestNttDeviceAsync(t *testing.T) { cfg := GetDefaultNttConfig() scalars := GenerateScalars(1 << largestTestSize) diff --git a/wrappers/golang/curves/bw6761/include/ntt.h b/wrappers/golang/curves/bw6761/include/ntt.h index 6716834c..4e45ca85 100644 --- a/wrappers/golang/curves/bw6761/include/ntt.h +++ b/wrappers/golang/curves/bw6761/include/ntt.h @@ -10,10 +10,11 @@ extern "C" { #endif cudaError_t bw6_761NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output); +cudaError_t bw6_761ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output); cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles); #ifdef __cplusplus } #endif -#endif +#endif \ No newline at end of file diff --git a/wrappers/golang/curves/bw6761/ntt.go b/wrappers/golang/curves/bw6761/ntt.go index 64bf1bfa..5b4e9d78 100644 --- a/wrappers/golang/curves/bw6761/ntt.go +++ b/wrappers/golang/curves/bw6761/ntt.go @@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo return core.FromCudaError(err) } +func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError { + core.NttCheck[T](points, cfg, results) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = 
unsafe.Pointer(&points.(core.HostSlice[Projective])[0]) + } + cPoints := (*C.projective_t)(pointsPointer) + cSize := (C.int)(points.Len() / int(cfg.BatchSize)) + cDir := (C.int)(dir) + cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg)) + + var resultsPointer unsafe.Pointer + if results.IsOnDevice() { + resultsPointer = results.(core.DeviceSlice).AsPointer() + } else { + resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) + } + cResults := (*C.projective_t)(resultsPointer) + + __ret := C.bw6_761ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults) + err := (cr.CudaError)(__ret) + return core.FromCudaError(err) +} + func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError { cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer())) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) diff --git a/wrappers/golang/curves/bw6761/ntt_test.go b/wrappers/golang/curves/bw6761/ntt_test.go index 39283636..e145f526 100644 --- a/wrappers/golang/curves/bw6761/ntt_test.go +++ b/wrappers/golang/curves/bw6761/ntt_test.go @@ -98,6 +98,26 @@ func TestNtt(t *testing.T) { } } +func TestECNtt(t *testing.T) { + cfg := GetDefaultNttConfig() + initDomain(largestTestSize, cfg) + points := GenerateProjectivePoints(1 << largestTestSize) + + for _, size := range []int{4, 5, 6, 7, 8} { + for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} { + testSize := 1 << size + + pointsCopy := core.HostSliceFromElements[Projective](points[:testSize]) + cfg.Ordering = v + cfg.NttAlgorithm = core.Radix2 + + output := make(core.HostSlice[Projective], testSize) + e := ECNtt(pointsCopy, core.KForward, &cfg, output) + assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed") + } + } +} + func TestNttDeviceAsync(t *testing.T) { cfg := GetDefaultNttConfig() scalars := GenerateScalars(1 << largestTestSize) diff --git a/wrappers/golang/internal/generator/templates/include/ntt.h.tmpl 
b/wrappers/golang/internal/generator/templates/include/ntt.h.tmpl index df915f06..16f972ac 100644 --- a/wrappers/golang/internal/generator/templates/include/ntt.h.tmpl +++ b/wrappers/golang/internal/generator/templates/include/ntt.h.tmpl @@ -10,10 +10,11 @@ extern "C" { #endif cudaError_t {{.Curve}}NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output); +cudaError_t {{.Curve}}ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output); cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles); #ifdef __cplusplus } #endif -#endif +#endif \ No newline at end of file diff --git a/wrappers/golang/internal/generator/templates/msm.go.tmpl b/wrappers/golang/internal/generator/templates/msm.go.tmpl index 7c5b6499..e60c691d 100644 --- a/wrappers/golang/internal/generator/templates/msm.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm.go.tmpl @@ -74,7 +74,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp cC := (C.int)(c) cPointsIsOnDevice := (C._Bool)(points.IsOnDevice()) cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx)) - + outputBasesPointer := outputBases.AsPointer() cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer) diff --git a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl index ab2284ac..59620f25 100644 --- a/wrappers/golang/internal/generator/templates/msm_test.go.tmpl +++ b/wrappers/golang/internal/generator/templates/msm_test.go.tmpl @@ -204,7 +204,6 @@ func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) { } } - func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) { cfg := GetDefaultMSMConfig() for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} { @@ -241,7 +240,6 @@ func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) { orig_device, _ := cr.GetDevice() wg := sync.WaitGroup{} - for i := 
0; i < numDevices; i++ { wg.Add(1) cr.RunOnDevice(i, func(args ...any) { diff --git a/wrappers/golang/internal/generator/templates/ntt.go.tmpl b/wrappers/golang/internal/generator/templates/ntt.go.tmpl index 2fb6623a..55036da9 100644 --- a/wrappers/golang/internal/generator/templates/ntt.go.tmpl +++ b/wrappers/golang/internal/generator/templates/ntt.go.tmpl @@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo return core.FromCudaError(err) } +func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError { + core.NttCheck[T](points, cfg, results) + + var pointsPointer unsafe.Pointer + if points.IsOnDevice() { + pointsPointer = points.(core.DeviceSlice).AsPointer() + } else { + pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0]) + } + cPoints := (*C.projective_t)(pointsPointer) + cSize := (C.int)(points.Len() / int(cfg.BatchSize)) + cDir := (C.int)(dir) + cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg)) + + var resultsPointer unsafe.Pointer + if results.IsOnDevice() { + resultsPointer = results.(core.DeviceSlice).AsPointer() + } else { + resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0]) + } + cResults := (*C.projective_t)(resultsPointer) + + __ret := C.{{.Curve}}ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults) + err := (cr.CudaError)(__ret) + return core.FromCudaError(err) +} + func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError { cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer())) cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx)) diff --git a/wrappers/golang/internal/generator/templates/ntt_test.go.tmpl b/wrappers/golang/internal/generator/templates/ntt_test.go.tmpl index 03e3f5c3..a2f205a0 100644 --- a/wrappers/golang/internal/generator/templates/ntt_test.go.tmpl +++ b/wrappers/golang/internal/generator/templates/ntt_test.go.tmpl @@ 
-98,6 +98,26 @@ func TestNtt(t *testing.T) { } } +func TestECNtt(t *testing.T) { + cfg := GetDefaultNttConfig() + initDomain(largestTestSize, cfg) + points := GenerateProjectivePoints(1 << largestTestSize) + + for _, size := range []int{4, 5, 6, 7, 8} { + for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} { + testSize := 1 << size + + pointsCopy := core.HostSliceFromElements[Projective](points[:testSize]) + cfg.Ordering = v + cfg.NttAlgorithm = core.Radix2 + + output := make(core.HostSlice[Projective], testSize) + e := ECNtt(pointsCopy, core.KForward, &cfg, output) + assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed") + } + } +} + func TestNttDeviceAsync(t *testing.T) { cfg := GetDefaultNttConfig() scalars := GenerateScalars(1 << largestTestSize)