diff --git a/deps.bzl b/deps.bzl index a7fc3b55d9..de0f55755c 100644 --- a/deps.bzl +++ b/deps.bzl @@ -2779,8 +2779,8 @@ def prysm_deps(): go_repository( name = "com_github_prysmaticlabs_go_bitfield", importpath = "github.com/prysmaticlabs/go-bitfield", - sum = "h1:yALGBNFMp40DeD3qGGRgiC0FWePzy0FIhxWEXoco3ZA=", - version = "v0.0.0-20210628171552-0c86d791fc37", + sum = "h1:nc95NsZcGforJ9a3QbsAdGcG3TbFHaLfHeaHi62CMso=", + version = "v0.0.0-20210628211147-0d89f726b4c2", ) go_repository( name = "com_github_prysmaticlabs_prombbolt", diff --git a/go.mod b/go.mod index e94d4d7b2c..a096084256 100644 --- a/go.mod +++ b/go.mod @@ -88,7 +88,7 @@ require ( github.com/prometheus/prom2json v1.3.0 github.com/prometheus/tsdb v0.10.0 // indirect github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d - github.com/prysmaticlabs/go-bitfield v0.0.0-20210628171552-0c86d791fc37 + github.com/prysmaticlabs/go-bitfield v0.0.0-20210628211147-0d89f726b4c2 github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c github.com/prysmaticlabs/protoc-gen-go-cast v0.0.0-20210504233148-1e141af6a0a1 github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc diff --git a/go.sum b/go.sum index 8a4674474c..02510421af 100644 --- a/go.sum +++ b/go.sum @@ -1069,8 +1069,8 @@ github.com/prysmaticlabs/bazel-go-ethereum v0.0.0-20210420143944-f4dfc9744288/go github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d h1:1dN7YAqMN3oAJ0LceWcyv/U4jHLh+5urnSnr4br6zg4= github.com/prysmaticlabs/eth2-types v0.0.0-20210303084904-c9735a06829d/go.mod h1:kOmQ/zdobQf7HUohDTifDNFEZfNaSCIY5fkONPL+dWU= github.com/prysmaticlabs/go-bitfield v0.0.0-20210108222456-8e92c3709aa0/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s= -github.com/prysmaticlabs/go-bitfield v0.0.0-20210628171552-0c86d791fc37 h1:yALGBNFMp40DeD3qGGRgiC0FWePzy0FIhxWEXoco3ZA= -github.com/prysmaticlabs/go-bitfield v0.0.0-20210628171552-0c86d791fc37/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s= +github.com/prysmaticlabs/go-bitfield v0.0.0-20210628211147-0d89f726b4c2 h1:nc95NsZcGforJ9a3QbsAdGcG3TbFHaLfHeaHi62CMso= +github.com/prysmaticlabs/go-bitfield v0.0.0-20210628211147-0d89f726b4c2/go.mod h1:hCwmef+4qXWjv0jLDbQdWnL0Ol7cS7/lCSS26WR+u6s= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210622145107-ca3041e1b380 h1:KzQOksIZB8poBiMk8h5Txzbp/OoBLFhS3H20ZN06hWg= github.com/prysmaticlabs/grpc-gateway/v2 v2.3.1-0.20210622145107-ca3041e1b380/go.mod h1:IOyTYjcIO0rkmnGBfJTL0NJ11exy/Tc2QEuv7hCXp24= github.com/prysmaticlabs/prombbolt v0.0.0-20210126082820-9b7adba6db7c h1:9PHRCuO/VN0s9k+RmLykho7AjDxblNYI5bYKed16NPU= diff --git a/shared/aggregation/sync_contribution/BUILD.bazel b/shared/aggregation/sync_contribution/BUILD.bazel new file mode 100644 index 0000000000..cb5bf0378a --- /dev/null +++ b/shared/aggregation/sync_contribution/BUILD.bazel @@ -0,0 +1,36 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") +load("@prysm//tools/go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = [ + "contribution.go", + "naive.go", + ], + importpath = "github.com/prysmaticlabs/prysm/shared/aggregation/sync_contribution", + visibility = ["//visibility:public"], + deps = [ + "//proto/prysm/v2:go_default_library", + "//shared/aggregation:go_default_library", + "//shared/bls:go_default_library", + "//shared/copyutil:go_default_library", + "@com_github_pkg_errors//:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", + ], +) + +go_test( + name = "go_default_test", + srcs = ["naive_test.go"], + embed = [":go_default_library"], + deps = [ + "//proto/prysm/v2:go_default_library", + "//shared/aggregation:go_default_library", + "//shared/aggregation/testing:go_default_library", + "//shared/bls:go_default_library", + "//shared/featureconfig:go_default_library", + "//shared/testutil/assert:go_default_library", + "//shared/testutil/require:go_default_library", + "@com_github_prysmaticlabs_go_bitfield//:go_default_library", + ], +) diff --git a/shared/aggregation/sync_contribution/contribution.go b/shared/aggregation/sync_contribution/contribution.go new file mode 100644 index 0000000000..c0c6ebae8a --- /dev/null +++ b/shared/aggregation/sync_contribution/contribution.go @@ -0,0 +1,41 @@ +package sync_contribution + +import ( + "github.com/pkg/errors" + v2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" + "github.com/prysmaticlabs/prysm/shared/aggregation" + "github.com/sirupsen/logrus" +) + +const ( + // NaiveAggregation is an aggregation strategy without any optimizations. + NaiveAggregation SyncContributionAggregationStrategy = "naive" + + // MaxCoverAggregation is a strategy based on Maximum Coverage greedy algorithm. + MaxCoverAggregation SyncContributionAggregationStrategy = "max_cover" +) + +// SyncContributionAggregationStrategy defines SyncContribution aggregation strategy. +type SyncContributionAggregationStrategy string + +var _ = logrus.WithField("prefix", "aggregation.sync_contribution") + +// ErrInvalidSyncContributionCount is returned when insufficient number +// of sync contributions is provided for aggregation. +var ErrInvalidSyncContributionCount = errors.New("invalid number of sync contributions") + +// Aggregate aggregates sync contributions. The minimal number of sync contributions is returned. +// Aggregation occurs in-place i.e. contents of input array will be modified. Should you need to +// preserve input sync contributions, clone them before aggregating. +func Aggregate(cs []*v2.SyncCommitteeContribution) ([]*v2.SyncCommitteeContribution, error) { + strategy := NaiveAggregation + switch strategy { + case "", NaiveAggregation: + return naiveSyncContributionAggregation(cs) + case MaxCoverAggregation: + // TODO: Implement max cover aggregation for sync contributions. + return nil, errors.New("no implemented") + default: + return nil, errors.Wrapf(aggregation.ErrInvalidStrategy, "%q", strategy) + } +} diff --git a/shared/aggregation/sync_contribution/naive.go b/shared/aggregation/sync_contribution/naive.go new file mode 100644 index 0000000000..0af7b633a1 --- /dev/null +++ b/shared/aggregation/sync_contribution/naive.go @@ -0,0 +1,93 @@ +package sync_contribution + +import ( + v2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" + "github.com/prysmaticlabs/prysm/shared/aggregation" + "github.com/prysmaticlabs/prysm/shared/bls" + "github.com/prysmaticlabs/prysm/shared/copyutil" +) + +// naiveSyncContributionAggregation aggregates naively, without any complex algorithms or optimizations. +// Note: this is currently a naive implementation to the order of O(mn^2). +func naiveSyncContributionAggregation(contributions []*v2.SyncCommitteeContribution) ([]*v2.SyncCommitteeContribution, error) { + if len(contributions) <= 1 { + return contributions, nil + } + + // Naive aggregation. O(n^2) time. + for i, a := range contributions { + if i >= len(contributions) { + break + } + for j := i + 1; j < len(contributions); j++ { + b := contributions[j] + if a.AggregationBits.Len() == b.AggregationBits.Len() && !a.AggregationBits.Overlaps(b.AggregationBits) { + var err error + a, err = aggregate(a, b) + if err != nil { + return nil, err + } + // Delete b + contributions = append(contributions[:j], contributions[j+1:]...) + j-- + contributions[i] = a + } + } + } + + // Naive deduplication of identical contributions. O(n^2) time. + for i, a := range contributions { + for j := i + 1; j < len(contributions); j++ { + b := contributions[j] + + if a.AggregationBits.Len() != b.AggregationBits.Len() { + continue + } + + if a.AggregationBits.Contains(b.AggregationBits) { + // If b is fully contained in a, then b can be removed. + contributions = append(contributions[:j], contributions[j+1:]...) + j-- + } else if b.AggregationBits.Contains(a.AggregationBits) { + // if a is fully contained in b, then a can be removed. + contributions = append(contributions[:i], contributions[i+1:]...) + break // Stop the inner loop, advance a. + } + } + } + + return contributions, nil +} + +// aggregates pair of sync contributions c1 and c2 together. +func aggregate(c1, c2 *v2.SyncCommitteeContribution) (*v2.SyncCommitteeContribution, error) { + if c1.AggregationBits.Overlaps(c2.AggregationBits) { + return nil, aggregation.ErrBitsOverlap + } + + baseContribution := copyutil.CopySyncCommitteeContribution(c1) + newContribution := copyutil.CopySyncCommitteeContribution(c2) + if newContribution.AggregationBits.Count() > baseContribution.AggregationBits.Count() { + baseContribution, newContribution = newContribution, baseContribution + } + + if baseContribution.AggregationBits.Contains(newContribution.AggregationBits) { + return baseContribution, nil + } + + newBits := baseContribution.AggregationBits.Or(newContribution.AggregationBits) + newSig, err := bls.SignatureFromBytes(newContribution.Signature) + if err != nil { + return nil, err + } + baseSig, err := bls.SignatureFromBytes(baseContribution.Signature) + if err != nil { + return nil, err + } + + aggregatedSig := bls.AggregateSignatures([]bls.Signature{baseSig, newSig}) + baseContribution.Signature = aggregatedSig.Marshal() + baseContribution.AggregationBits = newBits + + return baseContribution, nil +} diff --git a/shared/aggregation/sync_contribution/naive_test.go b/shared/aggregation/sync_contribution/naive_test.go new file mode 100644 index 0000000000..8386e77c82 --- /dev/null +++ b/shared/aggregation/sync_contribution/naive_test.go @@ -0,0 +1,173 @@ +package sync_contribution + +import ( + "fmt" + "sort" + "testing" + + "github.com/prysmaticlabs/go-bitfield" + prysmv2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" + "github.com/prysmaticlabs/prysm/shared/aggregation" + aggtesting "github.com/prysmaticlabs/prysm/shared/aggregation/testing" + "github.com/prysmaticlabs/prysm/shared/bls" + "github.com/prysmaticlabs/prysm/shared/featureconfig" + "github.com/prysmaticlabs/prysm/shared/testutil/assert" + "github.com/prysmaticlabs/prysm/shared/testutil/require" +) + +func TestAggregateAttestations_aggregate(t *testing.T) { + tests := []struct { + a1 *prysmv2.SyncCommitteeContribution + a2 *prysmv2.SyncCommitteeContribution + want *prysmv2.SyncCommitteeContribution + }{ + { + a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()}, + a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()}, + want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}}, + }, + { + a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x01}, Signature: bls.NewAggregateSignature().Marshal()}, + a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x02}, Signature: bls.NewAggregateSignature().Marshal()}, + want: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x03}}, + }, + } + for _, tt := range tests { + got, err := aggregate(tt.a1, tt.a2) + require.NoError(t, err) + require.DeepSSZEqual(t, tt.want.AggregationBits, got.AggregationBits) + } +} + +func TestAggregateAttestations_aggregate_OverlapFails(t *testing.T) { + tests := []struct { + a1 *prysmv2.SyncCommitteeContribution + a2 *prysmv2.SyncCommitteeContribution + }{ + { + a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x1F}}, + a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x11}}, + }, + { + a1: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0xFF, 0x85}}, + a2: &prysmv2.SyncCommitteeContribution{AggregationBits: bitfield.Bitvector128{0x13, 0x8F}}, + }, + } + for _, tt := range tests { + _, err := aggregate(tt.a1, tt.a2) + require.ErrorContains(t, aggregation.ErrBitsOverlap.Error(), err) + } +} + +func TestAggregateAttestations_Aggregate(t *testing.T) { + tests := []struct { + name string + inputs []bitfield.Bitvector128 + want []bitfield.Bitvector128 + }{ + { + name: "empty list", + inputs: []bitfield.Bitvector128{}, + want: []bitfield.Bitvector128{}, + }, + { + name: "single attestation", + inputs: []bitfield.Bitvector128{ + {0b00000010}, + }, + want: []bitfield.Bitvector128{ + {0b00000010}, + }, + }, + { + name: "two attestations with no overlap", + inputs: []bitfield.Bitvector128{ + {0b00000001}, + {0b00000010}, + }, + want: []bitfield.Bitvector128{ + {0b00000011}, + }, + }, + { + name: "two attestations with overlap", + inputs: []bitfield.Bitvector128{ + {0b00000101}, + {0b00000110}, + }, + want: []bitfield.Bitvector128{ + {0b00000101}, + {0b00000110}, + }, + }, + { + name: "some attestations overlap", + inputs: []bitfield.Bitvector128{ + {0b00001001}, + {0b00010110}, + {0b00001010}, + {0b00110001}, + }, + want: []bitfield.Bitvector128{ + {0b00111011}, + {0b00011111}, + }, + }, + { + name: "some attestations produce duplicates which are removed", + inputs: []bitfield.Bitvector128{ + {0b00000101}, + {0b00000110}, + {0b00001010}, + {0b00001001}, + }, + want: []bitfield.Bitvector128{ + {0b00001111}, // both 0&1 and 2&3 produce this bitlist + }, + }, + { + name: "two attestations where one is fully contained within the other", + inputs: []bitfield.Bitvector128{ + {0b00000001}, + {0b00000011}, + }, + want: []bitfield.Bitvector128{ + {0b00000011}, + }, + }, + { + name: "two attestations where one is fully contained within the other reversed", + inputs: []bitfield.Bitvector128{ + {0b00000011}, + {0b00000001}, + }, + want: []bitfield.Bitvector128{ + {0b00000011}, + }, + }, + } + + for _, tt := range tests { + runner := func() { + got, err := Aggregate(aggtesting.MakeSyncContributionsFromBitVector(tt.inputs)) + require.NoError(t, err) + sort.Slice(got, func(i, j int) bool { + return got[i].AggregationBits.Bytes()[0] < got[j].AggregationBits.Bytes()[0] + }) + sort.Slice(tt.want, func(i, j int) bool { + return tt.want[i].Bytes()[0] < tt.want[j].Bytes()[0] + }) + assert.Equal(t, len(tt.want), len(got)) + for i, w := range tt.want { + assert.DeepEqual(t, w.Bytes(), got[i].AggregationBits.Bytes()) + } + } + t.Run(fmt.Sprintf("%s/%s", tt.name, NaiveAggregation), func(t *testing.T) { + resetCfg := featureconfig.InitWithReset(&featureconfig.Flags{ + AttestationAggregationStrategy: string(NaiveAggregation), + }) + defer resetCfg() + runner() + }) + } +} diff --git a/shared/aggregation/testing/BUILD.bazel b/shared/aggregation/testing/BUILD.bazel index 8343f95bfc..6f86198660 100644 --- a/shared/aggregation/testing/BUILD.bazel +++ b/shared/aggregation/testing/BUILD.bazel @@ -7,8 +7,10 @@ go_library( visibility = ["//visibility:public"], deps = [ "//proto/eth/v1alpha1:go_default_library", + "//proto/prysm/v2:go_default_library", "//shared/bls:go_default_library", "//shared/timeutils:go_default_library", + "@com_github_prysmaticlabs_eth2_types//:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", ], ) diff --git a/shared/aggregation/testing/bitlistutils.go b/shared/aggregation/testing/bitlistutils.go index 759a03a716..dbb1abe585 100644 --- a/shared/aggregation/testing/bitlistutils.go +++ b/shared/aggregation/testing/bitlistutils.go @@ -4,8 +4,10 @@ import ( "math/rand" "testing" + types "github.com/prysmaticlabs/eth2-types" "github.com/prysmaticlabs/go-bitfield" ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1" + prysmv2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" "github.com/prysmaticlabs/prysm/shared/bls" "github.com/prysmaticlabs/prysm/shared/timeutils" ) @@ -75,7 +77,7 @@ func Bitlists64WithMultipleBitSet(t testing.TB, n, length, count uint64) []*bitf return lists } -// MakeAttestationsFromBitlists creates list of bitlists from list of attestations. +// MakeAttestationsFromBitlists creates list of attestations from list of bitlist. func MakeAttestationsFromBitlists(bl []bitfield.Bitlist) []*ethpb.Attestation { atts := make([]*ethpb.Attestation, len(bl)) for i, b := range bl { @@ -90,3 +92,17 @@ func MakeAttestationsFromBitlists(bl []bitfield.Bitlist) []*ethpb.Attestation { } return atts } + +// MakeSyncContributionsFromBitVector creates list of sync contributions from list of bitvector. +func MakeSyncContributionsFromBitVector(bl []bitfield.Bitvector128) []*prysmv2.SyncCommitteeContribution { + c := make([]*prysmv2.SyncCommitteeContribution, len(bl)) + for i, b := range bl { + c[i] = &prysmv2.SyncCommitteeContribution{ + Slot: types.Slot(1), + SubcommitteeIndex: 2, + AggregationBits: b, + Signature: bls.NewAggregateSignature().Marshal(), + } + } + return c +} diff --git a/shared/copyutil/BUILD.bazel b/shared/copyutil/BUILD.bazel index 3ea6d914bd..2566b03215 100644 --- a/shared/copyutil/BUILD.bazel +++ b/shared/copyutil/BUILD.bazel @@ -8,6 +8,7 @@ go_library( deps = [ "//proto/beacon/p2p/v1:go_default_library", "//proto/eth/v1alpha1:go_default_library", + "//proto/prysm/v2:go_default_library", "//shared/bytesutil:go_default_library", ], ) diff --git a/shared/copyutil/cloners.go b/shared/copyutil/cloners.go index 3eadc66adb..3b8573a211 100644 --- a/shared/copyutil/cloners.go +++ b/shared/copyutil/cloners.go @@ -3,6 +3,7 @@ package copyutil import ( pbp2p "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1" + prysmv2 "github.com/prysmaticlabs/prysm/proto/prysm/v2" "github.com/prysmaticlabs/prysm/shared/bytesutil" ) @@ -284,3 +285,17 @@ func CopyValidator(val *ethpb.Validator) *ethpb.Validator { WithdrawableEpoch: val.WithdrawableEpoch, } } + +// CopySyncCommitteeContribution copies the provided sync committee contribution object. +func CopySyncCommitteeContribution(c *prysmv2.SyncCommitteeContribution) *prysmv2.SyncCommitteeContribution { + if c == nil { + return nil + } + return &prysmv2.SyncCommitteeContribution{ + Slot: c.Slot, + BlockRoot: bytesutil.SafeCopyBytes(c.BlockRoot), + SubcommitteeIndex: c.SubcommitteeIndex, + AggregationBits: bytesutil.SafeCopyBytes(c.AggregationBits), + Signature: bytesutil.SafeCopyBytes(c.Signature), + } +}