From 4eab41ea4ca770840bf2a00c7613a5d67c7a54cc Mon Sep 17 00:00:00 2001 From: Jun Song <87601811+syjn99@users.noreply.github.com> Date: Tue, 14 Oct 2025 18:33:52 +0100 Subject: [PATCH] SSZ-QL: use `fastssz`-generated `SizeSSZ` method & clarify `Size` method (#15864) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add SizeSSZ as a member of SSZObject * Temporarily rename dereferencePointer function * Fix analyzeType: use reflect.Value for analyzing * Fix PopulateVariableLengthInfo: change function signature & reset pointer * Remove Container arm for Size function as it'll be handled in the previous branch * Remove OffsetBytes function in listInfo * Refactor and document codes * Remove misleading "fixedSize" concept & Add Uint8...64 SSZTypes * Add size testing * Move TestSSZObject_Batch and rename it as TestHashTreeRoot * Changelog :) * Rename endOffset to fixedOffset --------- Co-authored-by: Radosław Kapka --- changelog/syjn99_ssz-ql-fastssz-sizessz.md | 7 + encoding/ssz/query/BUILD.bazel | 1 - encoding/ssz/query/analyzer.go | 272 ++++++++++++--------- encoding/ssz/query/analyzer_test.go | 17 -- encoding/ssz/query/bitvector.go | 9 + encoding/ssz/query/container.go | 6 +- encoding/ssz/query/list.go | 14 -- encoding/ssz/query/query_test.go | 134 +++++----- encoding/ssz/query/ssz_info.go | 54 ++-- encoding/ssz/query/ssz_object.go | 1 + encoding/ssz/query/ssz_type.go | 18 +- encoding/ssz/query/testutil/runner.go | 8 +- encoding/ssz/query/vector.go | 12 + 13 files changed, 309 insertions(+), 244 deletions(-) create mode 100644 changelog/syjn99_ssz-ql-fastssz-sizessz.md delete mode 100644 encoding/ssz/query/analyzer_test.go diff --git a/changelog/syjn99_ssz-ql-fastssz-sizessz.md b/changelog/syjn99_ssz-ql-fastssz-sizessz.md new file mode 100644 index 0000000000..7a3a539c2a --- /dev/null +++ b/changelog/syjn99_ssz-ql-fastssz-sizessz.md @@ -0,0 +1,7 @@ +### Added + +- SSZ-QL: Use `fastssz`'s `SizeSSZ` method for calculating the size of `Container` type. + +### Changed + +- SSZ-QL: Clarify `Size` method with more sophisticated `SSZType`s. diff --git a/encoding/ssz/query/BUILD.bazel b/encoding/ssz/query/BUILD.bazel index 837d17054c..79b8ad84c1 100644 --- a/encoding/ssz/query/BUILD.bazel +++ b/encoding/ssz/query/BUILD.bazel @@ -24,7 +24,6 @@ go_library( go_test( name = "go_default_test", srcs = [ - "analyzer_test.go", "path_test.go", "query_test.go", "tag_parser_test.go", diff --git a/encoding/ssz/query/analyzer.go b/encoding/ssz/query/analyzer.go index 52ba85ad6d..83a5b964cc 100644 --- a/encoding/ssz/query/analyzer.go +++ b/encoding/ssz/query/analyzer.go @@ -11,20 +11,17 @@ const offsetBytes = 4 // AnalyzeObject analyzes given object and returns its SSZ information. func AnalyzeObject(obj SSZObject) (*sszInfo, error) { - value := dereferencePointer(obj) + value := reflect.ValueOf(obj) - info, err := analyzeType(value.Type(), nil) + info, err := analyzeType(value, nil) if err != nil { return nil, fmt.Errorf("could not analyze type %s: %w", value.Type().Name(), err) } - // Store the original object interface - info.source = obj - // Populate variable-length information using the actual value. - err = PopulateVariableLengthInfo(info, value.Interface()) + err = PopulateVariableLengthInfo(info, value) if err != nil { - return nil, fmt.Errorf("could not populate variable length info: %w", err) + return nil, fmt.Errorf("could not populate variable length info for type %s: %w", value.Type().Name(), err) } return info, nil @@ -33,13 +30,13 @@ func AnalyzeObject(obj SSZObject) (*sszInfo, error) { // PopulateVariableLengthInfo populates runtime information for SSZ fields of variable-sized types. // This function updates the sszInfo structure with actual lengths and offsets that can only // be determined at runtime for variable-sized items like Lists and variable-sized Container fields. -func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { +func PopulateVariableLengthInfo(sszInfo *sszInfo, value reflect.Value) error { if sszInfo == nil { return errors.New("sszInfo is nil") } - if value == nil { - return errors.New("value is nil") + if !value.IsValid() { + return errors.New("value is invalid") } // Short circuit: If the type is fixed-sized, we don't need to fill in the info. @@ -59,18 +56,18 @@ func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { return errors.New("listInfo is nil") } - val := reflect.ValueOf(value) - if val.Kind() != reflect.Slice { - return fmt.Errorf("expected slice for List type, got %v", val.Kind()) + if value.Kind() != reflect.Slice { + return fmt.Errorf("expected slice for List type, got %v", value.Kind()) } - length := val.Len() + + length := value.Len() if listInfo.element.isVariable { listInfo.elementSizes = make([]uint64, 0, length) // Populate nested variable-sized type element lengths recursively. for i := range length { - if err := PopulateVariableLengthInfo(listInfo.element, val.Index(i).Interface()); err != nil { + if err := PopulateVariableLengthInfo(listInfo.element, value.Index(i)); err != nil { return fmt.Errorf("could not populate nested list element at index %d: %w", i, err) } listInfo.elementSizes = append(listInfo.elementSizes, listInfo.element.Size()) @@ -94,8 +91,7 @@ func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { return errors.New("bitlistInfo is nil") } - val := reflect.ValueOf(value) - if err := bitlistInfo.SetLengthFromBytes(val.Bytes()); err != nil { + if err := bitlistInfo.SetLengthFromBytes(value.Bytes()); err != nil { return fmt.Errorf("could not set bitlist length from bytes: %w", err) } @@ -108,11 +104,21 @@ func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { return fmt.Errorf("could not get container info: %w", err) } + if containerInfo == nil { + return errors.New("containerInfo is nil") + } + // Dereference first in case value is a pointer. derefValue := dereferencePointer(value) + if derefValue.Kind() != reflect.Struct { + return fmt.Errorf("expected struct for Container type, got %v", derefValue.Kind()) + } - // Start with the fixed size of this Container. - currentOffset := sszInfo.FixedSize() + // Reset the pointer to the new value. + sszInfo.source = castToSSZObject(derefValue) + + // Start with the end offset of this Container. + currentOffset := containerInfo.fixedOffset for _, fieldName := range containerInfo.order { fieldInfo := containerInfo.fields[fieldName] @@ -128,13 +134,15 @@ func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { // Recursively populate variable-sized fields. fieldValue := derefValue.FieldByName(fieldInfo.goFieldName) - if err := PopulateVariableLengthInfo(childSszInfo, fieldValue.Interface()); err != nil { + if err := PopulateVariableLengthInfo(childSszInfo, fieldValue); err != nil { return fmt.Errorf("could not populate from value for field %s: %w", fieldName, err) } // Each variable-sized element needs an offset entry. - if childSszInfo.sszType == List { - currentOffset += childSszInfo.listInfo.OffsetBytes() + if listInfo, err := childSszInfo.ListInfo(); err == nil && listInfo != nil { + if listInfo.element.isVariable { + currentOffset += listInfo.Length() * offsetBytes + } } // Set the actual offset for variable-sized fields. @@ -149,66 +157,64 @@ func PopulateVariableLengthInfo(sszInfo *sszInfo, value any) error { } } -// analyzeType is an entry point that inspects a reflect.Type and computes its SSZ layout information. -func analyzeType(typ reflect.Type, tag *reflect.StructTag) (*sszInfo, error) { - switch typ.Kind() { +// analyzeType is an entry point that inspects a reflect.Value and computes its SSZ layout information. +func analyzeType(value reflect.Value, tag *reflect.StructTag) (*sszInfo, error) { + switch value.Kind() { // Basic types (e.g., uintN where N is 8, 16, 32, 64) // NOTE: uint128 and uint256 are represented as []byte in Go, // so we handle them as slices. See `analyzeHomogeneousColType`. case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Bool: - return analyzeBasicType(typ) + return analyzeBasicType(value) case reflect.Slice: - return analyzeHomogeneousColType(typ, tag) + return analyzeHomogeneousColType(value, tag) case reflect.Struct: - return analyzeContainerType(typ) + return analyzeContainerType(value) - case reflect.Ptr: - // Dereference pointer types. - return analyzeType(typ.Elem(), tag) + case reflect.Pointer: + derefValue := dereferencePointer(value) + return analyzeType(derefValue, tag) default: - return nil, fmt.Errorf("unsupported type %v for SSZ calculation", typ.Kind()) + return nil, fmt.Errorf("unsupported type %v for SSZ calculation", value.Kind()) } } // analyzeBasicType analyzes SSZ basic types (uintN, bool) and returns its info. -func analyzeBasicType(typ reflect.Type) (*sszInfo, error) { +func analyzeBasicType(value reflect.Value) (*sszInfo, error) { + var sszType SSZType + + switch value.Kind() { + case reflect.Uint64: + sszType = Uint64 + case reflect.Uint32: + sszType = Uint32 + case reflect.Uint16: + sszType = Uint16 + case reflect.Uint8: + sszType = Uint8 + case reflect.Bool: + sszType = Boolean + default: + return nil, fmt.Errorf("unsupported basic type %v for SSZ calculation", value.Kind()) + } + sszInfo := &sszInfo{ - typ: typ, + sszType: sszType, + typ: value.Type(), // Every basic type is fixed-size and not variable. isVariable: false, } - switch typ.Kind() { - case reflect.Uint64: - sszInfo.sszType = UintN - sszInfo.fixedSize = 8 - case reflect.Uint32: - sszInfo.sszType = UintN - sszInfo.fixedSize = 4 - case reflect.Uint16: - sszInfo.sszType = UintN - sszInfo.fixedSize = 2 - case reflect.Uint8: - sszInfo.sszType = UintN - sszInfo.fixedSize = 1 - case reflect.Bool: - sszInfo.sszType = Boolean - sszInfo.fixedSize = 1 - default: - return nil, fmt.Errorf("unsupported basic type %v for SSZ calculation", typ.Kind()) - } - return sszInfo, nil } // analyzeHomogeneousColType analyzes homogeneous collection types (e.g., List, Vector, Bitlist, Bitvector) and returns its SSZ info. -func analyzeHomogeneousColType(typ reflect.Type, tag *reflect.StructTag) (*sszInfo, error) { - if typ.Kind() != reflect.Slice { - return nil, fmt.Errorf("can only analyze slice types, got %v", typ.Kind()) +func analyzeHomogeneousColType(value reflect.Value, tag *reflect.StructTag) (*sszInfo, error) { + if value.Kind() != reflect.Slice { + return nil, fmt.Errorf("can only analyze slice types, got %v", value.Kind()) } // Parse the first dimension from the tag and get remaining tag for element @@ -220,8 +226,12 @@ func analyzeHomogeneousColType(typ reflect.Type, tag *reflect.StructTag) (*sszIn return nil, errors.New("ssz tag is required for slice types") } + // NOTE: Elem() won't panic because value is guaranteed to be a slice here. + elementType := value.Type().Elem() // Analyze element type with remaining dimensions - elementInfo, err := analyzeType(typ.Elem(), remainingTag) + // Note that it is enough to analyze by a zero value, + // as the actual value with variable-sized type will be populated later. + elementInfo, err := analyzeType(reflect.New(elementType), remainingTag) if err != nil { return nil, fmt.Errorf("could not analyze element type for homogeneous collection: %w", err) } @@ -233,7 +243,7 @@ func analyzeHomogeneousColType(typ reflect.Type, tag *reflect.StructTag) (*sszIn return nil, fmt.Errorf("could not get list limit: %w", err) } - return analyzeListType(typ, elementInfo, limit, sszDimension.isBitfield) + return analyzeListType(value, elementInfo, limit, sszDimension.isBitfield) } // 2. Handle Vector/Bitvector type @@ -243,7 +253,7 @@ func analyzeHomogeneousColType(typ reflect.Type, tag *reflect.StructTag) (*sszIn return nil, fmt.Errorf("could not get vector length: %w", err) } - return analyzeVectorType(typ, elementInfo, length, sszDimension.isBitfield) + return analyzeVectorType(value, elementInfo, length, sszDimension.isBitfield) } // Parsing ssz tag doesn't provide enough information to determine the collection type, @@ -252,13 +262,12 @@ func analyzeHomogeneousColType(typ reflect.Type, tag *reflect.StructTag) (*sszIn } // analyzeListType analyzes SSZ List/Bitlist type and returns its SSZ info. -func analyzeListType(typ reflect.Type, elementInfo *sszInfo, limit uint64, isBitfield bool) (*sszInfo, error) { +func analyzeListType(value reflect.Value, elementInfo *sszInfo, limit uint64, isBitfield bool) (*sszInfo, error) { if isBitfield { return &sszInfo{ sszType: Bitlist, - typ: typ, + typ: value.Type(), - fixedSize: offsetBytes, isVariable: true, bitlistInfo: &bitlistInfo{ @@ -273,9 +282,8 @@ func analyzeListType(typ reflect.Type, elementInfo *sszInfo, limit uint64, isBit return &sszInfo{ sszType: List, - typ: typ, + typ: value.Type(), - fixedSize: offsetBytes, isVariable: true, listInfo: &listInfo{ @@ -286,14 +294,12 @@ func analyzeListType(typ reflect.Type, elementInfo *sszInfo, limit uint64, isBit } // analyzeVectorType analyzes SSZ Vector/Bitvector type and returns its SSZ info. -func analyzeVectorType(typ reflect.Type, elementInfo *sszInfo, length uint64, isBitfield bool) (*sszInfo, error) { +func analyzeVectorType(value reflect.Value, elementInfo *sszInfo, length uint64, isBitfield bool) (*sszInfo, error) { if isBitfield { return &sszInfo{ sszType: Bitvector, - typ: typ, + typ: value.Type(), - // Size in bytes - fixedSize: length, isVariable: false, bitvectorInfo: &bitvectorInfo{ @@ -314,9 +320,8 @@ func analyzeVectorType(typ reflect.Type, elementInfo *sszInfo, length uint64, is return &sszInfo{ sszType: Vector, - typ: typ, + typ: value.Type(), - fixedSize: length * elementInfo.Size(), isVariable: false, vectorInfo: &vectorInfo{ @@ -327,44 +332,36 @@ func analyzeVectorType(typ reflect.Type, elementInfo *sszInfo, length uint64, is } // analyzeContainerType analyzes SSZ Container type and returns its SSZ info. -func analyzeContainerType(typ reflect.Type) (*sszInfo, error) { - if typ.Kind() != reflect.Struct { - return nil, fmt.Errorf("can only analyze struct types, got %v", typ.Kind()) +func analyzeContainerType(value reflect.Value) (*sszInfo, error) { + if value.Kind() != reflect.Struct { + return nil, fmt.Errorf("can only analyze struct types, got %v", value.Kind()) } + containerTyp := value.Type() fields := make(map[string]*fieldInfo) - order := make([]string, 0, typ.NumField()) + order := make([]string, 0) - sszInfo := &sszInfo{ - sszType: Container, - typ: typ, - } + isVariable := false var currentOffset uint64 - for i := 0; i < typ.NumField(); i++ { - field := typ.Field(i) + for i := 0; i < value.NumField(); i++ { + structFieldInfo := containerTyp.Field(i) // Protobuf-generated structs contain private fields we must skip. // e.g., state, sizeCache, unknownFields, etc. - if !field.IsExported() { + if !structFieldInfo.IsExported() { continue } - // The JSON tag contains the field name in the first part. - // e.g., "attesting_indices,omitempty" -> "attesting_indices". - jsonTag := field.Tag.Get("json") - if jsonTag == "" { - return nil, fmt.Errorf("field %s has no JSON tag", field.Name) - } - - // NOTE: `fieldName` is a string with `snake_case` format (following consensus specs). - fieldName := strings.Split(jsonTag, ",")[0] - if fieldName == "" { - return nil, fmt.Errorf("field %s has an empty JSON tag", field.Name) + tag := structFieldInfo.Tag + goFieldName := structFieldInfo.Name + fieldName, err := parseFieldNameFromTag(tag) + if err != nil { + return nil, fmt.Errorf("could not parse field name from tag for field %s: %w", goFieldName, err) } // Analyze each field so that we can complete full SSZ information. - info, err := analyzeType(field.Type, &field.Tag) + info, err := analyzeType(value.Field(i), &tag) if err != nil { return nil, fmt.Errorf("could not analyze type for field %s: %w", fieldName, err) } @@ -373,7 +370,7 @@ func analyzeContainerType(typ reflect.Type) (*sszInfo, error) { fields[fieldName] = &fieldInfo{ sszInfo: info, offset: currentOffset, - goFieldName: field.Name, + goFieldName: goFieldName, } // Persist order order = append(order, fieldName) @@ -382,34 +379,87 @@ func analyzeContainerType(typ reflect.Type) (*sszInfo, error) { if info.isVariable { // If one of the fields is variable-sized, // the entire struct is considered variable-sized. - sszInfo.isVariable = true + isVariable = true currentOffset += offsetBytes } else { - currentOffset += info.fixedSize + currentOffset += info.Size() } } - sszInfo.fixedSize = currentOffset - sszInfo.containerInfo = &containerInfo{ - fields: fields, - order: order, - } + return &sszInfo{ + sszType: Container, + typ: containerTyp, + source: castToSSZObject(value), - return sszInfo, nil + isVariable: isVariable, + + containerInfo: &containerInfo{ + fields: fields, + order: order, + fixedOffset: currentOffset, + }, + }, nil } // dereferencePointer dereferences a pointer to get the underlying value using reflection. -func dereferencePointer(obj any) reflect.Value { - value := reflect.ValueOf(obj) - if value.Kind() == reflect.Ptr { +func dereferencePointer(value reflect.Value) reflect.Value { + derefValue := value + + if value.IsValid() && value.Kind() == reflect.Pointer { if value.IsNil() { - // If we encounter a nil pointer before the end of the path, we can still proceed - // by analyzing the type, not the value. - value = reflect.New(value.Type().Elem()).Elem() + // Create a zero value if the pointer is nil. + derefValue = reflect.New(value.Type().Elem()).Elem() } else { - value = value.Elem() + derefValue = value.Elem() } } - return value + return derefValue +} + +// castToSSZObject attempts to cast a reflect.Value to the SSZObject interface. +// If failed, it returns nil. +func castToSSZObject(value reflect.Value) SSZObject { + if !value.IsValid() { + return nil + } + + // SSZObject is only implemented by struct types. + if value.Kind() != reflect.Struct { + return nil + } + + // To cast to SSZObject, we need the addressable value. + if !value.CanAddr() { + return nil + } + + if sszObj, ok := value.Addr().Interface().(SSZObject); ok { + return sszObj + } + + return nil +} + +// parseFieldNameFromTag extracts the field name (`snake_case` format) +// from a struct tag by looking for the json tag. +// The JSON tag contains the field name in the first part. +// e.g., "attesting_indices,omitempty" -> "attesting_indices". +func parseFieldNameFromTag(tag reflect.StructTag) (string, error) { + jsonTag := tag.Get("json") + if jsonTag == "" { + return "", errors.New("no JSON tag found") + } + + substrings := strings.Split(jsonTag, ",") + if len(substrings) == 0 { + return "", errors.New("invalid JSON tag format") + } + + fieldName := strings.TrimSpace(substrings[0]) + if fieldName == "" { + return "", errors.New("empty field name") + } + + return fieldName, nil } diff --git a/encoding/ssz/query/analyzer_test.go b/encoding/ssz/query/analyzer_test.go deleted file mode 100644 index 7c8d4666fa..0000000000 --- a/encoding/ssz/query/analyzer_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package query_test - -import ( - "testing" - - "github.com/OffchainLabs/prysm/v6/encoding/ssz/query" - sszquerypb "github.com/OffchainLabs/prysm/v6/proto/ssz_query" - "github.com/OffchainLabs/prysm/v6/testing/require" -) - -func TestAnalyzeSSZInfo(t *testing.T) { - info, err := query.AnalyzeObject(&sszquerypb.FixedTestContainer{}) - require.NoError(t, err) - - require.NotNil(t, info, "Expected non-nil SSZ info") - require.Equal(t, uint64(565), info.FixedSize()) -} diff --git a/encoding/ssz/query/bitvector.go b/encoding/ssz/query/bitvector.go index 046189d232..7ef42686d7 100644 --- a/encoding/ssz/query/bitvector.go +++ b/encoding/ssz/query/bitvector.go @@ -13,3 +13,12 @@ func (v *bitvectorInfo) Length() uint64 { return v.length } + +func (v *bitvectorInfo) Size() uint64 { + if v == nil { + return 0 + } + + // Size in bytes. + return v.length / 8 +} diff --git a/encoding/ssz/query/container.go b/encoding/ssz/query/container.go index 374bd7c4fb..f0d2fc9088 100644 --- a/encoding/ssz/query/container.go +++ b/encoding/ssz/query/container.go @@ -3,9 +3,11 @@ package query // containerInfo has // 1. fields: a field map that maps a field's JSON name to its sszInfo for nested Containers // 2. order: a list of field names in the order they should be serialized +// 3. fixedOffset: the total size of the fixed part of the container type containerInfo struct { - fields map[string]*fieldInfo - order []string + fields map[string]*fieldInfo + order []string + fixedOffset uint64 } type fieldInfo struct { diff --git a/encoding/ssz/query/list.go b/encoding/ssz/query/list.go index 5b797a422f..d09a5fd821 100644 --- a/encoding/ssz/query/list.go +++ b/encoding/ssz/query/list.go @@ -71,17 +71,3 @@ func (l *listInfo) Size() uint64 { } return totalSize } - -// OffsetBytes returns the total number of offset bytes used for the list elements. -// Each variable-sized element uses 4 bytes to store its offset. -func (l *listInfo) OffsetBytes() uint64 { - if l == nil { - return 0 - } - - if !l.element.isVariable { - return 0 - } - - return offsetBytes * l.length -} diff --git a/encoding/ssz/query/query_test.go b/encoding/ssz/query/query_test.go index 934f28a5f9..0b11a57556 100644 --- a/encoding/ssz/query/query_test.go +++ b/encoding/ssz/query/query_test.go @@ -11,6 +11,34 @@ import ( "github.com/prysmaticlabs/go-bitfield" ) +func TestSize(t *testing.T) { + tests := []struct { + name string + obj query.SSZObject + expectedSize uint64 + }{ + { + name: "FixedTestContainer", + obj: &sszquerypb.FixedTestContainer{}, + expectedSize: 565, + }, + { + name: "VariableTestContainer", + obj: &sszquerypb.VariableTestContainer{}, + expectedSize: 128, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info, err := query.AnalyzeObject(tt.obj) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, tt.expectedSize, info.Size()) + }) + } +} + func TestCalculateOffsetAndLength(t *testing.T) { type testCase struct { name string @@ -224,6 +252,56 @@ func TestCalculateOffsetAndLength(t *testing.T) { }) } +func TestHashTreeRoot(t *testing.T) { + tests := []struct { + name string + obj query.SSZObject + }{ + { + name: "FixedNestedContainer", + obj: &sszquerypb.FixedNestedContainer{ + Value1: 42, + Value2: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + }, + }, + { + name: "FixedTestContainer", + obj: createFixedTestContainer(), + }, + { + name: "VariableNestedContainer", + obj: &sszquerypb.VariableNestedContainer{ + Value1: 84, + FieldListUint64: []uint64{1, 2, 3, 4, 5}, + NestedListField: [][]byte{ + {0x0a, 0x0b, 0x0c}, + {0x1a, 0x1b, 0x1c, 0x1d}, + }, + }, + }, + { + name: "VariableTestContainer", + obj: createVariableTestContainer(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Analyze the object to get its sszInfo + info, err := query.AnalyzeObject(tt.obj) + require.NoError(t, err) + require.NotNil(t, info, "Expected non-nil SSZ info") + + // Call HashTreeRoot on the sszInfo and compare results + hashTreeRoot, err := info.HashTreeRoot() + require.NoError(t, err, "HashTreeRoot should not return an error") + expectedHashTreeRoot, err := tt.obj.HashTreeRoot() + require.NoError(t, err, "HashTreeRoot on original object should not return an error") + require.Equal(t, expectedHashTreeRoot, hashTreeRoot, "HashTreeRoot from sszInfo should match original object's HashTreeRoot") + }) + } +} + func TestRoundTripSszInfo(t *testing.T) { specs := []testutil.TestSpec{ getFixedTestContainerSpec(), @@ -364,62 +442,6 @@ func getFixedTestContainerSpec() testutil.TestSpec { } } -func TestSSZObject_batch(t *testing.T) { - tests := []struct { - name string - obj any - }{ - { - name: "FixedNestedContainer", - obj: &sszquerypb.FixedNestedContainer{ - Value1: 42, - Value2: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, - }, - }, - { - name: "FixedTestContainer", - obj: createFixedTestContainer(), - }, - { - name: "VariableNestedContainer", - obj: &sszquerypb.VariableNestedContainer{ - Value1: 84, - FieldListUint64: []uint64{1, 2, 3, 4, 5}, - NestedListField: [][]byte{ - {0x0a, 0x0b, 0x0c}, - {0x1a, 0x1b, 0x1c, 0x1d}, - }, - }, - }, - { - name: "VariableTestContainer", - obj: createVariableTestContainer(), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Analyze the object to get its sszInfo - object, ok := tt.obj.(query.SSZObject) - require.Equal(t, true, ok, "Expected object to implement SSZObject") - info, err := query.AnalyzeObject(object) - require.NoError(t, err) - require.NotNil(t, info, "Expected non-nil SSZ info") - - // Ensure the original object implements SSZObject - originalFunctions, ok := tt.obj.(query.SSZObject) - require.Equal(t, ok, true, "Original object does not implement SSZObject") - - // Call HashTreeRoot on the sszInfo and compare results - hashTreeRoot, err := info.HashTreeRoot() - require.NoError(t, err, "HashTreeRoot should not return an error") - expectedHashTreeRoot, err := originalFunctions.HashTreeRoot() - require.NoError(t, err, "HashTreeRoot on original object should not return an error") - require.Equal(t, expectedHashTreeRoot, hashTreeRoot, "HashTreeRoot from sszInfo should match original object's HashTreeRoot") - }) - } -} - func createVariableTestContainer() *sszquerypb.VariableTestContainer { leadingField := make([]byte, 32) for i := range leadingField { diff --git a/encoding/ssz/query/ssz_info.go b/encoding/ssz/query/ssz_info.go index 0ba3369196..37fe9507c5 100644 --- a/encoding/ssz/query/ssz_info.go +++ b/encoding/ssz/query/ssz_info.go @@ -18,8 +18,6 @@ type sszInfo struct { // isVariable is true if the struct contains any variable-size fields. isVariable bool - // fixedSize is the total size of the struct's fixed part. - fixedSize uint64 // For Container types. containerInfo *containerInfo @@ -37,46 +35,38 @@ type sszInfo struct { bitvectorInfo *bitvectorInfo } -func (info *sszInfo) FixedSize() uint64 { - if info == nil { - return 0 - } - return info.fixedSize -} - func (info *sszInfo) Size() uint64 { if info == nil { return 0 } - // Easy case: if the type is not variable, we can return the fixed size. - if !info.isVariable { - return info.fixedSize - } - switch info.sszType { + case Uint8: + return 1 + case Uint16: + return 2 + case Uint32: + return 4 + case Uint64: + return 8 + case Boolean: + return 1 + case Container: + // Using existing API if the pointer is available. + if info.source != nil { + return uint64(info.source.SizeSSZ()) + } + + return 0 + case Vector: + return info.vectorInfo.Size() case List: return info.listInfo.Size() - + case Bitvector: + return info.bitvectorInfo.Size() case Bitlist: return info.bitlistInfo.Size() - case Container: - size := info.fixedSize - for _, fieldInfo := range info.containerInfo.fields { - if !fieldInfo.sszInfo.isVariable { - continue - } - - // Include offset bytes inside nested lists. - if fieldInfo.sszInfo.sszType == List { - size += fieldInfo.sszInfo.listInfo.OffsetBytes() - } - - size += fieldInfo.sszInfo.Size() - } - return size - default: return 0 } @@ -193,7 +183,7 @@ func printRecursive(info *sszInfo, builder *strings.Builder, prefix string) { switch info.sszType { case Container: - builder.WriteString(fmt.Sprintf("%s (%s / fixed size: %d, total size: %d)\n", info, sizeDesc, info.FixedSize(), info.Size())) + builder.WriteString(fmt.Sprintf("%s (%s / size: %d)\n", info, sizeDesc, info.Size())) for i, key := range info.containerInfo.order { connector := "├─" diff --git a/encoding/ssz/query/ssz_object.go b/encoding/ssz/query/ssz_object.go index a56b15983d..ae60613696 100644 --- a/encoding/ssz/query/ssz_object.go +++ b/encoding/ssz/query/ssz_object.go @@ -4,6 +4,7 @@ import "errors" type SSZObject interface { HashTreeRoot() ([32]byte, error) + SizeSSZ() int } // HashTreeRoot calls the HashTreeRoot method on the stored interface if it implements SSZObject. diff --git a/encoding/ssz/query/ssz_type.go b/encoding/ssz/query/ssz_type.go index fe6195cc8f..a31e5e1b73 100644 --- a/encoding/ssz/query/ssz_type.go +++ b/encoding/ssz/query/ssz_type.go @@ -9,8 +9,10 @@ type SSZType int // SSZ type constants. const ( // Basic types - UintN SSZType = iota - Byte + Uint8 SSZType = iota + Uint16 + Uint32 + Uint64 Boolean // Composite types @@ -27,10 +29,14 @@ const ( func (t SSZType) String() string { switch t { - case UintN: - return "UintN" - case Byte: - return "Byte" + case Uint8: + return "Uint8" + case Uint16: + return "Uint16" + case Uint32: + return "Uint32" + case Uint64: + return "Uint64" case Boolean: return "Boolean" case Container: diff --git a/encoding/ssz/query/testutil/runner.go b/encoding/ssz/query/testutil/runner.go index 610e2d6be0..066d1d4464 100644 --- a/encoding/ssz/query/testutil/runner.go +++ b/encoding/ssz/query/testutil/runner.go @@ -1,6 +1,7 @@ package testutil import ( + "reflect" "testing" "github.com/OffchainLabs/prysm/v6/encoding/ssz/query" @@ -10,14 +11,11 @@ import ( func RunStructTest(t *testing.T, spec TestSpec) { t.Run(spec.Name, func(t *testing.T) { - object, ok := spec.Type.(query.SSZObject) - require.Equal(t, true, ok, "spec.Type must implement SSZObject interface") - require.NotNil(t, object, "spec.Type must not be nil") - info, err := query.AnalyzeObject(object) + info, err := query.AnalyzeObject(spec.Type) require.NoError(t, err) testInstance := spec.Instance - err = query.PopulateVariableLengthInfo(info, testInstance) + err = query.PopulateVariableLengthInfo(info, reflect.ValueOf(testInstance)) require.NoError(t, err) marshaller, ok := testInstance.(ssz.Marshaler) diff --git a/encoding/ssz/query/vector.go b/encoding/ssz/query/vector.go index c0c0f70d38..8e90856952 100644 --- a/encoding/ssz/query/vector.go +++ b/encoding/ssz/query/vector.go @@ -25,3 +25,15 @@ func (v *vectorInfo) Element() (*sszInfo, error) { return v.element, nil } + +func (v *vectorInfo) Size() uint64 { + if v == nil { + return 0 + } + + if v.element == nil { + return 0 + } + + return v.length * v.element.Size() +}