mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 21:08:10 -05:00
Optimize Multivalue Slice For Trie Recomputation (#13238)
Co-authored-by: Radosław Kapka <rkapka@wp.pl>
This commit is contained in:
@@ -12,6 +12,7 @@ go_library(
|
||||
"//beacon-chain/state/state-native/custom-types:go_default_library",
|
||||
"//beacon-chain/state/state-native/types:go_default_library",
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//container/multi-value-slice:go_default_library",
|
||||
"//math:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
@@ -30,9 +31,11 @@ go_test(
|
||||
"//beacon-chain/state/state-native/custom-types:go_default_library",
|
||||
"//beacon-chain/state/state-native/types:go_default_library",
|
||||
"//beacon-chain/state/stateutil:go_default_library",
|
||||
"//config/features:go_default_library",
|
||||
"//config/fieldparams:go_default_library",
|
||||
"//config/params:go_default_library",
|
||||
"//consensus-types/primitives:go_default_library",
|
||||
"//container/multi-value-slice:go_default_library",
|
||||
"//proto/prysm/v1alpha1:go_default_library",
|
||||
"//testing/assert:go_default_library",
|
||||
"//testing/require:go_default_library",
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/state-native/types"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/stateutil"
|
||||
multi_value_slice "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice"
|
||||
pmath "github.com/prysmaticlabs/prysm/v4/math"
|
||||
)
|
||||
|
||||
@@ -15,6 +16,14 @@ var (
|
||||
ErrEmptyFieldTrie = errors.New("empty field trie")
|
||||
)
|
||||
|
||||
// sliceAccessor describes an interface for a multivalue slice
|
||||
// object that returns information about the multivalue slice along with the
|
||||
// particular state instance we are referencing.
|
||||
type sliceAccessor interface {
|
||||
Len(obj multi_value_slice.Identifiable) int
|
||||
State() multi_value_slice.Identifiable
|
||||
}
|
||||
|
||||
// FieldTrie is the representation of the representative
|
||||
// trie of the particular field.
|
||||
type FieldTrie struct {
|
||||
@@ -51,6 +60,12 @@ func NewFieldTrie(field types.FieldIndex, fieldInfo types.DataType, elements int
|
||||
if err := validateElements(field, fieldInfo, elements, length); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numOfElems := 0
|
||||
if val, ok := elements.(sliceAccessor); ok {
|
||||
numOfElems = val.Len(val.State())
|
||||
} else {
|
||||
numOfElems = reflect.Indirect(reflect.ValueOf(elements)).Len()
|
||||
}
|
||||
switch fieldInfo {
|
||||
case types.BasicArray:
|
||||
fl, err := stateutil.ReturnTrieLayer(fieldRoots, length)
|
||||
@@ -64,7 +79,7 @@ func NewFieldTrie(field types.FieldIndex, fieldInfo types.DataType, elements int
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
numOfElems: reflect.Indirect(reflect.ValueOf(elements)).Len(),
|
||||
numOfElems: numOfElems,
|
||||
}, nil
|
||||
case types.CompositeArray, types.CompressedArray:
|
||||
return &FieldTrie{
|
||||
@@ -74,7 +89,7 @@ func NewFieldTrie(field types.FieldIndex, fieldInfo types.DataType, elements int
|
||||
reference: stateutil.NewRef(1),
|
||||
RWMutex: new(sync.RWMutex),
|
||||
length: length,
|
||||
numOfElems: reflect.Indirect(reflect.ValueOf(elements)).Len(),
|
||||
numOfElems: numOfElems,
|
||||
}, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(fieldInfo).Name())
|
||||
@@ -100,20 +115,23 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
|
||||
if err := f.validateIndices(indices); err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
if val, ok := elements.(sliceAccessor); ok {
|
||||
f.numOfElems = val.Len(val.State())
|
||||
} else {
|
||||
f.numOfElems = reflect.Indirect(reflect.ValueOf(elements)).Len()
|
||||
}
|
||||
switch f.dataType {
|
||||
case types.BasicArray:
|
||||
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayer(fieldRoots, indices, f.fieldLayers)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.Indirect(reflect.ValueOf(elements)).Len()
|
||||
return fieldRoot, nil
|
||||
case types.CompositeArray:
|
||||
fieldRoot, f.fieldLayers, err = stateutil.RecomputeFromLayerVariable(fieldRoots, indices, f.fieldLayers)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.Indirect(reflect.ValueOf(elements)).Len()
|
||||
return stateutil.AddInMixin(fieldRoot, uint64(len(f.fieldLayers[0])))
|
||||
case types.CompressedArray:
|
||||
numOfElems, err := f.field.ElemsInChunk()
|
||||
@@ -142,7 +160,6 @@ func (f *FieldTrie) RecomputeTrie(indices []uint64, elements interface{}) ([32]b
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
f.numOfElems = reflect.Indirect(reflect.ValueOf(elements)).Len()
|
||||
return stateutil.AddInMixin(fieldRoot, uint64(f.numOfElems))
|
||||
default:
|
||||
return [32]byte{}, errors.Errorf("unrecognized data type in field map: %v", reflect.TypeOf(f.dataType).Name())
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
customtypes "github.com/prysmaticlabs/prysm/v4/beacon-chain/state/state-native/custom-types"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/state-native/types"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/stateutil"
|
||||
multi_value_slice "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice"
|
||||
pmath "github.com/prysmaticlabs/prysm/v4/math"
|
||||
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
|
||||
)
|
||||
@@ -53,6 +54,13 @@ func validateElements(field types.FieldIndex, fieldInfo types.DataType, elements
|
||||
}
|
||||
length *= comLength
|
||||
}
|
||||
if val, ok := elements.(sliceAccessor); ok {
|
||||
totalLen := val.Len(val.State())
|
||||
if uint64(totalLen) > length {
|
||||
return errors.Errorf("elements length is larger than expected for field %s: %d > %d", field.String(), totalLen, length)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
val := reflect.Indirect(reflect.ValueOf(elements))
|
||||
if uint64(val.Len()) > length {
|
||||
return errors.Errorf("elements length is larger than expected for field %s: %d > %d", field.String(), val.Len(), length)
|
||||
@@ -63,12 +71,8 @@ func validateElements(field types.FieldIndex, fieldInfo types.DataType, elements
|
||||
// fieldConverters converts the corresponding field and the provided elements to the appropriate roots.
|
||||
func fieldConverters(field types.FieldIndex, indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
switch field {
|
||||
case types.BlockRoots:
|
||||
return convert32ByteArrays[customtypes.BlockRoots](indices, elements, convertAll)
|
||||
case types.StateRoots:
|
||||
return convert32ByteArrays[customtypes.StateRoots](indices, elements, convertAll)
|
||||
case types.RandaoMixes:
|
||||
return convert32ByteArrays[customtypes.RandaoMixes](indices, elements, convertAll)
|
||||
case types.BlockRoots, types.StateRoots, types.RandaoMixes:
|
||||
return convertRoots(indices, elements, convertAll)
|
||||
case types.Eth1DataVotes:
|
||||
return convertEth1DataVotes(indices, elements, convertAll)
|
||||
case types.Validators:
|
||||
@@ -82,13 +86,19 @@ func fieldConverters(field types.FieldIndex, indices []uint64, elements interfac
|
||||
}
|
||||
}
|
||||
|
||||
func convert32ByteArrays[T ~[][32]byte](indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
val, ok := elements.(T)
|
||||
if !ok {
|
||||
var t T
|
||||
return nil, errors.Errorf("Wanted type of %T but got %T", t, elements)
|
||||
func convertRoots(indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
switch castedType := elements.(type) {
|
||||
case customtypes.BlockRoots:
|
||||
return handle32ByteMVslice(multi_value_slice.BuildEmptyCompositeSlice[[32]byte](castedType), indices, convertAll)
|
||||
case customtypes.StateRoots:
|
||||
return handle32ByteMVslice(multi_value_slice.BuildEmptyCompositeSlice[[32]byte](castedType), indices, convertAll)
|
||||
case customtypes.RandaoMixes:
|
||||
return handle32ByteMVslice(multi_value_slice.BuildEmptyCompositeSlice[[32]byte](castedType), indices, convertAll)
|
||||
case multi_value_slice.MultiValueSliceComposite[[32]byte]:
|
||||
return handle32ByteMVslice(castedType, indices, convertAll)
|
||||
default:
|
||||
return nil, errors.Errorf("non-existent type provided %T", castedType)
|
||||
}
|
||||
return handle32ByteArrays(val, indices, convertAll)
|
||||
}
|
||||
|
||||
func convertEth1DataVotes(indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
@@ -100,11 +110,14 @@ func convertEth1DataVotes(indices []uint64, elements interface{}, convertAll boo
|
||||
}
|
||||
|
||||
func convertValidators(indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
val, ok := elements.([]*ethpb.Validator)
|
||||
if !ok {
|
||||
switch casted := elements.(type) {
|
||||
case []*ethpb.Validator:
|
||||
return handleValidatorMVSlice(multi_value_slice.BuildEmptyCompositeSlice[*ethpb.Validator](casted), indices, convertAll)
|
||||
case multi_value_slice.MultiValueSliceComposite[*ethpb.Validator]:
|
||||
return handleValidatorMVSlice(casted, indices, convertAll)
|
||||
default:
|
||||
return nil, errors.Errorf("Wanted type of %T but got %T", []*ethpb.Validator{}, elements)
|
||||
}
|
||||
return handleValidatorSlice(val, indices, convertAll)
|
||||
}
|
||||
|
||||
func convertAttestations(indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
@@ -116,45 +129,56 @@ func convertAttestations(indices []uint64, elements interface{}, convertAll bool
|
||||
}
|
||||
|
||||
func convertBalances(indices []uint64, elements interface{}, convertAll bool) ([][32]byte, error) {
|
||||
val, ok := elements.([]uint64)
|
||||
if !ok {
|
||||
switch casted := elements.(type) {
|
||||
case []uint64:
|
||||
return handleBalanceMVSlice(multi_value_slice.BuildEmptyCompositeSlice[uint64](casted), indices, convertAll)
|
||||
case multi_value_slice.MultiValueSliceComposite[uint64]:
|
||||
return handleBalanceMVSlice(casted, indices, convertAll)
|
||||
default:
|
||||
return nil, errors.Errorf("Wanted type of %T but got %T", []uint64{}, elements)
|
||||
}
|
||||
return handleBalanceSlice(val, indices, convertAll)
|
||||
}
|
||||
|
||||
// handle32ByteArrays computes and returns 32 byte arrays in a slice of root format.
|
||||
func handle32ByteArrays(val [][32]byte, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
// handle32ByteMVslice computes and returns 32 byte arrays in a slice of root format. This is modified
|
||||
// to be used with multivalue slices.
|
||||
func handle32ByteMVslice(mv multi_value_slice.MultiValueSliceComposite[[32]byte],
|
||||
indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
length = len(val)
|
||||
length = mv.Len(mv.State())
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
rootCreator := func(input [32]byte) {
|
||||
roots = append(roots, input)
|
||||
}
|
||||
if convertAll {
|
||||
val := mv.Value(mv.State())
|
||||
for i := range val {
|
||||
rootCreator(val[i])
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
totalLen := mv.Len(mv.State())
|
||||
if totalLen > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, len(val))
|
||||
if idx > uint64(totalLen)-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of byte arrays %d", idx, totalLen)
|
||||
}
|
||||
rootCreator(val[idx])
|
||||
val, err := mv.At(mv.State(), idx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rootCreator(val)
|
||||
}
|
||||
}
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
// handleValidatorSlice returns the validator indices in a slice of root format.
|
||||
func handleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
// handleValidatorMVSlice returns the validator indices in a slice of root format.
|
||||
func handleValidatorMVSlice(mv multi_value_slice.MultiValueSliceComposite[*ethpb.Validator], indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
length := len(indices)
|
||||
if convertAll {
|
||||
return stateutil.OptimizedValidatorRoots(val)
|
||||
return stateutil.OptimizedValidatorRoots(mv.Value(mv.State()))
|
||||
}
|
||||
roots := make([][32]byte, 0, length)
|
||||
rootCreator := func(input *ethpb.Validator) error {
|
||||
@@ -165,12 +189,17 @@ func handleValidatorSlice(val []*ethpb.Validator, indices []uint64, convertAll b
|
||||
roots = append(roots, newRoot)
|
||||
return nil
|
||||
}
|
||||
if len(val) > 0 {
|
||||
totalLen := mv.Len(mv.State())
|
||||
if totalLen > 0 {
|
||||
for _, idx := range indices {
|
||||
if idx > uint64(len(val))-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, len(val))
|
||||
if idx > uint64(totalLen)-1 {
|
||||
return nil, fmt.Errorf("index %d greater than number of validators %d", idx, totalLen)
|
||||
}
|
||||
err := rootCreator(val[idx])
|
||||
val, err := mv.At(mv.State(), idx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = rootCreator(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -255,12 +284,13 @@ func handlePendingAttestationSlice(val []*ethpb.PendingAttestation, indices []ui
|
||||
return roots, nil
|
||||
}
|
||||
|
||||
// handleBalanceSlice returns the root of a slice of validator balances.
|
||||
func handleBalanceSlice(val, indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
func handleBalanceMVSlice(mv multi_value_slice.MultiValueSliceComposite[uint64], indices []uint64, convertAll bool) ([][32]byte, error) {
|
||||
if convertAll {
|
||||
val := mv.Value(mv.State())
|
||||
return stateutil.PackUint64IntoChunks(val)
|
||||
}
|
||||
if len(val) > 0 {
|
||||
totalLen := mv.Len(mv.State())
|
||||
if totalLen > 0 {
|
||||
numOfElems, err := types.Balances.ElemsInChunk()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -283,8 +313,12 @@ func handleBalanceSlice(val, indices []uint64, convertAll bool) ([][32]byte, err
|
||||
// then you will need to zero out the rest of the chunk. Ex : 41 indexes,
|
||||
// so 41 % 4 = 1 . There are 3 indexes, which do not exist yet but we
|
||||
// have to add in as a root. These 3 indexes are then given a 'zero' value.
|
||||
if j < uint64(len(val)) {
|
||||
wantedVal = val[j]
|
||||
if j < uint64(totalLen) {
|
||||
val, err := mv.At(mv.State(), j)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wantedVal = val
|
||||
}
|
||||
binary.LittleEndian.PutUint64(chunk[i:i+sizeOfElem], wantedVal)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package fieldtrie_test
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/fieldtrie"
|
||||
. "github.com/prysmaticlabs/prysm/v4/beacon-chain/state/fieldtrie"
|
||||
customtypes "github.com/prysmaticlabs/prysm/v4/beacon-chain/state/state-native/custom-types"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/state-native/types"
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/stateutil"
|
||||
"github.com/prysmaticlabs/prysm/v4/config/features"
|
||||
"github.com/prysmaticlabs/prysm/v4/config/params"
|
||||
"github.com/prysmaticlabs/prysm/v4/consensus-types/primitives"
|
||||
mvslice "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice"
|
||||
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/v4/testing/assert"
|
||||
"github.com/prysmaticlabs/prysm/v4/testing/require"
|
||||
@@ -16,14 +18,38 @@ import (
|
||||
)
|
||||
|
||||
func TestFieldTrie_NewTrie(t *testing.T) {
|
||||
t.Run("native state", func(t *testing.T) {
|
||||
runNewTrie(t)
|
||||
})
|
||||
t.Run("native state with multivalue slice", func(t *testing.T) {
|
||||
cfg := &features.Flags{}
|
||||
cfg.EnableExperimentalState = true
|
||||
reset := features.InitWithReset(cfg)
|
||||
runNewTrie(t)
|
||||
|
||||
reset()
|
||||
})
|
||||
}
|
||||
|
||||
func runNewTrie(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 40)
|
||||
roots := newState.BlockRoots()
|
||||
var elements interface{}
|
||||
blockRoots := make([][32]byte, len(roots))
|
||||
for i, r := range roots {
|
||||
blockRoots[i] = [32]byte(r)
|
||||
}
|
||||
elements = customtypes.BlockRoots(blockRoots)
|
||||
|
||||
trie, err := fieldtrie.NewFieldTrie(types.BlockRoots, types.BasicArray, customtypes.BlockRoots(blockRoots), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
if features.Get().EnableExperimentalState {
|
||||
mvRoots := buildTestCompositeSlice[[32]byte](blockRoots)
|
||||
elements = mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: mockIdentifier{},
|
||||
MultiValueSlice: mvRoots,
|
||||
}
|
||||
}
|
||||
|
||||
trie, err := NewFieldTrie(types.BlockRoots, types.BasicArray, elements, uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
require.NoError(t, err)
|
||||
root, err := stateutil.RootsArrayHashTreeRoot(newState.BlockRoots(), uint64(params.BeaconConfig().SlotsPerHistoricalRoot))
|
||||
require.NoError(t, err)
|
||||
@@ -33,15 +59,40 @@ func TestFieldTrie_NewTrie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFieldTrie_NewTrie_NilElements(t *testing.T) {
|
||||
trie, err := fieldtrie.NewFieldTrie(types.BlockRoots, types.BasicArray, nil, 8234)
|
||||
trie, err := NewFieldTrie(types.BlockRoots, types.BasicArray, nil, 8234)
|
||||
require.NoError(t, err)
|
||||
_, err = trie.TrieRoot()
|
||||
require.ErrorIs(t, err, fieldtrie.ErrEmptyFieldTrie)
|
||||
require.ErrorIs(t, err, ErrEmptyFieldTrie)
|
||||
}
|
||||
|
||||
func TestFieldTrie_RecomputeTrie(t *testing.T) {
|
||||
t.Run("native state", func(t *testing.T) {
|
||||
runRecomputeTrie(t)
|
||||
})
|
||||
t.Run("native state with multivalue slice", func(t *testing.T) {
|
||||
cfg := &features.Flags{}
|
||||
cfg.EnableExperimentalState = true
|
||||
reset := features.InitWithReset(cfg)
|
||||
runRecomputeTrie(t)
|
||||
|
||||
reset()
|
||||
})
|
||||
}
|
||||
|
||||
func runRecomputeTrie(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
trie, err := fieldtrie.NewFieldTrie(types.Validators, types.CompositeArray, newState.Validators(), params.BeaconConfig().ValidatorRegistryLimit)
|
||||
|
||||
var elements interface{}
|
||||
elements = newState.Validators()
|
||||
if features.Get().EnableExperimentalState {
|
||||
mvRoots := buildTestCompositeSlice[*ethpb.Validator](newState.Validators())
|
||||
elements = mvslice.MultiValueSliceComposite[*ethpb.Validator]{
|
||||
Identifiable: mockIdentifier{},
|
||||
MultiValueSlice: mvRoots,
|
||||
}
|
||||
}
|
||||
|
||||
trie, err := NewFieldTrie(types.Validators, types.CompositeArray, elements, params.BeaconConfig().ValidatorRegistryLimit)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldroot, err := trie.TrieRoot()
|
||||
@@ -71,8 +122,32 @@ func TestFieldTrie_RecomputeTrie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFieldTrie_RecomputeTrie_CompressedArray(t *testing.T) {
|
||||
t.Run("native state", func(t *testing.T) {
|
||||
runRecomputeTrie_CompressedArray(t)
|
||||
})
|
||||
t.Run("native state with multivalue slice", func(t *testing.T) {
|
||||
cfg := &features.Flags{}
|
||||
cfg.EnableExperimentalState = true
|
||||
reset := features.InitWithReset(cfg)
|
||||
runRecomputeTrie_CompressedArray(t)
|
||||
|
||||
reset()
|
||||
})
|
||||
}
|
||||
|
||||
func runRecomputeTrie_CompressedArray(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
trie, err := fieldtrie.NewFieldTrie(types.Balances, types.CompressedArray, newState.Balances(), stateutil.ValidatorLimitForBalancesChunks())
|
||||
var elements interface{}
|
||||
elements = newState.Balances()
|
||||
if features.Get().EnableExperimentalState {
|
||||
mvBals := buildTestCompositeSlice(newState.Balances())
|
||||
elements = mvslice.MultiValueSliceComposite[uint64]{
|
||||
Identifiable: mockIdentifier{},
|
||||
MultiValueSlice: mvBals,
|
||||
}
|
||||
}
|
||||
|
||||
trie, err := NewFieldTrie(types.Balances, types.CompressedArray, elements, stateutil.ValidatorLimitForBalancesChunks())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, trie.Length(), stateutil.ValidatorLimitForBalancesChunks())
|
||||
changedIdx := []uint64{4, 8}
|
||||
@@ -89,7 +164,7 @@ func TestFieldTrie_RecomputeTrie_CompressedArray(t *testing.T) {
|
||||
|
||||
func TestNewFieldTrie_UnknownType(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
_, err := fieldtrie.NewFieldTrie(types.Balances, 4, newState.Balances(), 32)
|
||||
_, err := NewFieldTrie(types.Balances, 4, newState.Balances(), 32)
|
||||
require.ErrorContains(t, "unrecognized data type", err)
|
||||
}
|
||||
|
||||
@@ -101,7 +176,7 @@ func TestFieldTrie_CopyTrieImmutable(t *testing.T) {
|
||||
randaoMixes[i] = [32]byte(r)
|
||||
}
|
||||
|
||||
trie, err := fieldtrie.NewFieldTrie(types.RandaoMixes, types.BasicArray, customtypes.RandaoMixes(randaoMixes), uint64(params.BeaconConfig().EpochsPerHistoricalVector))
|
||||
trie, err := NewFieldTrie(types.RandaoMixes, types.BasicArray, customtypes.RandaoMixes(randaoMixes), uint64(params.BeaconConfig().EpochsPerHistoricalVector))
|
||||
require.NoError(t, err)
|
||||
|
||||
newTrie := trie.CopyTrie()
|
||||
@@ -127,7 +202,7 @@ func TestFieldTrie_CopyTrieImmutable(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestFieldTrie_CopyAndTransferEmpty(t *testing.T) {
|
||||
trie, err := fieldtrie.NewFieldTrie(types.RandaoMixes, types.BasicArray, nil, uint64(params.BeaconConfig().EpochsPerHistoricalVector))
|
||||
trie, err := NewFieldTrie(types.RandaoMixes, types.BasicArray, nil, uint64(params.BeaconConfig().EpochsPerHistoricalVector))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.DeepEqual(t, trie, trie.CopyTrie())
|
||||
@@ -137,14 +212,14 @@ func TestFieldTrie_CopyAndTransferEmpty(t *testing.T) {
|
||||
func TestFieldTrie_TransferTrie(t *testing.T) {
|
||||
newState, _ := util.DeterministicGenesisState(t, 32)
|
||||
maxLength := (params.BeaconConfig().ValidatorRegistryLimit*8 + 31) / 32
|
||||
trie, err := fieldtrie.NewFieldTrie(types.Balances, types.CompressedArray, newState.Balances(), maxLength)
|
||||
trie, err := NewFieldTrie(types.Balances, types.CompressedArray, newState.Balances(), maxLength)
|
||||
require.NoError(t, err)
|
||||
oldRoot, err := trie.TrieRoot()
|
||||
require.NoError(t, err)
|
||||
|
||||
newTrie := trie.TransferTrie()
|
||||
root, err := trie.TrieRoot()
|
||||
require.ErrorIs(t, err, fieldtrie.ErrEmptyFieldTrie)
|
||||
require.ErrorIs(t, err, ErrEmptyFieldTrie)
|
||||
require.Equal(t, root, [32]byte{})
|
||||
require.NotNil(t, newTrie)
|
||||
newRoot, err := newTrie.TrieRoot()
|
||||
@@ -165,7 +240,7 @@ func FuzzFieldTrie(f *testing.F) {
|
||||
for i := 32; i < len(data); i += 32 {
|
||||
roots = append(roots, data[i-32:i])
|
||||
}
|
||||
trie, err := fieldtrie.NewFieldTrie(types.FieldIndex(idx), types.DataType(typ), roots, slotsPerHistRoot)
|
||||
trie, err := NewFieldTrie(types.FieldIndex(idx), types.DataType(typ), roots, slotsPerHistRoot)
|
||||
if err != nil {
|
||||
return // invalid inputs
|
||||
}
|
||||
@@ -175,3 +250,18 @@ func FuzzFieldTrie(f *testing.F) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func buildTestCompositeSlice[V comparable](values []V) mvslice.MultiValueSliceComposite[V] {
|
||||
obj := &mvslice.Slice[V]{}
|
||||
obj.Init(values)
|
||||
return mvslice.MultiValueSliceComposite[V]{
|
||||
Identifiable: nil,
|
||||
MultiValueSlice: obj,
|
||||
}
|
||||
}
|
||||
|
||||
type mockIdentifier struct{}
|
||||
|
||||
func (_ mockIdentifier) Id() mvslice.Id {
|
||||
return 0
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/v4/beacon-chain/state/stateutil"
|
||||
fieldparams "github.com/prysmaticlabs/prysm/v4/config/fieldparams"
|
||||
"github.com/prysmaticlabs/prysm/v4/config/params"
|
||||
mvslice "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice"
|
||||
ethpb "github.com/prysmaticlabs/prysm/v4/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/v4/testing/assert"
|
||||
"github.com/prysmaticlabs/prysm/v4/testing/require"
|
||||
@@ -35,13 +36,13 @@ func Test_handleEth1DataSlice_OutOfRange(t *testing.T) {
|
||||
func Test_handleValidatorSlice_OutOfRange(t *testing.T) {
|
||||
vals := make([]*ethpb.Validator, 1)
|
||||
indices := []uint64{3}
|
||||
_, err := handleValidatorSlice(vals, indices, false)
|
||||
_, err := handleValidatorMVSlice(mvslice.BuildEmptyCompositeSlice[*ethpb.Validator](vals), indices, false)
|
||||
assert.ErrorContains(t, "index 3 greater than number of validators 1", err)
|
||||
}
|
||||
|
||||
func TestBalancesSlice_CorrectRoots_All(t *testing.T) {
|
||||
balances := []uint64{5, 2929, 34, 1291, 354305}
|
||||
roots, err := handleBalanceSlice(balances, []uint64{}, true)
|
||||
roots, err := handleBalanceMVSlice(mvslice.BuildEmptyCompositeSlice[uint64](balances), []uint64{}, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var root1 [32]byte
|
||||
@@ -58,7 +59,7 @@ func TestBalancesSlice_CorrectRoots_All(t *testing.T) {
|
||||
|
||||
func TestBalancesSlice_CorrectRoots_Some(t *testing.T) {
|
||||
balances := []uint64{5, 2929, 34, 1291, 354305}
|
||||
roots, err := handleBalanceSlice(balances, []uint64{2, 3}, false)
|
||||
roots, err := handleBalanceMVSlice(mvslice.BuildEmptyCompositeSlice[uint64](balances), []uint64{2, 3}, false)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var root1 [32]byte
|
||||
@@ -123,7 +124,7 @@ func TestFieldTrie_NativeState_fieldConvertersNative(t *testing.T) {
|
||||
convertAll: true,
|
||||
},
|
||||
wantHex: nil,
|
||||
errMsg: "Wanted type of customtypes.BlockRoots",
|
||||
errMsg: "non-existent type provided",
|
||||
},
|
||||
{
|
||||
name: "StateRoots customtypes.StateRoots",
|
||||
@@ -145,7 +146,7 @@ func TestFieldTrie_NativeState_fieldConvertersNative(t *testing.T) {
|
||||
convertAll: true,
|
||||
},
|
||||
wantHex: nil,
|
||||
errMsg: "Wanted type of customtypes.StateRoots",
|
||||
errMsg: "non-existent type provided",
|
||||
},
|
||||
{
|
||||
name: "StateRoots customtypes.StateRoots convert all false",
|
||||
@@ -178,7 +179,7 @@ func TestFieldTrie_NativeState_fieldConvertersNative(t *testing.T) {
|
||||
convertAll: true,
|
||||
},
|
||||
wantHex: nil,
|
||||
errMsg: "Wanted type of customtypes.RandaoMixes",
|
||||
errMsg: "non-existent type provided",
|
||||
},
|
||||
{
|
||||
name: "Eth1DataVotes type not found",
|
||||
|
||||
@@ -33,7 +33,7 @@ var (
|
||||
)
|
||||
|
||||
// MultiValueRandaoMixes is a multi-value slice of randao mixes.
|
||||
type MultiValueRandaoMixes = multi_value_slice.Slice[[32]byte, *BeaconState]
|
||||
type MultiValueRandaoMixes = multi_value_slice.Slice[[32]byte]
|
||||
|
||||
// NewMultiValueRandaoMixes creates a new slice whose shared items will be populated with copies of input values.
|
||||
func NewMultiValueRandaoMixes(mixes [][]byte) *MultiValueRandaoMixes {
|
||||
@@ -49,7 +49,7 @@ func NewMultiValueRandaoMixes(mixes [][]byte) *MultiValueRandaoMixes {
|
||||
}
|
||||
|
||||
// MultiValueBlockRoots is a multi-value slice of block roots.
|
||||
type MultiValueBlockRoots = multi_value_slice.Slice[[32]byte, *BeaconState]
|
||||
type MultiValueBlockRoots = multi_value_slice.Slice[[32]byte]
|
||||
|
||||
// NewMultiValueBlockRoots creates a new slice whose shared items will be populated with copies of input values.
|
||||
func NewMultiValueBlockRoots(roots [][]byte) *MultiValueBlockRoots {
|
||||
@@ -65,7 +65,7 @@ func NewMultiValueBlockRoots(roots [][]byte) *MultiValueBlockRoots {
|
||||
}
|
||||
|
||||
// MultiValueStateRoots is a multi-value slice of state roots.
|
||||
type MultiValueStateRoots = multi_value_slice.Slice[[32]byte, *BeaconState]
|
||||
type MultiValueStateRoots = multi_value_slice.Slice[[32]byte]
|
||||
|
||||
// NewMultiValueStateRoots creates a new slice whose shared items will be populated with copies of input values.
|
||||
func NewMultiValueStateRoots(roots [][]byte) *MultiValueStateRoots {
|
||||
@@ -81,7 +81,7 @@ func NewMultiValueStateRoots(roots [][]byte) *MultiValueStateRoots {
|
||||
}
|
||||
|
||||
// MultiValueBalances is a multi-value slice of balances.
|
||||
type MultiValueBalances = multi_value_slice.Slice[uint64, *BeaconState]
|
||||
type MultiValueBalances = multi_value_slice.Slice[uint64]
|
||||
|
||||
// NewMultiValueBalances creates a new slice whose shared items will be populated with copies of input values.
|
||||
func NewMultiValueBalances(balances []uint64) *MultiValueBalances {
|
||||
@@ -95,7 +95,7 @@ func NewMultiValueBalances(balances []uint64) *MultiValueBalances {
|
||||
}
|
||||
|
||||
// MultiValueInactivityScores is a multi-value slice of inactivity scores.
|
||||
type MultiValueInactivityScores = multi_value_slice.Slice[uint64, *BeaconState]
|
||||
type MultiValueInactivityScores = multi_value_slice.Slice[uint64]
|
||||
|
||||
// NewMultiValueInactivityScores creates a new slice whose shared items will be populated with copies of input values.
|
||||
func NewMultiValueInactivityScores(scores []uint64) *MultiValueInactivityScores {
|
||||
@@ -109,7 +109,7 @@ func NewMultiValueInactivityScores(scores []uint64) *MultiValueInactivityScores
|
||||
}
|
||||
|
||||
// MultiValueValidators is a multi-value slice of validator references.
|
||||
type MultiValueValidators = multi_value_slice.Slice[*ethpb.Validator, *BeaconState]
|
||||
type MultiValueValidators = multi_value_slice.Slice[*ethpb.Validator]
|
||||
|
||||
// NewMultiValueValidators creates a new slice whose shared items will be populated with input values.
|
||||
func NewMultiValueValidators(vals []*ethpb.Validator) *MultiValueValidators {
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/v4/config/features"
|
||||
fieldparams "github.com/prysmaticlabs/prysm/v4/config/fieldparams"
|
||||
"github.com/prysmaticlabs/prysm/v4/config/params"
|
||||
mvslice "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice"
|
||||
"github.com/prysmaticlabs/prysm/v4/container/slice"
|
||||
"github.com/prysmaticlabs/prysm/v4/encoding/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/v4/encoding/ssz"
|
||||
@@ -1178,7 +1179,10 @@ func finalizerCleanup(b *BeaconState) {
|
||||
func (b *BeaconState) blockRootsRootSelector(field types.FieldIndex) ([32]byte, error) {
|
||||
if b.rebuildTrie[field] {
|
||||
if features.Get().EnableExperimentalState {
|
||||
err := b.resetFieldTrie(field, customtypes.BlockRoots(b.blockRootsMultiValue.Value(b)), fieldparams.BlockRootsLength)
|
||||
err := b.resetFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.blockRootsMultiValue,
|
||||
}, fieldparams.BlockRootsLength)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
@@ -1192,7 +1196,10 @@ func (b *BeaconState) blockRootsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
if features.Get().EnableExperimentalState {
|
||||
return b.recomputeFieldTrie(field, customtypes.BlockRoots(b.blockRootsMultiValue.Value(b)))
|
||||
return b.recomputeFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.blockRootsMultiValue,
|
||||
})
|
||||
} else {
|
||||
return b.recomputeFieldTrie(field, b.blockRoots)
|
||||
}
|
||||
@@ -1201,7 +1208,10 @@ func (b *BeaconState) blockRootsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
func (b *BeaconState) stateRootsRootSelector(field types.FieldIndex) ([32]byte, error) {
|
||||
if b.rebuildTrie[field] {
|
||||
if features.Get().EnableExperimentalState {
|
||||
err := b.resetFieldTrie(field, customtypes.StateRoots(b.stateRootsMultiValue.Value(b)), fieldparams.StateRootsLength)
|
||||
err := b.resetFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.stateRootsMultiValue,
|
||||
}, fieldparams.StateRootsLength)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
@@ -1215,7 +1225,10 @@ func (b *BeaconState) stateRootsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
if features.Get().EnableExperimentalState {
|
||||
return b.recomputeFieldTrie(field, customtypes.StateRoots(b.stateRootsMultiValue.Value(b)))
|
||||
return b.recomputeFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.stateRootsMultiValue,
|
||||
})
|
||||
} else {
|
||||
return b.recomputeFieldTrie(field, b.stateRoots)
|
||||
}
|
||||
@@ -1224,7 +1237,10 @@ func (b *BeaconState) stateRootsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
func (b *BeaconState) validatorsRootSelector(field types.FieldIndex) ([32]byte, error) {
|
||||
if b.rebuildTrie[field] {
|
||||
if features.Get().EnableExperimentalState {
|
||||
err := b.resetFieldTrie(field, b.validatorsMultiValue.Value(b), fieldparams.ValidatorRegistryLimit)
|
||||
err := b.resetFieldTrie(field, mvslice.MultiValueSliceComposite[*ethpb.Validator]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.validatorsMultiValue,
|
||||
}, fieldparams.ValidatorRegistryLimit)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
@@ -1238,7 +1254,10 @@ func (b *BeaconState) validatorsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
if features.Get().EnableExperimentalState {
|
||||
return b.recomputeFieldTrie(field, b.validatorsMultiValue.Value(b))
|
||||
return b.recomputeFieldTrie(field, mvslice.MultiValueSliceComposite[*ethpb.Validator]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.validatorsMultiValue,
|
||||
})
|
||||
} else {
|
||||
return b.recomputeFieldTrie(field, b.validators)
|
||||
}
|
||||
@@ -1247,7 +1266,10 @@ func (b *BeaconState) validatorsRootSelector(field types.FieldIndex) ([32]byte,
|
||||
func (b *BeaconState) balancesRootSelector(field types.FieldIndex) ([32]byte, error) {
|
||||
if b.rebuildTrie[field] {
|
||||
if features.Get().EnableExperimentalState {
|
||||
err := b.resetFieldTrie(field, b.balancesMultiValue.Value(b), stateutil.ValidatorLimitForBalancesChunks())
|
||||
err := b.resetFieldTrie(field, mvslice.MultiValueSliceComposite[uint64]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.balancesMultiValue,
|
||||
}, stateutil.ValidatorLimitForBalancesChunks())
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
@@ -1261,7 +1283,10 @@ func (b *BeaconState) balancesRootSelector(field types.FieldIndex) ([32]byte, er
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
if features.Get().EnableExperimentalState {
|
||||
return b.recomputeFieldTrie(field, b.balancesMultiValue.Value(b))
|
||||
return b.recomputeFieldTrie(field, mvslice.MultiValueSliceComposite[uint64]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.balancesMultiValue,
|
||||
})
|
||||
} else {
|
||||
return b.recomputeFieldTrie(field, b.balances)
|
||||
}
|
||||
@@ -1270,7 +1295,10 @@ func (b *BeaconState) balancesRootSelector(field types.FieldIndex) ([32]byte, er
|
||||
func (b *BeaconState) randaoMixesRootSelector(field types.FieldIndex) ([32]byte, error) {
|
||||
if b.rebuildTrie[field] {
|
||||
if features.Get().EnableExperimentalState {
|
||||
err := b.resetFieldTrie(field, customtypes.RandaoMixes(b.randaoMixesMultiValue.Value(b)), fieldparams.RandaoMixesLength)
|
||||
err := b.resetFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.randaoMixesMultiValue,
|
||||
}, fieldparams.RandaoMixesLength)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
}
|
||||
@@ -1284,7 +1312,10 @@ func (b *BeaconState) randaoMixesRootSelector(field types.FieldIndex) ([32]byte,
|
||||
return b.stateFieldLeaves[field].TrieRoot()
|
||||
}
|
||||
if features.Get().EnableExperimentalState {
|
||||
return b.recomputeFieldTrie(field, customtypes.RandaoMixes(b.randaoMixesMultiValue.Value(b)))
|
||||
return b.recomputeFieldTrie(field, mvslice.MultiValueSliceComposite[[32]byte]{
|
||||
Identifiable: b,
|
||||
MultiValueSlice: b.randaoMixesMultiValue,
|
||||
})
|
||||
} else {
|
||||
return b.recomputeFieldTrie(field, b.randaoMixes)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ go_library(
|
||||
srcs = ["multi_value_slice.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/v4/container/multi-value-slice",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["@com_github_pkg_errors//:go_default_library"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
|
||||
@@ -92,6 +92,8 @@ package mvslice
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Id is an object identifier.
|
||||
@@ -103,8 +105,22 @@ type Identifiable interface {
|
||||
}
|
||||
|
||||
// MultiValueSlice defines an abstraction over all concrete implementations of the generic Slice.
|
||||
type MultiValueSlice[O Identifiable] interface {
|
||||
Len(obj O) int
|
||||
type MultiValueSlice[V comparable] interface {
|
||||
Len(obj Identifiable) int
|
||||
At(obj Identifiable, index uint64) (V, error)
|
||||
Value(obj Identifiable) []V
|
||||
}
|
||||
|
||||
// MultiValueSliceComposite describes a struct for which we have access to a multivalue
|
||||
// slice along with the desired state.
|
||||
type MultiValueSliceComposite[V comparable] struct {
|
||||
Identifiable
|
||||
MultiValueSlice[V]
|
||||
}
|
||||
|
||||
// State returns the referenced state.
|
||||
func (m MultiValueSliceComposite[V]) State() Identifiable {
|
||||
return m.Identifiable
|
||||
}
|
||||
|
||||
// Value defines a single value along with one or more IDs that share this value.
|
||||
@@ -124,7 +140,7 @@ type MultiValueItem[V any] struct {
|
||||
// - O interfaces.Identifiable - the type of objects sharing the slice. The constraint is required
|
||||
// because we need a way to compare objects against each other in order to know which objects
|
||||
// values should be accessed.
|
||||
type Slice[V comparable, O Identifiable] struct {
|
||||
type Slice[V comparable] struct {
|
||||
sharedItems []V
|
||||
individualItems map[uint64]*MultiValueItem[V]
|
||||
appendedItems []*MultiValueItem[V]
|
||||
@@ -133,7 +149,7 @@ type Slice[V comparable, O Identifiable] struct {
|
||||
}
|
||||
|
||||
// Init initializes the slice with sensible defaults. Input values are assigned to shared items.
|
||||
func (s *Slice[V, O]) Init(items []V) {
|
||||
func (s *Slice[V]) Init(items []V) {
|
||||
s.sharedItems = items
|
||||
s.individualItems = map[uint64]*MultiValueItem[V]{}
|
||||
s.appendedItems = []*MultiValueItem[V]{}
|
||||
@@ -141,7 +157,7 @@ func (s *Slice[V, O]) Init(items []V) {
|
||||
}
|
||||
|
||||
// Len returns the number of items for the input object.
|
||||
func (s *Slice[V, O]) Len(obj O) int {
|
||||
func (s *Slice[V]) Len(obj Identifiable) int {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
@@ -153,7 +169,7 @@ func (s *Slice[V, O]) Len(obj O) int {
|
||||
}
|
||||
|
||||
// Copy copies items between the source and destination.
|
||||
func (s *Slice[V, O]) Copy(src O, dst O) {
|
||||
func (s *Slice[V]) Copy(src, dst Identifiable) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -190,7 +206,7 @@ func (s *Slice[V, O]) Copy(src O, dst O) {
|
||||
}
|
||||
|
||||
// Value returns all items for the input object.
|
||||
func (s *Slice[V, O]) Value(obj O) []V {
|
||||
func (s *Slice[V]) Value(obj Identifiable) []V {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
@@ -228,7 +244,7 @@ func (s *Slice[V, O]) Value(obj O) []V {
|
||||
// We first check if the index is within the length of shared items.
|
||||
// If it is, then we return an individual value at that index - if it exists - or a shared value otherwise.
|
||||
// If the index is beyond the length of shared values, it is an appended item and that's what gets returned.
|
||||
func (s *Slice[V, O]) At(obj O, index uint64) (V, error) {
|
||||
func (s *Slice[V]) At(obj Identifiable, index uint64) (V, error) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
@@ -266,7 +282,7 @@ func (s *Slice[V, O]) At(obj O, index uint64) (V, error) {
|
||||
}
|
||||
|
||||
// UpdateAt updates the item at the required index for the input object to the passed in value.
|
||||
func (s *Slice[V, O]) UpdateAt(obj O, index uint64, val V) error {
|
||||
func (s *Slice[V]) UpdateAt(obj Identifiable, index uint64, val V) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -283,7 +299,7 @@ func (s *Slice[V, O]) UpdateAt(obj O, index uint64, val V) error {
|
||||
}
|
||||
|
||||
// Append adds a new item to the input object.
|
||||
func (s *Slice[V, O]) Append(obj O, val V) {
|
||||
func (s *Slice[V]) Append(obj Identifiable, val V) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -332,7 +348,7 @@ func (s *Slice[V, O]) Append(obj O, val V) {
|
||||
|
||||
// Detach removes the input object from the multi-value slice.
|
||||
// What this means in practice is that we remove all individual and appended values for that object and clear the cached length.
|
||||
func (s *Slice[V, O]) Detach(obj O) {
|
||||
func (s *Slice[V]) Detach(obj Identifiable) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -378,7 +394,7 @@ func (s *Slice[V, O]) Detach(obj O) {
|
||||
delete(s.cachedLengths, obj.Id())
|
||||
}
|
||||
|
||||
func (s *Slice[V, O]) fillOriginalItems(obj O, items *[]V) {
|
||||
func (s *Slice[V]) fillOriginalItems(obj Identifiable, items *[]V) {
|
||||
for i, item := range s.sharedItems {
|
||||
ind, ok := s.individualItems[uint64(i)]
|
||||
if !ok {
|
||||
@@ -399,7 +415,7 @@ func (s *Slice[V, O]) fillOriginalItems(obj O, items *[]V) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Slice[V, O]) updateOriginalItem(obj O, index uint64, val V) {
|
||||
func (s *Slice[V]) updateOriginalItem(obj Identifiable, index uint64, val V) {
|
||||
ind, ok := s.individualItems[index]
|
||||
if ok {
|
||||
for mvi, v := range ind.Values {
|
||||
@@ -440,7 +456,7 @@ func (s *Slice[V, O]) updateOriginalItem(obj O, index uint64, val V) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Slice[V, O]) updateAppendedItem(obj O, index uint64, val V) error {
|
||||
func (s *Slice[V]) updateAppendedItem(obj Identifiable, index uint64, val V) error {
|
||||
item := s.appendedItems[index-uint64(len(s.sharedItems))]
|
||||
found := false
|
||||
for vi, v := range item.Values {
|
||||
@@ -491,3 +507,34 @@ func deleteElemFromSlice[T any](s []T, i int) []T {
|
||||
s = s[:len(s)-1] // Truncate slice.
|
||||
return s
|
||||
}
|
||||
|
||||
// EmptyMVSlice specifies a type which allows a normal slice to conform
|
||||
// to the multivalue slice interface.
|
||||
type EmptyMVSlice[V comparable] struct {
|
||||
fullSlice []V
|
||||
}
|
||||
|
||||
func (e EmptyMVSlice[V]) Len(_ Identifiable) int {
|
||||
return len(e.fullSlice)
|
||||
}
|
||||
|
||||
func (e EmptyMVSlice[V]) At(_ Identifiable, index uint64) (V, error) {
|
||||
if index >= uint64(len(e.fullSlice)) {
|
||||
var def V
|
||||
return def, errors.Errorf("index %d out of bounds", index)
|
||||
}
|
||||
return e.fullSlice[index], nil
|
||||
}
|
||||
|
||||
func (e EmptyMVSlice[V]) Value(_ Identifiable) []V {
|
||||
return e.fullSlice
|
||||
}
|
||||
|
||||
// BuildEmptyCompositeSlice builds a composite multivalue object with a native
|
||||
// slice.
|
||||
func BuildEmptyCompositeSlice[V comparable](values []V) MultiValueSliceComposite[V] {
|
||||
return MultiValueSliceComposite[V]{
|
||||
Identifiable: nil,
|
||||
MultiValueSlice: EmptyMVSlice[V]{fullSlice: values},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ func (o *testObject) SetId(id uint64) {
|
||||
}
|
||||
|
||||
func TestLen(t *testing.T) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init([]int{1, 2, 3})
|
||||
s.cachedLengths[1] = 123
|
||||
t.Run("cached", func(t *testing.T) {
|
||||
@@ -93,7 +93,7 @@ func TestValue(t *testing.T) {
|
||||
assert.Equal(t, 3, v[6])
|
||||
assert.Equal(t, 2, v[7])
|
||||
|
||||
s = &Slice[int, *testObject]{}
|
||||
s = &Slice[int]{}
|
||||
s.Init([]int{1, 2, 3})
|
||||
|
||||
v = s.Value(&testObject{id: 999})
|
||||
@@ -246,7 +246,7 @@ func TestAppend(t *testing.T) {
|
||||
// - we also want to check that cached length is properly updated after every append
|
||||
|
||||
// we want to start with the simplest slice possible
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init([]int{0})
|
||||
first := &testObject{id: 1}
|
||||
second := &testObject{id: 2}
|
||||
@@ -335,8 +335,8 @@ func TestDetach(t *testing.T) {
|
||||
// Index 5: Different appended value
|
||||
// Index 6: Same appended value
|
||||
// Index 7: Appended value ONLY for the second object
|
||||
func setup() *Slice[int, *testObject] {
|
||||
s := &Slice[int, *testObject]{}
|
||||
func setup() *Slice[int] {
|
||||
s := &Slice[int]{}
|
||||
s.Init([]int{123, 123, 123, 123, 123})
|
||||
s.individualItems[1] = &MultiValueItem[int]{
|
||||
Values: []*Value[int]{
|
||||
@@ -410,7 +410,7 @@ func setup() *Slice[int, *testObject] {
|
||||
return s
|
||||
}
|
||||
|
||||
func assertIndividualFound(t *testing.T, slice *Slice[int, *testObject], id uint64, itemIndex uint64, expected int) {
|
||||
func assertIndividualFound(t *testing.T, slice *Slice[int], id uint64, itemIndex uint64, expected int) {
|
||||
found := false
|
||||
for _, v := range slice.individualItems[itemIndex].Values {
|
||||
for _, o := range v.ids {
|
||||
@@ -423,7 +423,7 @@ func assertIndividualFound(t *testing.T, slice *Slice[int, *testObject], id uint
|
||||
assert.Equal(t, true, found)
|
||||
}
|
||||
|
||||
func assertIndividualNotFound(t *testing.T, slice *Slice[int, *testObject], id uint64, itemIndex uint64) {
|
||||
func assertIndividualNotFound(t *testing.T, slice *Slice[int], id uint64, itemIndex uint64) {
|
||||
found := false
|
||||
for _, v := range slice.individualItems[itemIndex].Values {
|
||||
for _, o := range v.ids {
|
||||
@@ -435,7 +435,7 @@ func assertIndividualNotFound(t *testing.T, slice *Slice[int, *testObject], id u
|
||||
assert.Equal(t, false, found)
|
||||
}
|
||||
|
||||
func assertAppendedFound(t *testing.T, slice *Slice[int, *testObject], id uint64, itemIndex uint64, expected int) {
|
||||
func assertAppendedFound(t *testing.T, slice *Slice[int], id uint64, itemIndex uint64, expected int) {
|
||||
found := false
|
||||
for _, v := range slice.appendedItems[itemIndex].Values {
|
||||
for _, o := range v.ids {
|
||||
@@ -448,7 +448,7 @@ func assertAppendedFound(t *testing.T, slice *Slice[int, *testObject], id uint64
|
||||
assert.Equal(t, true, found)
|
||||
}
|
||||
|
||||
func assertAppendedNotFound(t *testing.T, slice *Slice[int, *testObject], id uint64, itemIndex uint64) {
|
||||
func assertAppendedNotFound(t *testing.T, slice *Slice[int], id uint64, itemIndex uint64) {
|
||||
found := false
|
||||
for _, v := range slice.appendedItems[itemIndex].Values {
|
||||
for _, o := range v.ids {
|
||||
@@ -466,14 +466,14 @@ func BenchmarkValue(b *testing.B) {
|
||||
const _10m = 10000000
|
||||
|
||||
b.Run("100,000 shared items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _100k))
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Value(&testObject{})
|
||||
}
|
||||
})
|
||||
b.Run("100,000 equal individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _100k))
|
||||
s.individualItems[0] = &MultiValueItem[int]{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}
|
||||
objs := make([]*testObject, _100k)
|
||||
@@ -486,7 +486,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("100,000 different individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _100k))
|
||||
objs := make([]*testObject, _100k)
|
||||
for i := 0; i < len(objs); i++ {
|
||||
@@ -498,7 +498,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("100,000 shared items and 100,000 equal appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _100k))
|
||||
s.appendedItems = []*MultiValueItem[int]{{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}}
|
||||
objs := make([]*testObject, _100k)
|
||||
@@ -511,7 +511,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("100,000 shared items and 100,000 different appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _100k))
|
||||
s.appendedItems = []*MultiValueItem[int]{}
|
||||
objs := make([]*testObject, _100k)
|
||||
@@ -524,14 +524,14 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("1,000,000 shared items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _1m))
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Value(&testObject{})
|
||||
}
|
||||
})
|
||||
b.Run("1,000,000 equal individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _1m))
|
||||
s.individualItems[0] = &MultiValueItem[int]{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}
|
||||
objs := make([]*testObject, _1m)
|
||||
@@ -544,7 +544,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("1,000,000 different individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _1m))
|
||||
objs := make([]*testObject, _1m)
|
||||
for i := 0; i < len(objs); i++ {
|
||||
@@ -556,7 +556,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("1,000,000 shared items and 1,000,000 equal appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _1m))
|
||||
s.appendedItems = []*MultiValueItem[int]{{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}}
|
||||
objs := make([]*testObject, _1m)
|
||||
@@ -569,7 +569,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("1,000,000 shared items and 1,000,000 different appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _1m))
|
||||
s.appendedItems = []*MultiValueItem[int]{}
|
||||
objs := make([]*testObject, _1m)
|
||||
@@ -582,14 +582,14 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("10,000,000 shared items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _10m))
|
||||
for i := 0; i < b.N; i++ {
|
||||
s.Value(&testObject{})
|
||||
}
|
||||
})
|
||||
b.Run("10,000,000 equal individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _10m))
|
||||
s.individualItems[0] = &MultiValueItem[int]{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}
|
||||
objs := make([]*testObject, _10m)
|
||||
@@ -602,7 +602,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("10,000,000 different individual items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _10m))
|
||||
objs := make([]*testObject, _10m)
|
||||
for i := 0; i < len(objs); i++ {
|
||||
@@ -614,7 +614,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("10,000,000 shared items and 10,000,000 equal appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _10m))
|
||||
s.appendedItems = []*MultiValueItem[int]{{Values: []*Value[int]{{val: 999, ids: []uint64{}}}}}
|
||||
objs := make([]*testObject, _10m)
|
||||
@@ -627,7 +627,7 @@ func BenchmarkValue(b *testing.B) {
|
||||
}
|
||||
})
|
||||
b.Run("10,000,000 shared items and 10,000,000 different appended items", func(b *testing.B) {
|
||||
s := &Slice[int, *testObject]{}
|
||||
s := &Slice[int]{}
|
||||
s.Init(make([]int, _10m))
|
||||
s.appendedItems = []*MultiValueItem[int]{}
|
||||
objs := make([]*testObject, _10m)
|
||||
|
||||
Reference in New Issue
Block a user