Files
icicle/wrappers/golang/core/ntt_test.go
2024-02-22 20:52:48 +02:00

97 lines
2.5 KiB
Go

package core
import (
// "unsafe"
"testing"
"github.com/ingonyama-zk/icicle/wrappers/golang/core/internal"
"github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
"github.com/stretchr/testify/assert"
)
// TestNTTDefaultConfig checks that GetDefaultNTTConfig populates every field of
// NTTConfig with its default value when given a single-limb coset generator
// built from a MockField on which One() has been called.
func TestNTTDefaultConfig(t *testing.T) {
	// One() is called only for its effect on the receiver; its return value is
	// unused. NOTE(review): presumably it sets the field to one in place — confirm.
	var one internal.MockField
	one.One()

	// copy stops at the shorter of the two slices, so at most the first limb
	// of the field is kept in the one-element generator.
	var gen [1]uint32
	copy(gen[:], one.GetLimbs())

	// The second return value of GetDefaultDeviceContext is intentionally
	// discarded; only the context itself is compared against the config.
	defaultCtx, _ := cuda_runtime.GetDefaultDeviceContext()

	want := NTTConfig[[1]uint32]{
		defaultCtx, // Ctx
		gen,        // CosetGen
		1,          // BatchSize
		KNN,        // Ordering
		false,      // areInputsOnDevice
		false,      // areOutputsOnDevice
		false,      // IsAsync
	}
	got := GetDefaultNTTConfig(gen)

	assert.Equal(t, want, got)
}
// TestNTTCheckHostScalars verifies that NttCheck accepts host-resident input
// and output slices of equal length without panicking, that both on-device
// flags in the config remain false afterwards, and that NttCheck panics when
// the output is longer than the input.
func TestNTTCheckHostScalars(t *testing.T) {
	// Fixed (not random) limb values shared by the coset generator and the inputs.
	limbs := []uint32{1, 2, 3, 4, 5, 6, 7, 8}

	var cosetGen internal.MockField
	cosetGen.FromLimbs(limbs)
	// Pass the coset generator by value, like the other tests in this file, so
	// the config's CosetGen is a field element rather than a pointer to one.
	cfg := GetDefaultNTTConfig(cosetGen)

	var field internal.MockField
	field.FromLimbs(limbs)

	rawInput := make([]internal.MockField, 10)
	for i := range rawInput {
		rawInput[i] = field
	}
	input := HostSliceFromElements[internal.MockField](rawInput)
	output := HostSliceFromElements[internal.MockField](rawInput)

	assert.NotPanics(t, func() { NttCheck(input, &cfg, output) })
	assert.False(t, cfg.areInputsOnDevice)
	assert.False(t, cfg.areOutputsOnDevice)

	// An output one element longer than the input must make NttCheck panic.
	rawInputLarger := make([]internal.MockField, 11)
	for i := range rawInputLarger {
		rawInputLarger[i] = field
	}
	output2 := HostSliceFromElements[internal.MockField](rawInputLarger)
	assert.Panics(t, func() { NttCheck(input, &cfg, output2) })
}
// TestNTTCheckDeviceScalars verifies that NttCheck accepts device-resident
// input and output slices of equal length without panicking, that both
// on-device flags in the config are true afterwards, and that NttCheck panics
// when the output is longer than the input. Every device buffer the test
// allocates is freed when the test returns.
func TestNTTCheckDeviceScalars(t *testing.T) {
	// Fixed (not random) limb values shared by the coset generator and the inputs.
	limbs := []uint32{1, 2, 3, 4, 5, 6, 7, 8}

	var cosetGen internal.MockField
	cosetGen.FromLimbs(limbs)
	cfg := GetDefaultNTTConfig(cosetGen)

	// NOTE(review): assumed to be the byte size of one MockField — confirm
	// against the MockField definition if its limb count changes.
	fieldBytesSize := 16
	numFields := 10

	rawInput := make([]internal.MockField, numFields)
	for i := range rawInput {
		var field internal.MockField
		field.FromLimbs(limbs)
		rawInput[i] = field
	}
	hostElements := HostSliceFromElements[internal.MockField](rawInput)

	// Each device allocation is paired with a deferred Free so the buffers
	// created by CopyToDevice/Malloc do not outlive the test.
	var input DeviceSlice
	hostElements.CopyToDevice(&input, true)
	defer input.Free()

	var output DeviceSlice
	output.Malloc(numFields*fieldBytesSize, fieldBytesSize)
	defer output.Free()

	assert.NotPanics(t, func() { NttCheck(input, &cfg, output) })
	assert.True(t, cfg.areInputsOnDevice)
	assert.True(t, cfg.areOutputsOnDevice)

	// An output one element longer than the input must make NttCheck panic.
	var output2 DeviceSlice
	output2.Malloc((numFields+1)*fieldBytesSize, fieldBytesSize)
	defer output2.Free()
	assert.Panics(t, func() { NttCheck(input, &cfg, output2) })
}
// TODO: add checks covering batched inputs and the BatchSize field.