mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-07 22:53:56 -05:00
Adds CI checks for building and testing Golang bindings; adds CI checks for formatting Rust and Golang files; fixes Golang tests for BN254; splits Actions checks for PRs against main into multiple files. Resolves #108. Resolves #107. Resolves #138.
131 lines
3.5 KiB
Cheetah
131 lines
3.5 KiB
Cheetah
import (
|
|
"fmt"
|
|
"github.com/stretchr/testify/assert"
|
|
"reflect"
|
|
"testing"
|
|
)
|
|
|
|
func TestNtt{{.CurveNameUpperCase}}Batch(t *testing.T) {
|
|
count := 1 << 20
|
|
scalars := GenerateScalars(count, false)
|
|
|
|
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
|
|
copy(nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, scalars)
|
|
NttBatch(&nttResult, false, count, 0)
|
|
assert.NotEqual(t, nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, nttResult)
|
|
}
|
|
|
|
func TestNtt{{.CurveNameUpperCase}}CompareToGnarkDIF(t *testing.T) {
|
|
count := 1 << 2
|
|
scalars := GenerateScalars(count, false)
|
|
|
|
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
|
|
copy(nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, scalars)
|
|
Ntt(&nttResult, false, 0)
|
|
assert.NotEqual(t, nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, nttResult)
|
|
}
|
|
|
|
func TestINtt{{.CurveNameUpperCase}}CompareToGnarkDIT(t *testing.T) {
|
|
count := 1 << 3
|
|
scalars := GenerateScalars(count, false)
|
|
|
|
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
|
|
copy(nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, scalars)
|
|
Ntt(&nttResult, true, 0)
|
|
assert.NotEqual(t, nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, nttResult)
|
|
}
|
|
|
|
func TestNtt{{.CurveNameUpperCase}}(t *testing.T) {
|
|
count := 1 << 3
|
|
|
|
scalars := GenerateScalars(count, false)
|
|
|
|
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
|
|
copy(nttResult, scalars)
|
|
|
|
assert.Equal(t, nttResult, scalars)
|
|
Ntt(&nttResult, false, 0)
|
|
assert.NotEqual(t, nttResult, scalars)
|
|
|
|
inttResult := make([]G1ScalarField, len(nttResult))
|
|
copy(inttResult, nttResult)
|
|
|
|
assert.Equal(t, inttResult, nttResult)
|
|
Ntt(&inttResult, true, 0)
|
|
assert.Equal(t, inttResult, scalars)
|
|
}
|
|
|
|
func TestNttBatch{{.CurveNameUpperCase}}(t *testing.T) {
|
|
count := 1 << 5
|
|
batches := 4
|
|
|
|
scalars := GenerateScalars(count*batches, false)
|
|
|
|
var scalarVecOfVec [][]G1ScalarField = make([][]G1ScalarField, 0)
|
|
|
|
for i := 0; i < batches; i++ {
|
|
start := i * count
|
|
end := (i + 1) * count
|
|
batch := make([]G1ScalarField, len(scalars[start:end]))
|
|
copy(batch, scalars[start:end])
|
|
scalarVecOfVec = append(scalarVecOfVec, batch)
|
|
}
|
|
|
|
nttBatchResult := make([]G1ScalarField, len(scalars))
|
|
copy(nttBatchResult, scalars)
|
|
|
|
NttBatch(&nttBatchResult, false, count, 0)
|
|
|
|
var nttResultVecOfVec [][]G1ScalarField
|
|
|
|
for i := 0; i < batches; i++ {
|
|
// Clone the slice
|
|
clone := make([]G1ScalarField, len(scalarVecOfVec[i]))
|
|
copy(clone, scalarVecOfVec[i])
|
|
|
|
// Add it to the result vector of vectors
|
|
nttResultVecOfVec = append(nttResultVecOfVec, clone)
|
|
|
|
// Call the ntt_{{.CurveNameLowerCase}} function
|
|
Ntt(&nttResultVecOfVec[i], false, 0)
|
|
}
|
|
|
|
assert.NotEqual(t, nttBatchResult, scalars)
|
|
|
|
// Check that the ntt of each vec of scalars is equal to the intt of the specific batch
|
|
for i := 0; i < batches; i++ {
|
|
if !reflect.DeepEqual(nttResultVecOfVec[i], nttBatchResult[i*count:((i+1)*count)]) {
|
|
t.Errorf("ntt of vec of scalars not equal to intt of specific batch")
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkNTT(b *testing.B) {
|
|
LOG_NTT_SIZES := []int{12, 15, 20, 21, 22, 23, 24, 25, 26}
|
|
|
|
for _, logNTTSize := range LOG_NTT_SIZES {
|
|
nttSize := 1 << logNTTSize
|
|
b.Run(fmt.Sprintf("NTT %d", logNTTSize), func(b *testing.B) {
|
|
scalars := GenerateScalars(nttSize, false)
|
|
|
|
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
|
|
copy(nttResult, scalars)
|
|
for n := 0; n < b.N; n++ {
|
|
Ntt(&nttResult, false, 0)
|
|
}
|
|
})
|
|
}
|
|
}
|