linea-monorepo/prover/circuits/internal/utils.go
Arya Tabaie fdd84f24e5 refactor: Use new GKR API (#1064)
* use new gkr API

* fix Table pointers

* refactor: remove removed test engine option

* chore: don't initialize struct for interface assertion

* refactor: plonk-in-wizard hardcoded over U64 for now

* refactor: use new gnark-crypto stateless RSis API

* test: disable incompatible tests

* chore: go mod update to PR tip

* chore: dependency against gnark master

* chore: cherry-pick 43141fc13d

* test: cherry pick test from 407d2e25ecfc32f5ed702ab42e5b829d7cabd483

* chore: remove magic values

* chore: update go version in Docker builder to match go.mod

---------

Co-authored-by: Ivo Kubjas <ivo.kubjas@consensys.net>
2025-06-09 14:17:34 +02:00

914 lines
26 KiB
Go

package internal
import (
"encoding/binary"
"errors"
"math/big"
"slices"
fr377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
fr381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr"
"github.com/consensys/gnark-crypto/hash"
hint "github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/compress"
snarkHash "github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/hash/mimc"
"github.com/consensys/gnark/std/lookup/logderivlookup"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/linea-monorepo/prover/circuits/internal/plonk"
"golang.org/x/exp/constraints"
)
// @reviewer: let's discuss which of these may be good candidates for gnark/std
// and if so, the ideal API for them
// AssertEqualIf asserts cond ≠ 0 ⇒ (a == b)
func AssertEqualIf(api frontend.API, cond, a, b frontend.Variable) {
// in r1cs it would be slightly cheaper to do api.AssertIsEqual(0, api.Mul(cond, api.Sub(a, b))), but the difference is negligible
// and the form below is easier to debug
api.AssertIsEqual(api.Mul(cond, a), api.Mul(cond, b))
}
// AssertIsLessIf asserts cond ≠ 0 ⇒ (a < b)
func AssertIsLessIf(api frontend.API, cond, a, b frontend.Variable) {
var (
condIsNonZero = api.Sub(1, api.IsZero(cond))
a_ = api.Mul(condIsNonZero, api.Add(a, 1))
b_ = api.Mul(condIsNonZero, b)
)
api.AssertIsLessOrEqual(a_, b_)
}
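// To illustrate the conditional-assertion pattern above: when cond ≠ 0, AssertIsLessIf reduces to
// AssertIsLessOrEqual(a+1, b), i.e. a < b; when cond == 0 both sides collapse to 0 and the
// constraint 0 ≤ 0 is trivially satisfied, so nothing is asserted about a and b.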
func SliceToTable(api frontend.API, slice []frontend.Variable) logderivlookup.Table {
table := logderivlookup.New(api)
for i := range slice {
table.Insert(slice[i])
}
return table
}
func AssertSliceEquals(api frontend.API, a, b []frontend.Variable) {
api.AssertIsEqual(len(a), len(b))
for i := range a {
api.AssertIsEqual(a[i], b[i])
}
}
type Range struct {
InRange, IsLast, IsFirstBeyond []frontend.Variable
api frontend.API
}
type NewRangeOption func(*bool)
func NoCheck(b *bool) {
*b = false
}
func NewRange(api frontend.API, n frontend.Variable, max int, opts ...NewRangeOption) *Range {
if max < 0 {
panic("negative maximum not allowed")
}
check := true
for _, o := range opts {
o(&check)
}
if max == 0 {
if check {
api.AssertIsEqual(n, 0)
}
return &Range{api: api}
}
inRange := make([]frontend.Variable, max)
isLast := make([]frontend.Variable, max)
isFirstBeyond := make([]frontend.Variable, max)
prevInRange := frontend.Variable(1)
for i := range isFirstBeyond {
isFirstBeyond[i] = api.IsZero(api.Sub(i, n))
prevInRange = api.Sub(prevInRange, isFirstBeyond[i])
inRange[i] = prevInRange
if i != 0 {
isLast[i-1] = isFirstBeyond[i]
}
}
isLast[max-1] = api.IsZero(api.Sub(max, n))
if check {
// if the last element is still in range, it must also be the last, i.e. isLast = 1 = inRange; otherwise n > max
// if the last element is not in range, then n < max already holds and there is nothing to check, but isLast = 0 = inRange holds in that case anyway
api.AssertIsEqual(isLast[max-1], inRange[max-1])
}
return &Range{inRange, isLast, isFirstBeyond, api}
}
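// As a concrete example of the slices NewRange produces: with max = 4 and n = 2,
// IsFirstBeyond = [0, 0, 1, 0]
// InRange       = [1, 1, 0, 0]
// IsLast        = [0, 1, 0, 0]
// i.e. InRange[i] = 1 exactly when i < n, IsLast[i] = 1 exactly when i == n-1, and
// IsFirstBeyond[i] = 1 exactly when i == n.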
// AssertEqualI asserts that i ∈ [0, n) ⇒ a == b, with n as given to the constructor
func (r *Range) AssertEqualI(i int, a, b frontend.Variable) {
AssertEqualIf(r.api, r.InRange[i], a, b)
}
// AssertEqualLastI asserts that i == n-1 ⇒ a == b, with n as given to the constructor
func (r *Range) AssertEqualLastI(i int, a, b frontend.Variable) {
AssertEqualIf(r.api, r.IsLast[i], a, b)
}
// AddIfLastI returns accumulator + IsLast[i] * atI
func (r *Range) AddIfLastI(i int, accumulator, atI frontend.Variable) frontend.Variable {
toAdd := r.api.Mul(r.IsLast[i], atI)
if accumulator == nil {
return toAdd
}
return r.api.Add(accumulator, toAdd)
}
func (r *Range) AssertArrays32Equal(a, b [][32]frontend.Variable) {
// TODO generify when array length parameters become available
r.api.AssertIsEqual(len(a), len(b))
for i := range a {
for j := range a[i] {
r.AssertEqualI(i, a[i][j], b[i][j])
}
}
}
func (r *Range) LastArray32(slice [][32]frontend.Variable) [32]frontend.Variable {
return r.LastArray32F(func(i int) [32]frontend.Variable { return slice[i] })
}
func (r *Range) LastArray32F(provider func(int) [32]frontend.Variable) [32]frontend.Variable {
// TODO generify when array length parameters become available
var res [32]frontend.Variable
for i := 0; i < len(r.InRange); i++ {
for j := 0; j < 32; j++ {
res[j] = r.AddIfLastI(i, res[j], provider(i)[j])
}
}
return res
}
func RegisterHints() {
hint.RegisterHint(toCrumbsHint, concatHint, checksumSubSlicesHint, partitionSliceHint, divEuclideanHint)
}
func toCrumbsHint(_ *big.Int, ins, outs []*big.Int) error {
if len(ins) != 1 {
return errors.New("expected 1 input")
}
if ins[0].Cmp(big.NewInt(0)) < 0 {
return errors.New("input is negative")
}
if ins[0].Cmp(new(big.Int).Lsh(big.NewInt(1), 2*uint(len(outs)))) >= 0 {
return errors.New("input exceeds the expected number of crumbs")
}
inLen := (len(outs) + 3) / 4
in := ins[0].Bytes()
in = append(make([]byte, inLen-len(in), inLen), in...) // zero pad to the left
slices.Reverse(in) // to little-endian
for i := range outs {
outs[i].SetUint64(uint64(in[0] & 3))
in[0] >>= 2
if i%4 == 3 {
in = in[1:]
}
}
return nil
}
// TODO add to gnark: bits.ToBase
// ToCrumbs decomposes scalar v into nbCrumbs 2-bit digits.
// It uses Little Endian order for compatibility with gnark, even though we use Big Endian order in the circuit
func ToCrumbs(api frontend.API, v frontend.Variable, nbCrumbs int) []frontend.Variable {
res, err := api.Compiler().NewHint(toCrumbsHint, nbCrumbs, v)
if err != nil {
panic(err)
}
for _, c := range res {
api.AssertIsCrumb(c)
}
return res
}
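// For example, decomposing v = 0x6C = 0b01101100 into nbCrumbs = 4 crumbs yields
// [0b00, 0b11, 0b10, 0b01] = [0, 3, 2, 1]: the least significant crumb comes first,
// matching the little-endian order documented for ToCrumbs above.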
// PackedBytesToCrumbs converts a slice of bytes, construed as bitsPerElem-bit field elements each zero-padded on the left to a whole number of bytes, into a slice of two-bit crumbs
// it panics if bitsPerElem is not a multiple of 2
func PackedBytesToCrumbs(api frontend.API, bytes []frontend.Variable, bitsPerElem int) []frontend.Variable {
crumbsPerElem := bitsPerElem / 2
if bitsPerElem != 2*crumbsPerElem {
panic("packing size must be a multiple of 2")
}
bytesPerElem := (bitsPerElem + 7) / 8
firstByteNbCrumbs := crumbsPerElem % 4
if firstByteNbCrumbs == 0 {
firstByteNbCrumbs = 4
}
nbElems := (len(bytes) + bytesPerElem - 1) / bytesPerElem
if nbElems*bytesPerElem != len(bytes) { // pad with zeros if necessary
tmp := bytes
bytes = make([]frontend.Variable, nbElems*bytesPerElem)
copy(bytes, tmp)
for i := len(tmp); i < len(bytes); i++ {
bytes[i] = 0
}
}
res := make([]frontend.Variable, 0, nbElems*crumbsPerElem)
for i := 0; i < len(bytes); i += bytesPerElem {
// first byte
b := ToCrumbs(api, bytes[i], firstByteNbCrumbs)
slices.Reverse(b)
res = append(res, b...)
// remaining bytes
for j := 1; j < bytesPerElem; j++ {
b = ToCrumbs(api, bytes[i+j], 4)
slices.Reverse(b)
res = append(res, b...)
}
}
return res
}
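// For example, with bitsPerElem = 12 we get crumbsPerElem = 6, bytesPerElem = 2 and
// firstByteNbCrumbs = 2: each pair of bytes encodes one 12-bit element, the first byte carrying
// the element's top 4 bits in its low nibble (2 crumbs, with its high nibble required to be zero)
// and the second byte the remaining 8 bits (4 crumbs), emitted as 6 crumbs in big-endian order.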
func (r *Range) StaticLength() int {
return len(r.InRange)
}
func MimcHash(api frontend.API, e ...frontend.Variable) frontend.Variable {
hsh, err := mimc.NewMiMC(api)
if err != nil {
panic(err)
}
hsh.Write(e...)
return hsh.Sum()
}
type Slice[T any] struct {
// @reviewer: better for slice to be non-generic with frontend.Variable type and just duplicate the funcs for the few [32]frontend.Variable applications?
Values []T // Values[:Length] contains the data
Length frontend.Variable
}
func (s VarSlice) Range(api frontend.API) *Range {
return NewRange(api, s.Length, len(s.Values))
}
// Concat does not range check the slice lengths individually, but it does range check their sum.
// All slices have to be nonempty; this is not checked either.
// Runtime is in the order of maxLinearizedLength + len(slices) * max_i(len(slices[i])),
// so it does not perform well when one of the slices is statically much longer than the others
func Concat(api frontend.API, maxLinearizedLength int, slices ...VarSlice) Slice[frontend.Variable] {
res := Slice[frontend.Variable]{make([]frontend.Variable, maxLinearizedLength), 0}
var outT logderivlookup.Table
{ // hint
inLen := 2 * len(slices)
for i := range slices {
inLen += len(slices[i].Values)
}
in := make([]frontend.Variable, inLen)
i := 0
for _, s := range slices {
in[i], in[i+1] = len(s.Values), s.Length
copy(in[i+2:], s.Values)
i += 2 + len(s.Values)
}
out, err := api.Compiler().NewHint(concatHint, maxLinearizedLength, in...)
if err != nil {
panic(err)
}
copy(res.Values, out)
outT = SliceToTable(api, out)
outT.Insert(0)
}
for _, s := range slices {
r := NewRange(api, s.Length, len(s.Values))
for j := range s.Values {
AssertEqualIf(api, r.InRange[j], outT.Lookup(res.Length)[0], s.Values[j])
res.Length = api.Add(res.Length, r.InRange[j])
}
}
api.AssertIsDifferent(res.Length, maxLinearizedLength+1) // the only possible overflow without a range error on the lookup side
return res
}
// ins = [maxLen(s[0]), len(s[0]), s[0][0], s[0][1], ..., maxLen(s[1]), len(s[1]), s[1][0], ...]
func concatHint(_ *big.Int, ins, outs []*big.Int) error {
for len(ins) != 0 {
m := ins[0].Uint64()
l := ins[1].Uint64()
if !ins[0].IsUint64() || !ins[1].IsUint64() || l > m {
return errors.New("unacceptable lengths")
}
for i := uint64(0); i < l; i++ {
outs[i].Set(ins[2+i])
}
ins = ins[2+m:]
outs = outs[l:]
}
for i := range outs {
outs[i].SetUint64(0)
}
return nil
}
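// As an example of the layout above: two slices with (maxLen, len) = (3, 2) and (2, 1) and
// values [a0, a1, a2], [b0, b1] are passed to concatHint as
// ins  = [3, 2, a0, a1, a2, 2, 1, b0, b1]
// and, with maxLinearizedLength = 5, produce
// outs = [a0, a1, b0, 0, 0].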
// SubSlice DOES NOT range check start and length, but it does check end := start + length against slice.Length
func SubSlice(api frontend.API, slice Slice[frontend.Variable], start, length frontend.Variable, maxLength int) Slice[frontend.Variable] {
if maxLength < 0 || maxLength > len(slice.Values) {
panic("invalid maxLength")
}
res := make([]frontend.Variable, maxLength)
t := SliceToTable(api, slice.Values)
/*for i := 0; i < maxLength; i++ { TODO make sure padding is unnecessary
t.Insert(0)
}*/
end := api.Add(start, length)
api.AssertIsLessOrEqual(end, slice.Length)
r := NewRange(api, length, maxLength)
for i := range res {
res[i] = api.Mul(r.InRange[i], t.Lookup(api.Add(start, i))[0])
}
return Slice[frontend.Variable]{res, length}
}
// ChecksumSlice returns H(len(slice), partial) where partial is slice[0] if the slice has a single element,
// and the iterated hash H(... H(H(slice[0], slice[1]), slice[2]) ..., slice[len(slice)-1]) otherwise.
// Proof of collision resistance: write f for the function computing partial. By the collision resistance of H,
// the only collisions f(a_1, ..., a_n) = f(b_1, ..., b_m) not reducible to collisions in H are those with (say) m = 1,
// i.e. f(a_1, ..., a_n) = f(b) = b. If n = 1 as well, then f(a) = a and thus a = b, so there is no collision.
// Otherwise n ≠ 1, and the length prefix separates the two checksums: H(n, f(a_1, ..., a_n)) ≠ H(1, b).
func ChecksumSlice(slice [][]byte) []byte {
if len(slice) == 0 {
panic("empty slice not supported")
}
partial := slice[0]
hsh := hash.MIMC_BLS12_377.New()
for i := 1; i < len(slice); i++ {
hsh.Reset()
hsh.Write(partial)
hsh.Write(slice[i])
partial = hsh.Sum(nil)
}
hsh.Reset()
n := Uint64To32Bytes(uint64(len(slice)))
hsh.Write(n[:])
hsh.Write(partial)
return hsh.Sum(nil)
}
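// For example, for a three-element slice [a, b, c] the checksum is H(3, H(H(a, b), c)),
// while for a single-element slice [a] it is H(1, a), where H denotes hashing its arguments
// in sequence with MiMC over BLS12-377 and the length is encoded as a 32-byte big-endian value.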
// TODO test that ChecksumSliceSnark matches ChecksumSubSlices
type VarSlice Slice[frontend.Variable]
type Var32Slice Slice[[32]frontend.Variable]
// Checksum is the SNARK equivalent of ChecksumSlice
// TODO consider doing (r *Range) f (slice []frontend.Variable)
func (s VarSlice) Checksum(api frontend.API, hsh snarkHash.FieldHasher) frontend.Variable {
if len(s.Values) == 0 {
panic("zero-length input")
}
api.AssertIsDifferent(s.Length, 0)
r := NewRange(api, s.Length, len(s.Values))
hsh.Reset()
sum := s.Values[0]
for i := 1; i < len(s.Values); i++ {
hsh.Reset()
hsh.Write(sum, s.Values[i])
sum = api.Select(r.InRange[i], hsh.Sum(), sum)
}
hsh.Reset()
hsh.Write(s.Length, sum)
return hsh.Sum()
}
// ChecksumSubSlices returns the hashes of consecutive sub-slices, compatible with ChecksumSliceSnark.
// The endpoints are not range checked; they are assumed to be strictly increasing, with subEndPoints.Values[0] ≠ 0
// and subEndPoints.Values[subEndPoints.Length-1] ≤ len(slice).
// Due to its hint, it only works with MiMC over BLS12-377.
func ChecksumSubSlices(api frontend.API, hsh snarkHash.FieldHasher, slice []frontend.Variable, subEndPoints VarSlice) []frontend.Variable {
lastElems := make([]frontend.Variable, len(subEndPoints.Values))
endpointsR := subEndPoints.Range(api)
for i, e := range subEndPoints.Values {
// inRange ? e-1 : l  ==  l + inRange*(e-1-l)  ==  inRange*e - (1+l)*inRange + l, where l = len(slice)
lastElems[i] =
plonk.EvaluateExpression(api, e, endpointsR.InRange[i], 0, -1-len(slice), 1, len(slice))
}
var lastElemsT logderivlookup.Table
if len(slice) > 1 {
lastElemsT = SliceToTable(api, lastElems)
lastElemsT.Insert(len(slice)) // in case subEndPoints is tight
}
hintIn := make([]frontend.Variable, len(lastElems)+len(slice))
copy(hintIn, lastElems)
copy(hintIn[len(lastElems):], slice)
res, err := api.Compiler().NewHint(checksumSubSlicesHint, len(subEndPoints.Values), hintIn...)
if err != nil {
panic(err)
}
resT := SliceToTable(api, res)
resT.Insert(0) // in case subEndPoints is tight
var subI, workingHash, incrementI frontend.Variable
subI, workingHash, incrementI = 0, slice[0], api.IsZero(api.Sub(subEndPoints.Values[0], 1))
for i := 1; i < len(slice); i++ {
// if we are done with this hash, check it
AssertEqualIf(api, incrementI, workingHash, resT.Lookup(subI)[0])
subI = api.Add(subI, incrementI)
hsh.Reset()
hsh.Write(workingHash, slice[i])
workingHash = api.Select(incrementI, slice[i], hsh.Sum()) // if this is a new checksum, we will just take the raw value
incrementI = api.IsZero(api.Sub(i, lastElemsT.Lookup(subI)[0]))
}
// if we are done with this hash, check it (needed in case subEndPoints is tight)
AssertEqualIf(api, incrementI, workingHash, resT.Lookup(subI)[0])
subI = api.Add(subI, incrementI)
api.AssertIsEqual(subI, subEndPoints.Length)
lengths := Differences(api, subEndPoints.Values)
for i := range res {
hsh.Reset()
hsh.Write(lengths[i], res[i])
res[i] = hsh.Sum()
}
return res
}
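// For example, with slice = [a, b, c, d] and subEndPoints.Values = [2, 4] (Length = 2) the function
// returns [H(2, H(a, b)), H(2, H(c, d))]: the hint computes the inner hashes H(a, b) and H(c, d),
// the loop above verifies them against the slice, and the final loop prepends each sub-slice's
// length (obtained via Differences) before hashing once more, matching the ChecksumSlice format.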
// ins has the format [subEndPoints..., slice...]
// TODO figure out the expected behavior for outs past the end of subEndPoints
func checksumSubSlicesHint(_ *big.Int, ins, outs []*big.Int) error {
subLastPoints := ins[:len(outs)]
slice := ins[len(outs):]
sliceAt := func(i int64) []byte {
res := slice[i].Bytes()
if len(res) == 0 {
return []byte{0} // the mimc hash impl ignores empty input
}
return res
}
hsh := hash.MIMC_BLS12_377.New()
var (
first int64
i int
)
for ; i < len(outs); i++ {
last := subLastPoints[i].Int64()
if last >= int64(len(slice)) {
break
}
out := sliceAt(first)
for j := first + 1; j <= last; j++ { // TODO just do a loop of "writes" as this is how mimc computes long hashes anyway
hsh.Reset()
hsh.Write(out)
hsh.Write(sliceAt(j))
out = hsh.Sum(nil)
}
outs[i].SetBytes(out)
first = last + 1
}
for ; i < len(outs); i++ {
outs[i].SetUint64(0)
}
return nil
}
// PartialSums returns a slice of the same length as slice, where res[i] = slice[0] + ... + slice[i]. Out of range values are excluded.
func (r *Range) PartialSums(slice []frontend.Variable) []frontend.Variable {
return r.PartialSumsF(func(i int) frontend.Variable { return slice[i] })
}
func (r *Range) PartialSumsF(provider func(int) frontend.Variable) []frontend.Variable {
if len(r.InRange) == 0 {
return nil
}
res := make([]frontend.Variable, len(r.InRange))
res[0] = r.api.Mul(provider(0), r.InRange[0])
for i := 1; i < len(r.InRange); i++ {
res[i] = r.api.Add(r.api.Mul(provider(i), r.InRange[i]), res[i-1])
}
return res
}
func (r *Range) LastF(provider func(i int) frontend.Variable) frontend.Variable {
if len(r.IsLast) == 0 {
return 0
}
res := r.api.Mul(r.IsLast[0], provider(0))
for i := 1; i < len(r.IsLast); i++ {
res = r.api.Add(res, r.api.Mul(r.IsLast[i], provider(i)))
}
return res
}
// PackFull packs as many words as fit into each field element
// The words are interpreted in big-endian order, and 0 padding is added as needed on the left for every element and on the right for the last element
func PackFull(api frontend.API, words []frontend.Variable, bitsPerWord int) []frontend.Variable {
return Pack(api, words, api.Compiler().FieldBitLen()-1, bitsPerWord)
}
func Pack(api frontend.API, words []frontend.Variable, bitsPerElem, bitsPerWord int) []frontend.Variable {
if bitsPerWord > bitsPerElem {
panic("words don't fit in elements")
}
wordsPerElem := bitsPerElem / bitsPerWord
res := make([]frontend.Variable, (len(words)+wordsPerElem-1)/wordsPerElem)
if len(words) != len(res)*wordsPerElem {
tmp := words
words = make([]frontend.Variable, len(res)*wordsPerElem)
copy(words, tmp)
for i := len(tmp); i < len(words); i++ {
words[i] = 0
}
}
// TODO add this to gnark: bits.FromBase
coeffs := make([]*big.Int, wordsPerElem)
for i := range coeffs {
coeffs[len(coeffs)-1-i] = new(big.Int).Lsh(big.NewInt(1), uint(i*bitsPerWord))
}
// TODO use compress.ReadNum?
for i := range res {
currWords := words[i*wordsPerElem : (i+1)*wordsPerElem]
res[i] = api.Mul(coeffs[0], currWords[0]) // TODO once the "add 0" optimization is implemented in gnark, remove this line
for j := 1; j < len(currWords); j++ {
res[i] = api.MulAcc(res[i], coeffs[j], currWords[j])
}
}
return res
}
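// For example, Pack with bitsPerElem = 24 and bitsPerWord = 8 packs three bytes per element
// using coeffs = [2^16, 2^8, 1], so words [w0, w1, w2] become the element w0*2^16 + w1*2^8 + w2
// (big-endian, as documented for PackFull above).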
func flatten(s [][32]frontend.Variable) []frontend.Variable {
res := make([]frontend.Variable, len(s)*32)
for i := range s {
for j := range s[i] {
res[i*32+j] = s[i][j]
}
}
return res
}
func (s Var32Slice) Checksum(api frontend.API) frontend.Variable {
values := PackFull(api, flatten(s.Values), 8)
valsAndLen := make([]frontend.Variable, 1, len(values)+1)
valsAndLen[0] = s.Length
return MimcHash(api, append(valsAndLen, values...)...)
}
func CombineBytesIntoElements(api frontend.API, b [32]frontend.Variable) [2]frontend.Variable {
r := big.NewInt(256)
return [2]frontend.Variable{
compress.ReadNum(api, b[:16], r),
compress.ReadNum(api, b[16:], r),
}
}
// Bls12381ScalarToBls12377Scalars interprets its input as a BLS12-381 scalar, with a modular reduction if necessary, returning two BLS12-377 scalars
// r[1] is the lower 252 bits. r[0] is the higher 3 bits.
// useful in circuit "assign" functions
func Bls12381ScalarToBls12377Scalars(v interface{}) (r [2][]byte, err error) {
var x fr381.Element
if _, err = x.SetInterface(v); err != nil {
return
}
b := x.Bytes()
r[0] = make([]byte, fr377.Bytes)
r[0][fr381.Bytes-1] = b[0] >> 4
b[0] &= 0x0f
r[1] = b[:]
return
}
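// To illustrate the split above: a BLS12-381 scalar x < 2^255 with big-endian bytes b[0..31]
// is divided at bit 252, so x = hi*2^252 + lo with hi = b[0] >> 4 (at most 3 meaningful bits)
// and lo the remaining 252 bits. For instance x = 2^253 + 5 gives r[0] encoding 2 and r[1]
// encoding 5, both as 32-byte big-endian values that fit in the BLS12-377 scalar field.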
// PartialSums returns s[0], s[0]+s[1], ..., s[0]+s[1]+...+s[len(s)-1]
func PartialSums(api frontend.API, s []frontend.Variable) []frontend.Variable {
res := make([]frontend.Variable, len(s))
res[0] = s[0]
for i := 1; i < len(s); i++ {
res[i] = api.Add(res[i-1], s[i])
}
return res
}
func Differences(api frontend.API, s []frontend.Variable) []frontend.Variable {
res := make([]frontend.Variable, len(s))
prev := frontend.Variable(0)
for i := range s {
res[i] = api.Sub(s[i], prev)
prev = s[i]
}
return res
}
func Sum[T constraints.Integer](x ...T) T {
var res T
for _, xI := range x {
res += xI
}
return res
}
func Uint64To32Bytes(i uint64) [32]byte {
var b [32]byte
binary.BigEndian.PutUint64(b[24:], i)
return b
}
func PartialSumsInt[T constraints.Integer](s []T) []T {
res := make([]T, len(s))
prev := T(0)
for i := range res {
prev += s[i]
res[i] = prev
}
return res
}
func MapSlice[X, Y any](f func(X) Y, x ...X) []Y {
y := make([]Y, len(x))
for i := range x {
y[i] = f(x[i])
}
return y
}
// Truncate returns a copy of slice in which every element from the n-th onward is zeroed
func Truncate(api frontend.API, slice []frontend.Variable, n frontend.Variable) []frontend.Variable {
nYet := frontend.Variable(0)
res := make([]frontend.Variable, len(slice))
for i := range slice {
nYet = api.Add(nYet, api.IsZero(api.Sub(i, n)))
res[i] = api.MulAcc(api.Mul(1, slice[i]), slice[i], api.Neg(nYet))
}
return res
}
// RotateLeft rotates the slice v by n positions to the left, so that res[i] becomes v[(i+n)%len(v)]
func RotateLeft(api frontend.API, v []frontend.Variable, n frontend.Variable) (res []frontend.Variable) {
res = make([]frontend.Variable, len(v))
t := SliceToTable(api, v)
for _, x := range v {
t.Insert(x)
}
for i := range res {
res[i] = t.Lookup(api.Add(i, n))[0]
}
return
}
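// The doubled table above avoids computing (i+n) mod len(v) in-circuit: the lookup index i+n
// must stay within the table's 2*len(v) entries, which holds whenever n ≤ len(v).
// For example, v = [a, b, c, d] and n = 1 yield [b, c, d, a].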
func CloneSlice[T any](s []T, cap ...int) []T {
res := make([]T, len(s), max(len(s), Sum(cap...)))
copy(res, s)
return res
}
// PartitionSlice populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i
// There are no guarantees on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact.
// It may produce an incorrect result if selectors are out of range
func PartitionSlice(api frontend.API, s []frontend.Variable, selectors []frontend.Variable, subs ...[]frontend.Variable) {
if len(s) != len(selectors) {
panic("s and selectors must have the same length")
}
hintIn := make([]frontend.Variable, 1+len(subs)+len(s)+len(selectors))
hintIn[0] = len(subs)
hintOutLen := 0
for i := range subs {
hintIn[1+i] = len(subs[i])
hintOutLen += len(subs[i])
}
for i := range s {
hintIn[1+len(subs)+i] = s[i]
hintIn[1+len(subs)+len(s)+i] = selectors[i]
}
subsGlued, err := api.Compiler().NewHint(partitionSliceHint, hintOutLen, hintIn...)
if err != nil {
panic(err)
}
subsT := make([]logderivlookup.Table, len(subs))
for i := range subs {
copy(subs[i], subsGlued[:len(subs[i])])
subsGlued = subsGlued[len(subs[i]):]
subsT[i] = SliceToTable(api, subs[i])
subsT[i].Insert(0)
}
subI := make([]frontend.Variable, len(subs))
for i := range subI {
subI[i] = 0
}
indicators := make([]frontend.Variable, len(subs))
subHeads := make([]frontend.Variable, len(subs))
for i := range s {
for j := range subs[:len(subs)-1] {
indicators[j] = api.IsZero(api.Sub(selectors[i], j))
}
indicators[len(subs)-1] = api.Sub(1, SumSnark(api, indicators[:len(subs)-1]...))
for j := range subs {
subHeads[j] = subsT[j].Lookup(subI[j])[0]
subI[j] = api.Add(subI[j], indicators[j])
}
api.AssertIsEqual(s[i], InnerProd(api, subHeads, indicators))
}
// Check that the dummy trailing values weren't actually picked
for i := range subI {
api.AssertIsDifferent(subI[i], len(subs[i])+1)
}
}
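// For example, s = [10, 20, 30] with selectors = [1, 0, 1] and two sub-slices of max length 2
// yields subs[0] = [20, 0] and subs[1] = [10, 30]: the hint fills the sub-slices and zero-pads
// them, and the loop above checks, element by element, that s[i] equals the current head of
// the sub-slice chosen by selectors[i].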
func SumSnark(api frontend.API, x ...frontend.Variable) frontend.Variable {
res := frontend.Variable(0)
for i := range x {
res = api.Add(res, x[i])
}
return res
}
// ins: [nbSubs, maxLen_0, ..., maxLen_{nbSubs-1}, s..., indicators...]
func partitionSliceHint(_ *big.Int, ins, outs []*big.Int) error {
subs := make([][]*big.Int, ins[0].Uint64())
for i := range subs {
subs[i] = outs[:ins[1+i].Uint64()]
outs = outs[len(subs[i]):]
}
if len(outs) != 0 {
return errors.New("the sum of subslice max lengths does not equal output length")
}
ins = ins[1+len(subs):]
s := ins[:len(ins)/2]
indicators := ins[len(s):]
if len(s) != len(indicators) {
return errors.New("s and indicators must be of the same length")
}
for i := range s {
b := indicators[i].Int64()
if b < 0 || b >= int64(len(subs)) || !indicators[i].IsUint64() {
return errors.New("indicator out of range")
}
subs[b][0] = s[i]
subs[b] = subs[b][1:]
}
for i := range subs {
for j := range subs[i] {
subs[i][j].SetInt64(0)
}
}
return nil
}
// PartitionSliceEmulated populates sub-slices subs[0], ... where subs[i] contains the elements s[j] with selectors[j] = i
// There are no guarantees on the values in the subs past their actual lengths. The hint sets them to zero but PartitionSlice does not check that fact.
// It may produce an incorrect result if selectors are out of range
func PartitionSliceEmulated[T emulated.FieldParams](api frontend.API, s []emulated.Element[T], selectors []frontend.Variable, subSliceMaxLens ...int) [][]emulated.Element[T] {
field, err := emulated.NewField[T](api)
if err != nil {
panic(err)
}
// transpose limbs for selection
limbs := make([][]frontend.Variable, len(s[0].Limbs)) // limbs are indexed limb first, element second
for i := range limbs {
limbs[i] = make([]frontend.Variable, len(s))
}
for i := range s {
if len(limbs) != len(s[i].Limbs) {
panic("expected uniform number of limbs")
}
for j := range limbs {
limbs[j][i] = s[i].Limbs[j]
}
}
subLimbs := make([][][]frontend.Variable, len(limbs)) // subLimbs is indexed limb first, sub-slice second, element third
for i := range limbs { // construct the sub-slices limb by limb
subLimbs[i] = make([][]frontend.Variable, len(subSliceMaxLens))
for j := range subSliceMaxLens {
subLimbs[i][j] = make([]frontend.Variable, subSliceMaxLens[j])
}
PartitionSlice(api, limbs[i], selectors, subLimbs[i]...)
}
// put the limbs back together
subSlices := make([][]emulated.Element[T], len(subSliceMaxLens))
for i := range subSlices {
subSlices[i] = make([]emulated.Element[T], subSliceMaxLens[i])
for j := range subSlices[i] {
currLimbs := make([]frontend.Variable, len(limbs))
for k := range currLimbs {
currLimbs[k] = subLimbs[k][i][j]
}
subSlices[i][j] = *field.NewElement(currLimbs) // TODO make sure dereferencing is not problematic
}
}
return subSlices
}
func InnerProd(api frontend.API, x, y []frontend.Variable) frontend.Variable {
if len(x) != len(y) {
panic("mismatched lengths")
}
res := frontend.Variable(0)
for i := range x {
res = api.Add(res, api.Mul(x[i], y[i]))
}
return res
}
func SelectMany(api frontend.API, c frontend.Variable, ifSo, ifNot []frontend.Variable) []frontend.Variable {
if len(ifSo) != len(ifNot) {
panic("incompatible lengths")
}
res := make([]frontend.Variable, len(ifSo))
for i := range res {
res[i] = api.Select(c, ifSo[i], ifNot[i])
}
return res
}
// DivEuclidean performs conventional integer division with a remainder
// TODO @Tabaie replace all/most special-case divisions with this, barring performance issues
func DivEuclidean(api frontend.API, a, b frontend.Variable) (quotient, remainder frontend.Variable) {
api.AssertIsDifferent(b, 0)
outs, err := api.Compiler().NewHint(divEuclideanHint, 2, a, b)
if err != nil {
panic(err)
}
quotient, remainder = outs[0], outs[1]
api.AssertIsEqual(a, api.Add(api.Mul(quotient, b), remainder)) // bind the hint outputs: a = quotient*b + remainder
api.AssertIsLessOrEqual(remainder, api.Sub(b, 1))
api.AssertIsLessOrEqual(quotient, a)
return
}
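// For example, DivEuclidean(api, 17, 5) yields quotient = 3 and remainder = 2; the constraints
// above enforce 17 == 3*5 + 2 and remainder ≤ 5 - 1 = 4.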
func divEuclideanHint(_ *big.Int, ins, outs []*big.Int) error {
if len(ins) != 2 || len(outs) != 2 {
return errors.New("expected two inputs and two outputs")
}
a, b := ins[0], ins[1]
quotient, remainder := outs[0], outs[1]
quotient.Div(a, b)
remainder.Mod(a, b)
return nil
}