diff --git a/beacon-chain/db/kv/state_diff.go b/beacon-chain/db/kv/state_diff.go index 7772cdc351..20f02ea129 100644 --- a/beacon-chain/db/kv/state_diff.go +++ b/beacon-chain/db/kv/state_diff.go @@ -132,6 +132,9 @@ func (s *Store) saveHdiff(lvl int, anchor, st state.ReadOnlyBeaconState) error { return err } } + if err := s.stateDiffCache.setLevelHasData(lvl); err != nil { + return err + } return nil } @@ -171,6 +174,9 @@ func (s *Store) saveFullSnapshot(st state.ReadOnlyBeaconState) error { if err != nil { return err } + if err := s.stateDiffCache.setLevelHasData(0); err != nil { + return err + } return nil } diff --git a/beacon-chain/db/kv/state_diff_cache.go b/beacon-chain/db/kv/state_diff_cache.go index fcf390b90f..cd7175311a 100644 --- a/beacon-chain/db/kv/state_diff_cache.go +++ b/beacon-chain/db/kv/state_diff_cache.go @@ -3,6 +3,7 @@ package kv import ( "encoding/binary" "errors" + "fmt" "sync" "github.com/OffchainLabs/prysm/v7/beacon-chain/state" @@ -12,8 +13,44 @@ import ( type stateDiffCache struct { sync.RWMutex - anchors []state.ReadOnlyBeaconState - offset uint64 + anchors []state.ReadOnlyBeaconState + levelsWithData []bool + offset uint64 +} + +func populateStateDiffCacheFromDB(s *Store, offset uint64) (*stateDiffCache, error) { + cache := &stateDiffCache{ + anchors: make([]state.ReadOnlyBeaconState, len(flags.Get().StateDiffExponents)-1), + levelsWithData: make([]bool, len(flags.Get().StateDiffExponents)), + offset: offset, + } + + if err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + for level := range cache.levelsWithData { + cursor := bucket.Cursor() + prefix := []byte{byte(level)} + key, _ := cursor.Seek(prefix) + if key != nil && key[0] == byte(level) { + cache.levelsWithData[level] = true + } + } + return nil + }); err != nil { + return nil, err + } + + anchor0, err := s.getFullSnapshot(offset) + if err != nil { + return nil, fmt.Errorf("state diff cache: missing offset snapshot at %d: %w", offset, err) + } + cache.anchors[0] = anchor0 + cache.levelsWithData[0] = true + + return cache, nil } func newStateDiffCache(s *Store) (*stateDiffCache, error) { @@ -37,8 +74,9 @@ func newStateDiffCache(s *Store) (*stateDiffCache, error) { } return &stateDiffCache{ - anchors: make([]state.ReadOnlyBeaconState, len(flags.Get().StateDiffExponents)-1), // -1 because last level doesn't need to be cached - offset: offset, + anchors: make([]state.ReadOnlyBeaconState, len(flags.Get().StateDiffExponents)-1), // -1 because last level doesn't need to be cached + levelsWithData: make([]bool, len(flags.Get().StateDiffExponents)), + offset: offset, }, nil } @@ -58,6 +96,25 @@ func (c *stateDiffCache) setAnchor(level int, anchor state.ReadOnlyBeaconState) return nil } +func (c *stateDiffCache) levelHasData(level int) bool { + c.RLock() + defer c.RUnlock() + if level < 0 || level >= len(c.levelsWithData) { + return false + } + return c.levelsWithData[level] +} + +func (c *stateDiffCache) setLevelHasData(level int) error { + c.Lock() + defer c.Unlock() + if level < 0 || level >= len(c.levelsWithData) { + return errors.New("state diff cache: level data index out of range") + } + c.levelsWithData[level] = true + return nil +} + func (c *stateDiffCache) getOffset() uint64 { c.RLock() defer c.RUnlock() diff --git a/beacon-chain/db/kv/state_diff_helpers.go b/beacon-chain/db/kv/state_diff_helpers.go index 68f095ebb1..6f95da83d1 100644 --- a/beacon-chain/db/kv/state_diff_helpers.go +++ b/beacon-chain/db/kv/state_diff_helpers.go @@ -394,7 +394,12 @@ func (s *Store) getBaseAndDiffChain(offset uint64, slot primitives.Slot) (state. if diffSlot == lastSeenAnchorSlot { continue } - diffChainItems = append(diffChainItems, diffItem{level: i + 1, slot: diffSlot + offset}) + level := i + 1 + if s.stateDiffCache != nil && !s.stateDiffCache.levelHasData(level) { + lastSeenAnchorSlot = diffSlot + continue + } + diffChainItems = append(diffChainItems, diffItem{level: level, slot: diffSlot + offset}) lastSeenAnchorSlot = diffSlot } diff --git a/beacon-chain/db/kv/state_diff_test.go b/beacon-chain/db/kv/state_diff_test.go index 006ee5b10f..9d6421ef5f 100644 --- a/beacon-chain/db/kv/state_diff_test.go +++ b/beacon-chain/db/kv/state_diff_test.go @@ -230,6 +230,73 @@ func TestStateDiff_SaveFullSnapshot(t *testing.T) { } } +func TestStateDiff_PopulateStateDiffCacheFromDB(t *testing.T) { + setDefaultStateDiffExponents() + + db := setupDB(t) + _, err := populateStateDiffCacheFromDB(db, 0) + require.ErrorContains(t, "missing offset snapshot", err) + + st, _ := createState(t, 0, version.Phase0) + require.NoError(t, setOffsetInDB(db, 0)) + require.NoError(t, db.saveStateByDiff(context.Background(), st)) + + err = db.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + key := makeKeyForStateDiffTree(2, 0) + return bucket.Put(append(key, stateSuffix...), []byte{1}) + }) + require.NoError(t, err) + + cache, err := populateStateDiffCacheFromDB(db, 0) + require.NoError(t, err) + require.NotNil(t, cache) + require.Equal(t, uint64(0), cache.getOffset()) + require.NotNil(t, cache.getAnchor(0)) + require.Equal(t, true, cache.levelHasData(0)) + require.Equal(t, false, cache.levelHasData(1)) + require.Equal(t, true, cache.levelHasData(2)) +} + +func TestStateDiff_GetBaseAndDiffChainSkipsEmptyLevels(t *testing.T) { + setDefaultStateDiffExponents() + + db := setupDB(t) + require.NoError(t, setOffsetInDB(db, 0)) + st, _ := createState(t, 0, version.Phase0) + require.NoError(t, db.saveFullSnapshot(st)) + + cache, err := populateStateDiffCacheFromDB(db, 0) + require.NoError(t, err) + cache.levelsWithData[0] = true + cache.levelsWithData[1] = false + cache.levelsWithData[2] = true + db.stateDiffCache = cache + + slot := primitives.Slot(math.PowerOf2(18) + math.PowerOf2(16)) + key := makeKeyForStateDiffTree(2, uint64(slot)) + require.NoError(t, db.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + if err := bucket.Put(append(key, stateSuffix...), []byte{1}); err != nil { + return err + } + if err := bucket.Put(append(key, validatorSuffix...), []byte{2}); err != nil { + return err + } + return bucket.Put(append(key, balancesSuffix...), []byte{3}) + })) + + _, diffChain, err := db.getBaseAndDiffChain(0, slot) + require.NoError(t, err) + require.Equal(t, 1, len(diffChain)) +} + func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) { setDefaultStateDiffExponents()