Files
prysm/encoding/ssz/query/generalized_index.go
Bastin 92bd211e4d upgrade v6 to v7 (#15989)
* upgrade v6 to v7

* changelog

* update-go-ssz
2025-11-06 16:16:23 +00:00

322 lines
11 KiB
Go

package query
import (
"errors"
"fmt"
"github.com/OffchainLabs/prysm/v7/encoding/ssz"
)
// listBaseIndex is the extra factor applied to the generalized index of a List
// or Bitlist field before descending into its elements: element chunks are
// addressed under parent*listBaseIndex, while the length node is addressed at
// parent*2+1 (see calculateLengthGeneralizedIndex).
const listBaseIndex = 2
// GetGeneralizedIndexFromPath resolves a path into the generalized index of the
// node it designates.
//
// Inputs:
//   - info: the SszInfo of the root object, used to navigate the SSZ structure.
//   - path: the path to the target (e.g., "field_a.field_b[3].field_c").
//
// Starting from the root (generalized index 1), every path element first
// selects a field of the current container. The element may then either
// request the length of that field (final element only, when path.Length is
// set) or index into it (List, Vector, Bitlist or Bitvector).
//
// An error is returned when info is nil, when the path is empty, when a step
// is applied to something that is not a container, when a field cannot be
// resolved, or when one of the per-type calculations fails.
func GetGeneralizedIndexFromPath(info *SszInfo, path Path) (uint64, error) {
	if info == nil {
		return 0, errors.New("SszInfo is nil")
	}
	// Nothing is designated by an empty path.
	if len(path.Elements) == 0 {
		return 0, errors.New("cannot compute generalized index for an empty path")
	}

	gindex := uint64(1) // generalized index of the root
	node := info
	lastPos := len(path.Elements) - 1

	for pos, step := range path.Elements {
		// Every step begins with a field access, which requires a container.
		if node.sszType != Container {
			return 0, fmt.Errorf("indexing requires a container field step first, got %s", node.sszType)
		}

		// Position and SszInfo of the requested field inside the container.
		fieldPos, fieldInfo, err := getContainerFieldByName(node, step.Name)
		if err != nil {
			return 0, fmt.Errorf("container field %s not found: %w", step.Name, err)
		}

		// Chunk count of the container we are descending from.
		chunks, err := getChunkCount(node)
		if err != nil {
			return 0, fmt.Errorf("chunk count error: %w", err)
		}

		// Move from the container's node to the selected field's node.
		gindex = gindex*nextPowerOfTwo(chunks) + fieldPos
		node = fieldInfo

		// A length request only applies to the final element of the path.
		if path.Length && pos == lastPos {
			node, gindex, err = calculateLengthGeneralizedIndex(fieldInfo, step, gindex)
			if err != nil {
				return 0, fmt.Errorf("length calculation error: %w", err)
			}
			continue
		}

		// Plain field access: this step is fully resolved.
		if step.Index == nil {
			continue
		}

		// Indexed access: delegate to the calculation matching the field type.
		var kind string
		switch fieldInfo.sszType {
		case List:
			kind = "list"
			node, gindex, err = calculateListGeneralizedIndex(fieldInfo, step, gindex)
		case Vector:
			kind = "vector"
			node, gindex, err = calculateVectorGeneralizedIndex(fieldInfo, step, gindex)
		case Bitlist:
			kind = "bitlist"
			node, gindex, err = calculateBitlistGeneralizedIndex(fieldInfo, step, gindex)
		case Bitvector:
			kind = "bitvector"
			node, gindex, err = calculateBitvectorGeneralizedIndex(fieldInfo, step, gindex)
		default:
			return 0, fmt.Errorf("indexing not supported for type %s", fieldInfo.sszType)
		}
		if err != nil {
			return 0, fmt.Errorf("%s calculation error: %w", kind, err)
		}
	}

	return gindex, nil
}
// getContainerFieldByName looks up fieldName among the fields of the container
// described by info.
//
// It returns the zero-based position of the field within the container's
// field order together with the field's SszInfo. An error is returned when
// info.ContainerInfo() fails, when the matching field carries no SSZ info, or
// when no field has the requested name.
func getContainerFieldByName(info *SszInfo, fieldName string) (uint64, *SszInfo, error) {
	container, err := info.ContainerInfo()
	if err != nil {
		return 0, nil, err
	}

	for pos, candidate := range container.order {
		if candidate != fieldName {
			continue
		}

		field := container.fields[candidate]
		if field == nil || field.sszInfo == nil {
			return 0, nil, fmt.Errorf("field %s has no ssz info", candidate)
		}
		return uint64(pos), field.sszInfo, nil
	}

	return 0, nil, fmt.Errorf("field %s not found", fieldName)
}
// Helpers for Generalized Index calculation per type
// calculateLengthGeneralizedIndex computes the generalized index of the length
// node of a field.
//
// Only List and Bitlist fields have a length node, and the path element must
// not carry an index (len() over multi-dimensional arrays is not supported).
//
// Returns:
//   - the SszInfo of the descendant (the length, a uint64),
//   - the generalized index of the length node (parentIndex*2 + 1).
func calculateLengthGeneralizedIndex(fieldSsz *SszInfo, element PathElement, parentIndex uint64) (*SszInfo, uint64, error) {
	// len(field[i]) is rejected.
	if element.Index != nil {
		return nil, 0, fmt.Errorf("len() is not supported for multi-dimensional arrays")
	}

	switch fieldSsz.sszType {
	case List, Bitlist:
		// The length is a uint64 stored in the right child of the field's node.
		return &SszInfo{sszType: Uint64}, parentIndex*2 + 1, nil
	default:
		return nil, 0, fmt.Errorf("len() is only supported for List and Bitlist types, got %s", fieldSsz.sszType)
	}
}
// calculateListGeneralizedIndex computes the generalized index of the chunk
// holding a list element.
//
// element.Index must be non-nil and strictly below the list limit.
//
// Returns:
//   - the SszInfo of the descendant (the list element type),
//   - the generalized index of the chunk that contains the element.
func calculateListGeneralizedIndex(fieldSsz *SszInfo, element PathElement, parentIndex uint64) (*SszInfo, uint64, error) {
	listInfo, err := fieldSsz.ListInfo()
	if err != nil {
		return nil, 0, fmt.Errorf("list info error: %w", err)
	}
	elemInfo, err := listInfo.Element()
	if err != nil {
		return nil, 0, fmt.Errorf("list element error: %w", err)
	}

	idx := *element.Index
	if limit := listInfo.Limit(); idx >= limit {
		return nil, 0, fmt.Errorf("index %d out of bounds for list with limit %d", idx, limit)
	}

	// Composite elements are addressed by their index; for basic elements the
	// chunk is derived from the element's byte offset.
	chunkPos := idx
	if elemInfo.sszType.isBasic() {
		byteOffset := idx * itemLength(elemInfo)
		chunkPos = byteOffset / ssz.BytesPerChunk
	}

	chunks, err := getChunkCount(fieldSsz)
	if err != nil {
		return nil, 0, fmt.Errorf("chunk count error: %w", err)
	}

	// gindex = parent * base_index * pow2ceil(chunk_count(list)) + chunkPos
	gindex := parentIndex*listBaseIndex*nextPowerOfTwo(chunks) + chunkPos
	return elemInfo, gindex, nil
}
// calculateVectorGeneralizedIndex computes the generalized index of the chunk
// holding a vector element.
//
// element.Index must be non-nil and strictly below the vector length. Unlike
// the list calculation, no listBaseIndex factor is applied.
//
// Returns:
//   - the SszInfo of the descendant (the vector element type),
//   - the generalized index of the chunk that contains the element.
func calculateVectorGeneralizedIndex(fieldSsz *SszInfo, element PathElement, parentIndex uint64) (*SszInfo, uint64, error) {
	vectorInfo, err := fieldSsz.VectorInfo()
	if err != nil {
		return nil, 0, fmt.Errorf("vector info error: %w", err)
	}
	elemInfo, err := vectorInfo.Element()
	if err != nil {
		return nil, 0, fmt.Errorf("vector element error: %w", err)
	}

	idx := *element.Index
	if length := vectorInfo.Length(); idx >= length {
		return nil, 0, fmt.Errorf("index %d out of bounds for vector with length %d", idx, length)
	}

	// Composite elements are addressed by their index; for basic elements the
	// chunk is derived from the element's byte offset.
	chunkPos := idx
	if elemInfo.sszType.isBasic() {
		byteOffset := idx * itemLength(elemInfo)
		chunkPos = byteOffset / ssz.BytesPerChunk
	}

	chunks, err := getChunkCount(fieldSsz)
	if err != nil {
		return nil, 0, fmt.Errorf("chunk count error: %w", err)
	}

	// gindex = parent * pow2ceil(chunk_count(vector)) + chunkPos
	gindex := parentIndex*nextPowerOfTwo(chunks) + chunkPos
	return elemInfo, gindex, nil
}
// calculateBitlistGeneralizedIndex calculates the generalized index of the
// chunk holding a bitlist element.
//
// element.Index must be non-nil and strictly below the bitlist limit. An
// out-of-range index is rejected, consistently with the list and vector
// calculations, because it would otherwise yield a generalized index that
// does not belong to this bitlist.
//
// Returns:
//   - its descendant SSZInfo (bitlist element i.e. a boolean)
//   - the generalized index of the chunk that contains the bit.
func calculateBitlistGeneralizedIndex(fieldSsz *SszInfo, element PathElement, parentIndex uint64) (*SszInfo, uint64, error) {
	bitlistInfo, err := fieldSsz.BitlistInfo()
	if err != nil {
		return nil, 0, fmt.Errorf("bitlist info error: %w", err)
	}

	bitIndex := *element.Index
	if limit := bitlistInfo.Limit(); bitIndex >= limit {
		return nil, 0, fmt.Errorf("index %d out of bounds for bitlist with limit %d", bitIndex, limit)
	}

	// Bits are packed ssz.BitsPerChunk per chunk; select the chunk containing the bit.
	chunkPos := bitIndex / ssz.BitsPerChunk

	innerChunkCount, err := getChunkCount(fieldSsz)
	if err != nil {
		return nil, 0, fmt.Errorf("chunk count error: %w", err)
	}

	// gindex = parent * base_index * pow2ceil(chunk_count(bitlist)) + chunkPos
	bitlistIndex := parentIndex*listBaseIndex*nextPowerOfTwo(innerChunkCount) + chunkPos

	// A bit cannot be descended into any further; reporting it as a basic type
	// makes any following path step fail the container check.
	return &SszInfo{sszType: Boolean}, bitlistIndex, nil
}
// calculateBitvectorGeneralizedIndex calculates the generalized index of the
// chunk holding a bitvector element.
//
// element.Index must be non-nil and strictly below the bitvector length. An
// out-of-range index is rejected, consistently with the list and vector
// calculations, because it would otherwise yield a generalized index that
// does not belong to this bitvector.
//
// Returns:
//   - its descendant SSZInfo (bitvector element i.e. a boolean)
//   - the generalized index of the chunk that contains the bit.
func calculateBitvectorGeneralizedIndex(fieldSsz *SszInfo, element PathElement, parentIndex uint64) (*SszInfo, uint64, error) {
	bitvectorInfo, err := fieldSsz.BitvectorInfo()
	if err != nil {
		return nil, 0, fmt.Errorf("bitvector info error: %w", err)
	}

	bitIndex := *element.Index
	if length := bitvectorInfo.Length(); bitIndex >= length {
		return nil, 0, fmt.Errorf("index %d out of bounds for bitvector with length %d", bitIndex, length)
	}

	// Bits are packed ssz.BitsPerChunk per chunk; select the chunk containing the bit.
	chunkPos := bitIndex / ssz.BitsPerChunk

	innerChunkCount, err := getChunkCount(fieldSsz)
	if err != nil {
		return nil, 0, fmt.Errorf("chunk count error: %w", err)
	}

	// gindex = parent * pow2ceil(chunk_count(bitvector)) + chunkPos
	bitvectorIndex := parentIndex*nextPowerOfTwo(innerChunkCount) + chunkPos

	// A bit cannot be descended into any further; reporting it as a basic type
	// makes any following path step fail the container check.
	return &SszInfo{sszType: Boolean}, bitvectorIndex, nil
}
// Helper functions from SSZ spec
// itemLength reports how many bytes a single SSZ item contributes when laid out
// in chunks.
//
// Basic types report their own size (info.Size()). Every other type reports a
// full chunk, ssz.BytesPerChunk.
func itemLength(info *SszInfo) uint64 {
	if !info.sszType.isBasic() {
		return ssz.BytesPerChunk
	}
	return info.Size()
}
// nextPowerOfTwo computes the smallest power of two greater than or equal to v.
//
// nextPowerOfTwo(0) returns 1, matching get_power_of_two_ceil in the SSZ
// Merkle proof specification. If the result does not fit in a uint64
// (v > 1<<63) the computation wraps around and 0 is returned.
func nextPowerOfTwo(v uint64) uint64 {
	if v == 0 {
		return 1
	}
	// Smear the highest set bit of v-1 into every lower position, then add one.
	// All six shifts are required to cover the full 64-bit width.
	v--
	v |= v >> 1
	v |= v >> 2
	v |= v >> 4
	v |= v >> 8
	v |= v >> 16
	v |= v >> 32
	v++
	return v
}
// getChunkCount reports how many chunks the type described by info occupies
// (the equivalent of chunk_count in the spec).
//
//   - basic types (uint8/16/32/64, boolean): 1
//   - Container: the number of fields
//   - List / Vector: limit (resp. length) times the item length, rounded up to
//     whole chunks
//   - Bitlist / Bitvector: limit (resp. length) in bits, rounded up to whole
//     chunks
//
// Any other type, or a failure to obtain the type-specific info, results in
// an error.
func getChunkCount(info *SszInfo) (uint64, error) {
	switch info.sszType {
	case Uint8, Uint16, Uint32, Uint64, Boolean:
		return 1, nil

	case Container:
		container, err := info.ContainerInfo()
		if err != nil {
			return 0, err
		}
		fieldCount := len(container.fields)
		return uint64(fieldCount), nil

	case List:
		list, err := info.ListInfo()
		if err != nil {
			return 0, err
		}
		elem, err := list.Element()
		if err != nil {
			return 0, err
		}
		perItem := itemLength(elem)
		totalBytes := list.Limit() * perItem
		// Round up to a whole number of chunks.
		return (totalBytes + 31) / ssz.BytesPerChunk, nil

	case Vector:
		vector, err := info.VectorInfo()
		if err != nil {
			return 0, err
		}
		elem, err := vector.Element()
		if err != nil {
			return 0, err
		}
		perItem := itemLength(elem)
		totalBytes := vector.Length() * perItem
		// Round up to a whole number of chunks.
		return (totalBytes + 31) / ssz.BytesPerChunk, nil

	case Bitlist:
		bitlist, err := info.BitlistInfo()
		if err != nil {
			return 0, err
		}
		// Bits are packed into chunks; round up to a whole number of chunks.
		totalBits := bitlist.Limit()
		return (totalBits + 255) / ssz.BitsPerChunk, nil

	case Bitvector:
		bitvector, err := info.BitvectorInfo()
		if err != nil {
			return 0, err
		}
		// Bits are packed into chunks; round up to a whole number of chunks.
		totalBits := bitvector.Length()
		return (totalBits + 255) / ssz.BitsPerChunk, nil

	default:
		return 0, errors.New("unsupported SSZ type for chunk count calculation")
	}
}