mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-08 23:18:15 -05:00
Integrate state-diff into State() (#16033)
**What type of PR is this?** Feature **What does this PR do? Why is it needed?** This PR integrates the state diff path into the `State()` function from `db/kv`, which allows reading of states using the state diff db, when the `EnableStateDiff` flag is enabled. **Notes for reviewers:** Files `kv/state_diff_test.go` and `config/features/config.go` only contain renamings: - `kv/state_diff_test.go`: rename `setDefaultExponents()` to `setDefaultStateDiffExponents()` to be less vague. - `config/features/config.go`: rename `enableStateDiff` to `EnableStateDiff` to make it public.
This commit is contained in:
@@ -28,6 +28,17 @@ func (s *Store) State(ctx context.Context, blockRoot [32]byte) (state.BeaconStat
|
||||
ctx, span := trace.StartSpan(ctx, "BeaconDB.State")
|
||||
defer span.End()
|
||||
startTime := time.Now()
|
||||
|
||||
// If state diff is enabled, we get the state from the state-diff db.
|
||||
if features.Get().EnableStateDiff {
|
||||
st, err := s.getStateUsingStateDiff(ctx, blockRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stateReadingTime.Observe(float64(time.Since(startTime).Milliseconds()))
|
||||
return st, nil
|
||||
}
|
||||
|
||||
enc, err := s.stateBytes(ctx, blockRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1031,3 +1042,34 @@ func (s *Store) isStateValidatorMigrationOver() (bool, error) {
|
||||
}
|
||||
return returnFlag, nil
|
||||
}
|
||||
|
||||
func (s *Store) getStateUsingStateDiff(ctx context.Context, blockRoot [32]byte) (state.BeaconState, error) {
|
||||
var slot primitives.Slot
|
||||
|
||||
stateSummary, err := s.StateSummary(ctx, blockRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if stateSummary == nil {
|
||||
blk, err := s.Block(ctx, blockRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if blk == nil || blk.IsNil() {
|
||||
return nil, errors.New("neither state summary nor block found")
|
||||
}
|
||||
slot = blk.Block().Slot()
|
||||
} else {
|
||||
slot = stateSummary.Slot
|
||||
}
|
||||
|
||||
st, err := s.stateByDiff(ctx, slot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if st == nil || st.IsNil() {
|
||||
return nil, errors.New("state not found")
|
||||
}
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStateDiff_LoadOrInitOffset(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
db := setupDB(t)
|
||||
err := setOffsetInDB(db, 10)
|
||||
@@ -36,7 +36,7 @@ func TestStateDiff_LoadOrInitOffset(t *testing.T) {
|
||||
|
||||
func TestStateDiff_ComputeLevel(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
@@ -118,7 +118,7 @@ func TestStateDiff_ComputeLevel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveFullSnapshot(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -151,7 +151,7 @@ func TestStateDiff_SaveFullSnapshot(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -179,7 +179,7 @@ func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveDiff(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -245,7 +245,7 @@ func TestStateDiff_SaveDiff(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiff(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -319,7 +319,7 @@ func TestStateDiff_SaveAndReadDiff_WithRepetitiveAnchorSlots(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -385,7 +385,7 @@ func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All()[:len(version.All())-1] {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -419,7 +419,7 @@ func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_OffsetCache(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
// test for slot numbers 0 and 1 for every version
|
||||
for slotNum := range 2 {
|
||||
@@ -450,7 +450,7 @@ func TestStateDiff_OffsetCache(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestStateDiff_AnchorCache(t *testing.T) {
|
||||
setDefaultExponents()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
@@ -654,7 +654,7 @@ func setOffsetInDB(s *Store, offset uint64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func setDefaultExponents() {
|
||||
func setDefaultStateDiffExponents() {
|
||||
globalFlags := flags.GlobalFlags{
|
||||
StateDiffExponents: []int{21, 18, 16, 13, 11, 9, 5},
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
mathRand "math/rand"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/OffchainLabs/prysm/v7/beacon-chain/state"
|
||||
"github.com/OffchainLabs/prysm/v7/cmd/beacon-chain/flags"
|
||||
"github.com/OffchainLabs/prysm/v7/config/features"
|
||||
fieldparams "github.com/OffchainLabs/prysm/v7/config/fieldparams"
|
||||
"github.com/OffchainLabs/prysm/v7/config/params"
|
||||
@@ -17,8 +19,10 @@ import (
|
||||
"github.com/OffchainLabs/prysm/v7/consensus-types/primitives"
|
||||
"github.com/OffchainLabs/prysm/v7/encoding/bytesutil"
|
||||
"github.com/OffchainLabs/prysm/v7/genesis"
|
||||
"github.com/OffchainLabs/prysm/v7/math"
|
||||
enginev1 "github.com/OffchainLabs/prysm/v7/proto/engine/v1"
|
||||
ethpb "github.com/OffchainLabs/prysm/v7/proto/prysm/v1alpha1"
|
||||
"github.com/OffchainLabs/prysm/v7/runtime/version"
|
||||
"github.com/OffchainLabs/prysm/v7/testing/assert"
|
||||
"github.com/OffchainLabs/prysm/v7/testing/require"
|
||||
"github.com/OffchainLabs/prysm/v7/testing/util"
|
||||
@@ -1329,3 +1333,243 @@ func TestStore_CleanUpDirtyStates_NoOriginRoot(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStore_CanSaveRetrieveStateUsingStateDiff(t *testing.T) {
|
||||
t.Run("No state summary or block", func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), [32]byte{'A'})
|
||||
require.IsNil(t, readSt)
|
||||
require.ErrorContains(t, "neither state summary nor block found", err)
|
||||
})
|
||||
|
||||
t.Run("Slot not in tree", func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := bytesutil.ToBytes32([]byte{'A'})
|
||||
ss := ðpb.StateSummary{Slot: 1, Root: r[:]} // slot 1 not in tree
|
||||
err = db.SaveStateSummary(context.Background(), ss)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.ErrorContains(t, "slot not in tree", err)
|
||||
require.IsNil(t, readSt)
|
||||
|
||||
})
|
||||
|
||||
t.Run("State not found", func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := bytesutil.ToBytes32([]byte{'A'})
|
||||
ss := ðpb.StateSummary{Slot: 32, Root: r[:]} // slot 32 is in tree
|
||||
err = db.SaveStateSummary(context.Background(), ss)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.ErrorContains(t, "state not found", err)
|
||||
require.IsNil(t, readSt)
|
||||
})
|
||||
|
||||
t.Run("Full state snapshot", func(t *testing.T) {
|
||||
t.Run("using state summary", func(t *testing.T) {
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := bytesutil.ToBytes32([]byte{'A'})
|
||||
ss := ðpb.StateSummary{Slot: 0, Root: r[:]}
|
||||
err = db.SaveStateSummary(context.Background(), ss)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("using block", func(t *testing.T) {
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
blk := util.NewBeaconBlock()
|
||||
blk.Block.Slot = 0
|
||||
signedBlk, err := blocks.NewSignedBeaconBlock(blk)
|
||||
require.NoError(t, err)
|
||||
err = db.SaveBlock(context.Background(), signedBlk)
|
||||
require.NoError(t, err)
|
||||
r, err := signedBlk.Block().HashTreeRoot()
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Diffed state", func(t *testing.T) {
|
||||
t.Run("using state summary", func(t *testing.T) {
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
exponents := flags.Get().StateDiffExponents
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])))
|
||||
st, _ = createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1])))
|
||||
st, _ = createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
r := bytesutil.ToBytes32([]byte{'A'})
|
||||
ss := ðpb.StateSummary{Slot: slot, Root: r[:]}
|
||||
err = db.SaveStateSummary(context.Background(), ss)
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("using block", func(t *testing.T) {
|
||||
for v := range version.All() {
|
||||
t.Run(version.String(v), func(t *testing.T) {
|
||||
db := setupDB(t)
|
||||
featCfg := &features.Flags{}
|
||||
featCfg.EnableStateDiff = true
|
||||
reset := features.InitWithReset(featCfg)
|
||||
defer reset()
|
||||
setDefaultStateDiffExponents()
|
||||
|
||||
exponents := flags.Get().StateDiffExponents
|
||||
|
||||
err := setOffsetInDB(db, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
st, _ := createState(t, 0, v)
|
||||
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])))
|
||||
st, _ = createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1])))
|
||||
st, _ = createState(t, slot, v)
|
||||
err = db.saveStateByDiff(context.Background(), st)
|
||||
require.NoError(t, err)
|
||||
|
||||
blk := util.NewBeaconBlock()
|
||||
blk.Block.Slot = slot
|
||||
signedBlk, err := blocks.NewSignedBeaconBlock(blk)
|
||||
require.NoError(t, err)
|
||||
err = db.SaveBlock(context.Background(), signedBlk)
|
||||
require.NoError(t, err)
|
||||
r, err := signedBlk.Block().HashTreeRoot()
|
||||
require.NoError(t, err)
|
||||
|
||||
readSt, err := db.State(context.Background(), r)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, readSt)
|
||||
|
||||
stSSZ, err := st.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
readStSSZ, err := readSt.MarshalSSZ()
|
||||
require.NoError(t, err)
|
||||
require.DeepSSZEqual(t, stSSZ, readStSSZ)
|
||||
})
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
3
changelog/bastin_state-read-diff.md
Normal file
3
changelog/bastin_state-read-diff.md
Normal file
@@ -0,0 +1,3 @@
|
||||
### Added
|
||||
|
||||
- Integrate state-diff into `State()`.
|
||||
@@ -286,8 +286,8 @@ func ConfigureBeaconChain(ctx *cli.Context) error {
|
||||
cfg.DisableLastEpochTargets = true
|
||||
}
|
||||
|
||||
if ctx.IsSet(enableStateDiff.Name) {
|
||||
logEnabled(enableStateDiff)
|
||||
if ctx.IsSet(EnableStateDiff.Name) {
|
||||
logEnabled(EnableStateDiff)
|
||||
cfg.EnableStateDiff = true
|
||||
|
||||
if ctx.IsSet(enableHistoricalSpaceRepresentation.Name) {
|
||||
|
||||
@@ -172,7 +172,7 @@ var (
|
||||
Name: "enable-experimental-attestation-pool",
|
||||
Usage: "Enables an experimental attestation pool design.",
|
||||
}
|
||||
enableStateDiff = &cli.BoolFlag{
|
||||
EnableStateDiff = &cli.BoolFlag{
|
||||
Name: "enable-state-diff",
|
||||
Usage: "Enables the experimental state diff feature.",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user