Return state interface from native state constructors (#10208)

Co-authored-by: Raul Jordan <raul@prysmaticlabs.com>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Radosław Kapka
2022-02-14 10:51:22 +01:00
committed by GitHub
parent 28af5bc601
commit 3c76cc3af5
11 changed files with 66 additions and 47 deletions

View File

@@ -20,24 +20,26 @@ func TestStateReferenceSharing_Finalizer(t *testing.T) {
a, err := InitializeFromProtoUnsafe(&ethpb.BeaconState{RandaoMixes: [][]byte{[]byte("foo")}})
require.NoError(t, err)
assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes")
s, ok := a.(*BeaconState)
require.Equal(t, true, ok)
assert.Equal(t, uint(1), s.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes")
func() {
// Create object in a different scope for GC
b := a.Copy()
assert.Equal(t, uint(2), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes")
assert.Equal(t, uint(2), s.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes")
_ = b
}()
runtime.GC() // Should run finalizer on object b
assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!")
assert.Equal(t, uint(1), s.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!")
copied := a.Copy()
b, ok := copied.(*BeaconState)
require.Equal(t, true, ok)
assert.Equal(t, uint(2), b.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 shared references to RANDAO mixes")
require.NoError(t, b.UpdateRandaoMixesAtIndex(0, []byte("bar")))
if b.sharedFieldReferences[randaoMixes].Refs() != 1 || a.sharedFieldReferences[randaoMixes].Refs() != 1 {
if b.sharedFieldReferences[randaoMixes].Refs() != 1 || s.sharedFieldReferences[randaoMixes].Refs() != 1 {
t.Error("Expected 1 shared reference to RANDAO mix for both a and b")
}
}
@@ -53,15 +55,17 @@ func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
},
})
require.NoError(t, err)
assertRefCount(t, a, blockRoots, 1)
assertRefCount(t, a, stateRoots, 1)
s, ok := a.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, s, blockRoots, 1)
assertRefCount(t, s, stateRoots, 1)
// Copy, increases reference count.
copied := a.Copy()
b, ok := copied.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, a, blockRoots, 2)
assertRefCount(t, a, stateRoots, 2)
assertRefCount(t, s, blockRoots, 2)
assertRefCount(t, s, stateRoots, 2)
assertRefCount(t, b, blockRoots, 2)
assertRefCount(t, b, stateRoots, 2)
assert.Equal(t, 8192, len(b.BlockRoots()), "Wrong number of block roots found")
@@ -106,8 +110,8 @@ func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
assert.DeepEqual(t, root1[:], stateRootsB[0], "Unexpected mutation found")
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, blockRoots, 1)
assertRefCount(t, a, stateRoots, 1)
assertRefCount(t, s, blockRoots, 1)
assertRefCount(t, s, stateRoots, 1)
assertRefCount(t, b, blockRoots, 1)
assertRefCount(t, b, stateRoots, 1)
}
@@ -121,13 +125,15 @@ func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
},
})
require.NoError(t, err)
assertRefCount(t, a, randaoMixes, 1)
s, ok := a.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, s, randaoMixes, 1)
// Copy, increases reference count.
copied := a.Copy()
b, ok := copied.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, a, randaoMixes, 2)
assertRefCount(t, s, randaoMixes, 2)
assertRefCount(t, b, randaoMixes, 2)
// Assert shared state.
@@ -156,7 +162,7 @@ func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
assert.DeepEqual(t, val1, mixesB[0], "Unexpected mutation found")
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, randaoMixes, 1)
assertRefCount(t, s, randaoMixes, 1)
assertRefCount(t, b, randaoMixes, 1)
}
@@ -182,16 +188,18 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
a, err := InitializeFromProtoUnsafe(&ethpb.BeaconState{})
require.NoError(t, err)
assertRefCount(t, a, previousEpochAttestations, 1)
assertRefCount(t, a, currentEpochAttestations, 1)
s, ok := a.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, s, previousEpochAttestations, 1)
assertRefCount(t, s, currentEpochAttestations, 1)
// Update initial state.
atts := []*ethpb.PendingAttestation{
{AggregationBits: bitfield.NewBitlist(1)},
{AggregationBits: bitfield.NewBitlist(2)},
}
a.setPreviousEpochAttestations(atts[:1])
a.setCurrentEpochAttestations(atts[:1])
s.setPreviousEpochAttestations(atts[:1])
s.setCurrentEpochAttestations(atts[:1])
curAtt, err := a.CurrentEpochAttestations()
require.NoError(t, err)
assert.Equal(t, 1, len(curAtt), "Unexpected number of attestations")
@@ -203,8 +211,8 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
copied := a.Copy()
b, ok := copied.(*BeaconState)
require.Equal(t, true, ok)
assertRefCount(t, a, previousEpochAttestations, 2)
assertRefCount(t, a, currentEpochAttestations, 2)
assertRefCount(t, s, previousEpochAttestations, 2)
assertRefCount(t, s, currentEpochAttestations, 2)
assertRefCount(t, b, previousEpochAttestations, 2)
assertRefCount(t, b, currentEpochAttestations, 2)
bPrevEpochAtts, err := b.PreviousEpochAttestations()
@@ -289,7 +297,7 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
state.previousEpochAttestations[i] = att
}
}
applyToEveryAttestation(a)
applyToEveryAttestation(s)
aCurrEpochAtts, err = a.CurrentEpochAttestations()
require.NoError(t, err)
@@ -315,9 +323,9 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
assertAttNotFound(bPrevEpochAtts, 2)
// Copy on write happened, reference counters are reset.
assertRefCount(t, a, currentEpochAttestations, 1)
assertRefCount(t, s, currentEpochAttestations, 1)
assertRefCount(t, b, currentEpochAttestations, 1)
assertRefCount(t, a, previousEpochAttestations, 1)
assertRefCount(t, s, previousEpochAttestations, 1)
assertRefCount(t, b, previousEpochAttestations, 1)
}

View File

@@ -17,6 +17,10 @@ func TestBeaconState_RotateAttestations(t *testing.T) {
require.NoError(t, err)
require.NoError(t, st.RotateAttestations())
require.Equal(t, 0, len(st.currentEpochAttestationsVal()))
require.Equal(t, types.Slot(456), st.previousEpochAttestationsVal()[0].Data.Slot)
currEpochAtts, err := st.CurrentEpochAttestations()
require.NoError(t, err)
require.Equal(t, 0, len(currEpochAtts))
prevEpochAtts, err := st.PreviousEpochAttestations()
require.NoError(t, err)
require.Equal(t, types.Slot(456), prevEpochAtts[0].Data.Slot)
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
"github.com/prysmaticlabs/prysm/testing/require"
)
func TestValidatorMap_DistinctCopy(t *testing.T) {
@@ -67,6 +68,8 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
Validators: vals,
})
assert.NoError(t, err)
s, ok := st.(*BeaconState)
require.Equal(t, true, ok)
wg := new(sync.WaitGroup)
@@ -75,7 +78,7 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
// Continuously lock and unlock the state
// by acquiring the lock.
for i := 0; i < 1000; i++ {
for _, f := range st.stateFieldLeaves {
for _, f := range s.stateFieldLeaves {
f.Lock()
if f.Empty() {
f.InsertFieldLayer(make([][]*[32]byte, 10))
@@ -180,8 +183,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
}
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
s, ok := st.(*BeaconState)
require.Equal(t, true, ok)
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(s.Balances())
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
}

View File

@@ -25,13 +25,13 @@ import (
)
// InitializeFromProto the beacon state from a protobuf representation.
func InitializeFromProto(st *ethpb.BeaconState) (*BeaconState, error) {
func InitializeFromProto(st *ethpb.BeaconState) (state.BeaconState, error) {
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconState))
}
// InitializeFromProtoUnsafe directly uses the beacon state protobuf fields
// and sets them as fields of the BeaconState type.
func InitializeFromProtoUnsafe(st *ethpb.BeaconState) (*BeaconState, error) {
func InitializeFromProtoUnsafe(st *ethpb.BeaconState) (state.BeaconState, error) {
if st == nil {
return nil, errors.New("received nil state")
}

View File

@@ -31,7 +31,7 @@ func TestBeaconState_ProtoBeaconStateCompatibility(t *testing.T) {
require.NoError(t, err)
cloned, ok := proto.Clone(genesis).(*ethpb.BeaconState)
assert.Equal(t, true, ok, "Object is not of type *ethpb.BeaconState")
custom := customState.ToProto()
custom := customState.CloneInnerState()
assert.DeepSSZEqual(t, cloned, custom)
r1, err := customState.HashTreeRoot(ctx)
@@ -146,7 +146,7 @@ func BenchmarkStateClone_Manual(b *testing.B) {
require.NoError(b, err)
b.StartTimer()
for i := 0; i < b.N; i++ {
_ = st.ToProto()
_ = st.CloneInnerState()
}
}
@@ -230,7 +230,7 @@ func TestForkManualCopy_OK(t *testing.T) {
}
require.NoError(t, a.SetFork(wantedFork))
pbState, err := v1.ProtobufBeaconState(a.ToProtoUnsafe())
pbState, err := v1.ProtobufBeaconState(a.InnerStateUnsafe())
require.NoError(t, err)
require.DeepEqual(t, pbState.Fork, wantedFork)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
"github.com/prysmaticlabs/prysm/testing/require"
)
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
@@ -106,8 +107,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
}
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
s, ok := st.(*BeaconState)
require.Equal(t, true, ok)
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances())
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
}

View File

@@ -26,14 +26,14 @@ import (
)
// InitializeFromProto the beacon state from a protobuf representation.
func InitializeFromProto(st *ethpb.BeaconStateAltair) (*BeaconState, error) {
func InitializeFromProto(st *ethpb.BeaconStateAltair) (state.BeaconStateAltair, error) {
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconStateAltair))
}
// InitializeFromSSZReader can be used when the source for a serialized BeaconState object
// is an io.Reader. This allows client code to remain agnostic about whether the data comes
// from the network or a file without needing to read the entire state into mem as a large byte slice.
func InitializeFromSSZReader(r io.Reader) (*BeaconState, error) {
func InitializeFromSSZReader(r io.Reader) (state.BeaconStateAltair, error) {
b, err := ioutil.ReadAll(r)
if err != nil {
return nil, err

View File

@@ -31,9 +31,3 @@ func TestBeaconState_MatchPreviousJustifiedCheckpt(t *testing.T) {
},
)
}
func TestBeaconState_ValidatorByPubkey(t *testing.T) {
testtmpl.VerifyBeaconState_ValidatorByPubkey(t, func() (state.BeaconState, error) {
return InitializeFromProto(&ethpb.BeaconStateBellatrix{})
})
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/testing/assert"
"github.com/prysmaticlabs/prysm/testing/require"
)
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
@@ -118,8 +119,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
}
_, err = st.HashTreeRoot(context.Background())
assert.NoError(t, err)
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
s, ok := st.(*BeaconState)
require.Equal(t, true, ok)
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances())
assert.NoError(t, err)
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
}

View File

@@ -25,13 +25,13 @@ import (
)
// InitializeFromProto the beacon state from a protobuf representation.
func InitializeFromProto(st *ethpb.BeaconStateBellatrix) (*BeaconState, error) {
func InitializeFromProto(st *ethpb.BeaconStateBellatrix) (state.BeaconStateBellatrix, error) {
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconStateBellatrix))
}
// InitializeFromProtoUnsafe directly uses the beacon state protobuf fields
// and sets them as fields of the BeaconState type.
func InitializeFromProtoUnsafe(st *ethpb.BeaconStateBellatrix) (*BeaconState, error) {
func InitializeFromProtoUnsafe(st *ethpb.BeaconStateBellatrix) (state.BeaconStateBellatrix, error) {
if st == nil {
return nil, errors.New("received nil state")
}

View File

@@ -108,6 +108,8 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
Validators: vals,
})
assert.NoError(t, err)
s, ok := st.(*BeaconState)
require.Equal(t, true, ok)
wg := new(sync.WaitGroup)
@@ -116,7 +118,7 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
// Continuously lock and unlock the state
// by acquiring the lock.
for i := 0; i < 1000; i++ {
for _, f := range st.stateFieldLeaves {
for _, f := range s.stateFieldLeaves {
f.Lock()
if f.Empty() {
f.InsertFieldLayer(make([][]*[fieldparams.RootLength]byte, 10))