mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-10 05:47:59 -05:00
Return state interface from native state constructors (#10208)
Co-authored-by: Raul Jordan <raul@prysmaticlabs.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
@@ -20,24 +20,26 @@ func TestStateReferenceSharing_Finalizer(t *testing.T) {
|
||||
|
||||
a, err := InitializeFromProtoUnsafe(ðpb.BeaconState{RandaoMixes: [][]byte{[]byte("foo")}})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes")
|
||||
s, ok := a.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assert.Equal(t, uint(1), s.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes")
|
||||
|
||||
func() {
|
||||
// Create object in a different scope for GC
|
||||
b := a.Copy()
|
||||
assert.Equal(t, uint(2), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes")
|
||||
assert.Equal(t, uint(2), s.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes")
|
||||
_ = b
|
||||
}()
|
||||
|
||||
runtime.GC() // Should run finalizer on object b
|
||||
assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!")
|
||||
assert.Equal(t, uint(1), s.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!")
|
||||
|
||||
copied := a.Copy()
|
||||
b, ok := copied.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assert.Equal(t, uint(2), b.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 shared references to RANDAO mixes")
|
||||
require.NoError(t, b.UpdateRandaoMixesAtIndex(0, []byte("bar")))
|
||||
if b.sharedFieldReferences[randaoMixes].Refs() != 1 || a.sharedFieldReferences[randaoMixes].Refs() != 1 {
|
||||
if b.sharedFieldReferences[randaoMixes].Refs() != 1 || s.sharedFieldReferences[randaoMixes].Refs() != 1 {
|
||||
t.Error("Expected 1 shared reference to RANDAO mix for both a and b")
|
||||
}
|
||||
}
|
||||
@@ -53,15 +55,17 @@ func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assertRefCount(t, a, blockRoots, 1)
|
||||
assertRefCount(t, a, stateRoots, 1)
|
||||
s, ok := a.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, s, blockRoots, 1)
|
||||
assertRefCount(t, s, stateRoots, 1)
|
||||
|
||||
// Copy, increases reference count.
|
||||
copied := a.Copy()
|
||||
b, ok := copied.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, a, blockRoots, 2)
|
||||
assertRefCount(t, a, stateRoots, 2)
|
||||
assertRefCount(t, s, blockRoots, 2)
|
||||
assertRefCount(t, s, stateRoots, 2)
|
||||
assertRefCount(t, b, blockRoots, 2)
|
||||
assertRefCount(t, b, stateRoots, 2)
|
||||
assert.Equal(t, 8192, len(b.BlockRoots()), "Wrong number of block roots found")
|
||||
@@ -106,8 +110,8 @@ func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
|
||||
assert.DeepEqual(t, root1[:], stateRootsB[0], "Unexpected mutation found")
|
||||
|
||||
// Copy on write happened, reference counters are reset.
|
||||
assertRefCount(t, a, blockRoots, 1)
|
||||
assertRefCount(t, a, stateRoots, 1)
|
||||
assertRefCount(t, s, blockRoots, 1)
|
||||
assertRefCount(t, s, stateRoots, 1)
|
||||
assertRefCount(t, b, blockRoots, 1)
|
||||
assertRefCount(t, b, stateRoots, 1)
|
||||
}
|
||||
@@ -121,13 +125,15 @@ func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assertRefCount(t, a, randaoMixes, 1)
|
||||
s, ok := a.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, s, randaoMixes, 1)
|
||||
|
||||
// Copy, increases reference count.
|
||||
copied := a.Copy()
|
||||
b, ok := copied.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, a, randaoMixes, 2)
|
||||
assertRefCount(t, s, randaoMixes, 2)
|
||||
assertRefCount(t, b, randaoMixes, 2)
|
||||
|
||||
// Assert shared state.
|
||||
@@ -156,7 +162,7 @@ func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
|
||||
assert.DeepEqual(t, val1, mixesB[0], "Unexpected mutation found")
|
||||
|
||||
// Copy on write happened, reference counters are reset.
|
||||
assertRefCount(t, a, randaoMixes, 1)
|
||||
assertRefCount(t, s, randaoMixes, 1)
|
||||
assertRefCount(t, b, randaoMixes, 1)
|
||||
}
|
||||
|
||||
@@ -182,16 +188,18 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
|
||||
|
||||
a, err := InitializeFromProtoUnsafe(ðpb.BeaconState{})
|
||||
require.NoError(t, err)
|
||||
assertRefCount(t, a, previousEpochAttestations, 1)
|
||||
assertRefCount(t, a, currentEpochAttestations, 1)
|
||||
s, ok := a.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, s, previousEpochAttestations, 1)
|
||||
assertRefCount(t, s, currentEpochAttestations, 1)
|
||||
|
||||
// Update initial state.
|
||||
atts := []*ethpb.PendingAttestation{
|
||||
{AggregationBits: bitfield.NewBitlist(1)},
|
||||
{AggregationBits: bitfield.NewBitlist(2)},
|
||||
}
|
||||
a.setPreviousEpochAttestations(atts[:1])
|
||||
a.setCurrentEpochAttestations(atts[:1])
|
||||
s.setPreviousEpochAttestations(atts[:1])
|
||||
s.setCurrentEpochAttestations(atts[:1])
|
||||
curAtt, err := a.CurrentEpochAttestations()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, len(curAtt), "Unexpected number of attestations")
|
||||
@@ -203,8 +211,8 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
|
||||
copied := a.Copy()
|
||||
b, ok := copied.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
assertRefCount(t, a, previousEpochAttestations, 2)
|
||||
assertRefCount(t, a, currentEpochAttestations, 2)
|
||||
assertRefCount(t, s, previousEpochAttestations, 2)
|
||||
assertRefCount(t, s, currentEpochAttestations, 2)
|
||||
assertRefCount(t, b, previousEpochAttestations, 2)
|
||||
assertRefCount(t, b, currentEpochAttestations, 2)
|
||||
bPrevEpochAtts, err := b.PreviousEpochAttestations()
|
||||
@@ -289,7 +297,7 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
|
||||
state.previousEpochAttestations[i] = att
|
||||
}
|
||||
}
|
||||
applyToEveryAttestation(a)
|
||||
applyToEveryAttestation(s)
|
||||
|
||||
aCurrEpochAtts, err = a.CurrentEpochAttestations()
|
||||
require.NoError(t, err)
|
||||
@@ -315,9 +323,9 @@ func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
|
||||
assertAttNotFound(bPrevEpochAtts, 2)
|
||||
|
||||
// Copy on write happened, reference counters are reset.
|
||||
assertRefCount(t, a, currentEpochAttestations, 1)
|
||||
assertRefCount(t, s, currentEpochAttestations, 1)
|
||||
assertRefCount(t, b, currentEpochAttestations, 1)
|
||||
assertRefCount(t, a, previousEpochAttestations, 1)
|
||||
assertRefCount(t, s, previousEpochAttestations, 1)
|
||||
assertRefCount(t, b, previousEpochAttestations, 1)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,10 @@ func TestBeaconState_RotateAttestations(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, st.RotateAttestations())
|
||||
require.Equal(t, 0, len(st.currentEpochAttestationsVal()))
|
||||
require.Equal(t, types.Slot(456), st.previousEpochAttestationsVal()[0].Data.Slot)
|
||||
currEpochAtts, err := st.CurrentEpochAttestations()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(currEpochAtts))
|
||||
prevEpochAtts, err := st.PreviousEpochAttestations()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, types.Slot(456), prevEpochAtts[0].Data.Slot)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
"github.com/prysmaticlabs/prysm/testing/require"
|
||||
)
|
||||
|
||||
func TestValidatorMap_DistinctCopy(t *testing.T) {
|
||||
@@ -67,6 +68,8 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
|
||||
Validators: vals,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
s, ok := st.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
|
||||
wg := new(sync.WaitGroup)
|
||||
|
||||
@@ -75,7 +78,7 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
|
||||
// Continuously lock and unlock the state
|
||||
// by acquiring the lock.
|
||||
for i := 0; i < 1000; i++ {
|
||||
for _, f := range st.stateFieldLeaves {
|
||||
for _, f := range s.stateFieldLeaves {
|
||||
f.Lock()
|
||||
if f.Empty() {
|
||||
f.InsertFieldLayer(make([][]*[32]byte, 10))
|
||||
@@ -180,8 +183,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
}
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
|
||||
s, ok := st.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(s.Balances())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
|
||||
}
|
||||
|
||||
@@ -25,13 +25,13 @@ import (
|
||||
)
|
||||
|
||||
// InitializeFromProto the beacon state from a protobuf representation.
|
||||
func InitializeFromProto(st *ethpb.BeaconState) (*BeaconState, error) {
|
||||
func InitializeFromProto(st *ethpb.BeaconState) (state.BeaconState, error) {
|
||||
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconState))
|
||||
}
|
||||
|
||||
// InitializeFromProtoUnsafe directly uses the beacon state protobuf fields
|
||||
// and sets them as fields of the BeaconState type.
|
||||
func InitializeFromProtoUnsafe(st *ethpb.BeaconState) (*BeaconState, error) {
|
||||
func InitializeFromProtoUnsafe(st *ethpb.BeaconState) (state.BeaconState, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("received nil state")
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestBeaconState_ProtoBeaconStateCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
cloned, ok := proto.Clone(genesis).(*ethpb.BeaconState)
|
||||
assert.Equal(t, true, ok, "Object is not of type *ethpb.BeaconState")
|
||||
custom := customState.ToProto()
|
||||
custom := customState.CloneInnerState()
|
||||
assert.DeepSSZEqual(t, cloned, custom)
|
||||
|
||||
r1, err := customState.HashTreeRoot(ctx)
|
||||
@@ -146,7 +146,7 @@ func BenchmarkStateClone_Manual(b *testing.B) {
|
||||
require.NoError(b, err)
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = st.ToProto()
|
||||
_ = st.CloneInnerState()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,7 +230,7 @@ func TestForkManualCopy_OK(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, a.SetFork(wantedFork))
|
||||
|
||||
pbState, err := v1.ProtobufBeaconState(a.ToProtoUnsafe())
|
||||
pbState, err := v1.ProtobufBeaconState(a.InnerStateUnsafe())
|
||||
require.NoError(t, err)
|
||||
require.DeepEqual(t, pbState.Fork, wantedFork)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
"github.com/prysmaticlabs/prysm/testing/require"
|
||||
)
|
||||
|
||||
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
@@ -106,8 +107,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
}
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
|
||||
s, ok := st.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
|
||||
}
|
||||
|
||||
@@ -26,14 +26,14 @@ import (
|
||||
)
|
||||
|
||||
// InitializeFromProto the beacon state from a protobuf representation.
|
||||
func InitializeFromProto(st *ethpb.BeaconStateAltair) (*BeaconState, error) {
|
||||
func InitializeFromProto(st *ethpb.BeaconStateAltair) (state.BeaconStateAltair, error) {
|
||||
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconStateAltair))
|
||||
}
|
||||
|
||||
// InitializeFromSSZReader can be used when the source for a serialized BeaconState object
|
||||
// is an io.Reader. This allows client code to remain agnostic about whether the data comes
|
||||
// from the network or a file without needing to read the entire state into mem as a large byte slice.
|
||||
func InitializeFromSSZReader(r io.Reader) (*BeaconState, error) {
|
||||
func InitializeFromSSZReader(r io.Reader) (state.BeaconStateAltair, error) {
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -31,9 +31,3 @@ func TestBeaconState_MatchPreviousJustifiedCheckpt(t *testing.T) {
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func TestBeaconState_ValidatorByPubkey(t *testing.T) {
|
||||
testtmpl.VerifyBeaconState_ValidatorByPubkey(t, func() (state.BeaconState, error) {
|
||||
return InitializeFromProto(ðpb.BeaconStateBellatrix{})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/encoding/bytesutil"
|
||||
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
|
||||
"github.com/prysmaticlabs/prysm/testing/assert"
|
||||
"github.com/prysmaticlabs/prysm/testing/require"
|
||||
)
|
||||
|
||||
func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
@@ -118,8 +119,10 @@ func TestBeaconState_AppendBalanceWithTrie(t *testing.T) {
|
||||
}
|
||||
_, err = st.HashTreeRoot(context.Background())
|
||||
assert.NoError(t, err)
|
||||
newRt := bytesutil.ToBytes32(st.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.balances)
|
||||
s, ok := st.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
newRt := bytesutil.ToBytes32(s.merkleLayers[0][balances])
|
||||
wantedRt, err := stateutil.Uint64ListRootWithRegistryLimit(st.Balances())
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, wantedRt, newRt, "state roots are unequal")
|
||||
}
|
||||
|
||||
@@ -25,13 +25,13 @@ import (
|
||||
)
|
||||
|
||||
// InitializeFromProto the beacon state from a protobuf representation.
|
||||
func InitializeFromProto(st *ethpb.BeaconStateBellatrix) (*BeaconState, error) {
|
||||
func InitializeFromProto(st *ethpb.BeaconStateBellatrix) (state.BeaconStateBellatrix, error) {
|
||||
return InitializeFromProtoUnsafe(proto.Clone(st).(*ethpb.BeaconStateBellatrix))
|
||||
}
|
||||
|
||||
// InitializeFromProtoUnsafe directly uses the beacon state protobuf fields
|
||||
// and sets them as fields of the BeaconState type.
|
||||
func InitializeFromProtoUnsafe(st *ethpb.BeaconStateBellatrix) (*BeaconState, error) {
|
||||
func InitializeFromProtoUnsafe(st *ethpb.BeaconStateBellatrix) (state.BeaconStateBellatrix, error) {
|
||||
if st == nil {
|
||||
return nil, errors.New("received nil state")
|
||||
}
|
||||
|
||||
@@ -108,6 +108,8 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
|
||||
Validators: vals,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
s, ok := st.(*BeaconState)
|
||||
require.Equal(t, true, ok)
|
||||
|
||||
wg := new(sync.WaitGroup)
|
||||
|
||||
@@ -116,7 +118,7 @@ func TestBeaconState_NoDeadlock(t *testing.T) {
|
||||
// Continuously lock and unlock the state
|
||||
// by acquiring the lock.
|
||||
for i := 0; i < 1000; i++ {
|
||||
for _, f := range st.stateFieldLeaves {
|
||||
for _, f := range s.stateFieldLeaves {
|
||||
f.Lock()
|
||||
if f.Empty() {
|
||||
f.InsertFieldLayer(make([][]*[fieldparams.RootLength]byte, 10))
|
||||
|
||||
Reference in New Issue
Block a user