Update go-bitfield (#9162)

This commit is contained in:
Preston Van Loon
2021-07-08 10:31:40 -05:00
committed by GitHub
parent 72886986ea
commit 9dc3dd04c7
30 changed files with 371 additions and 164 deletions

View File

@@ -47,6 +47,7 @@ go_test(
"//shared/testutil/require:go_default_library",
"@com_github_ferranbt_fastssz//:go_default_library",
"@com_github_patrickmn_go_cache//:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
],

View File

@@ -212,7 +212,9 @@ func (c *AttCaches) DeleteAggregatedAttestation(att *ethpb.Attestation) error {
filtered := make([]*ethpb.Attestation, 0)
for _, a := range attList {
if att.AggregationBits.Len() == a.AggregationBits.Len() && !att.AggregationBits.Contains(a.AggregationBits) {
if c, err := att.AggregationBits.Contains(a.AggregationBits); err != nil {
return err
} else if !c {
filtered = append(filtered, a)
}
}
@@ -239,7 +241,9 @@ func (c *AttCaches) HasAggregatedAttestation(att *ethpb.Attestation) (bool, erro
defer c.aggregatedAttLock.RUnlock()
if atts, ok := c.aggregatedAtt[r]; ok {
for _, a := range atts {
if a.AggregationBits.Len() == att.AggregationBits.Len() && a.AggregationBits.Contains(att.AggregationBits) {
if c, err := a.AggregationBits.Contains(att.AggregationBits); err != nil {
return false, err
} else if c {
return true, nil
}
}
@@ -249,7 +253,9 @@ func (c *AttCaches) HasAggregatedAttestation(att *ethpb.Attestation) (bool, erro
defer c.blockAttLock.RUnlock()
if atts, ok := c.blockAtt[r]; ok {
for _, a := range atts {
if a.AggregationBits.Len() == att.AggregationBits.Len() && a.AggregationBits.Contains(att.AggregationBits) {
if c, err := a.AggregationBits.Contains(att.AggregationBits); err != nil {
return false, err
} else if c {
return true, nil
}
}

View File

@@ -7,6 +7,7 @@ import (
fssz "github.com/ferranbt/fastssz"
c "github.com/patrickmn/go-cache"
"github.com/pkg/errors"
types "github.com/prysmaticlabs/eth2-types"
"github.com/prysmaticlabs/go-bitfield"
ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
@@ -297,9 +298,9 @@ func TestKV_Aggregated_DeleteAggregatedAttestation(t *testing.T) {
t.Run("non-filtered deletion", func(t *testing.T) {
cache := NewAttCaches()
att1 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b1101}})
att2 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 2}, AggregationBits: bitfield.Bitlist{0b1101}})
att3 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b1101}})
att1 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b11010}})
att2 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 2}, AggregationBits: bitfield.Bitlist{0b11010}})
att3 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b11010}})
att4 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b10101}})
atts := []*ethpb.Attestation{att1, att2, att3, att4}
require.NoError(t, cache.SaveAggregatedAttestations(atts))
@@ -338,21 +339,21 @@ func TestKV_Aggregated_HasAggregatedAttestation(t *testing.T) {
existing []*ethpb.Attestation
input *ethpb.Attestation
want bool
error bool
err error
}{
{
name: "nil attestation",
input: nil,
want: false,
error: true,
err: errors.New("can't be nil"),
},
{
name: "nil attestation data",
input: &ethpb.Attestation{
AggregationBits: bitfield.Bitlist{0b1111},
},
want: false,
error: true,
want: false,
err: errors.New("can't be nil"),
},
{
name: "empty cache aggregated",
@@ -503,6 +504,7 @@ func TestKV_Aggregated_HasAggregatedAttestation(t *testing.T) {
AggregationBits: bitfield.Bitlist{0b1111},
},
want: false,
err: bitfield.ErrBitlistDifferentLength,
},
}
@@ -515,9 +517,9 @@ func TestKV_Aggregated_HasAggregatedAttestation(t *testing.T) {
tt.input.Signature = make([]byte, 96)
}
if tt.error == true {
if tt.err != nil {
_, err := cache.HasAggregatedAttestation(tt.input)
require.ErrorContains(t, "can't be nil", err)
require.ErrorContains(t, tt.err.Error(), err)
} else {
result, err := cache.HasAggregatedAttestation(tt.input)
require.NoError(t, err)

View File

@@ -25,7 +25,9 @@ func (c *AttCaches) SaveBlockAttestation(att *ethpb.Attestation) error {
// Ensure that this attestation is not already fully contained in an existing attestation.
for _, a := range atts {
if a.AggregationBits.Len() == att.AggregationBits.Len() && a.AggregationBits.Contains(att.AggregationBits) {
if c, err := a.AggregationBits.Contains(att.AggregationBits); err != nil {
return err
} else if c {
return nil
}
}

View File

@@ -17,12 +17,16 @@ func TestKV_BlockAttestation_CanSaveRetrieve(t *testing.T) {
att1 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b1101}})
att2 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 2}, AggregationBits: bitfield.Bitlist{0b1101}})
att3 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b1101}})
att4 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b11011}}) // Diff bit length should not panic.
atts := []*ethpb.Attestation{att1, att2, att3, att4}
atts := []*ethpb.Attestation{att1, att2, att3}
for _, att := range atts {
require.NoError(t, cache.SaveBlockAttestation(att))
}
// Diff bit length should not panic.
att4 := testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 3}, AggregationBits: bitfield.Bitlist{0b11011}})
if err := cache.SaveBlockAttestation(att4); err != bitfield.ErrBitlistDifferentLength {
t.Errorf("Unexpected error: wanted %v, got %v", bitfield.ErrBitlistDifferentLength, err)
}
returned := cache.BlockAttestations()

View File

@@ -21,7 +21,9 @@ func (c *AttCaches) insertSeenBit(att *ethpb.Attestation) error {
}
alreadyExists := false
for _, bit := range seenBits {
if bit.Len() == att.AggregationBits.Len() && bit.Contains(att.AggregationBits) {
if c, err := bit.Contains(att.AggregationBits); err != nil {
return err
} else if c {
alreadyExists = true
break
}
@@ -50,7 +52,9 @@ func (c *AttCaches) hasSeenBit(att *ethpb.Attestation) (bool, error) {
return false, errors.New("could not convert to bitlist type")
}
for _, bit := range seenBits {
if bit.Len() == att.AggregationBits.Len() && bit.Contains(att.AggregationBits) {
if c, err := bit.Contains(att.AggregationBits); err != nil {
return false, err
} else if c {
return true, nil
}
}

View File

@@ -24,7 +24,7 @@ func TestAttCaches_hasSeenBit(t *testing.T) {
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b10000001}}), want: true},
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b11100000}}), want: true},
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b10000011}}), want: true},
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b00001000}}), want: false},
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b10001000}}), want: false},
{att: testutil.HydrateAttestation(&ethpb.Attestation{AggregationBits: bitfield.Bitlist{0b11110111}}), want: false},
}
for _, tt := range tests {

View File

@@ -115,11 +115,20 @@ func (s *Service) seen(att *ethpb.Attestation) (bool, error) {
}
if savedBitlist.Len() == incomingBits.Len() {
// Returns true if the node has seen all the bits in the new bit field of the incoming attestation.
if bytes.Equal(savedBitlist, incomingBits) || savedBitlist.Contains(incomingBits) {
if bytes.Equal(savedBitlist, incomingBits) {
return true, nil
}
if c, err := savedBitlist.Contains(incomingBits); err != nil {
return false, err
} else if c {
return true, nil
}
var err error
// Update the bit fields by Or'ing them with the new ones.
incomingBits = incomingBits.Or(savedBitlist)
incomingBits, err = incomingBits.Or(savedBitlist)
if err != nil {
return false, err
}
}
}

View File

@@ -15,7 +15,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/operations/attestations"
mockp2p "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stategen"
"github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
mockSync "github.com/prysmaticlabs/prysm/beacon-chain/sync/initial-sync/testing"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"

View File

@@ -591,7 +591,15 @@ func (vs *Server) filterAttestationsForBlockInclusion(ctx context.Context, st if
if err := vs.deleteAttsInPool(ctx, invalidAtts); err != nil {
return nil, err
}
return validAtts.dedup().sortByProfitability().limitToMaxAttestations(), nil
deduped, err := validAtts.dedup()
if err != nil {
return nil, err
}
sorted, err := deduped.sortByProfitability()
if err != nil {
return nil, err
}
return sorted.limitToMaxAttestations(), nil
}
// The input attestations are processed and seen by the node, this deletes them from pool
@@ -669,7 +677,15 @@ func (vs *Server) packAttestations(ctx context.Context, latestState iface.Beacon
}
attsForInclusion = append(attsForInclusion, as...)
}
atts = attsForInclusion.dedup().sortByProfitability().limitToMaxAttestations()
deduped, err := attsForInclusion.dedup()
if err != nil {
return nil, err
}
sorted, err := deduped.sortByProfitability()
if err != nil {
return nil, err
}
atts = sorted.limitToMaxAttestations()
}
return atts, nil
}

View File

@@ -33,9 +33,9 @@ func (a proposerAtts) filter(ctx context.Context, state iface.BeaconState) (prop
}
// sortByProfitability orders attestations by highest slot and by highest aggregation bit count.
func (a proposerAtts) sortByProfitability() proposerAtts {
func (a proposerAtts) sortByProfitability() (proposerAtts, error) {
if len(a) < 2 {
return a
return a, nil
}
if featureconfig.Get().ProposerAttsSelectionUsingMaxCover {
return a.sortByProfitabilityUsingMaxCover()
@@ -46,12 +46,12 @@ func (a proposerAtts) sortByProfitability() proposerAtts {
}
return a[i].Data.Slot > a[j].Data.Slot
})
return a
return a, nil
}
// sortByProfitabilityUsingMaxCover orders attestations by highest slot and by highest aggregation bit count.
// Duplicate bits are counted only once, using max-cover algorithm.
func (a proposerAtts) sortByProfitabilityUsingMaxCover() proposerAtts {
func (a proposerAtts) sortByProfitabilityUsingMaxCover() (proposerAtts, error) {
// Separate attestations by slot, as slot number takes higher precedence when sorting.
var slots []types.Slot
attsBySlot := map[types.Slot]proposerAtts{}
@@ -62,13 +62,17 @@ func (a proposerAtts) sortByProfitabilityUsingMaxCover() proposerAtts {
attsBySlot[att.Data.Slot] = append(attsBySlot[att.Data.Slot], att)
}
selectAtts := func(atts proposerAtts) proposerAtts {
selectAtts := func(atts proposerAtts) (proposerAtts, error) {
if len(atts) < 2 {
return atts
return atts, nil
}
candidates := make([]*bitfield.Bitlist64, len(atts))
for i := 0; i < len(atts); i++ {
candidates[i] = atts[i].AggregationBits.ToBitlist64()
var err error
candidates[i], err = atts[i].AggregationBits.ToBitlist64()
if err != nil {
return nil, err
}
}
// Add selected candidates on top, those that are not selected - append at bottom.
selectedKeys, _, err := aggregation.MaxCover(candidates, len(candidates), true /* allowOverlaps */)
@@ -89,9 +93,9 @@ func (a proposerAtts) sortByProfitabilityUsingMaxCover() proposerAtts {
sort.Slice(leftoverAtts, func(i, j int) bool {
return leftoverAtts[i].AggregationBits.Count() > leftoverAtts[j].AggregationBits.Count()
})
return append(selectedAtts, leftoverAtts...)
return append(selectedAtts, leftoverAtts...), nil
}
return atts
return atts, nil
}
// Select attestations. Slots are sorted from higher to lower at this point. Within slots attestations
@@ -102,10 +106,14 @@ func (a proposerAtts) sortByProfitabilityUsingMaxCover() proposerAtts {
return slots[i] > slots[j]
})
for _, slot := range slots {
sortedAtts = append(sortedAtts, selectAtts(attsBySlot[slot])...)
selected, err := selectAtts(attsBySlot[slot])
if err != nil {
return nil, err
}
sortedAtts = append(sortedAtts, selected...)
}
return sortedAtts
return sortedAtts, nil
}
// limitToMaxAttestations limits attestations to maximum attestations per block.
@@ -119,9 +127,9 @@ func (a proposerAtts) limitToMaxAttestations() proposerAtts {
// dedup removes duplicate attestations (ones with the same bits set on).
// Important: not only exact duplicates are removed, but proper subsets are removed too
// (their known bits are redundant and are already contained in their supersets).
func (a proposerAtts) dedup() proposerAtts {
func (a proposerAtts) dedup() (proposerAtts, error) {
if len(a) < 2 {
return a
return a, nil
}
attsByDataRoot := make(map[[32]byte][]*ethpb.Attestation, len(a))
for _, att := range a {
@@ -138,13 +146,17 @@ func (a proposerAtts) dedup() proposerAtts {
a := atts[i]
for j := i + 1; j < len(atts); j++ {
b := atts[j]
if a.AggregationBits.Contains(b.AggregationBits) {
if c, err := a.AggregationBits.Contains(b.AggregationBits); err != nil {
return nil, err
} else if c {
// a contains b, b is redundant.
atts[j] = atts[len(atts)-1]
atts[len(atts)-1] = nil
atts = atts[:len(atts)-1]
j--
} else if b.AggregationBits.Contains(a.AggregationBits) {
} else if c, err := b.AggregationBits.Contains(a.AggregationBits); err != nil {
return nil, err
} else if c {
// b contains a, a is redundant.
atts[i] = atts[len(atts)-1]
atts[len(atts)-1] = nil
@@ -157,5 +169,5 @@ func (a proposerAtts) dedup() proposerAtts {
uniqAtts = append(uniqAtts, atts...)
}
return uniqAtts
return uniqAtts, nil
}

View File

@@ -31,7 +31,10 @@ func TestProposer_ProposerAtts_sortByProfitability(t *testing.T) {
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b11100000}}),
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b11000000}}),
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
}
@@ -57,7 +60,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
t.Run("no atts", func(t *testing.T) {
atts := getAtts([]testData{})
want := getAtts([]testData{})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
@@ -68,7 +74,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
want := getAtts([]testData{
{4, bitfield.Bitlist{0b11100000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
@@ -81,7 +90,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{4, bitfield.Bitlist{0b11100000, 0b1}},
{1, bitfield.Bitlist{0b11000000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
@@ -96,7 +108,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{4, bitfield.Bitlist{0b11100000, 0b1}},
{1, bitfield.Bitlist{0b11000000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
@@ -121,7 +136,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{1, bitfield.Bitlist{0b11001000, 0b1}},
{1, bitfield.Bitlist{0b00001100, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
t.Run("max-cover", func(t *testing.T) {
@@ -140,7 +158,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{1, bitfield.Bitlist{0b00001100, 0b1}},
{1, bitfield.Bitlist{0b11001000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
})
@@ -162,7 +183,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{1, bitfield.Bitlist{0b11100000, 0b1}},
{1, bitfield.Bitlist{0b11000000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
@@ -191,7 +215,10 @@ func TestProposer_ProposerAtts_sortByProfitabilityUsingMaxCover(t *testing.T) {
{1, bitfield.Bitlist{0b11100000, 0b1}},
{1, bitfield.Bitlist{0b11000000, 0b1}},
})
atts = atts.sortByProfitability()
atts, err := atts.sortByProfitability()
if err != nil {
t.Error(err)
}
require.DeepEqual(t, want, atts)
})
}
@@ -414,7 +441,10 @@ func TestProposer_ProposerAtts_dedup(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
atts := tt.atts.dedup()
atts, err := tt.atts.dedup()
if err != nil {
t.Error(err)
}
sort.Slice(atts, func(i, j int) bool {
if atts[i].AggregationBits.Count() == atts[j].AggregationBits.Count() {
if atts[i].Data.Slot == atts[j].Data.Slot {

View File

@@ -35,9 +35,9 @@ func (cs proposerSyncContributions) filterBySubIndex(i uint64) proposerSyncContr
// dedup removes duplicate sync contributions (ones with the same bits set on).
// Important: not only exact duplicates are removed, but proper subsets are removed too
// (their known bits are redundant and are already contained in their supersets).
func (cs proposerSyncContributions) dedup() proposerSyncContributions {
func (cs proposerSyncContributions) dedup() (proposerSyncContributions, error) {
if len(cs) < 2 {
return cs
return cs, nil
}
contributionsBySubIdx := make(map[uint64][]*eth.SyncCommitteeContribution, len(cs))
for _, c := range cs {
@@ -50,13 +50,17 @@ func (cs proposerSyncContributions) dedup() proposerSyncContributions {
a := cs[i]
for j := i + 1; j < len(cs); j++ {
b := cs[j]
if a.AggregationBits.Contains(b.AggregationBits) {
if c, err := a.AggregationBits.Contains(b.AggregationBits); err != nil {
return nil, err
} else if c {
// a contains b, b is redundant.
cs[j] = cs[len(cs)-1]
cs[len(cs)-1] = nil
cs = cs[:len(cs)-1]
j--
} else if b.AggregationBits.Contains(a.GetAggregationBits()) {
} else if c, err := b.AggregationBits.Contains(a.GetAggregationBits()); err != nil {
return nil, err
} else if c {
// b contains a, a is redundant.
cs[i] = cs[len(cs)-1]
cs[len(cs)-1] = nil
@@ -68,7 +72,7 @@ func (cs proposerSyncContributions) dedup() proposerSyncContributions {
}
uniqContributions = append(uniqContributions, cs...)
}
return uniqContributions
return uniqContributions, nil
}
// mostProfitable returns the most profitable sync contribution, the one with the most

View File

@@ -326,7 +326,10 @@ func TestProposerSyncContributions_Dedup(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cs := tt.cs.dedup()
cs, err := tt.cs.dedup()
if err != nil {
t.Error(err)
}
sort.Slice(cs, func(i, j int) bool {
if cs[i].AggregationBits.Count() == cs[j].AggregationBits.Count() {
if cs[i].SubcommitteeIndex == cs[j].SubcommitteeIndex {

View File

@@ -18,7 +18,7 @@ import (
mockp2p "github.com/prysmaticlabs/prysm/beacon-chain/p2p/testing"
mockPOW "github.com/prysmaticlabs/prysm/beacon-chain/powchain/testing"
"github.com/prysmaticlabs/prysm/beacon-chain/state/stategen"
"github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
mockSync "github.com/prysmaticlabs/prysm/beacon-chain/sync/initial-sync/testing"
dbpb "github.com/prysmaticlabs/prysm/proto/beacon/db"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
@@ -1831,12 +1831,6 @@ func TestProposer_FilterAttestation(t *testing.T) {
require.NoError(t, db.SaveState(ctx, state, genesisRoot), "Could not save genesis state")
require.NoError(t, db.SaveHeadBlockRoot(ctx, genesisRoot), "Could not save genesis state")
proposerServer := &Server{
BeaconDB: db,
AttPool: attestations.NewPool(),
HeadFetcher: &mock.ChainService{State: state, Root: genesisRoot[:]},
}
tests := []struct {
name string
wantedErr string
@@ -1910,6 +1904,11 @@ func TestProposer_FilterAttestation(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
proposerServer := &Server{
BeaconDB: db,
AttPool: attestations.NewPool(),
HeadFetcher: &mock.ChainService{State: state, Root: genesisRoot[:]},
}
atts := tt.inputAtts()
received, err := proposerServer.filterAttestationsForBlockInclusion(context.Background(), state, atts)
if tt.wantedErr != "" {
@@ -2029,8 +2028,8 @@ func TestProposer_DeleteAttsInPool_Aggregated(t *testing.T) {
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b10101}, Signature: sig}),
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b11010}, Signature: sig})}
unaggregatedAtts := []*ethpb.Attestation{
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b1001}, Signature: sig}),
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b0001}, Signature: sig})}
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b10010}, Signature: sig}),
testutil.HydrateAttestation(&ethpb.Attestation{Data: &ethpb.AttestationData{Slot: 1}, AggregationBits: bitfield.Bitlist{0b10100}, Signature: sig})}
require.NoError(t, s.AttPool.SaveAggregatedAttestations(aggregatedAtts))
require.NoError(t, s.AttPool.SaveUnaggregatedAttestations(unaggregatedAtts))

View File

@@ -14,7 +14,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/core/helpers"
dbutil "github.com/prysmaticlabs/prysm/beacon-chain/db/testing"
mockPOW "github.com/prysmaticlabs/prysm/beacon-chain/powchain/testing"
"github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
v1 "github.com/prysmaticlabs/prysm/beacon-chain/state/v1"
pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
"github.com/prysmaticlabs/prysm/proto/eth/v1alpha1/wrapper"

View File

@@ -2761,8 +2761,8 @@ def prysm_deps():
go_repository(
name = "com_github_prysmaticlabs_go_bitfield",
importpath = "github.com/prysmaticlabs/go-bitfield",
sum = "h1:3feHotPCE8LiZ12vizC8UPPrguf2FgGmiphi3n2xkPc=",
version = "v0.0.0-20210701052645-7ef93542bdc8",
sum = "h1:9rrmgQval2GOmtMAgGLdqcCEZLraNaN3k2mY+07cx64=",
version = "v0.0.0-20210706153858-5cb5ce8bdbfe",
)
go_repository(
name = "com_github_prysmaticlabs_prombbolt",

2
go.mod
View File

@@ -88,7 +88,7 @@ require (
github.com/prometheus/prom2json v1.3.0
github.com/prometheus/tsdb v0.10.0 // indirect
github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d
github.com/prysmaticlabs/go-bitfield v0.0.0-20210701052645-7ef93542bdc8
github.com/prysmaticlabs/go-bitfield v0.0.0-20210706153858-5cb5ce8bdbfe
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c
github.com/prysmaticlabs/protoc-gen-go-cast v0.0.0-20210504233148-1e141af6a0a1
github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc

4
go.sum
View File

@@ -1070,8 +1070,8 @@ github.com/prysmaticlabs/bazel-go-ethereum v0.0.0-20210420143944-f4dfc9744288/go
github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d h1:1dN7YAqMN3oAJ0LceWcyv/U4jHLh+5urnSnr4br6zg4=
github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d/go.mod h1:kOmQ/zdobQf7HUohDTifDNFEZfNaSCIY5fkONPL+dWU=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210701052645-7ef93542bdc8 h1:3feHotPCE8LiZ12vizC8UPPrguf2FgGmiphi3n2xkPc=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210701052645-7ef93542bdc8/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210706153858-5cb5ce8bdbfe h1:9rrmgQval2GOmtMAgGLdqcCEZLraNaN3k2mY+07cx64=
github.com/prysmaticlabs/go-bitfield v0.0.0-20210706153858-5cb5ce8bdbfe/go.mod h1:wmuf/mdK4VMD+jA9ThwcUKjg3a2XWM9cVfFYjDyY4j4=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210622145107-ca3041e1b380 h1:KzQOksIZB8poBiMk8h5Txzbp/OoBLFhS3H20ZN06hWg=
github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210622145107-ca3041e1b380/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24=
github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU=

View File

@@ -14,9 +14,6 @@ var (
// ErrBitsOverlap is returned when two bitlists overlap with each other.
ErrBitsOverlap = errors.New("overlapping aggregation bits")
// ErrBitsDifferentLen is returned when two bitlists have different lengths.
ErrBitsDifferentLen = errors.New("different bitlist lengths")
// ErrInvalidStrategy is returned when invalid aggregation strategy is selected.
ErrInvalidStrategy = errors.New("invalid aggregation strategy")
)

View File

@@ -21,10 +21,10 @@ go_library(
],
)
# gazelle:exclude attestations_bench_test.go
go_test(
name = "go_default_test",
srcs = [
"attestations_bench_test.go",
"attestations_test.go",
"maxcover_test.go",
],
@@ -34,6 +34,7 @@ go_test(
"//shared/aggregation:go_default_library",
"//shared/aggregation/testing:go_default_library",
"//shared/bls:go_default_library",
"//shared/copyutil:go_default_library",
"//shared/featureconfig:go_default_library",
"//shared/params:go_default_library",
"//shared/sszutil:go_default_library",

View File

@@ -66,10 +66,11 @@ func Aggregate(atts []*ethpb.Attestation) ([]*ethpb.Attestation, error) {
// AggregatePair aggregates pair of attestations a1 and a2 together.
func AggregatePair(a1, a2 *ethpb.Attestation) (*ethpb.Attestation, error) {
if a1.AggregationBits.Len() != a2.AggregationBits.Len() {
return nil, aggregation.ErrBitsDifferentLen
o, err := a1.AggregationBits.Overlaps(a2.AggregationBits)
if err != nil {
return nil, err
}
if a1.AggregationBits.Overlaps(a2.AggregationBits) {
if o {
return nil, aggregation.ErrBitsOverlap
}
@@ -79,11 +80,18 @@ func AggregatePair(a1, a2 *ethpb.Attestation) (*ethpb.Attestation, error) {
baseAtt, newAtt = newAtt, baseAtt
}
if baseAtt.AggregationBits.Contains(newAtt.AggregationBits) {
c, err := baseAtt.AggregationBits.Contains(newAtt.AggregationBits)
if err != nil {
return nil, err
}
if c {
return baseAtt, nil
}
newBits := baseAtt.AggregationBits.Or(newAtt.AggregationBits)
newBits, err := baseAtt.AggregationBits.Or(newAtt.AggregationBits)
if err != nil {
return nil, err
}
newSig, err := signatureFromBytes(newAtt.Signature)
if err != nil {
return nil, err

View File

@@ -10,7 +10,6 @@ import (
ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing"
"github.com/prysmaticlabs/prysm/shared/bls"
"github.com/prysmaticlabs/prysm/shared/bls/common"
"github.com/prysmaticlabs/prysm/shared/featureconfig"
"github.com/prysmaticlabs/prysm/shared/params"
"github.com/prysmaticlabs/prysm/shared/testutil/require"
@@ -19,10 +18,10 @@ import (
func BenchmarkAggregateAttestations_Aggregate(b *testing.B) {
// Override expensive BLS aggregation method with cheap no-op such that this benchmark profiles
// the logic of aggregation selection rather than BLS logic.
aggregateSignatures = func(sigs []common.Signature) common.Signature {
aggregateSignatures = func(sigs []bls.Signature) bls.Signature {
return sigs[0]
}
signatureFromBytes = func(sig []byte) (common.Signature, error) {
signatureFromBytes = func(sig []byte) (bls.Signature, error) {
return bls.NewAggregateSignature(), nil
}
defer func() {

View File

@@ -90,7 +90,7 @@ func TestAggregateAttestations_AggregatePair_DiffLengthFails(t *testing.T) {
}
for _, tt := range tests {
_, err := AggregatePair(tt.a1, tt.a2)
require.ErrorContains(t, aggregation.ErrBitsDifferentLen.Error(), err)
require.ErrorContains(t, bitfield.ErrBitlistDifferentLength.Error(), err)
}
}
@@ -101,6 +101,7 @@ func TestAggregateAttestations_Aggregate(t *testing.T) {
name string
inputs []bitfield.Bitlist
want []bitfield.Bitlist
err error
}{
{
name: "empty list",
@@ -208,12 +209,17 @@ func TestAggregateAttestations_Aggregate(t *testing.T) {
{0b00000111, 0b100},
{0b00000100, 0b1},
},
err: bitfield.ErrBitlistDifferentLength,
},
}
for _, tt := range tests {
runner := func() {
got, err := Aggregate(aggtesting.MakeAttestationsFromBitlists(tt.inputs))
if tt.err != nil {
require.ErrorContains(t, tt.err.Error(), err)
return
}
require.NoError(t, err)
sort.Slice(got, func(i, j int) bool {
return got[i].AggregationBits.Bytes()[0] < got[j].AggregationBits.Bytes()[0]

View File

@@ -23,9 +23,6 @@ func MaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attesta
unaggregated := attList(atts)
if err := unaggregated.validate(); err != nil {
if errors.Is(err, aggregation.ErrBitsDifferentLen) {
return unaggregated, nil
}
return nil, err
}
@@ -49,7 +46,9 @@ func MaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attesta
}
// Create aggregated attestation and update solution lists.
if !aggregated.hasCoverage(solution.Coverage) {
if has, err := aggregated.hasCoverage(solution.Coverage); err != nil {
return nil, err
} else if !has {
att, err := unaggregated.selectUsingKeys(solution.Keys).aggregate(solution.Coverage)
if err != nil {
return aggregated.merge(unaggregated), err
@@ -59,7 +58,11 @@ func MaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attesta
unaggregated = unaggregated.selectComplementUsingKeys(solution.Keys)
}
return aggregated.merge(unaggregated.filterContained()), nil
filtered, err := unaggregated.filterContained()
if err != nil {
return nil, err
}
return aggregated.merge(filtered), nil
}
// optMaxCoverAttestationAggregation relies on Maximum Coverage greedy algorithm for aggregation.
@@ -73,9 +76,6 @@ func optMaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Atte
}
if err := attList(atts).validate(); err != nil {
if errors.Is(err, aggregation.ErrBitsDifferentLen) {
return atts, nil
}
return nil, err
}
@@ -83,7 +83,11 @@ func optMaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Atte
// type, so incoming `atts` parameters can be used as candidates list directly.
candidates := make([]*bitfield.Bitlist64, len(atts))
for i := 0; i < len(atts); i++ {
candidates[i] = atts[i].AggregationBits.ToBitlist64()
var err error
candidates[i], err = atts[i].AggregationBits.ToBitlist64()
if err != nil {
return nil, err
}
}
coveredBitsSoFar := bitfield.NewBitlist64(candidates[0].Len())
@@ -120,7 +124,11 @@ func optMaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Atte
// Create aggregated attestation and update solution lists. Process aggregates only if they
// feature at least one unknown bit i.e. can increase the overall coverage.
if coveredBitsSoFar.XorCount(coverage) > 0 {
xc, err := coveredBitsSoFar.XorCount(coverage)
if err != nil {
return nil, err
}
if xc > 0 {
aggIdx, err := aggregateAttestations(atts, keys, coverage)
if err != nil {
return append(aggregated, unaggregated...), err
@@ -140,7 +148,9 @@ func optMaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Atte
unaggregated = unaggregated[1:]
// Update covered bits map.
coveredBitsSoFar.NoAllocOr(coverage, coveredBitsSoFar)
if err := coveredBitsSoFar.NoAllocOr(coverage, coveredBitsSoFar); err != nil {
return nil, err
}
keys = keys[1:]
}
@@ -149,7 +159,11 @@ func optMaxCoverAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Atte
unaggregated = unaggregated[:len(unaggregated)-len(keys)]
}
return append(aggregated, attList(unaggregated).filterContained()...), nil
filtered, err := attList(unaggregated).filterContained()
if err != nil {
return nil, err
}
return append(aggregated, filtered...), nil
}
// NewMaxCover returns initialized Maximum Coverage problem for attestations aggregation.
@@ -286,19 +300,23 @@ func (al attList) selectComplementUsingKeys(keys []int) attList {
}
// hasCoverage returns true if a given coverage is found in attestations list.
func (al attList) hasCoverage(coverage bitfield.Bitlist) bool {
func (al attList) hasCoverage(coverage bitfield.Bitlist) (bool, error) {
for _, att := range al {
if att.AggregationBits.Xor(coverage).Count() == 0 {
return true
x, err := att.AggregationBits.Xor(coverage)
if err != nil {
return false, err
}
if x.Count() == 0 {
return true, nil
}
}
return false
return false, nil
}
// filterContained removes attestations that are contained within other attestations.
func (al attList) filterContained() attList {
func (al attList) filterContained() (attList, error) {
if len(al) < 2 {
return al
return al, nil
}
sort.Slice(al, func(i, j int) bool {
return al[i].AggregationBits.Count() > al[j].AggregationBits.Count()
@@ -306,12 +324,16 @@ func (al attList) filterContained() attList {
filtered := al[:0]
filtered = append(filtered, al[0])
for i := 1; i < len(al); i++ {
if filtered[len(filtered)-1].AggregationBits.Contains(al[i].AggregationBits) {
c, err := filtered[len(filtered)-1].AggregationBits.Contains(al[i].AggregationBits)
if err != nil {
return nil, err
}
if c {
continue
}
filtered = append(filtered, al[i])
}
return filtered
return filtered, nil
}
// validate checks attestation list for validity (equal bitlength, non-nil bitlist etc).
@@ -325,10 +347,9 @@ func (al attList) validate() error {
if al[0].AggregationBits == nil || al[0].AggregationBits.Len() == 0 {
return errors.Wrap(aggregation.ErrInvalidMaxCoverProblem, "bitlist cannot be nil or empty")
}
bitlistLen := al[0].AggregationBits.Len()
for i := 1; i < len(al); i++ {
if al[i].AggregationBits == nil || bitlistLen != al[i].AggregationBits.Len() {
return aggregation.ErrBitsDifferentLen
if al[i].AggregationBits == nil || al[i].AggregationBits.Len() == 0 {
return errors.Wrap(aggregation.ErrInvalidMaxCoverProblem, "bitlist cannot be nil or empty")
}
}
return nil

View File

@@ -102,7 +102,7 @@ func TestAggregateAttestations_MaxCover_AttList_validate(t *testing.T) {
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(64)},
&ethpb.Attestation{},
},
wantedErr: aggregation.ErrBitsDifferentLen.Error(),
wantedErr: "bitlist cannot be nil or empty",
},
{
name: "first bitlist is empty",
@@ -117,17 +117,7 @@ func TestAggregateAttestations_MaxCover_AttList_validate(t *testing.T) {
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(64)},
&ethpb.Attestation{AggregationBits: bitfield.Bitlist{}},
},
wantedErr: aggregation.ErrBitsDifferentLen.Error(),
},
{
name: "bitlists of non equal length",
atts: attList{
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(64)},
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(64)},
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(63)},
&ethpb.Attestation{AggregationBits: bitfield.NewBitlist(64)},
},
wantedErr: aggregation.ErrBitsDifferentLen.Error(),
wantedErr: "bitlist cannot be nil or empty",
},
{
name: "valid bitlists",
@@ -295,7 +285,11 @@ func TestAggregateAttestations_rearrangeProcessedAttestations(t *testing.T) {
candidates := make([]*bitfield.Bitlist64, len(tt.atts))
for i := 0; i < len(tt.atts); i++ {
if tt.atts[i] != nil {
candidates[i] = tt.atts[i].AggregationBits.ToBitlist64()
var err error
candidates[i], err = tt.atts[i].AggregationBits.ToBitlist64()
if err != nil {
t.Error(err)
}
}
}
rearrangeProcessedAttestations(tt.atts, candidates, tt.keys)
@@ -381,7 +375,13 @@ func TestAggregateAttestations_aggregateAttestations(t *testing.T) {
},
wantTargetIdx: 0,
keys: []int{0, 1},
coverage: bitfield.NewBitlist64FromBytes(64, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0b00000011}),
coverage: func() *bitfield.Bitlist64 {
b, err := bitfield.NewBitlist64FromBytes(64, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0b00000011})
if err != nil {
t.Fatal(err)
}
return b
}(),
},
{
name: "many attestations, several selected",
@@ -403,7 +403,13 @@ func TestAggregateAttestations_aggregateAttestations(t *testing.T) {
},
wantTargetIdx: 1,
keys: []int{1, 2, 4},
coverage: bitfield.NewBitlist64FromBytes(64, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0b00010110}),
coverage: func() *bitfield.Bitlist64 {
b, err := bitfield.NewBitlist64FromBytes(64, []byte{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0b00010110})
if err != nil {
t.Fatal(err)
}
return b
}(),
},
}
for _, tt := range tests {

View File

@@ -16,7 +16,9 @@ func NaiveAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attestatio
}
for j := i + 1; j < len(atts); j++ {
b := atts[j]
if a.AggregationBits.Len() == b.AggregationBits.Len() && !a.AggregationBits.Overlaps(b.AggregationBits) {
if o, err := a.AggregationBits.Overlaps(b.AggregationBits); err != nil {
return nil, err
} else if !o {
var err error
a, err = AggregatePair(a, b)
if err != nil {
@@ -39,11 +41,15 @@ func NaiveAttestationAggregation(atts []*ethpb.Attestation) ([]*ethpb.Attestatio
continue
}
if a.AggregationBits.Contains(b.AggregationBits) {
if c, err := a.AggregationBits.Contains(b.AggregationBits); err != nil {
return nil, err
} else if c {
// If b is fully contained in a, then b can be removed.
atts = append(atts[:j], atts[j+1:]...)
j--
} else if b.AggregationBits.Contains(a.AggregationBits) {
} else if c, err := b.AggregationBits.Contains(a.AggregationBits); err != nil {
return nil, err
} else if c {
// if a is fully contained in b, then a can be removed.
atts = append(atts[:i], atts[i+1:]...)
break // Stop the inner loop, advance a.

View File

@@ -62,7 +62,10 @@ func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error
Coverage: bitfield.NewBitlist(mc.Candidates[0].bits.Len()),
Keys: make([]int, 0, k),
}
remainingBits := mc.Candidates.union()
remainingBits, err := mc.Candidates.union()
if err != nil {
return nil, err
}
if remainingBits == nil {
return nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists")
}
@@ -71,16 +74,31 @@ func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error
// Score candidates against remaining bits.
// Filter out processed and overlapping (when disallowed).
// Sort by score in a descending order.
mc.Candidates.score(remainingBits).filter(solution.Coverage, allowOverlaps).sort()
s, err := mc.Candidates.score(remainingBits)
if err != nil {
return nil, err
}
s, err = s.filter(solution.Coverage, allowOverlaps)
if err != nil {
return nil, err
}
s.sort()
for _, candidate := range mc.Candidates {
if len(solution.Keys) >= k {
break
}
if !candidate.processed {
solution.Coverage = solution.Coverage.Or(*candidate.bits)
var err error
solution.Coverage, err = solution.Coverage.Or(*candidate.bits)
if err != nil {
return nil, err
}
solution.Keys = append(solution.Keys, candidate.key)
remainingBits = remainingBits.And(candidate.bits.Not())
remainingBits, err = remainingBits.And(candidate.bits.Not())
if err != nil {
return nil, err
}
candidate.processed = true
break
}
@@ -104,7 +122,10 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
// Track bits covered so far as a bitlist.
coveredBits := bitfield.NewBitlist64(candidates[0].Len())
remainingBits := union(candidates)
remainingBits, err := union(candidates)
if err != nil {
return nil, nil, err
}
if remainingBits == nil {
return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists")
}
@@ -117,7 +138,7 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
if attempts > k {
break
}
attempts += 1
attempts++
// Greedy select the next best candidate (from usable ones) to cover the remaining bits maximally.
maxScore := uint64(0)
@@ -128,7 +149,11 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
// Score is calculated by taking into account uncovered bits only.
score := uint64(0)
if candidates[idx].Len() == remainingBits.Len() {
score = candidates[idx].AndCount(remainingBits)
var err error
score, err = candidates[idx].AndCount(remainingBits)
if err != nil {
return nil, nil, err
}
}
// Filter out zero-score candidates.
@@ -139,10 +164,16 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
// Filter out overlapping candidates (if overlapping is not allowed).
wrongLen := coveredBits.Len() != candidates[idx].Len()
overlaps := func(idx int) bool {
return !allowOverlaps && coveredBits.Overlaps(candidates[idx])
overlaps := func(idx int) (bool, error) {
o, err := coveredBits.Overlaps(candidates[idx])
return !allowOverlaps && o, err
}
if wrongLen || overlaps(idx) {
if wrongLen { // Shortcut for wrong length check
usableCandidates.SetBitAt(uint64(idx), false)
continue
} else if o, err := overlaps(idx); err != nil {
return nil, nil, err
} else if o {
usableCandidates.SetBitAt(uint64(idx), false)
continue
}
@@ -155,10 +186,14 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
}
// Process greedy selected candidate.
if maxScore > 0 {
coveredBits.NoAllocOr(candidates[bestIdx], coveredBits)
if err := coveredBits.NoAllocOr(candidates[bestIdx], coveredBits); err != nil {
return nil, nil, err
}
selectedCandidates.SetBitAt(bestIdx, true)
candidates[bestIdx].NoAllocNot(tmpBitlist)
remainingBits.NoAllocAnd(tmpBitlist, remainingBits)
if err := remainingBits.NoAllocAnd(tmpBitlist, remainingBits); err != nil {
return nil, nil, err
}
usableCandidates.SetBitAt(bestIdx, false)
}
}
@@ -166,32 +201,46 @@ func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (sele
}
// score updates scores of candidates, taking into account the uncovered elements only.
func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) *MaxCoverCandidates {
func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) (*MaxCoverCandidates, error) {
for i := 0; i < len(*cl); i++ {
if (*cl)[i].bits.Len() == uncovered.Len() {
(*cl)[i].score = (*cl)[i].bits.And(uncovered).Count()
a, err := (*cl)[i].bits.And(uncovered)
if err != nil {
return nil, err
}
(*cl)[i].score = a.Count()
}
}
return cl
return cl, nil
}
// filter removes processed, overlapping and zero-score candidates.
func (cl *MaxCoverCandidates) filter(covered bitfield.Bitlist, allowOverlaps bool) *MaxCoverCandidates {
overlaps := func(e bitfield.Bitlist) bool {
return !allowOverlaps && covered.Len() == e.Len() && covered.Overlaps(e)
func (cl *MaxCoverCandidates) filter(covered bitfield.Bitlist, allowOverlaps bool) (*MaxCoverCandidates, error) {
overlaps := func(e bitfield.Bitlist) (bool, error) {
if !allowOverlaps && covered.Len() == e.Len() {
return covered.Overlaps(e)
}
return false, nil
}
cur, end := 0, len(*cl)
for cur < end {
e := *(*cl)[cur]
if e.processed || overlaps(*e.bits) || e.score == 0 {
if e.processed || e.score == 0 {
(*cl)[cur] = (*cl)[end-1]
end--
continue
} else if o, err := overlaps(*e.bits); err == nil && o {
(*cl)[cur] = (*cl)[end-1]
end--
continue
} else if err != nil {
return nil, err
}
cur++
}
*cl = (*cl)[:end]
return cl
return cl, nil
}
// sort orders candidates by their score, starting from the candidate with the highest score.
@@ -206,31 +255,37 @@ func (cl *MaxCoverCandidates) sort() *MaxCoverCandidates {
}
// union merges all candidate bitlists using logical OR operator.
func (cl *MaxCoverCandidates) union() bitfield.Bitlist {
func (cl *MaxCoverCandidates) union() (bitfield.Bitlist, error) {
if len(*cl) == 0 {
return nil
return nil, nil
}
if (*cl)[0].bits == nil || (*cl)[0].bits.Len() == 0 {
return nil
return nil, nil
}
ret := bitfield.NewBitlist((*cl)[0].bits.Len())
var err error
for i := 0; i < len(*cl); i++ {
if *(*cl)[i].bits != nil && ret.Len() == (*cl)[i].bits.Len() {
ret = ret.Or(*(*cl)[i].bits)
ret, err = ret.Or(*(*cl)[i].bits)
if err != nil {
return nil, err
}
}
}
return ret
return ret, nil
}
func union(candidates []*bitfield.Bitlist64) *bitfield.Bitlist64 {
func union(candidates []*bitfield.Bitlist64) (*bitfield.Bitlist64, error) {
if len(candidates) == 0 || candidates[0].Len() == 0 {
return nil
return nil, nil
}
ret := bitfield.NewBitlist64(candidates[0].Len())
for _, bl := range candidates {
if ret.Len() == bl.Len() {
ret.NoAllocOr(bl, ret)
if err := ret.NoAllocOr(bl, ret); err != nil {
return nil, err
}
}
}
return ret
return ret, nil
}

View File

@@ -132,7 +132,10 @@ func TestMaxCover_MaxCoverCandidates_filter(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.cl.filter(tt.args.covered, tt.args.allowOverlaps)
got, err := tt.cl.filter(tt.args.covered, tt.args.allowOverlaps)
if err != nil {
t.Error(err)
}
sort.Slice(*got, func(i, j int) bool {
return (*got)[i].key < (*got)[j].key
})
@@ -272,8 +275,8 @@ func TestMaxCover_MaxCoverCandidates_union(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.cl.union(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("union(), got: %#b, want: %#b", got, tt.want)
if got, err := tt.cl.union(); !reflect.DeepEqual(got, tt.want) || err != nil {
t.Errorf("union(), got: %#b, %v, want: %#b", got, err, tt.want)
}
})
}
@@ -349,8 +352,8 @@ func TestMaxCover_MaxCoverCandidates_score(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.cl.score(tt.uncovered); !reflect.DeepEqual(got, tt.want) {
t.Errorf("score() = %v, want %v", got, tt.want)
if got, err := tt.cl.score(tt.uncovered); !reflect.DeepEqual(got, tt.want) || err != nil {
t.Errorf("score() = %v, %v, want %v", got, err, tt.want)
}
})
}

View File

@@ -21,7 +21,9 @@ func naiveSyncContributionAggregation(contributions []*v2.SyncCommitteeContribut
}
for j := i + 1; j < len(contributions); j++ {
b := contributions[j]
if a.AggregationBits.Len() == b.AggregationBits.Len() && !a.AggregationBits.Overlaps(b.AggregationBits) {
if o, err := a.AggregationBits.Overlaps(b.AggregationBits); err != nil {
return nil, err
} else if !o {
var err error
a, err = aggregate(a, b)
if err != nil {
@@ -44,11 +46,15 @@ func naiveSyncContributionAggregation(contributions []*v2.SyncCommitteeContribut
continue
}
if a.AggregationBits.Contains(b.AggregationBits) {
if c, err := a.AggregationBits.Contains(b.AggregationBits); err != nil {
return nil, err
} else if c {
// If b is fully contained in a, then b can be removed.
contributions = append(contributions[:j], contributions[j+1:]...)
j--
} else if b.AggregationBits.Contains(a.AggregationBits) {
} else if c, err := b.AggregationBits.Contains(a.AggregationBits); err != nil {
return nil, err
} else if c {
// if a is fully contained in b, then a can be removed.
contributions = append(contributions[:i], contributions[i+1:]...)
break // Stop the inner loop, advance a.
@@ -61,7 +67,9 @@ func naiveSyncContributionAggregation(contributions []*v2.SyncCommitteeContribut
// aggregates pair of sync contributions c1 and c2 together.
func aggregate(c1, c2 *v2.SyncCommitteeContribution) (*v2.SyncCommitteeContribution, error) {
if c1.AggregationBits.Overlaps(c2.AggregationBits) {
if o, err := c1.AggregationBits.Overlaps(c2.AggregationBits); err != nil {
return nil, err
} else if o {
return nil, aggregation.ErrBitsOverlap
}
@@ -71,11 +79,16 @@ func aggregate(c1, c2 *v2.SyncCommitteeContribution) (*v2.SyncCommitteeContribut
baseContribution, newContribution = newContribution, baseContribution
}
if baseContribution.AggregationBits.Contains(newContribution.AggregationBits) {
if c, err := baseContribution.AggregationBits.Contains(newContribution.AggregationBits); err != nil {
return nil, err
} else if c {
return baseContribution, nil
}
newBits := baseContribution.AggregationBits.Or(newContribution.AggregationBits)
newBits, err := baseContribution.AggregationBits.Or(newContribution.AggregationBits)
if err != nil {
return nil, err
}
newSig, err := bls.SignatureFromBytes(newContribution.Signature)
if err != nil {
return nil, err