From 4f654d30ac85b3fa84b9eaadd69a76eb19d5ec76 Mon Sep 17 00:00:00 2001 From: Nishant Das Date: Mon, 10 Feb 2020 23:05:58 +0800 Subject: [PATCH] Optimize Copying of Fields (#4811) * add new changes * memory pool * add test * final optimization * preston's review --- .../blockchain/process_attestation_helpers.go | 9 +++++ .../blockchain/process_block_helpers.go | 8 ++++- beacon-chain/state/BUILD.bazel | 1 + beacon-chain/state/setters.go | 35 ++++++++++++++----- beacon-chain/state/types.go | 13 ++++++- shared/memorypool/BUILD.bazel | 14 ++++++++ shared/memorypool/memorypool.go | 27 ++++++++++++++ shared/memorypool/memorypool_test.go | 16 +++++++++ 8 files changed, 113 insertions(+), 10 deletions(-) create mode 100644 shared/memorypool/BUILD.bazel create mode 100644 shared/memorypool/memorypool.go create mode 100644 shared/memorypool/memorypool_test.go diff --git a/beacon-chain/blockchain/process_attestation_helpers.go b/beacon-chain/blockchain/process_attestation_helpers.go index 5a2e7b7ef8..47321804c6 100644 --- a/beacon-chain/blockchain/process_attestation_helpers.go +++ b/beacon-chain/blockchain/process_attestation_helpers.go @@ -1,6 +1,7 @@ package blockchain import ( + "bytes" "context" "fmt" @@ -28,6 +29,14 @@ func (s *Service) getAttPreState(ctx context.Context, c *ethpb.Checkpoint) (*sta return cachedState, nil } + headRoot, err := s.HeadRoot(ctx) + if err != nil { + return nil, errors.Wrapf(err, "could not get head root") + } + if bytes.Equal(headRoot, c.Root) { + return s.HeadState(ctx) + } + baseState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(c.Root)) if err != nil { return nil, errors.Wrapf(err, "could not get pre state for slot %d", helpers.StartSlot(c.Epoch)) diff --git a/beacon-chain/blockchain/process_block_helpers.go b/beacon-chain/blockchain/process_block_helpers.go index da9e57529a..f534696e9b 100644 --- a/beacon-chain/blockchain/process_block_helpers.go +++ b/beacon-chain/blockchain/process_block_helpers.go @@ -73,7 +73,13 @@ func (s *Service) verifyBlkPreState(ctx context.Context, b *ethpb.BeaconBlock) ( } return preState.Copy(), nil } - + headRoot, err := s.HeadRoot(ctx) + if err != nil { + return nil, errors.Wrapf(err, "could not get head root") + } + if bytes.Equal(headRoot, b.ParentRoot) { + return s.HeadState(ctx) + } preState, err := s.beaconDB.State(ctx, bytesutil.ToBytes32(b.ParentRoot)) if err != nil { return nil, errors.Wrapf(err, "could not get pre state for slot %d", b.Slot) diff --git a/beacon-chain/state/BUILD.bazel b/beacon-chain/state/BUILD.bazel index ecdd65b042..ff1f1df89d 100644 --- a/beacon-chain/state/BUILD.bazel +++ b/beacon-chain/state/BUILD.bazel @@ -19,6 +19,7 @@ go_library( "//proto/beacon/p2p/v1:go_default_library", "//shared/bytesutil:go_default_library", "//shared/hashutil:go_default_library", + "//shared/memorypool:go_default_library", "//shared/params:go_default_library", "//shared/stateutil:go_default_library", "@com_github_gogo_protobuf//proto:go_default_library", diff --git a/beacon-chain/state/setters.go b/beacon-chain/state/setters.go index c0b9bada84..3d1eda483c 100644 --- a/beacon-chain/state/setters.go +++ b/beacon-chain/state/setters.go @@ -9,6 +9,7 @@ import ( "github.com/prysmaticlabs/go-bitfield" pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" "github.com/prysmaticlabs/prysm/shared/hashutil" + "github.com/prysmaticlabs/prysm/shared/memorypool" ) type fieldIndex int @@ -39,6 +40,9 @@ const ( previousJustifiedCheckpoint currentJustifiedCheckpoint finalizedCheckpoint + // validatorIdxMap is not part of the state, but is used so as to be able to keep + // track of references to it to allow for efficient copy on write. + validatorIdxMap ) // SetGenesisTime for the beacon state. @@ -308,14 +312,21 @@ func (b *BeaconState) UpdateValidatorAtIndex(idx uint64, val *ethpb.Validator) e // SetValidatorIndexByPubkey updates the validator index mapping maintained internally to // a given input 48-byte, public key. func (b *BeaconState) SetValidatorIndexByPubkey(pubKey [48]byte, validatorIdx uint64) { - // Copy on write since this is a shared map. - m := b.validatorIndexMap() + idxMap := b.valIdxMap + b.lock.RLock() + if b.sharedFieldReferences[validatorIdxMap].refs > 1 { + // copy-on-write for idx map + idxMap = b.validatorIndexMap() + b.sharedFieldReferences[validatorIdxMap].refs-- + b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1} + } + b.lock.RUnlock() b.lock.Lock() defer b.lock.Unlock() - m[pubKey] = validatorIdx - b.valIdxMap = m + idxMap[pubKey] = validatorIdx + b.valIdxMap = idxMap } // SetBalances for the beacon state. This PR updates the entire @@ -381,7 +392,9 @@ func (b *BeaconState) UpdateRandaoMixesAtIndex(val []byte, idx uint64) error { b.lock.RLock() mixes := b.state.RandaoMixes if refs := b.sharedFieldReferences[randaoMixes].refs; refs > 1 { - mixes = b.RandaoMixes() + newMixes := memorypool.GetDoubleByteSlice(len(mixes)) + copy(newMixes, mixes) + mixes = newMixes b.sharedFieldReferences[randaoMixes].refs-- b.sharedFieldReferences[randaoMixes] = &reference{refs: 1} } @@ -492,7 +505,9 @@ func (b *BeaconState) AppendCurrentEpochAttestations(val *pbp2p.PendingAttestati atts := b.state.CurrentEpochAttestations if b.sharedFieldReferences[currentEpochAttestations].refs > 1 { - atts = b.CurrentEpochAttestations() + copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1) + copy(copiedAtts, atts) + atts = copiedAtts b.sharedFieldReferences[currentEpochAttestations].refs-- b.sharedFieldReferences[currentEpochAttestations] = &reference{refs: 1} } @@ -512,7 +527,9 @@ func (b *BeaconState) AppendPreviousEpochAttestations(val *pbp2p.PendingAttestat b.lock.RLock() atts := b.state.PreviousEpochAttestations if b.sharedFieldReferences[previousEpochAttestations].refs > 1 { - atts = b.PreviousEpochAttestations() + copiedAtts := make([]*pbp2p.PendingAttestation, len(atts), len(atts)+1) + copy(copiedAtts, atts) + atts = copiedAtts b.sharedFieldReferences[previousEpochAttestations].refs-- b.sharedFieldReferences[previousEpochAttestations] = &reference{refs: 1} } @@ -532,7 +549,9 @@ func (b *BeaconState) AppendValidator(val *ethpb.Validator) error { b.lock.RLock() vals := b.state.Validators if b.sharedFieldReferences[validators].refs > 1 { - vals = b.Validators() + copiedVals := make([]*ethpb.Validator, len(b.state.Validators), len(b.state.Validators)+1) + copy(copiedVals, b.state.Validators) + vals = copiedVals b.sharedFieldReferences[validators].refs-- b.sharedFieldReferences[validators] = &reference{refs: 1} } diff --git a/beacon-chain/state/types.go b/beacon-chain/state/types.go index 67a41d3bc5..0c1a24dd24 100644 --- a/beacon-chain/state/types.go +++ b/beacon-chain/state/types.go @@ -12,6 +12,7 @@ import ( pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" "github.com/prysmaticlabs/prysm/shared/bytesutil" "github.com/prysmaticlabs/prysm/shared/hashutil" + "github.com/prysmaticlabs/prysm/shared/memorypool" "github.com/prysmaticlabs/prysm/shared/params" "github.com/prysmaticlabs/prysm/shared/stateutil" ) @@ -73,6 +74,7 @@ func InitializeFromProtoUnsafe(st *pbp2p.BeaconState) (*BeaconState, error) { b.sharedFieldReferences[validators] = &reference{refs: 1} b.sharedFieldReferences[balances] = &reference{refs: 1} b.sharedFieldReferences[historicalRoots] = &reference{refs: 1} + b.sharedFieldReferences[validatorIdxMap] = &reference{refs: 1} return b, nil } @@ -141,8 +143,11 @@ func (b *BeaconState) Copy() *BeaconState { // Finalizer runs when dst is being destroyed in garbage collection. runtime.SetFinalizer(dst, func(b *BeaconState) { - for _, v := range b.sharedFieldReferences { + for i, v := range b.sharedFieldReferences { v.refs-- + if i == randaoMixes && v.refs == 0 { + memorypool.PutDoubleByteSlice(b.state.RandaoMixes) + } } }) @@ -166,6 +171,12 @@ func (b *BeaconState) HashTreeRoot() ([32]byte, error) { } for field := range b.dirtyFields { + // do not compute root for field + // thats not part of the state. + if field == validatorIdxMap { + delete(b.dirtyFields, field) + continue + } root, err := b.rootSelector(field) if err != nil { return [32]byte{}, err diff --git a/shared/memorypool/BUILD.bazel b/shared/memorypool/BUILD.bazel new file mode 100644 index 0000000000..10cf669f79 --- /dev/null +++ b/shared/memorypool/BUILD.bazel @@ -0,0 +1,14 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["memorypool.go"], + importpath = "github.com/prysmaticlabs/prysm/shared/memorypool", + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["memorypool_test.go"], + embed = [":go_default_library"], +) diff --git a/shared/memorypool/memorypool.go b/shared/memorypool/memorypool.go new file mode 100644 index 0000000000..f7cfed6c75 --- /dev/null +++ b/shared/memorypool/memorypool.go @@ -0,0 +1,27 @@ +package memorypool + +import "sync" + +// DoubleByteSlicePool represents the memory pool +// for 2d byte slices +var DoubleByteSlicePool = new(sync.Pool) + +// GetDoubleByteSlice retrieves the 2d byte slice of +// the desired size from the memory pool. +func GetDoubleByteSlice(size int) [][]byte { + rawObj := DoubleByteSlicePool.Get() + if rawObj == nil { + return make([][]byte, size) + } + byteSlice := rawObj.([][]byte) + if len(byteSlice) >= size { + return byteSlice[:size] + } + return append(byteSlice, make([][]byte, size-len(byteSlice))...) +} + +// PutDoubleByteSlice places the provided 2d byte slice +// in the memory pool +func PutDoubleByteSlice(data [][]byte) { + DoubleByteSlicePool.Put(data) +} diff --git a/shared/memorypool/memorypool_test.go b/shared/memorypool/memorypool_test.go new file mode 100644 index 0000000000..16519c9e1a --- /dev/null +++ b/shared/memorypool/memorypool_test.go @@ -0,0 +1,16 @@ +package memorypool + +import ( + "testing" +) + +func TestRoundTripMemoryRetrieval(t *testing.T) { + byteSlice := make([][]byte, 1000) + PutDoubleByteSlice(byteSlice) + newSlice := GetDoubleByteSlice(1000) + + if len(newSlice) != 1000 { + t.Errorf("Wanted same slice object, but got different object. "+ + "Wanted slice with length %d but got length %d", 1000, len(newSlice)) + } +}