State/v2: Add InitializeFromProto and InitializeFromProtoUnsafe (#9226)

* Add `InitializeFromProto`

* Update BUILD.bazel
This commit is contained in:
terence tsao
2021-07-19 09:47:42 -07:00
committed by GitHub
parent 2a0c4e0d5f
commit 1beb0071b5
3 changed files with 152 additions and 1 deletions

View File

@@ -4,6 +4,7 @@ go_library(
name = "go_default_library",
srcs = [
"field_trie.go",
"state_trie.go",
"types.go",
],
importpath = "github.com/prysmaticlabs/prysm/beacon-chain/state/v2",
@@ -28,15 +29,20 @@ go_library(
"//proto/eth/v1alpha1:go_default_library",
"//shared/params:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["field_trie_test.go"],
srcs = [
"field_trie_test.go",
"state_trie_test.go",
],
embed = [":go_default_library"],
deps = [
"//beacon-chain/state/v1:go_default_library",
"//proto/beacon/p2p/v1:go_default_library",
"//proto/eth/v1alpha1:go_default_library",
"//shared/params:go_default_library",
"//shared/testutil:go_default_library",

View File

@@ -0,0 +1,61 @@
package v2
import (
"sync"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stateutil"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/params"
"google.golang.org/protobuf/proto"
)
// InitializeFromProto the beacon state from a protobuf representation.
func InitializeFromProto(st *pbp2p.BeaconStateAltair) (*BeaconState, error) {
return InitializeFromProtoUnsafe(proto.Clone(st).(*pbp2p.BeaconStateAltair))
}
// InitializeFromProtoUnsafe directly uses the beacon state protobuf pointer
// and sets it as the inner state of the BeaconState type.
func InitializeFromProtoUnsafe(st *pbp2p.BeaconStateAltair) (*BeaconState, error) {
if st == nil {
return nil, errors.New("received nil state")
}
fieldCount := params.BeaconConfig().BeaconStateAltairFieldCount
b := &BeaconState{
state: st,
dirtyFields: make(map[fieldIndex]interface{}, fieldCount),
dirtyIndices: make(map[fieldIndex][]uint64, fieldCount),
stateFieldLeaves: make(map[fieldIndex]*FieldTrie, fieldCount),
sharedFieldReferences: make(map[fieldIndex]*stateutil.Reference, 11),
rebuildTrie: make(map[fieldIndex]bool, fieldCount),
valMapHandler: stateutil.NewValMapHandler(st.Validators),
}
for i := 0; i < fieldCount; i++ {
b.dirtyFields[fieldIndex(i)] = true
b.rebuildTrie[fieldIndex(i)] = true
b.dirtyIndices[fieldIndex(i)] = []uint64{}
b.stateFieldLeaves[fieldIndex(i)] = &FieldTrie{
field: fieldIndex(i),
reference: stateutil.NewRef(1),
RWMutex: new(sync.RWMutex),
}
}
// Initialize field reference tracking for shared data.
b.sharedFieldReferences[randaoMixes] = stateutil.NewRef(1)
b.sharedFieldReferences[stateRoots] = stateutil.NewRef(1)
b.sharedFieldReferences[blockRoots] = stateutil.NewRef(1)
b.sharedFieldReferences[previousEpochParticipationBits] = stateutil.NewRef(1) // New in Altair.
b.sharedFieldReferences[currentEpochParticipationBits] = stateutil.NewRef(1) // New in Altair.
b.sharedFieldReferences[slashings] = stateutil.NewRef(1)
b.sharedFieldReferences[eth1DataVotes] = stateutil.NewRef(1)
b.sharedFieldReferences[validators] = stateutil.NewRef(1)
b.sharedFieldReferences[balances] = stateutil.NewRef(1)
b.sharedFieldReferences[inactivityScores] = stateutil.NewRef(1) // New in Altair.
b.sharedFieldReferences[historicalRoots] = stateutil.NewRef(1)
return b, nil
}

View File

@@ -0,0 +1,84 @@
package v2_test
import (
"testing"
stateAltair "github.com/prysmaticlabs/prysm/beacon-chain/state/v2"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/testutil/assert"
"github.com/prysmaticlabs/prysm/shared/testutil/require"
)
func TestInitializeFromProto(t *testing.T) {
type test struct {
name string
state *pb.BeaconStateAltair
error string
}
initTests := []test{
{
name: "nil state",
state: nil,
error: "received nil state",
},
{
name: "nil validators",
state: &pb.BeaconStateAltair{
Slot: 4,
Validators: nil,
},
},
{
name: "empty state",
state: &pb.BeaconStateAltair{},
},
// TODO: Add full state. Blocked by testutil migration.
}
for _, tt := range initTests {
t.Run(tt.name, func(t *testing.T) {
_, err := stateAltair.InitializeFromProto(tt.state)
if tt.error != "" {
require.ErrorContains(t, tt.error, err)
} else {
require.NoError(t, err)
}
})
}
}
func TestInitializeFromProtoUnsafe(t *testing.T) {
type test struct {
name string
state *pb.BeaconStateAltair
error string
}
initTests := []test{
{
name: "nil state",
state: nil,
error: "received nil state",
},
{
name: "nil validators",
state: &pb.BeaconStateAltair{
Slot: 4,
Validators: nil,
},
},
{
name: "empty state",
state: &pb.BeaconStateAltair{},
},
// TODO: Add full state. Blocked by testutil migration.
}
for _, tt := range initTests {
t.Run(tt.name, func(t *testing.T) {
_, err := stateAltair.InitializeFromProtoUnsafe(tt.state)
if tt.error != "" {
assert.ErrorContains(t, tt.error, err)
} else {
assert.NoError(t, err)
}
})
}
}