Optimize process attestations (#9418)

This commit is contained in:
terence tsao
2021-08-18 15:41:09 -07:00
committed by GitHub
parent dfd740414f
commit d874b9140f
4 changed files with 71 additions and 17 deletions

View File

@@ -28,9 +28,12 @@ func ProcessAttestationsNoVerifySignature(
return nil, err
}
body := b.Block().Body()
var err error
totalBalance, err := helpers.TotalActiveBalance(beaconState)
if err != nil {
return nil, err
}
for idx, attestation := range body.Attestations() {
beaconState, err = ProcessAttestationNoVerifySignature(ctx, beaconState, attestation)
beaconState, err = ProcessAttestationNoVerifySignature(ctx, beaconState, attestation, totalBalance)
if err != nil {
return nil, errors.Wrapf(err, "could not verify attestation at index %d in block", idx)
}
@@ -44,6 +47,7 @@ func ProcessAttestationNoVerifySignature(
ctx context.Context,
beaconState state.BeaconStateAltair,
att *ethpb.Attestation,
totalBalance uint64,
) (state.BeaconStateAltair, error) {
ctx, span := trace.StartSpan(ctx, "altair.ProcessAttestationNoVerifySignature")
defer span.End()
@@ -69,7 +73,7 @@ func ProcessAttestationNoVerifySignature(
return nil, err
}
return SetParticipationAndRewardProposer(beaconState, att.Data.Target.Epoch, indices, participatedFlags)
return SetParticipationAndRewardProposer(beaconState, att.Data.Target.Epoch, indices, participatedFlags, totalBalance)
}
// SetParticipationAndRewardProposer retrieves and sets the epoch participation bits in state. Based on the epoch participation, it rewards
@@ -97,7 +101,7 @@ func SetParticipationAndRewardProposer(
beaconState state.BeaconState,
targetEpoch types.Epoch,
indices []uint64,
participatedFlags map[uint8]bool) (state.BeaconState, error) {
participatedFlags map[uint8]bool, totalBalance uint64) (state.BeaconState, error) {
var epochParticipation []byte
currentEpoch := helpers.CurrentEpoch(beaconState)
var err error
@@ -113,7 +117,7 @@ func SetParticipationAndRewardProposer(
}
}
proposerRewardNumerator, epochParticipation, err := EpochParticipation(beaconState, indices, epochParticipation, participatedFlags)
proposerRewardNumerator, epochParticipation, err := EpochParticipation(beaconState, indices, epochParticipation, participatedFlags, totalBalance)
if err != nil {
return nil, err
}
@@ -154,16 +158,12 @@ func AddValidatorFlag(flag, flagPosition uint8) uint8 {
// if flag_index in participation_flag_indices and not has_flag(epoch_participation[index], flag_index):
// epoch_participation[index] = add_flag(epoch_participation[index], flag_index)
// proposer_reward_numerator += get_base_reward(state, index) * weight
func EpochParticipation(beaconState state.BeaconState, indices []uint64, epochParticipation []byte, participatedFlags map[uint8]bool) (uint64, []byte, error) {
func EpochParticipation(beaconState state.BeaconState, indices []uint64, epochParticipation []byte, participatedFlags map[uint8]bool, totalBalance uint64) (uint64, []byte, error) {
cfg := params.BeaconConfig()
sourceFlagIndex := cfg.TimelySourceFlagIndex
targetFlagIndex := cfg.TimelyTargetFlagIndex
headFlagIndex := cfg.TimelyHeadFlagIndex
proposerRewardNumerator := uint64(0)
totalBalance, err := helpers.TotalActiveBalance(beaconState)
if err != nil {
return 0, nil, err
}
for _, index := range indices {
if index >= uint64(len(epochParticipation)) {
return 0, nil, fmt.Errorf("index %d exceeds participation length %d", index, len(epochParticipation))
@@ -281,7 +281,7 @@ func AttestationParticipationFlagIndices(beaconState state.BeaconStateAltair, da
// is_matching_source = data.source == justified_checkpoint
// is_matching_target = is_matching_source and data.target.root == get_block_root(state, data.target.epoch)
// is_matching_head = is_matching_target and data.beacon_block_root == get_block_root_at_slot(state, data.slot)
func MatchingStatus(beaconState state.BeaconState, data *ethpb.AttestationData, cp *ethpb.Checkpoint) (matchedSrc bool, matchedTgt bool, matchedHead bool, err error) {
func MatchingStatus(beaconState state.BeaconState, data *ethpb.AttestationData, cp *ethpb.Checkpoint) (matchedSrc, matchedTgt, matchedHead bool, err error) {
matchedSrc = attestationutil.CheckPointIsEqual(data.Source, cp)
r, err := helpers.BlockRoot(beaconState, data.Target.Epoch)

View File

@@ -5,11 +5,13 @@ import (
"fmt"
"testing"
fuzz "github.com/google/gofuzz"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/beacon-chain/core/altair"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
stateAltair "github.com/prysmaticlabs/prysm/beacon-chain/state/v2"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/wrapper"
"github.com/prysmaticlabs/prysm/shared/attestationutil"
@@ -258,7 +260,9 @@ func TestProcessAttestationNoVerify_SourceTargetHead(t *testing.T) {
copy(ckp.Root, make([]byte, 32))
require.NoError(t, beaconState.SetCurrentJustifiedCheckpoint(ckp))
beaconState, err = altair.ProcessAttestationNoVerifySignature(context.Background(), beaconState, att)
b, err := helpers.TotalActiveBalance(beaconState)
require.NoError(t, err)
beaconState, err = altair.ProcessAttestationNoVerifySignature(context.Background(), beaconState, att, b)
require.NoError(t, err)
p, err := beaconState.CurrentEpochParticipation()
@@ -379,6 +383,28 @@ func TestValidatorFlag_Add(t *testing.T) {
}
}
func TestFuzzProcessAttestationsNoVerify_10000(t *testing.T) {
fuzzer := fuzz.NewWithSeed(0)
state := &ethpb.BeaconStateAltair{}
b := &ethpb.SignedBeaconBlockAltair{Block: &ethpb.BeaconBlockAltair{}}
ctx := context.Background()
for i := 0; i < 10000; i++ {
fuzzer.Fuzz(state)
fuzzer.Fuzz(b)
if b.Block == nil {
b.Block = &ethpb.BeaconBlockAltair{}
}
s, err := stateAltair.InitializeFromProtoUnsafe(state)
require.NoError(t, err)
wsb, err := wrapper.WrappedAltairSignedBeaconBlock(b)
require.NoError(t, err)
r, err := altair.ProcessAttestationsNoVerifySignature(ctx, s, wsb)
if err != nil && r != nil {
t.Fatalf("return value should be nil on err. found: %v on error: %v for state: %v and block: %v", r, err, state, b)
}
}
}
func TestSetParticipationAndRewardProposer(t *testing.T) {
cfg := params.BeaconConfig()
sourceFlagIndex := cfg.TimelySourceFlagIndex
@@ -450,12 +476,15 @@ func TestSetParticipationAndRewardProposer(t *testing.T) {
} else {
require.NoError(t, beaconState.SetPreviousParticipationBits(test.epochParticipation))
}
st, err := altair.SetParticipationAndRewardProposer(beaconState, test.epoch, test.indices, test.participatedFlags)
b, err := helpers.TotalActiveBalance(beaconState)
require.NoError(t, err)
st, err := altair.SetParticipationAndRewardProposer(beaconState, test.epoch, test.indices, test.participatedFlags, b)
require.NoError(t, err)
i, err := helpers.BeaconProposerIndex(st)
require.NoError(t, err)
b, err := beaconState.BalanceAtIndex(i)
b, err = beaconState.BalanceAtIndex(i)
require.NoError(t, err)
require.Equal(t, test.wantedBalance, b)
@@ -533,7 +562,9 @@ func TestEpochParticipation(t *testing.T) {
},
}
for _, test := range tests {
n, p, err := altair.EpochParticipation(beaconState, test.indices, test.epochParticipation, test.participatedFlags)
b, err := helpers.TotalActiveBalance(beaconState)
require.NoError(t, err)
n, p, err := altair.EpochParticipation(beaconState, test.indices, test.epochParticipation, test.participatedFlags, b)
require.NoError(t, err)
require.Equal(t, test.wantedNumerator, n)
require.DeepSSZEqual(t, test.wantedEpochParticipation, p)

View File

@@ -19,6 +19,7 @@ go_library(
"//beacon-chain/blockchain:go_default_library",
"//beacon-chain/cache:go_default_library",
"//beacon-chain/cache/depositcache:go_default_library",
"//beacon-chain/core/altair:go_default_library",
"//beacon-chain/core/blocks:go_default_library",
"//beacon-chain/core/feed:go_default_library",
"//beacon-chain/core/feed/block:go_default_library",
@@ -53,6 +54,7 @@ go_library(
"//shared/timeutils:go_default_library",
"//shared/traceutil:go_default_library",
"//shared/trieutil:go_default_library",
"//shared/version:go_default_library",
"@com_github_ferranbt_fastssz//:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",

View File

@@ -6,12 +6,15 @@ import (
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/go-bitfield"
"github.com/prysmaticlabs/prysm/beacon-chain/core/altair"
"github.com/prysmaticlabs/prysm/beacon-chain/core/blocks"
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
"github.com/prysmaticlabs/prysm/beacon-chain/state"
ethpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"github.com/prysmaticlabs/prysm/shared/aggregation"
"github.com/prysmaticlabs/prysm/shared/featureconfig"
"github.com/prysmaticlabs/prysm/shared/params"
"github.com/prysmaticlabs/prysm/shared/version"
)
type proposerAtts []*ethpb.Attestation
@@ -19,11 +22,29 @@ type proposerAtts []*ethpb.Attestation
// filter separates attestation list into two groups: valid and invalid attestations.
// The first group passes the all the required checks for attestation to be considered for proposing.
// And attestations from the second group should be deleted.
func (a proposerAtts) filter(ctx context.Context, state state.BeaconState) (proposerAtts, proposerAtts) {
func (a proposerAtts) filter(ctx context.Context, st state.BeaconState) (proposerAtts, proposerAtts) {
validAtts := make([]*ethpb.Attestation, 0, len(a))
invalidAtts := make([]*ethpb.Attestation, 0, len(a))
var attestationProcessor func(context.Context, state.BeaconState, *ethpb.Attestation) (state.BeaconState, error)
switch st.Version() {
case version.Phase0:
attestationProcessor = blocks.ProcessAttestationNoVerifySignature
case version.Altair:
// Use a wrapper here, as go needs strong typing for the function signature.
attestationProcessor = func(ctx context.Context, st state.BeaconState, attestation *ethpb.Attestation) (state.BeaconState, error) {
totalBalance, err := helpers.TotalActiveBalance(st)
if err != nil {
return nil, err
}
return altair.ProcessAttestationNoVerifySignature(ctx, st, attestation, totalBalance)
}
default:
// Exit early if there is an unknown state type.
return validAtts, invalidAtts
}
for _, att := range a {
if _, err := blocks.ProcessAttestationNoVerifySignature(ctx, state, att); err == nil {
if _, err := attestationProcessor(ctx, st, att); err == nil {
validAtts = append(validAtts, att)
continue
}