Optimize Copying of Fields (#4811)

* add new changes

* memory pool

* add test

* final optimization

* preston's review
This commit is contained in:
Nishant Das
2020-02-10 23:05:58 +08:00
committed by GitHub
parent 18fbdd53b9
commit 4f654d30ac
8 changed files with 113 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
package blockchain
import (
"bytes"
"context"
"fmt"
@@ -28,6 +29,14 @@ func (s *Service) getAttPreState(ctx context.Context, c *ethpb.Checkpoint) (*sta
return cachedState, nil
}
headRoot, err := s.HeadRoot(ctx)
if err != nil {
return nil, errors.Wrapf(err, "could not get head root")
}
if bytes.Equal(headRoot, c.Root) {
return s.HeadState(ctx)
}
baseState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(c.Root))
if err != nil {
return nil, errors.Wrapf(err, "could not get pre state for slot %d", helpers.StartSlot(c.Epoch))

View File

@@ -73,7 +73,13 @@ func (s *Service) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) (
}
return preState.Copy(), nil
}
headRoot, err := s.HeadRoot(ctx)
if err != nil {
return nil, errors.Wrapf(err, "could not get head root")
}
if bytes.Equal(headRoot, b.ParentRoot) {
return s.HeadState(ctx)
}
preState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(b.ParentRoot))
if err != nil {
return nil, errors.Wrapf(err, "could not get pre state for slot %d", b.Slot)

View File

@@ -19,6 +19,7 @@ go_library(
"//proto/beacon/p2p/v1:go_default_library",
"//shared/bytesutil:go_default_library",
"//shared/hashutil:go_default_library",
"//shared/memorypool:go_default_library",
"//shared/params:go_default_library",
"//shared/stateutil:go_default_library",
"@com_github_gogo_protobuf//proto:go_default_library",

View File

@@ -9,6 +9,7 @@ import (
"github.com/prysmaticlabs/go-bitfield"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/memorypool"
)
type fieldIndex int
@@ -39,6 +40,9 @@ const (
previousJustifiedCheckpoint
currentJustifiedCheckpoint
finalizedCheckpoint
// validatorIdxMap is not part of the state, but is used so as to be able to keep
// track of references to it to allow for efficient copy on write.
validatorIdxMap
)
// SetGenesisTime for the beacon state.
@@ -308,14 +312,21 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e
// SetValidatorIndexByPubkey updates the validator index mapping maintained internally to
// a given input 48-byte, public key.
func (b *BeaconState) SetValidatorIndexByPubkey(pubKey [48]byte, validatorIdx uint64) {
// Copy on write since this is a shared map.
m := b.validatorIndexMap()
idxMap := b.valIdxMap
b.lock.RLock()
if b.sharedFieldReferences[validatorIdxMap].refs > 1 {
// copy-on-write for idx map
idxMap = b.validatorIndexMap()
b.sharedFieldReferences[validatorIdxMap].refs--
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}
}
b.lock.RUnlock()
b.lock.Lock()
defer b.lock.Unlock()
m[pubKey] = validatorIdx
b.valIdxMap = m
idxMap[pubKey] = validatorIdx
b.valIdxMap = idxMap
}
// SetBalances for the beacon state. This PR updates the entire
@@ -381,7 +392,9 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(val []byte, idx uint64) error {
b.lock.RLock()
mixes := b.state.RandaoMixes
if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 {
mixes = b.RandaoMixes()
newMixes := memorypool.GetDoubleByteSlice(len(mixes))
copy(newMixes, mixes)
mixes = newMixes
b.sharedFieldReferences[randaoMixes].refs--
b.sharedFieldReferences[randaoMixes] = &reference{refs: 1}
}
@@ -492,7 +505,9 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati
atts := b.state.CurrentEpochAttestations
if b.sharedFieldReferences[currentEpochAttestations].refs > 1 {
atts = b.CurrentEpochAttestations()
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
copy(copiedAtts, atts)
atts = copiedAtts
b.sharedFieldReferences[currentEpochAttestations].refs--
b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1}
}
@@ -512,7 +527,9 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat
b.lock.RLock()
atts := b.state.PreviousEpochAttestations
if b.sharedFieldReferences[previousEpochAttestations].refs > 1 {
atts = b.PreviousEpochAttestations()
copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1)
copy(copiedAtts, atts)
atts = copiedAtts
b.sharedFieldReferences[previousEpochAttestations].refs--
b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1}
}
@@ -532,7 +549,9 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error {
b.lock.RLock()
vals := b.state.Validators
if b.sharedFieldReferences[validators].refs > 1 {
vals = b.Validators()
copiedVals := make([]*ethpb.Validator, len(b.state.Validators), len(b.state.Validators)+1)
copy(copiedVals, b.state.Validators)
vals = copiedVals
b.sharedFieldReferences[validators].refs--
b.sharedFieldReferences[validators] = &reference{refs: 1}
}

View File

@@ -12,6 +12,7 @@ import (
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
"github.com/prysmaticlabs/prysm/shared/bytesutil"
"github.com/prysmaticlabs/prysm/shared/hashutil"
"github.com/prysmaticlabs/prysm/shared/memorypool"
"github.com/prysmaticlabs/prysm/shared/params"
"github.com/prysmaticlabs/prysm/shared/stateutil"
)
@@ -73,6 +74,7 @@ func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) {
b.sharedFieldReferences[validators] = &reference{refs: 1}
b.sharedFieldReferences[balances] = &reference{refs: 1}
b.sharedFieldReferences[historicalRoots] = &reference{refs: 1}
b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1}
return b, nil
}
@@ -141,8 +143,11 @@ func (b *BeaconState) Copy() *BeaconState {
// Finalizer runs when dst is being destroyed in garbage collection.
runtime.SetFinalizer(dst, func(b *BeaconState) {
for _, v := range b.sharedFieldReferences {
for i, v := range b.sharedFieldReferences {
v.refs--
if i == randaoMixes && v.refs == 0 {
memorypool.PutDoubleByteSlice(b.state.RandaoMixes)
}
}
})
@@ -166,6 +171,12 @@ func (b *BeaconState) HashTreeRoot() ([32]byte, error) {
}
for field := range b.dirtyFields {
// do not compute root for field
// thats not part of the state.
if field == validatorIdxMap {
delete(b.dirtyFields, field)
continue
}
root, err := b.rootSelector(field)
if err != nil {
return [32]byte{}, err

View File

@@ -0,0 +1,14 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = ["memorypool.go"],
importpath = "github.com/prysmaticlabs/prysm/shared/memorypool",
visibility = ["//visibility:public"],
)
go_test(
name = "go_default_test",
srcs = ["memorypool_test.go"],
embed = [":go_default_library"],
)

View File

@@ -0,0 +1,27 @@
package memorypool
import "sync"
// DoubleByteSlicePool represents the memory pool
// for 2d byte slices
var DoubleByteSlicePool = new(sync.Pool)
// GetDoubleByteSlice retrieves the 2d byte slice of
// the desired size from the memory pool.
func GetDoubleByteSlice(size int) [][]byte {
rawObj := DoubleByteSlicePool.Get()
if rawObj == nil {
return make([][]byte, size)
}
byteSlice := rawObj.([][]byte)
if len(byteSlice) >= size {
return byteSlice[:size]
}
return append(byteSlice, make([][]byte, size-len(byteSlice))...)
}
// PutDoubleByteSlice places the provided 2d byte slice
// in the memory pool
func PutDoubleByteSlice(data [][]byte) {
DoubleByteSlicePool.Put(data)
}

View File

@@ -0,0 +1,16 @@
package memorypool
import (
"testing"
)
func TestRoundTripMemoryRetrieval(t *testing.T) {
byteSlice := make([][]byte, 1000)
PutDoubleByteSlice(byteSlice)
newSlice := GetDoubleByteSlice(1000)
if len(newSlice) != 1000 {
t.Errorf("Wanted same slice object, but got different object. "+
"Wanted slice with length %d but got length %d", 1000, len(newSlice))
}
}