State: Return interface{} for pb related methods (#8617)

* Return interface{} instead of *pbp2p.BeaconState

* Comment
This commit is contained in:
terence tsao
2021-03-16 20:26:17 -07:00
committed by GitHub
parent 50e99fb6c1
commit 7b16601399
19 changed files with 105 additions and 42 deletions

View File

@@ -70,8 +70,10 @@ func setupBeaconChain(t *testing.T, beaconDB db.Database) *Service {
var web3Service *powchain.Service
var err error
bState, _ := testutil.DeterministicGenesisState(t, 10)
pbState, err := beaconstate.ProtobufBeaconState(bState.InnerStateUnsafe())
require.NoError(t, err)
err = beaconDB.SavePowchainData(ctx, &protodb.ETH1ChainData{
BeaconState: bState.InnerStateUnsafe(),
BeaconState: pbState,
Trie: &protodb.SparseMerkleTrie{},
CurrentEth1Data: &protodb.LatestETH1Data{
BlockHash: make([]byte, 32),

View File

@@ -34,7 +34,11 @@ func TestCheckpointStateCache_StateByCheckpoint(t *testing.T) {
state, err = cache.StateByCheckpoint(cp1)
require.NoError(t, err)
if !proto.Equal(state.InnerStateUnsafe(), st.InnerStateUnsafe()) {
pbState1, err := stateTrie.ProtobufBeaconState(state.InnerStateUnsafe())
require.NoError(t, err)
pbState2, err := stateTrie.ProtobufBeaconState(st.InnerStateUnsafe())
require.NoError(t, err)
if !proto.Equal(pbState1, pbState2) {
t.Error("incorrectly cached state")
}

View File

@@ -55,7 +55,9 @@ func runBlockHeaderTest(t *testing.T, config string) {
postBeaconState := &pb.BeaconState{}
require.NoError(t, postBeaconState.UnmarshalSSZ(postBeaconStateFile), "Failed to unmarshal")
if !proto.Equal(beaconState.CloneInnerState(), postBeaconState) {
pbState, err := stateTrie.ProtobufBeaconState(beaconState.CloneInnerState())
require.NoError(t, err)
if !proto.Equal(pbState, postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.CloneInnerState(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")

View File

@@ -84,8 +84,9 @@ func runBlockProcessingTest(t *testing.T, config string) {
postBeaconState := &pb.BeaconState{}
require.NoError(t, postBeaconState.UnmarshalSSZ(postBeaconStateFile), "Failed to unmarshal")
if !proto.Equal(beaconState.InnerStateUnsafe(), postBeaconState) {
pbState, err := stateTrie.ProtobufBeaconState(beaconState.InnerStateUnsafe())
require.NoError(t, err)
if !proto.Equal(pbState, postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.InnerStateUnsafe(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")

View File

@@ -6,7 +6,8 @@ import (
"github.com/gogo/protobuf/proto"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
coreState "github.com/prysmaticlabs/prysm/beacon-chain/core/state"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/benchutil"
@@ -26,7 +27,7 @@ func BenchmarkExecuteStateTransition_FullBlock(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := state.ExecuteStateTransition(context.Background(), cleanStates[i], block)
_, err := coreState.ExecuteStateTransition(context.Background(), cleanStates[i], block)
require.NoError(b, err)
}
}
@@ -47,12 +48,12 @@ func BenchmarkExecuteStateTransition_WithCache(b *testing.B) {
require.NoError(b, helpers.UpdateCommitteeCache(beaconState, helpers.CurrentEpoch(beaconState)))
require.NoError(b, beaconState.SetSlot(currentSlot))
// Run the state transition once to populate the cache.
_, err = state.ExecuteStateTransition(context.Background(), beaconState, block)
_, err = coreState.ExecuteStateTransition(context.Background(), beaconState, block)
require.NoError(b, err, "Failed to process block, benchmarks will fail")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := state.ExecuteStateTransition(context.Background(), cleanStates[i], block)
_, err := coreState.ExecuteStateTransition(context.Background(), cleanStates[i], block)
require.NoError(b, err, "Failed to process block, benchmarks will fail")
}
}
@@ -73,7 +74,7 @@ func BenchmarkProcessEpoch_2FullEpochs(b *testing.B) {
for i := 0; i < b.N; i++ {
// ProcessEpochPrecompute is the optimized version of process epoch. It's enabled by default
// at run time.
_, err := state.ProcessEpochPrecompute(context.Background(), beaconState.Copy())
_, err := coreState.ProcessEpochPrecompute(context.Background(), beaconState.Copy())
require.NoError(b, err)
}
}
@@ -109,8 +110,8 @@ func BenchmarkHashTreeRootState_FullState(b *testing.B) {
func BenchmarkMarshalState_FullState(b *testing.B) {
beaconState, err := benchutil.PreGenState2FullEpochs()
require.NoError(b, err)
natState := beaconState.InnerStateUnsafe()
natState, err := state.ProtobufBeaconState(beaconState.InnerStateUnsafe())
require.NoError(b, err)
b.Run("Proto_Marshal", func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
@@ -133,7 +134,8 @@ func BenchmarkMarshalState_FullState(b *testing.B) {
func BenchmarkUnmarshalState_FullState(b *testing.B) {
beaconState, err := benchutil.PreGenState2FullEpochs()
require.NoError(b, err)
natState := beaconState.InnerStateUnsafe()
natState, err := state.ProtobufBeaconState(beaconState.InnerStateUnsafe())
require.NoError(b, err)
protoObject, err := proto.Marshal(natState)
require.NoError(b, err)
sszObject, err := natState.MarshalSSZ()

View File

@@ -18,7 +18,9 @@ func TestSkipSlotCache_OK(t *testing.T) {
state.SkipSlotCache.Enable()
defer state.SkipSlotCache.Disable()
bState, privs := testutil.DeterministicGenesisState(t, params.MinimalSpecConfig().MinGenesisActiveValidatorCount)
originalState, err := beaconstate.InitializeFromProto(bState.CloneInnerState())
pbState, err := beaconstate.ProtobufBeaconState(bState.CloneInnerState())
require.NoError(t, err)
originalState, err := beaconstate.InitializeFromProto(pbState)
require.NoError(t, err)
blkCfg := testutil.DefaultBlockGenConfig()
@@ -40,7 +42,9 @@ func TestSkipSlotCache_OK(t *testing.T) {
func TestSkipSlotCache_ConcurrentMixup(t *testing.T) {
bState, privs := testutil.DeterministicGenesisState(t, params.MinimalSpecConfig().MinGenesisActiveValidatorCount)
originalState, err := beaconstate.InitializeFromProto(bState.CloneInnerState())
pbState, err := beaconstate.ProtobufBeaconState(bState.CloneInnerState())
require.NoError(t, err)
originalState, err := beaconstate.InitializeFromProto(pbState)
require.NoError(t, err)
blkCfg := testutil.DefaultBlockGenConfig()

View File

@@ -46,7 +46,9 @@ func runSlotProcessingTests(t *testing.T, config string) {
postState, err := state.ProcessSlots(context.Background(), beaconState, beaconState.Slot().Add(uint64(slotsCount)))
require.NoError(t, err)
if !proto.Equal(postState.CloneInnerState(), postBeaconState) {
pbState, err := beaconstate.ProtobufBeaconState(postState.CloneInnerState())
require.NoError(t, err)
if !proto.Equal(pbState, postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState, postBeaconState)
t.Fatalf("Post state does not match expected. Diff between states %s", diff)
}

View File

@@ -7,6 +7,7 @@ import (
types "github.com/prysmaticlabs/eth2-types"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
beaconstate "github.com/prysmaticlabs/prysm/beacon-chain/state"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/params"
@@ -93,8 +94,13 @@ func TestGenesisState_HashEquality(t *testing.T) {
state2, err := state.GenesisBeaconState(deposits, 0, &ethpb.Eth1Data{BlockHash: make([]byte, 32)})
require.NoError(t, err)
root1, err1 := hashutil.HashProto(state1.CloneInnerState())
root2, err2 := hashutil.HashProto(state2.CloneInnerState())
pbState1, err := beaconstate.ProtobufBeaconState(state1.CloneInnerState())
require.NoError(t, err)
pbState2, err := beaconstate.ProtobufBeaconState(state2.CloneInnerState())
require.NoError(t, err)
root1, err1 := hashutil.HashProto(pbState1)
root2, err2 := hashutil.HashProto(pbState2)
if err1 != nil || err2 != nil {
t.Fatalf("Failed to marshal state to bytes: %v %v", err1, err2)

View File

@@ -83,10 +83,13 @@ func (s *Store) SaveStates(ctx context.Context, states []iface.ReadOnlyBeaconSta
if states == nil {
return errors.New("nil state")
}
var err error
multipleEncs := make([][]byte, len(states))
for i, st := range states {
multipleEncs[i], err = encode(ctx, st.InnerStateUnsafe())
pbState, err := state.ProtobufBeaconState(st.InnerStateUnsafe())
if err != nil {
return err
}
multipleEncs[i], err = encode(ctx, pbState)
if err != nil {
return err
}

View File

@@ -16,7 +16,8 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/core/feed"
statefeed "github.com/prysmaticlabs/prysm/beacon-chain/core/feed/state"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
coreState "github.com/prysmaticlabs/prysm/beacon-chain/core/state"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
contracts "github.com/prysmaticlabs/prysm/contracts/deposit-contract"
protodb "github.com/prysmaticlabs/prysm/proto/beacon/db"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
@@ -521,7 +522,7 @@ func (s *Service) checkForChainstart(blockHash [32]byte, blockNumber *big.Int, b
if valCount == 0 {
return
}
triggered := state.IsValidGenesisState(valCount, genesisTime)
triggered := coreState.IsValidGenesisState(valCount, genesisTime)
if triggered {
s.chainStartData.GenesisTime = genesisTime
s.ProcessChainStart(s.chainStartData.GenesisTime, blockHash, blockNumber)
@@ -530,10 +531,14 @@ func (s *Service) checkForChainstart(blockHash [32]byte, blockNumber *big.Int, b
// save all powchain related metadata to disk.
func (s *Service) savePowchainData(ctx context.Context) error {
pbState, err := state.ProtobufBeaconState(s.preGenesisState.InnerStateUnsafe())
if err != nil {
return err
}
eth1Data := &protodb.ETH1ChainData{
CurrentEth1Data: s.latestEth1Data,
ChainstartData: s.chainStartData,
BeaconState: s.preGenesisState.InnerStateUnsafe(), // I promise not to mutate it!
BeaconState: pbState, // I promise not to mutate it!
Trie: s.depositTrie.ToProto(),
DepositContainers: s.depositCache.AllDepositContainers(ctx),
}

View File

@@ -27,6 +27,7 @@ go_library(
"//shared/aggregation:__subpackages__",
"//shared/benchutil:__pkg__",
"//shared/depositutil:__subpackages__",
"//shared/interop:__subpackages__",
"//shared/testutil:__pkg__",
"//slasher/rpc:__subpackages__",
"//tools/benchmark-files-gen:__pkg__",

View File

@@ -16,7 +16,7 @@ import (
// InnerStateUnsafe returns the pointer value of the underlying
// beacon state proto object, bypassing immutability. Use with care.
func (b *BeaconState) InnerStateUnsafe() *pbp2p.BeaconState {
func (b *BeaconState) InnerStateUnsafe() interface{} {
if b == nil {
return nil
}
@@ -24,7 +24,7 @@ func (b *BeaconState) InnerStateUnsafe() *pbp2p.BeaconState {
}
// CloneInnerState the beacon state into a protobuf for usage.
func (b *BeaconState) CloneInnerState() *pbp2p.BeaconState {
func (b *BeaconState) CloneInnerState() interface{} {
if b == nil || b.state == nil {
return nil
}
@@ -1035,3 +1035,13 @@ func (b *BeaconState) MarshalSSZ() ([]byte, error) {
}
return b.state.MarshalSSZ()
}
// ProtobufBeaconState transforms an input into beacon state in the form of protobuf.
// Error is returned if the input is not type protobuf beacon state.
func ProtobufBeaconState(s interface{}) (*pbp2p.BeaconState, error) {
pbState, ok := s.(*pbp2p.BeaconState)
if !ok {
return nil, errors.New("input is not type pb.BeaconState")
}
return pbState, nil
}

View File

@@ -30,8 +30,8 @@ type ReadOnlyBeaconState interface {
ReadOnlyBalances
ReadOnlyCheckpoint
ReadOnlyAttestations
InnerStateUnsafe() *pbp2p.BeaconState
CloneInnerState() *pbp2p.BeaconState
InnerStateUnsafe() interface{}
CloneInnerState() interface{}
GenesisTime() uint64
GenesisValidatorRoot() []byte
Slot() types.Slot

View File

@@ -18,7 +18,8 @@ import (
func TestInitializeFromProto(t *testing.T) {
testState, _ := testutil.DeterministicGenesisState(t, 64)
pbState, err := state.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(t, err)
type test struct {
name string
state *pbp2p.BeaconState
@@ -43,7 +44,7 @@ func TestInitializeFromProto(t *testing.T) {
},
{
name: "full state",
state: testState.InnerStateUnsafe(),
state: pbState,
},
}
for _, tt := range initTests {
@@ -60,7 +61,8 @@ func TestInitializeFromProto(t *testing.T) {
func TestInitializeFromProtoUnsafe(t *testing.T) {
testState, _ := testutil.DeterministicGenesisState(t, 64)
pbState, err := state.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(t, err)
type test struct {
name string
state *pbp2p.BeaconState
@@ -85,7 +87,7 @@ func TestInitializeFromProtoUnsafe(t *testing.T) {
},
{
name: "full state",
state: testState.InnerStateUnsafe(),
state: pbState,
},
}
for _, tt := range initTests {
@@ -153,7 +155,9 @@ func TestBeaconState_HashTreeRoot(t *testing.T) {
if err == nil && tt.error != "" {
t.Errorf("Expected error, expected %v, recevied %v", tt.error, err)
}
genericHTR, err := testState.InnerStateUnsafe().HashTreeRoot()
pbState, err := state.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(t, err)
genericHTR, err := pbState.HashTreeRoot()
if err == nil && tt.error != "" {
t.Errorf("Expected error, expected %v, recevied %v", tt.error, err)
}
@@ -220,7 +224,9 @@ func TestBeaconState_HashTreeRoot_FieldTrie(t *testing.T) {
if err == nil && tt.error != "" {
t.Errorf("Expected error, expected %v, recevied %v", tt.error, err)
}
genericHTR, err := testState.InnerStateUnsafe().HashTreeRoot()
pbState, err := state.ProtobufBeaconState(testState.InnerStateUnsafe())
require.NoError(t, err)
genericHTR, err := pbState.HashTreeRoot()
if err == nil && tt.error != "" {
t.Errorf("Expected error, expected %v, recevied %v", tt.error, err)
}

View File

@@ -221,6 +221,7 @@ func TestForkManualCopy_OK(t *testing.T) {
}
require.NoError(t, a.SetFork(wantedFork))
newState := a.CloneInnerState()
require.DeepEqual(t, newState.Fork, wantedFork)
pbState, err := stateTrie.ProtobufBeaconState(a.InnerStateUnsafe())
require.NoError(t, err)
require.DeepEqual(t, pbState.Fork, wantedFork)
}

View File

@@ -12,6 +12,7 @@ go_library(
deps = [
"//beacon-chain/core/helpers:go_default_library",
"//beacon-chain/core/state:go_default_library",
"//beacon-chain/state:go_default_library",
"//proto/beacon/p2p/v1:go_default_library",
"//shared/bls:go_default_library",
"//shared/hashutil:go_default_library",

View File

@@ -8,7 +8,8 @@ import (
"github.com/pkg/errors"
ethpb "github.com/prysmaticlabs/ethereumapis/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/beacon-chain/core/state"
coreState "github.com/prysmaticlabs/prysm/beacon-chain/core/state"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/bls"
"github.com/prysmaticlabs/prysm/shared/hashutil"
@@ -55,7 +56,7 @@ func GenerateGenesisStateFromDepositData(
if genesisTime == 0 {
genesisTime = uint64(timeutils.Now().Unix())
}
beaconState, err := state.GenesisBeaconState(deposits, genesisTime, &ethpb.Eth1Data{
beaconState, err := coreState.GenesisBeaconState(deposits, genesisTime, &ethpb.Eth1Data{
DepositRoot: root[:],
DepositCount: uint64(len(deposits)),
BlockHash: mockEth1BlockHash,
@@ -63,7 +64,12 @@ func GenerateGenesisStateFromDepositData(
if err != nil {
return nil, nil, errors.Wrap(err, "could not generate genesis state")
}
return beaconState.CloneInnerState(), deposits, nil
pbState, err := state.ProtobufBeaconState(beaconState.CloneInnerState())
if err != nil {
return nil, nil, err
}
return pbState, deposits, nil
}
// GenerateDepositsFromData a list of deposit items by creating proofs for each of them from a sparse Merkle trie.

View File

@@ -60,7 +60,11 @@ func GenerateAttestations(
var err error
// Only calculate head state if its an attestation for the current slot or future slot.
if generateHeadState || slot == bState.Slot() {
genState, err := stateTrie.InitializeFromProtoUnsafe(bState.CloneInnerState())
pbState, err := stateTrie.ProtobufBeaconState(bState.CloneInnerState())
if err != nil {
return nil, err
}
genState, err := stateTrie.InitializeFromProtoUnsafe(pbState)
if err != nil {
return nil, err
}

View File

@@ -119,8 +119,9 @@ func RunBlockOperationTest(
if err := postBeaconState.UnmarshalSSZ(postBeaconStateFile); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
if !proto.Equal(beaconState.InnerStateUnsafe(), postBeaconState) {
pbState, err := beaconstate.ProtobufBeaconState(beaconState.InnerStateUnsafe())
require.NoError(t, err)
if !proto.Equal(pbState, postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.InnerStateUnsafe(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")
@@ -172,7 +173,9 @@ func RunEpochOperationTest(
t.Fatalf("Failed to unmarshal: %v", err)
}
if !proto.Equal(beaconState.InnerStateUnsafe(), postBeaconState) {
pbState, err := beaconstate.ProtobufBeaconState(beaconState.InnerStateUnsafe())
require.NoError(t, err)
if !proto.Equal(pbState, postBeaconState) {
diff, _ := messagediff.PrettyDiff(beaconState.InnerStateUnsafe(), postBeaconState)
t.Log(diff)
t.Fatal("Post state does not match expected")