Files
icicle/goicicle/curves/bw6761/ntt_test.go
liuxiao fd62fe5ae8 Support bw6-761 (#188)
Resolves #191 and #113

---------

Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
Co-authored-by: ImmanuelSegol <3ditds@gmail.com>
2023-10-21 18:49:06 +03:00

149 lines
4.0 KiB
Go

// Copyright 2023 Ingonyama
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by Ingonyama DO NOT EDIT
package bw6761
import (
"fmt"
"github.com/stretchr/testify/assert"
"reflect"
"testing"
)
func TestNttBW6761Batch(t *testing.T) {
count := 1 << 20
scalars := GenerateScalars(count, false)
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
copy(nttResult, scalars)
assert.Equal(t, nttResult, scalars)
NttBatch(&nttResult, false, count, 0)
assert.NotEqual(t, nttResult, scalars)
assert.Equal(t, nttResult, nttResult)
}
func TestNttBW6761CompareToGnarkDIF(t *testing.T) {
count := 1 << 2
scalars := GenerateScalars(count, false)
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
copy(nttResult, scalars)
assert.Equal(t, nttResult, scalars)
Ntt(&nttResult, false, 0)
assert.NotEqual(t, nttResult, scalars)
assert.Equal(t, nttResult, nttResult)
}
func TestINttBW6761CompareToGnarkDIT(t *testing.T) {
count := 1 << 3
scalars := GenerateScalars(count, false)
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
copy(nttResult, scalars)
assert.Equal(t, nttResult, scalars)
Ntt(&nttResult, true, 0)
assert.NotEqual(t, nttResult, scalars)
assert.Equal(t, nttResult, nttResult)
}
func TestNttBW6761(t *testing.T) {
count := 1 << 3
scalars := GenerateScalars(count, false)
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
copy(nttResult, scalars)
assert.Equal(t, nttResult, scalars)
Ntt(&nttResult, false, 0)
assert.NotEqual(t, nttResult, scalars)
inttResult := make([]G1ScalarField, len(nttResult))
copy(inttResult, nttResult)
assert.Equal(t, inttResult, nttResult)
Ntt(&inttResult, true, 0)
assert.Equal(t, inttResult, scalars)
}
func TestNttBatchBW6761(t *testing.T) {
count := 1 << 5
batches := 4
scalars := GenerateScalars(count*batches, false)
var scalarVecOfVec [][]G1ScalarField = make([][]G1ScalarField, 0)
for i := 0; i < batches; i++ {
start := i * count
end := (i + 1) * count
batch := make([]G1ScalarField, len(scalars[start:end]))
copy(batch, scalars[start:end])
scalarVecOfVec = append(scalarVecOfVec, batch)
}
nttBatchResult := make([]G1ScalarField, len(scalars))
copy(nttBatchResult, scalars)
NttBatch(&nttBatchResult, false, count, 0)
var nttResultVecOfVec [][]G1ScalarField
for i := 0; i < batches; i++ {
// Clone the slice
clone := make([]G1ScalarField, len(scalarVecOfVec[i]))
copy(clone, scalarVecOfVec[i])
// Add it to the result vector of vectors
nttResultVecOfVec = append(nttResultVecOfVec, clone)
// Call the ntt_bw6_761 function
Ntt(&nttResultVecOfVec[i], false, 0)
}
assert.NotEqual(t, nttBatchResult, scalars)
// Check that the ntt of each vec of scalars is equal to the intt of the specific batch
for i := 0; i < batches; i++ {
if !reflect.DeepEqual(nttResultVecOfVec[i], nttBatchResult[i*count:((i+1)*count)]) {
t.Errorf("ntt of vec of scalars not equal to intt of specific batch")
}
}
}
func BenchmarkNTT(b *testing.B) {
LOG_NTT_SIZES := []int{12, 15, 20, 21, 22, 23, 24, 25, 26}
for _, logNTTSize := range LOG_NTT_SIZES {
nttSize := 1 << logNTTSize
b.Run(fmt.Sprintf("NTT %d", logNTTSize), func(b *testing.B) {
scalars := GenerateScalars(nttSize, false)
nttResult := make([]G1ScalarField, len(scalars)) // Make a new slice with the same length
copy(nttResult, scalars)
for n := 0; n < b.N; n++ {
Ntt(&nttResult, false, 0)
}
})
}
}