[FEAT] ReleaseDomain API (#465)

## Describe the changes

This PR adds a NTT ReleaseDomain API in Golang and Rust

## Linked Issues

Resolves #

---------

Co-authored-by: Yuval Shekel <yshekel@gmail.com>
This commit is contained in:
Leon Hibnik
2024-04-09 12:58:19 +03:00
committed by GitHub
parent 4a35eece51
commit a7b0dc40c1
18 changed files with 246 additions and 17 deletions

View File

@@ -523,6 +523,7 @@ namespace ntt {
domain.fast_internal_twiddles_inv = nullptr;
CHK_IF_RETURN(cudaFreeAsync(domain.fast_basic_twiddles_inv, ctx.stream));
domain.fast_basic_twiddles_inv = nullptr;
domain.initialized = false;
return CHK_LAST();
}
@@ -749,6 +750,17 @@ namespace ntt {
return NTT<curve_config::scalar_t, curve_config::scalar_t>(input, size, dir, config, output);
}
/**
* Extern "C" version of [ReleaseDomain](@ref ReleaseDomain) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
* - `S` is the [scalar field](@ref scalar_t) of the curve;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, ReleaseDomain)(device_context::DeviceContext& ctx)
{
return ReleaseDomain<curve_config::scalar_t>(ctx);
}
#if defined(ECNTT_DEFINED)
/**
* Extern "C" version of [NTT](@ref NTT) function with the following values of template parameters

View File

@@ -12,6 +12,7 @@ extern "C" {
cudaError_t bls12_377NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_377ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bls12_377ReleaseDomain(DeviceContext* ctx);
#ifdef __cplusplus
}

View File

@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bls12_377ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

View File

@@ -1,6 +1,7 @@
package bls12377
import (
"os"
"reflect"
"testing"
@@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}
func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}
func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}
func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}
// execute tests
os.Exit(m.Run())
// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}
// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {

View File

@@ -12,6 +12,7 @@ extern "C" {
cudaError_t bls12_381NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_381ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bls12_381ReleaseDomain(DeviceContext* ctx);
#ifdef __cplusplus
}

View File

@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bls12_381ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

View File

@@ -1,6 +1,7 @@
package bls12381
import (
"os"
"reflect"
"testing"
@@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}
func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}
func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}
func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}
// execute tests
os.Exit(m.Run())
// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}
// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {

View File

@@ -12,6 +12,7 @@ extern "C" {
cudaError_t bn254NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bn254ReleaseDomain(DeviceContext* ctx);
#ifdef __cplusplus
}

View File

@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bn254ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

View File

@@ -1,6 +1,7 @@
package bn254
import (
"os"
"reflect"
"testing"
@@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}
func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}
func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}
func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}
// execute tests
os.Exit(m.Run())
// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}
// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {

View File

@@ -12,6 +12,7 @@ extern "C" {
cudaError_t bw6_761NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t bw6_761ReleaseDomain(DeviceContext* ctx);
#ifdef __cplusplus
}

View File

@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.bw6_761ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

View File

@@ -1,6 +1,7 @@
package bw6761
import (
"os"
"reflect"
"testing"
@@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}
func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}
func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}
func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}
// execute tests
os.Exit(m.Run())
// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}
// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {

View File

@@ -12,6 +12,7 @@ extern "C" {
cudaError_t {{.Curve}}NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
cudaError_t {{.Curve}}ReleaseDomain(DeviceContext* ctx);
#ifdef __cplusplus
}

View File

@@ -88,3 +88,10 @@ func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bo
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func ReleaseDomain(ctx cr.DeviceContext) core.IcicleError {
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
__ret := C.{{.Curve}}ReleaseDomain(cCtx)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

View File

@@ -1,6 +1,7 @@
package {{.PackageName}}
import (
"os"
"reflect"
"testing"
@@ -21,14 +22,15 @@ func init() {
initDomain(largestTestSize, cfg)
}
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) {
func initDomain[T any](largestTestSize int, cfg core.NTTConfig[T]) core.IcicleError {
rouMont, _ := fft.Generator(uint64(1 << largestTestSize))
rou := rouMont.Bits()
rouIcicle := ScalarField{}
limbs := core.ConvertUint64ArrToUint32Arr(rou[:])
rouIcicle.FromLimbs(limbs)
InitDomain(rouIcicle, cfg.Ctx, false)
e := InitDomain(rouIcicle, cfg.Ctx, false)
return e
}
func testAgainstGnarkCryptoNtt(size int, scalars core.HostSlice[ScalarField], output core.HostSlice[ScalarField], order core.Ordering, direction core.NTTDir) bool {
@@ -78,7 +80,7 @@ func TestNTTGetDefaultConfig(t *testing.T) {
}
func TestInitDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the init() function")
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
assert.NotPanics(t, func() { initDomain(largestTestSize, cfg) })
}
@@ -201,6 +203,31 @@ func TestNttBatch(t *testing.T) {
}
}
func TestReleaseDomain(t *testing.T) {
t.Skip("Skipped because each test requires the domain to be initialized before running. We ensure this using the TestMain() function")
cfg := GetDefaultNttConfig()
e := ReleaseDomain(cfg.Ctx)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ReleasDomain failed")
}
func TestMain(m *testing.M) {
// setup domain
cfg := GetDefaultNttConfig()
e := initDomain(largestTestSize, cfg)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("initDomain failed")
}
// execute tests
os.Exit(m.Run())
// release domain
e = ReleaseDomain(cfg.Ctx)
if e.IcicleErrorCode != core.IcicleErrorCode(0) {
panic("ReleaseDomain failed")
}
}
// func TestNttArbitraryCoset(t *testing.T) {
// for _, size := range []int{20} {
// for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {

View File

@@ -124,6 +124,7 @@ pub trait NTT<F: FieldImpl> {
fn ntt_inplace_unchecked(inout: &mut HostOrDeviceSlice<F>, dir: NTTDir, cfg: &NTTConfig<F>) -> IcicleResult<()>;
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
fn initialize_domain_fast_twiddles_mode(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
fn release_domain(ctx: &DeviceContext) -> IcicleResult<()>;
}
/// Computes the NTT, or a batch of several NTTs.
@@ -206,6 +207,14 @@ where
<<F as FieldImpl>::Config as NTT<F>>::initialize_domain_fast_twiddles_mode(primitive_root, ctx)
}
pub fn release_domain<F>(ctx: &DeviceContext) -> IcicleResult<()>
where
F: FieldImpl,
<F as FieldImpl>::Config: NTT<F>,
{
<<F as FieldImpl>::Config as NTT<F>>::release_domain(ctx)
}
#[macro_export]
macro_rules! impl_ntt {
(
@@ -233,6 +242,9 @@ macro_rules! impl_ntt {
ctx: &DeviceContext,
fast_twiddles_mode: bool,
) -> CudaError;
#[link_name = concat!($field_prefix, "ReleaseDomain")]
pub(crate) fn release_ntt_domain(ctx: &DeviceContext) -> CudaError;
}
}
@@ -278,6 +290,9 @@ macro_rules! impl_ntt {
fn initialize_domain_fast_twiddles_mode(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::initialize_ntt_domain(&primitive_root, ctx, true).wrap() }
}
fn release_domain(ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::release_ntt_domain(ctx).wrap() }
}
}
};
}
@@ -289,6 +304,7 @@ macro_rules! impl_ntt_tests {
) => {
const MAX_SIZE: u64 = 1 << 17;
static INIT: OnceLock<()> = OnceLock::new();
static RELEASE: OnceLock<()> = OnceLock::new(); // for release domain test
const FAST_TWIDDLES_MODE: bool = false;
#[test]
@@ -320,5 +336,12 @@ macro_rules! impl_ntt_tests {
// init_domain is in this test is performed per-device
check_ntt_device_async::<$field>()
}
#[test]
fn test_ntt_release_domain() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE));
check_release_domain::<$field>();
*RELEASE.get_or_init(move || init_domain::<$field>(MAX_SIZE, DEFAULT_DEVICE_ID, FAST_TWIDDLES_MODE))
}
};
}

View File

@@ -7,7 +7,10 @@ use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use crate::{
ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, NTTDir, NttAlgorithm, Ordering},
ntt::{
initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, release_domain, NTTDir,
NttAlgorithm, Ordering,
},
traits::{ArkConvertible, FieldImpl, GenerateRandom},
vec_ops::{transpose_matrix, VecOps},
};
@@ -28,6 +31,13 @@ where
}
}
pub fn rel_domain<F: FieldImpl>(ctx: &DeviceContext)
where
<F as FieldImpl>::Config: NTT<F>,
{
release_domain::<F>(&ctx).unwrap();
}
pub fn reverse_bit_order(n: u32, order: u32) -> u32 {
fn is_power_of_two(n: u32) -> bool {
n != 0 && n & (n - 1) == 0
@@ -333,11 +343,11 @@ where
set_device(device_id).unwrap();
// if have more than one device, it will use fast-twiddles-mode (note that domain is reused per device if not released)
init_domain::<F>(1 << 16, device_id, true /*=fast twiddles mode*/); // init domain per device
let mut config: NTTConfig<'static, F> = NTTConfig::default_for_device(device_id);
let test_sizes = [1 << 4, 1 << 12];
let batch_sizes = [1, 1 << 4, 100];
for test_size in test_sizes {
let coset_generators = [F::one(), F::Config::generate_random(1)[0]];
let mut config = NTTConfig::default_for_device(device_id);
let stream = config
.ctx
.stream;
@@ -386,3 +396,12 @@ where
}
});
}
pub fn check_release_domain<F: FieldImpl + ArkConvertible>()
where
F::ArkEquivalent: FftField,
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
{
let config: NTTConfig<'static, F> = NTTConfig::default();
rel_domain::<F>(&config.ctx);
}