mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
[FEAT] ReleaseDomain API (#465)
## Describe the changes This PR adds a NTT ReleaseDomain API in Golang and Rust ## Linked Issues Resolves # --------- Co-authored-by: Yuval Shekel <yshekel@gmail.com>
This commit is contained in:
@@ -523,6 +523,7 @@ namespace ntt {
|
||||
domain.fast_internal_twiddles_inv = nullptr;
|
||||
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream));
|
||||
domain.fast_basic_twiddles_inv = nullptr;
|
||||
domain.initialized = false;
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
@@ -749,6 +750,17 @@ namespace ntt {
|
||||
return NTT<curve_config::scalar_t, curve_config::scalar_t>(input, size, dir, config, output);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extern "C" version of [ReleaseDomain](@ref ReleaseDomain) function with the following values of template parameters
|
||||
* (where the curve is given by `-DCURVE` env variable during build):
|
||||
* - `S` is the [scalar field](@ref scalar_t) of the curve;
|
||||
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
|
||||
*/
|
||||
extern "C" cudaError_t CONCAT_EXPAND(CURVE, ReleaseDomain)(device_context::DeviceContext& ctx)
|
||||
{
|
||||
return ReleaseDomain<curve_config::scalar_t>(ctx);
|
||||
}
|
||||
|
||||
#if defined(ECNTT_DEFINED)
|
||||
/**
|
||||
* Extern "C" version of [NTT](@ref NTT) function with the following values of template parameters
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
cudaError_t bls12_377NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_377ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
cudaError_t bls12_377ReleaseDomain(DeviceContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
__ret := C.bls12_377ReleaseDomain(cCtx)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -21,14 +22,15 @@ func init() {
|
||||
initDomain(largestTestSize, cfg)
|
||||
}
|
||||
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
|
||||
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
|
||||
rou := rouMont.Bits()
|
||||
rouIcicle := ScalarField{}
|
||||
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
|
||||
|
||||
rouIcicle.FromLimbs(limbs)
|
||||
InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
e := InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
return e
|
||||
}
|
||||
|
||||
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
|
||||
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInitDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
|
||||
}
|
||||
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := ReleaseDomain(cfg.Ctx)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// setup domain
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := initDomain(largestTestSize, cfg)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("initDomain failed")
|
||||
}
|
||||
|
||||
// execute tests
|
||||
os.Exit(m.Run())
|
||||
|
||||
// release domain
|
||||
e = ReleaseDomain(cfg.Ctx)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("ReleaseDomain failed")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestNttArbitraryCoset(t *testing.T) {
|
||||
// for _, size := range []int{20} {
|
||||
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
cudaError_t bls12_381NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_381ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
cudaError_t bls12_381ReleaseDomain(DeviceContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
__ret := C.bls12_381ReleaseDomain(cCtx)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bls12381
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -21,14 +22,15 @@ func init() {
|
||||
initDomain(largestTestSize, cfg)
|
||||
}
|
||||
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
|
||||
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
|
||||
rou := rouMont.Bits()
|
||||
rouIcicle := ScalarField{}
|
||||
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
|
||||
|
||||
rouIcicle.FromLimbs(limbs)
|
||||
InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
e := InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
return e
|
||||
}
|
||||
|
||||
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
|
||||
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInitDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
|
||||
}
|
||||
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := ReleaseDomain(cfg.Ctx)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// setup domain
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := initDomain(largestTestSize, cfg)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("initDomain failed")
|
||||
}
|
||||
|
||||
// execute tests
|
||||
os.Exit(m.Run())
|
||||
|
||||
// release domain
|
||||
e = ReleaseDomain(cfg.Ctx)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("ReleaseDomain failed")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestNttArbitraryCoset(t *testing.T) {
|
||||
// for _, size := range []int{20} {
|
||||
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
cudaError_t bn254NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bn254ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
cudaError_t bn254ReleaseDomain(DeviceContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
__ret := C.bn254ReleaseDomain(cCtx)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -21,14 +22,15 @@ func init() {
|
||||
initDomain(largestTestSize, cfg)
|
||||
}
|
||||
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
|
||||
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
|
||||
rou := rouMont.Bits()
|
||||
rouIcicle := ScalarField{}
|
||||
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
|
||||
|
||||
rouIcicle.FromLimbs(limbs)
|
||||
InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
e := InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
return e
|
||||
}
|
||||
|
||||
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
|
||||
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInitDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
|
||||
}
|
||||
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := ReleaseDomain(cfg.Ctx)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// setup domain
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := initDomain(largestTestSize, cfg)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("initDomain failed")
|
||||
}
|
||||
|
||||
// execute tests
|
||||
os.Exit(m.Run())
|
||||
|
||||
// release domain
|
||||
e = ReleaseDomain(cfg.Ctx)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("ReleaseDomain failed")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestNttArbitraryCoset(t *testing.T) {
|
||||
// for _, size := range []int{20} {
|
||||
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
cudaError_t bw6_761NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bw6_761ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
cudaError_t bw6_761ReleaseDomain(DeviceContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
__ret := C.bw6_761ReleaseDomain(cCtx)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -21,14 +22,15 @@ func init() {
|
||||
initDomain(largestTestSize, cfg)
|
||||
}
|
||||
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
|
||||
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
|
||||
rou := rouMont.Bits()
|
||||
rouIcicle := ScalarField{}
|
||||
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
|
||||
|
||||
rouIcicle.FromLimbs(limbs)
|
||||
InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
e := InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
return e
|
||||
}
|
||||
|
||||
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
|
||||
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInitDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
|
||||
}
|
||||
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := ReleaseDomain(cfg.Ctx)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// setup domain
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := initDomain(largestTestSize, cfg)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("initDomain failed")
|
||||
}
|
||||
|
||||
// execute tests
|
||||
os.Exit(m.Run())
|
||||
|
||||
// release domain
|
||||
e = ReleaseDomain(cfg.Ctx)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("ReleaseDomain failed")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestNttArbitraryCoset(t *testing.T) {
|
||||
// for _, size := range []int{20} {
|
||||
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
|
||||
@@ -12,6 +12,7 @@ extern "C" {
|
||||
cudaError_t {{.Curve}}NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t {{.Curve}}ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
cudaError_t {{.Curve}}ReleaseDomain(DeviceContext* ctx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
__ret := C.{{.Curve}}ReleaseDomain(cCtx)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -21,14 +22,15 @@ func init() {
|
||||
initDomain(largestTestSize, cfg)
|
||||
}
|
||||
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
|
||||
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
|
||||
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
|
||||
rou := rouMont.Bits()
|
||||
rouIcicle := ScalarField{}
|
||||
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
|
||||
|
||||
rouIcicle.FromLimbs(limbs)
|
||||
InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
e := InitDomain(rouIcicle, cfg.Ctx, false)
|
||||
return e
|
||||
}
|
||||
|
||||
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
|
||||
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInitDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
|
||||
}
|
||||
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReleaseDomain(t *testing.T) {
|
||||
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := ReleaseDomain(cfg.Ctx)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// setup domain
|
||||
cfg := GetDefaultNttConfig()
|
||||
e := initDomain(largestTestSize, cfg)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("initDomain failed")
|
||||
}
|
||||
|
||||
// execute tests
|
||||
os.Exit(m.Run())
|
||||
|
||||
// release domain
|
||||
e = ReleaseDomain(cfg.Ctx)
|
||||
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
|
||||
panic("ReleaseDomain failed")
|
||||
}
|
||||
}
|
||||
|
||||
// func TestNttArbitraryCoset(t *testing.T) {
|
||||
// for _, size := range []int{20} {
|
||||
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
|
||||
@@ -124,6 +124,7 @@ pub trait NTT<F: FieldImpl> {
|
||||
fn ntt_inplace_unchecked(inout: &mut HostOrDeviceSlice<F>, dir: NTTDir, cfg: &NTTConfig<F>) -> IcicleResult<()>;
|
||||
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
|
||||
fn initialize_domain_fast_twiddles_mode(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
|
||||
fn release_domain(ctx: &DeviceContext) -> IcicleResult<()>;
|
||||
}
|
||||
|
||||
/// Computes the NTT, or a batch of several NTTs.
|
||||
@@ -206,6 +207,14 @@ where
|
||||
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain_fast_twiddles_mode(primitive_root, ctx)
|
||||
}
|
||||
|
||||
pub fn release_domain<F>(ctx: &DeviceContext) -> IcicleResult<()>
|
||||
where
|
||||
F: FieldImpl,
|
||||
<F as FieldImpl>::Config: NTT<F>,
|
||||
{
|
||||
<<F as FieldImpl>::Config as NTT<F>>::release_domain(ctx)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_ntt {
|
||||
(
|
||||
@@ -233,6 +242,9 @@ macro_rules! impl_ntt {
|
||||
ctx: &DeviceContext,
|
||||
fast_twiddles_mode: bool,
|
||||
) -> CudaError;
|
||||
|
||||
#[link_name = concat!($field_prefix, "ReleaseDomain")]
|
||||
pub(crate) fn release_ntt_domain(ctx: &DeviceContext) -> CudaError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -278,6 +290,9 @@ macro_rules! impl_ntt {
|
||||
fn initialize_domain_fast_twiddles_mode(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
|
||||
unsafe { $field_prefix_ident::initialize_ntt_domain(&primitive_root, ctx, true).wrap() }
|
||||
}
|
||||
fn release_domain(ctx: &DeviceContext) -> IcicleResult<()> {
|
||||
unsafe { $field_prefix_ident::release_ntt_domain(ctx).wrap() }
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -289,6 +304,7 @@ macro_rules! impl_ntt_tests {
|
||||
) => {
|
||||
const MAX_SIZE: u64 = 1 << 17;
|
||||
static INIT: OnceLock<()> = OnceLock::new();
|
||||
static RELEASE: OnceLock<()> = OnceLock::new(); // for release domain test
|
||||
const FAST_TWIDDLES_MODE: bool = false;
|
||||
|
||||
#[test]
|
||||
@@ -320,5 +336,12 @@ macro_rules! impl_ntt_tests {
|
||||
// init_domain is in this test is performed per-device
|
||||
check_ntt_device_async::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntt_release_domain() {
|
||||
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
|
||||
check_release_domain::<$field>();
|
||||
*RELEASE.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::{
|
||||
ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, NTTDir, NttAlgorithm, Ordering},
|
||||
ntt::{
|
||||
initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, release_domain, NTTDir,
|
||||
NttAlgorithm, Ordering,
|
||||
},
|
||||
traits::{ArkConvertible, FieldImpl, GenerateRandom},
|
||||
vec_ops::{transpose_matrix, VecOps},
|
||||
};
|
||||
@@ -28,6 +31,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rel_domain<F: FieldImpl>(ctx: &DeviceContext)
|
||||
where
|
||||
<F as FieldImpl>::Config: NTT<F>,
|
||||
{
|
||||
release_domain::<F>(&ctx).unwrap();
|
||||
}
|
||||
|
||||
pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
|
||||
fn is_power_of_two(n: u32) -> bool {
|
||||
n != 0 && n & (n - 1) == 0
|
||||
@@ -333,11 +343,11 @@ where
|
||||
set_device(device_id).unwrap();
|
||||
// if have more than one device, it will use fast-twiddles-mode (note that domain is reused per device if not released)
|
||||
init_domain::<F>(1 << 16, device_id, true /*=fast twiddles mode*/); // init domain per device
|
||||
let mut config: NTTConfig<'static, F> = NTTConfig::default_for_device(device_id);
|
||||
let test_sizes = [1 << 4, 1 << 12];
|
||||
let batch_sizes = [1, 1 << 4, 100];
|
||||
for test_size in test_sizes {
|
||||
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
|
||||
let mut config = NTTConfig::default_for_device(device_id);
|
||||
let stream = config
|
||||
.ctx
|
||||
.stream;
|
||||
@@ -386,3 +396,12 @@ where
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub fn check_release_domain<F: FieldImpl + ArkConvertible>()
|
||||
where
|
||||
F::ArkEquivalent: FftField,
|
||||
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
|
||||
{
|
||||
let config: NTTConfig<'static, F> = NTTConfig::default();
|
||||
rel_domain::<F>(&config.ctx);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user