diff --git a/beacon-chain/db/kv/state.go b/beacon-chain/db/kv/state.go index a84890fcf6..52b2d9b246 100644 --- a/beacon-chain/db/kv/state.go +++ b/beacon-chain/db/kv/state.go @@ -28,6 +28,17 @@ func (s *Store) State(ctx context.Context, blockRoot [32]byte) (state.BeaconStat ctx, span := trace.StartSpan(ctx, "BeaconDB.State") defer span.End() startTime := time.Now() + + // If state diff is enabled, we get the state from the state-diff db. + if features.Get().EnableStateDiff { + st, err := s.getStateUsingStateDiff(ctx, blockRoot) + if err != nil { + return nil, err + } + stateReadingTime.Observe(float64(time.Since(startTime).Milliseconds())) + return st, nil + } + enc, err := s.stateBytes(ctx, blockRoot) if err != nil { return nil, err @@ -1031,3 +1042,34 @@ func (s *Store) isStateValidatorMigrationOver() (bool, error) { } return returnFlag, nil } + +func (s *Store) getStateUsingStateDiff(ctx context.Context, blockRoot [32]byte) (state.BeaconState, error) { + var slot primitives.Slot + + stateSummary, err := s.StateSummary(ctx, blockRoot) + if err != nil { + return nil, err + } + if stateSummary == nil { + blk, err := s.Block(ctx, blockRoot) + if err != nil { + return nil, err + } + if blk == nil || blk.IsNil() { + return nil, errors.New("neither state summary nor block found") + } + slot = blk.Block().Slot() + } else { + slot = stateSummary.Slot + } + + st, err := s.stateByDiff(ctx, slot) + if err != nil { + return nil, err + } + if st == nil || st.IsNil() { + return nil, errors.New("state not found") + } + + return st, nil +} diff --git a/beacon-chain/db/kv/state_diff_test.go b/beacon-chain/db/kv/state_diff_test.go index 06d1ae5022..fa49f96caa 100644 --- a/beacon-chain/db/kv/state_diff_test.go +++ b/beacon-chain/db/kv/state_diff_test.go @@ -20,7 +20,7 @@ import ( ) func TestStateDiff_LoadOrInitOffset(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() db := setupDB(t) err := setOffsetInDB(db, 10) @@ -36,7 +36,7 @@ func TestStateDiff_LoadOrInitOffset(t *testing.T) { func TestStateDiff_ComputeLevel(t *testing.T) { db := setupDB(t) - setDefaultExponents() + setDefaultStateDiffExponents() err := setOffsetInDB(db, 0) require.NoError(t, err) @@ -118,7 +118,7 @@ func TestStateDiff_ComputeLevel(t *testing.T) { } func TestStateDiff_SaveFullSnapshot(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -151,7 +151,7 @@ func TestStateDiff_SaveFullSnapshot(t *testing.T) { } func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -179,7 +179,7 @@ func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) { } func TestStateDiff_SaveDiff(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -245,7 +245,7 @@ func TestStateDiff_SaveDiff(t *testing.T) { } func TestStateDiff_SaveAndReadDiff(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -319,7 +319,7 @@ func TestStateDiff_SaveAndReadDiff_WithRepetitiveAnchorSlots(t *testing.T) { } func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -385,7 +385,7 @@ func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) { } func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All()[:len(version.All())-1] { t.Run(version.String(v), func(t *testing.T) { @@ -419,7 +419,7 @@ func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) { } func TestStateDiff_OffsetCache(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() // test for slot numbers 0 and 1 for every version for slotNum := range 2 { @@ -450,7 +450,7 @@ func TestStateDiff_OffsetCache(t *testing.T) { } func TestStateDiff_AnchorCache(t *testing.T) { - setDefaultExponents() + setDefaultStateDiffExponents() for v := range version.All() { t.Run(version.String(v), func(t *testing.T) { @@ -654,7 +654,7 @@ func setOffsetInDB(s *Store, offset uint64) error { return nil } -func setDefaultExponents() { +func setDefaultStateDiffExponents() { globalFlags := flags.GlobalFlags{ StateDiffExponents: []int{21, 18, 16, 13, 11, 9, 5}, } diff --git a/beacon-chain/db/kv/state_test.go b/beacon-chain/db/kv/state_test.go index ca4ec52d27..ed9d2c5377 100644 --- a/beacon-chain/db/kv/state_test.go +++ b/beacon-chain/db/kv/state_test.go @@ -1,6 +1,7 @@ package kv import ( + "context" "crypto/rand" "encoding/binary" mathRand "math/rand" @@ -9,6 +10,7 @@ import ( "time" "github.com/OffchainLabs/prysm/v7/beacon-chain/state" + "github.com/OffchainLabs/prysm/v7/cmd/beacon-chain/flags" "github.com/OffchainLabs/prysm/v7/config/features" fieldparams "github.com/OffchainLabs/prysm/v7/config/fieldparams" "github.com/OffchainLabs/prysm/v7/config/params" @@ -17,8 +19,10 @@ import ( "github.com/OffchainLabs/prysm/v7/consensus-types/primitives" "github.com/OffchainLabs/prysm/v7/encoding/bytesutil" "github.com/OffchainLabs/prysm/v7/genesis" + "github.com/OffchainLabs/prysm/v7/math" enginev1 "github.com/OffchainLabs/prysm/v7/proto/engine/v1" ethpb "github.com/OffchainLabs/prysm/v7/proto/prysm/v1alpha1" + "github.com/OffchainLabs/prysm/v7/runtime/version" "github.com/OffchainLabs/prysm/v7/testing/assert" "github.com/OffchainLabs/prysm/v7/testing/require" "github.com/OffchainLabs/prysm/v7/testing/util" @@ -1329,3 +1333,243 @@ func TestStore_CleanUpDirtyStates_NoOriginRoot(t *testing.T) { } } } + +func TestStore_CanSaveRetrieveStateUsingStateDiff(t *testing.T) { + t.Run("No state summary or block", func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + readSt, err := db.State(context.Background(), [32]byte{'A'}) + require.IsNil(t, readSt) + require.ErrorContains(t, "neither state summary nor block found", err) + }) + + t.Run("Slot not in tree", func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + r := bytesutil.ToBytes32([]byte{'A'}) + ss := ðpb.StateSummary{Slot: 1, Root: r[:]} // slot 1 not in tree + err = db.SaveStateSummary(context.Background(), ss) + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.ErrorContains(t, "slot not in tree", err) + require.IsNil(t, readSt) + + }) + + t.Run("State not found", func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + r := bytesutil.ToBytes32([]byte{'A'}) + ss := ðpb.StateSummary{Slot: 32, Root: r[:]} // slot 32 is in tree + err = db.SaveStateSummary(context.Background(), ss) + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.ErrorContains(t, "state not found", err) + require.IsNil(t, readSt) + }) + + t.Run("Full state snapshot", func(t *testing.T) { + t.Run("using state summary", func(t *testing.T) { + for v := range version.All() { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + st, _ := createState(t, 0, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + r := bytesutil.ToBytes32([]byte{'A'}) + ss := ðpb.StateSummary{Slot: 0, Root: r[:]} + err = db.SaveStateSummary(context.Background(), ss) + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } + }) + + t.Run("using block", func(t *testing.T) { + for v := range version.All() { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + st, _ := createState(t, 0, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + blk := util.NewBeaconBlock() + blk.Block.Slot = 0 + signedBlk, err := blocks.NewSignedBeaconBlock(blk) + require.NoError(t, err) + err = db.SaveBlock(context.Background(), signedBlk) + require.NoError(t, err) + r, err := signedBlk.Block().HashTreeRoot() + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } + }) + }) + + t.Run("Diffed state", func(t *testing.T) { + t.Run("using state summary", func(t *testing.T) { + for v := range version.All() { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + exponents := flags.Get().StateDiffExponents + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + st, _ := createState(t, 0, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2]))) + st, _ = createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1]))) + st, _ = createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + r := bytesutil.ToBytes32([]byte{'A'}) + ss := ðpb.StateSummary{Slot: slot, Root: r[:]} + err = db.SaveStateSummary(context.Background(), ss) + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } + }) + + t.Run("using block", func(t *testing.T) { + for v := range version.All() { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + featCfg := &features.Flags{} + featCfg.EnableStateDiff = true + reset := features.InitWithReset(featCfg) + defer reset() + setDefaultStateDiffExponents() + + exponents := flags.Get().StateDiffExponents + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + st, _ := createState(t, 0, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot := primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2]))) + st, _ = createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot = primitives.Slot(math.PowerOf2(uint64(exponents[len(exponents)-2])) + math.PowerOf2(uint64(exponents[len(exponents)-1]))) + st, _ = createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + blk := util.NewBeaconBlock() + blk.Block.Slot = slot + signedBlk, err := blocks.NewSignedBeaconBlock(blk) + require.NoError(t, err) + err = db.SaveBlock(context.Background(), signedBlk) + require.NoError(t, err) + r, err := signedBlk.Block().HashTreeRoot() + require.NoError(t, err) + + readSt, err := db.State(context.Background(), r) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } + }) + }) +} diff --git a/changelog/bastin_state-read-diff.md b/changelog/bastin_state-read-diff.md new file mode 100644 index 0000000000..9d5f5262d5 --- /dev/null +++ b/changelog/bastin_state-read-diff.md @@ -0,0 +1,3 @@ +### Added + +- Integrate state-diff into `State()`. \ No newline at end of file diff --git a/config/features/config.go b/config/features/config.go index 282521e973..c641cdda6b 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -286,8 +286,8 @@ func ConfigureBeaconChain(ctx *cli.Context) error { cfg.DisableLastEpochTargets = true } - if ctx.IsSet(enableStateDiff.Name) { - logEnabled(enableStateDiff) + if ctx.IsSet(EnableStateDiff.Name) { + logEnabled(EnableStateDiff) cfg.EnableStateDiff = true if ctx.IsSet(enableHistoricalSpaceRepresentation.Name) { diff --git a/config/features/flags.go b/config/features/flags.go index e943773acf..d63e7cad71 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -172,7 +172,7 @@ var ( Name: "enable-experimental-attestation-pool", Usage: "Enables an experimental attestation pool design.", } - enableStateDiff = &cli.BoolFlag{ + EnableStateDiff = &cli.BoolFlag{ Name: "enable-state-diff", Usage: "Enables the experimental state diff feature.", }