mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-05-02 03:02:54 -04:00
Optimize Copying of Fields (#4811)
* add new changes * memory pool * add test * final optimization * preston's review
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package blockchain
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
@@ -28,6 +29,14 @@ func (s *Service) getAttPreState(ctx context.Context, c *ethpb.Checkpoint) (*sta
|
||||
return cachedState, nil
|
||||
}
|
||||
|
||||
headRoot, err := s.HeadRoot(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not get head root")
|
||||
}
|
||||
if bytes.Equal(headRoot, c.Root) {
|
||||
return s.HeadState(ctx)
|
||||
}
|
||||
|
||||
baseState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(c.Root))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not get pre state for slot %d", helpers.StartSlot(c.Epoch))
|
||||
|
||||
@@ -73,7 +73,13 @@ func (s *Service) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) (
|
||||
}
|
||||
return preState.Copy(), nil
|
||||
}
|
||||
|
||||
headRoot, err := s.HeadRoot(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not get head root")
|
||||
}
|
||||
if bytes.Equal(headRoot, b.ParentRoot) {
|
||||
return s.HeadState(ctx)
|
||||
}
|
||||
preState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(b.ParentRoot))
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not get pre state for slot %d", b.Slot)
|
||||
|
||||
@@ -19,6 +19,7 @@ go_library(
|
||||
"//proto/beacon/p2p/v1:go_default_library",
|
||||
"//shared/bytesutil:go_default_library",
|
||||
"//shared/hashutil:go_default_library",
|
||||
"//shared/memorypool:go_default_library",
|
||||
"//shared/params:go_default_library",
|
||||
"//shared/stateutil:go_default_library",
|
||||
"@com_github_gogo_protobuf//proto:go_default_library",
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/prysmaticlabs/go-bitfield"
|
||||
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
|
||||
"github.com/prysmaticlabs/prysm/shared/hashutil"
|
||||
"github.com/prysmaticlabs/prysm/shared/memorypool"
|
||||
)
|
||||
|
||||
type fieldIndex int
|
||||
@@ -39,6 +40,9 @@ const (
|
||||
previousJustifiedCheckpoint
|
||||
currentJustifiedCheckpoint
|
||||
finalizedCheckpoint
|
||||
// validatorIdxMap is not part of the state, but is used so as to be able to keep
|
||||
// track of references to it to allow for efficient copy on write.
|
||||
validatorIdxMap
|
||||
)
|
||||
|
||||
// SetGenesisTime for the beacon state.
|
||||
@@ -308,14 +312,21 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e
|
||||
// SetValidatorIndexByPubkey updates the validator index mapping maintained internally to
|
||||
// a given input 48-byte, public key.
|
||||
func (b *BeaconState) SetValidatorIndexByPubkey(pubKey [48]byte, validatorIdx uint64) {
|
||||
// Copy on write since this is a shared map.
|
||||
m := b.validatorIndexMap()
|
||||
idxMap := b.valIdxMap
|
||||
b.lock.RLock()
|
||||
if b.sharedFieldReferences[validatorIdxMap].refs > 1 {
|
||||
// copy-on-write for idx map
|
||||
idxMap = b.validatorIndexMap()
|
||||
b.sharedFieldReferences[validatorIdxMap].refs--
|
||||
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}
|
||||
}
|
||||
b.lock.RUnlock()
|
||||
|
||||
b.lock.Lock()
|
||||
defer b.lock.Unlock()
|
||||
|
||||
m[pubKey] = validatorIdx
|
||||
b.valIdxMap = m
|
||||
idxMap[pubKey] = validatorIdx
|
||||
b.valIdxMap = idxMap
|
||||
}
|
||||
|
||||
// SetBalances for the beacon state. This PR updates the entire
|
||||
@@ -381,7 +392,9 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(val []byte, idx uint64) error {
|
||||
b.lock.RLock()
|
||||
mixes := b.state.RandaoMixes
|
||||
if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 {
|
||||
mixes = b.RandaoMixes()
|
||||
newMixes := memorypool.GetDoubleByteSlice(len(mixes))
|
||||
copy(newMixes, mixes)
|
||||
mixes = newMixes
|
||||
b.sharedFieldReferences[randaoMixes].refs--
|
||||
b.sharedFieldReferences[randaoMixes] = &reference{refs: 1}
|
||||
}
|
||||
@@ -492,7 +505,9 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati
|
||||
|
||||
atts := b.state.CurrentEpochAttestations
|
||||
if b.sharedFieldReferences[currentEpochAttestations].refs > 1 {
|
||||
atts = b.CurrentEpochAttestations()
|
||||
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
|
||||
copy(copiedAtts, atts)
|
||||
atts = copiedAtts
|
||||
b.sharedFieldReferences[currentEpochAttestations].refs--
|
||||
b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1}
|
||||
}
|
||||
@@ -512,7 +527,9 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat
|
||||
b.lock.RLock()
|
||||
atts := b.state.PreviousEpochAttestations
|
||||
if b.sharedFieldReferences[previousEpochAttestations].refs > 1 {
|
||||
atts = b.PreviousEpochAttestations()
|
||||
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
|
||||
copy(copiedAtts, atts)
|
||||
atts = copiedAtts
|
||||
b.sharedFieldReferences[previousEpochAttestations].refs--
|
||||
b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1}
|
||||
}
|
||||
@@ -532,7 +549,9 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error {
|
||||
b.lock.RLock()
|
||||
vals := b.state.Validators
|
||||
if b.sharedFieldReferences[validators].refs > 1 {
|
||||
vals = b.Validators()
|
||||
copiedVals := make([]*ethpb.Validator, len(b.state.Validators), len(b.state.Validators)+1)
|
||||
copy(copiedVals, b.state.Validators)
|
||||
vals = copiedVals
|
||||
b.sharedFieldReferences[validators].refs--
|
||||
b.sharedFieldReferences[validators] = &reference{refs: 1}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
|
||||
"github.com/prysmaticlabs/prysm/shared/bytesutil"
|
||||
"github.com/prysmaticlabs/prysm/shared/hashutil"
|
||||
"github.com/prysmaticlabs/prysm/shared/memorypool"
|
||||
"github.com/prysmaticlabs/prysm/shared/params"
|
||||
"github.com/prysmaticlabs/prysm/shared/stateutil"
|
||||
)
|
||||
@@ -73,6 +74,7 @@ func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) {
|
||||
b.sharedFieldReferences[validators] = &reference{refs: 1}
|
||||
b.sharedFieldReferences[balances] = &reference{refs: 1}
|
||||
b.sharedFieldReferences[historicalRoots] = &reference{refs: 1}
|
||||
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
@@ -141,8 +143,11 @@ func (b *BeaconState) Copy() *BeaconState {
|
||||
|
||||
// Finalizer runs when dst is being destroyed in garbage collection.
|
||||
runtime.SetFinalizer(dst, func(b *BeaconState) {
|
||||
for _, v := range b.sharedFieldReferences {
|
||||
for i, v := range b.sharedFieldReferences {
|
||||
v.refs--
|
||||
if i == randaoMixes && v.refs == 0 {
|
||||
memorypool.PutDoubleByteSlice(b.state.RandaoMixes)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -166,6 +171,12 @@ func (b *BeaconState) HashTreeRoot() ([32]byte, error) {
|
||||
}
|
||||
|
||||
for field := range b.dirtyFields {
|
||||
// do not compute root for field
|
||||
// thats not part of the state.
|
||||
if field == validatorIdxMap {
|
||||
delete(b.dirtyFields, field)
|
||||
continue
|
||||
}
|
||||
root, err := b.rootSelector(field)
|
||||
if err != nil {
|
||||
return [32]byte{}, err
|
||||
|
||||
14
shared/memorypool/BUILD.bazel
Normal file
14
shared/memorypool/BUILD.bazel
Normal file
@@ -0,0 +1,14 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "go_default_library",
|
||||
srcs = ["memorypool.go"],
|
||||
importpath = "github.com/prysmaticlabs/prysm/shared/memorypool",
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "go_default_test",
|
||||
srcs = ["memorypool_test.go"],
|
||||
embed = [":go_default_library"],
|
||||
)
|
||||
27
shared/memorypool/memorypool.go
Normal file
27
shared/memorypool/memorypool.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package memorypool
|
||||
|
||||
import "sync"
|
||||
|
||||
// DoubleByteSlicePool represents the memory pool
|
||||
// for 2d byte slices
|
||||
var DoubleByteSlicePool = new(sync.Pool)
|
||||
|
||||
// GetDoubleByteSlice retrieves the 2d byte slice of
|
||||
// the desired size from the memory pool.
|
||||
func GetDoubleByteSlice(size int) [][]byte {
|
||||
rawObj := DoubleByteSlicePool.Get()
|
||||
if rawObj == nil {
|
||||
return make([][]byte, size)
|
||||
}
|
||||
byteSlice := rawObj.([][]byte)
|
||||
if len(byteSlice) >= size {
|
||||
return byteSlice[:size]
|
||||
}
|
||||
return append(byteSlice, make([][]byte, size-len(byteSlice))...)
|
||||
}
|
||||
|
||||
// PutDoubleByteSlice places the provided 2d byte slice
|
||||
// in the memory pool
|
||||
func PutDoubleByteSlice(data [][]byte) {
|
||||
DoubleByteSlicePool.Put(data)
|
||||
}
|
||||
16
shared/memorypool/memorypool_test.go
Normal file
16
shared/memorypool/memorypool_test.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package memorypool
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRoundTripMemoryRetrieval(t *testing.T) {
|
||||
byteSlice := make([][]byte, 1000)
|
||||
PutDoubleByteSlice(byteSlice)
|
||||
newSlice := GetDoubleByteSlice(1000)
|
||||
|
||||
if len(newSlice) != 1000 {
|
||||
t.Errorf("Wanted same slice object, but got different object. "+
|
||||
"Wanted slice with length %d but got length %d", 1000, len(newSlice))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user