mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 13:07:59 -05:00
Golang bindings for ECNTT (#433)
This commit is contained in:
2
.github/workflows/golang.yml
vendored
2
.github/workflows/golang.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
- name: Build
|
||||
working-directory: ./wrappers/golang
|
||||
if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
|
||||
run: ./build.sh ${{ matrix.curve }} ON # builds a single curve with G2 enabled
|
||||
run: ./build.sh ${{ matrix.curve }} ON ON # builds a single curve with G2 and ECNTT enabled
|
||||
- name: Upload ICICLE lib artifacts
|
||||
uses: actions/upload-artifact@v4
|
||||
if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
|
||||
|
||||
@@ -96,6 +96,10 @@ if (G2_DEFINED STREQUAL "ON")
|
||||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DG2_DEFINED=ON")
|
||||
endif ()
|
||||
|
||||
# Forward the ECNTT_DEFINED build option to the CUDA compiler as a
# preprocessor definition when it is set to "ON".
if (ECNTT_DEFINED STREQUAL "ON")
  string(APPEND CMAKE_CUDA_FLAGS " -DECNTT_DEFINED=ON")
endif ()
|
||||
|
||||
option(BUILD_TESTS "Build tests" OFF)
|
||||
|
||||
if (NOT BUILD_TESTS)
|
||||
@@ -110,6 +114,9 @@ if (NOT BUILD_TESTS)
|
||||
# Add the NTT translation units unless the selected curve is listed as having
# no NTT support.
if (NOT CURVE IN_LIST SUPPORTED_CURVES_WITHOUT_NTT)
  list(APPEND ICICLE_SOURCES
    appUtils/ntt/ntt.cu
    appUtils/ntt/kernel_ntt.cu
  )
  # -DECNTT_DEFINED=ON is already appended to CMAKE_CUDA_FLAGS by the
  # ECNTT_DEFINED check placed next to the G2_DEFINED check in this file,
  # so it is intentionally not appended a second time here.
endif()
|
||||
|
||||
add_library(
|
||||
|
||||
@@ -644,29 +644,34 @@ namespace ntt {
|
||||
h_coset.clear();
|
||||
}
|
||||
|
||||
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
|
||||
const bool is_inverse = dir == NTTDir::kInverse;
|
||||
|
||||
if (is_radix2_algorithm) {
|
||||
if constexpr (std::is_same_v<E, curve_config::projective_t>) {
|
||||
CHK_IF_RETURN(ntt::radix2_ntt(
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
const bool is_on_coset = (coset_index != 0) || coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
|
||||
S* twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
|
||||
: domain.twiddles;
|
||||
S* internal_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
|
||||
: domain.internal_twiddles;
|
||||
S* basic_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
|
||||
: domain.basic_twiddles;
|
||||
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
|
||||
if (is_radix2_algorithm) {
|
||||
CHK_IF_RETURN(ntt::radix2_ntt(
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
const bool is_on_coset = (coset_index != 0) || coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
|
||||
S* twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
|
||||
: domain.twiddles;
|
||||
S* internal_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
|
||||
: domain.internal_twiddles;
|
||||
S* basic_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
|
||||
: domain.basic_twiddles;
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
}
|
||||
}
|
||||
|
||||
if (!are_outputs_on_device)
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
G2_DEFINED=OFF
|
||||
ECNTT_DEFINED=OFF
|
||||
|
||||
if [[ $2 ]]
|
||||
if [[ $2 == "ON" ]]
|
||||
then
|
||||
G2_DEFINED=ON
|
||||
fi
|
||||
|
||||
# Enable ECNTT only when the third argument is exactly "ON". A bare `[[ $3 ]]`
# is true for any non-empty string, so passing "OFF" would have enabled it.
if [[ $3 == "ON" ]]
then
  ECNTT_DEFINED=ON
fi
|
||||
|
||||
BUILD_DIR=$(realpath "$PWD/../../icicle/build")
|
||||
SUPPORTED_CURVES=("bn254" "bls12_377" "bls12_381" "bw6_761")
|
||||
|
||||
@@ -22,6 +28,6 @@ mkdir -p build
|
||||
|
||||
for CURVE in "${BUILD_CURVES[@]}"
|
||||
do
|
||||
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
|
||||
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DECNTT_DEFINED=$ECNTT_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
|
||||
cmake --build build -j8
|
||||
done
|
||||
@@ -10,26 +10,26 @@ type NTTDir int8
|
||||
|
||||
const (
|
||||
KForward NTTDir = iota
|
||||
KInverse NTTDir = 1
|
||||
KInverse
|
||||
)
|
||||
|
||||
type Ordering uint32
|
||||
|
||||
const (
|
||||
KNN Ordering = iota
|
||||
KNR Ordering = 1
|
||||
KRN Ordering = 2
|
||||
KRR Ordering = 3
|
||||
KNM Ordering = 4
|
||||
KMN Ordering = 5
|
||||
KNR
|
||||
KRN
|
||||
KRR
|
||||
KNM
|
||||
KMN
|
||||
)
|
||||
|
||||
type NttAlgorithm uint32
|
||||
|
||||
const (
|
||||
Auto NttAlgorithm = iota
|
||||
Radix2 NttAlgorithm = 1
|
||||
MixedRadix NttAlgorithm = 2
|
||||
Auto NttAlgorithm = iota
|
||||
Radix2
|
||||
MixedRadix
|
||||
)
|
||||
|
||||
type NTTConfig[T any] struct {
|
||||
@@ -47,10 +47,9 @@ type NTTConfig[T any] struct {
|
||||
areOutputsOnDevice bool
|
||||
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
|
||||
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
|
||||
IsAsync bool
|
||||
/// Explicitly select the NTT algorithm.
|
||||
/// Default value: Auto (the implementation selects radix-2 or mixed-radix algorithm based on heuristics).
|
||||
NttAlgorithm NttAlgorithm
|
||||
IsAsync bool
|
||||
NttAlgorithm NttAlgorithm /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
|
||||
selects radix-2 or mixed-radix algorithm based on heuristics). */
|
||||
}
|
||||
|
||||
func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
|
||||
@@ -64,7 +63,7 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
|
||||
false, // areInputsOnDevice
|
||||
false, // areOutputsOnDevice
|
||||
false, // IsAsync
|
||||
Auto, // NttAlgorithm
|
||||
Auto,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,10 +10,11 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_377NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_377ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bls12_377ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -10,10 +10,11 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_381NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_381ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bls12_381ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -10,10 +10,11 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bn254NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bn254ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bn254ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -10,10 +10,11 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bw6_761NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bw6_761ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bw6_761ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -10,10 +10,11 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t {{.Curve}}NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t {{.Curve}}ECNTTCuda(projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -74,7 +74,7 @@ func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomp
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer)
|
||||
|
||||
|
||||
@@ -204,7 +204,6 @@ func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -241,7 +240,6 @@ func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) {
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
|
||||
@@ -53,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.{{.Curve}}ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
Reference in New Issue
Block a user