From 1beb0071b5730f0ec0fb1beb9056c06d55ef562b Mon Sep 17 00:00:00 2001 From: terence tsao Date: Mon, 19 Jul 2021 09:47:42 -0700 Subject: [PATCH] State/v2: Add `InitializeFromProto` and `InitializeFromProtoUnsafe` (#9226) * Add `InitializeFromProto` * Update BUILD.bazel --- beacon-chain/state/v2/BUILD.bazel | 8 ++- beacon-chain/state/v2/state_trie.go | 61 +++++++++++++++++ beacon-chain/state/v2/state_trie_test.go | 84 ++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 beacon-chain/state/v2/state_trie.go create mode 100644 beacon-chain/state/v2/state_trie_test.go diff --git a/beacon-chain/state/v2/BUILD.bazel b/beacon-chain/state/v2/BUILD.bazel index a4b5b4f1da..644e8cdd9c 100644 --- a/beacon-chain/state/v2/BUILD.bazel +++ b/beacon-chain/state/v2/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "go_default_library", srcs = [ "field_trie.go", + "state_trie.go", "types.go", ], importpath = "github.com/prysmaticlabs/prysm/beacon-chain/state/v2", @@ -28,15 +29,20 @@ go_library( "//proto/eth/v1alpha1:go_default_library", "//shared/params:go_default_library", "@com_github_pkg_errors//:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", ], ) go_test( name = "go_default_test", - srcs = ["field_trie_test.go"], + srcs = [ + "field_trie_test.go", + "state_trie_test.go", + ], embed = [":go_default_library"], deps = [ "//beacon-chain/state/v1:go_default_library", + "//proto/beacon/p2p/v1:go_default_library", "//proto/eth/v1alpha1:go_default_library", "//shared/params:go_default_library", "//shared/testutil:go_default_library", diff --git a/beacon-chain/state/v2/state_trie.go b/beacon-chain/state/v2/state_trie.go new file mode 100644 index 0000000000..8f54214a09 --- /dev/null +++ b/beacon-chain/state/v2/state_trie.go @@ -0,0 +1,61 @@ +package v2 + +import ( + "sync" + + "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil" + pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" + "github.com/prysmaticlabs/prysm/shared/params" + "google.golang.org/protobuf/proto" +) + +// InitializeFromProto the beacon state from a protobuf representation. +func InitializeFromProto(st *pbp2p.BeaconStateAltair) (*BeaconState, error) { + return InitializeFromProtoUnsafe(proto.Clone(st).(*pbp2p.BeaconStateAltair)) +} + +// InitializeFromProtoUnsafe directly uses the beacon state protobuf pointer +// and sets it as the inner state of the BeaconState type. +func InitializeFromProtoUnsafe(st *pbp2p.BeaconStateAltair) (*BeaconState, error) { + if st == nil { + return nil, errors.New("received nil state") + } + + fieldCount := params.BeaconConfig().BeaconStateAltairFieldCount + b := &BeaconState{ + state: st, + dirtyFields: make(map[fieldIndex]interface{}, fieldCount), + dirtyIndices: make(map[fieldIndex][]uint64, fieldCount), + stateFieldLeaves: make(map[fieldIndex]*FieldTrie, fieldCount), + sharedFieldReferences: make(map[fieldIndex]*stateutil.Reference, 11), + rebuildTrie: make(map[fieldIndex]bool, fieldCount), + valMapHandler: stateutil.NewValMapHandler(st.Validators), + } + + for i := 0; i < fieldCount; i++ { + b.dirtyFields[fieldIndex(i)] = true + b.rebuildTrie[fieldIndex(i)] = true + b.dirtyIndices[fieldIndex(i)] = []uint64{} + b.stateFieldLeaves[fieldIndex(i)] = &FieldTrie{ + field: fieldIndex(i), + reference: stateutil.NewRef(1), + RWMutex: new(sync.RWMutex), + } + } + + // Initialize field reference tracking for shared data. + b.sharedFieldReferences[randaoMixes] = stateutil.NewRef(1) + b.sharedFieldReferences[stateRoots] = stateutil.NewRef(1) + b.sharedFieldReferences[blockRoots] = stateutil.NewRef(1) + b.sharedFieldReferences[previousEpochParticipationBits] = stateutil.NewRef(1) // New in Altair. + b.sharedFieldReferences[currentEpochParticipationBits] = stateutil.NewRef(1) // New in Altair. + b.sharedFieldReferences[slashings] = stateutil.NewRef(1) + b.sharedFieldReferences[eth1DataVotes] = stateutil.NewRef(1) + b.sharedFieldReferences[validators] = stateutil.NewRef(1) + b.sharedFieldReferences[balances] = stateutil.NewRef(1) + b.sharedFieldReferences[inactivityScores] = stateutil.NewRef(1) // New in Altair. + b.sharedFieldReferences[historicalRoots] = stateutil.NewRef(1) + + return b, nil +} diff --git a/beacon-chain/state/v2/state_trie_test.go b/beacon-chain/state/v2/state_trie_test.go new file mode 100644 index 0000000000..66ef6a54b5 --- /dev/null +++ b/beacon-chain/state/v2/state_trie_test.go @@ -0,0 +1,84 @@ +package v2_test + +import ( + "testing" + + stateAltair "github.com/prysmaticlabs/prysm/beacon-chain/state/v2" + pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" +) + +func TestInitializeFromProto(t *testing.T) { + type test struct { + name string + state *pb.BeaconStateAltair + error string + } + initTests := []test{ + { + name: "nil state", + state: nil, + error: "received nil state", + }, + { + name: "nil validators", + state: &pb.BeaconStateAltair{ + Slot: 4, + Validators: nil, + }, + }, + { + name: "empty state", + state: &pb.BeaconStateAltair{}, + }, + // TODO: Add full state. Blocked by testutil migration. + } + for _, tt := range initTests { + t.Run(tt.name, func(t *testing.T) { + _, err := stateAltair.InitializeFromProto(tt.state) + if tt.error != "" { + require.ErrorContains(t, tt.error, err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestInitializeFromProtoUnsafe(t *testing.T) { + type test struct { + name string + state *pb.BeaconStateAltair + error string + } + initTests := []test{ + { + name: "nil state", + state: nil, + error: "received nil state", + }, + { + name: "nil validators", + state: &pb.BeaconStateAltair{ + Slot: 4, + Validators: nil, + }, + }, + { + name: "empty state", + state: &pb.BeaconStateAltair{}, + }, + // TODO: Add full state. Blocked by testutil migration. + } + for _, tt := range initTests { + t.Run(tt.name, func(t *testing.T) { + _, err := stateAltair.InitializeFromProtoUnsafe(tt.state) + if tt.error != "" { + assert.ErrorContains(t, tt.error, err) + } else { + assert.NoError(t, err) + } + }) + } +}