diff --git a/beacon-chain/sync/batch_verifier.go b/beacon-chain/sync/batch_verifier.go index e4f0982345..6528b49469 100644 --- a/beacon-chain/sync/batch_verifier.go +++ b/beacon-chain/sync/batch_verifier.go @@ -6,6 +6,7 @@ import ( pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/config/features" "github.com/prysmaticlabs/prysm/crypto/bls" "github.com/prysmaticlabs/prysm/monitoring/tracing" "go.opencensus.io/trace" @@ -83,19 +84,44 @@ func verifyBatch(verifierBatch []*signatureVerifier) { return } aggSet := verifierBatch[0].set - verificationErr := error(nil) for i := 1; i < len(verifierBatch); i++ { aggSet = aggSet.Join(verifierBatch[i].set) } - verified, err := aggSet.Verify() - switch { - case err != nil: - verificationErr = err - case !verified: - verificationErr = errors.New("batch signature verification failed") + var verificationErr error + + if features.Get().EnableBatchGossipAggregation { + aggSet, verificationErr = performBatchAggregation(aggSet) + } + if verificationErr == nil { + verified, err := aggSet.Verify() + switch { + case err != nil: + verificationErr = err + case !verified: + verificationErr = errors.New("batch signature verification failed") + } } for i := 0; i < len(verifierBatch); i++ { verifierBatch[i].resChan <- verificationErr } } + +func performBatchAggregation(aggSet *bls.SignatureBatch) (*bls.SignatureBatch, error) { + currLen := len(aggSet.Signatures) + num, aggSet, err := aggSet.RemoveDuplicates() + if err != nil { + return nil, err + } + duplicatesRemovedCounter.Add(float64(num)) + // Aggregate batches in the provided signature batch. + aggSet, err = aggSet.AggregateBatch() + if err != nil { + return nil, err + } + // Record number of signature sets successfully batched. + if currLen > len(aggSet.Signatures) { + numberOfSetsAggregated.Observe(float64(currLen - len(aggSet.Signatures))) + } + return aggSet, nil +} diff --git a/beacon-chain/sync/metrics.go b/beacon-chain/sync/metrics.go index 324d99f331..70642625a8 100644 --- a/beacon-chain/sync/metrics.go +++ b/beacon-chain/sync/metrics.go @@ -62,6 +62,19 @@ var ( Help: "Count the number of times a node resyncs.", }, ) + duplicatesRemovedCounter = promauto.NewCounter( + prometheus.CounterOpts{ + Name: "number_of_duplicates_removed", + Help: "Count the number of times a duplicate signature set has been removed.", + }, + ) + numberOfSetsAggregated = promauto.NewHistogram( + prometheus.HistogramOpts{ + Name: "number_of_sets_aggregated", + Help: "Count the number of times different sets have been successfully aggregated in a batch.", + Buckets: []float64{10, 50, 100, 200, 400, 800, 1600, 3200}, + }, + ) arrivalBlockPropagationHistogram = promauto.NewHistogram( prometheus.HistogramOpts{ diff --git a/config/features/config.go b/config/features/config.go index 5d59e3df20..a13dadd4d1 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -66,6 +66,7 @@ type Flags struct { EnableNativeState bool // EnableNativeState defines whether the beacon state will be represented as a pure Go struct or a Go struct that wraps a proto struct. EnableVectorizedHTR bool // EnableVectorizedHTR specifies whether the beacon state will use the optimized sha256 routines. EnableForkChoiceDoublyLinkedTree bool // EnableForkChoiceDoublyLinkedTree specifies whether fork choice store will use a doubly linked tree. + EnableBatchGossipAggregation bool // EnableBatchGossipAggregation specifies whether to further aggregate our gossip batches before verifying them. // KeystoreImportDebounceInterval specifies the time duration the validator waits to reload new keys if they have // changed on disk. This feature is for advanced use cases only. @@ -186,6 +187,10 @@ func ConfigureBeaconChain(ctx *cli.Context) { logEnabled(enableForkChoiceDoublyLinkedTree) cfg.EnableForkChoiceDoublyLinkedTree = true } + if ctx.Bool(enableGossipBatchAggregation.Name) { + logEnabled(enableGossipBatchAggregation) + cfg.EnableBatchGossipAggregation = true + } Init(cfg) } diff --git a/config/features/flags.go b/config/features/flags.go index 30626c98a7..8bbf439f5f 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -113,6 +113,10 @@ var ( Name: "enable-forkchoice-doubly-linked-tree", Usage: "Enables new forkchoice store structure that uses doubly linked trees", } + enableGossipBatchAggregation = &cli.BoolFlag{ + Name: "enable-gossip-batch-aggregation", + Usage: "Enables new methods to further aggregate our gossip batches before verifying them.", + } ) // devModeFlags holds list of flags that are set when development mode is on. @@ -120,6 +124,7 @@ var devModeFlags = []cli.Flag{ enablePeerScorer, enableVecHTR, enableForkChoiceDoublyLinkedTree, + enableGossipBatchAggregation, } // ValidatorFlags contains a list of all the feature flags that apply to the validator client. @@ -158,6 +163,7 @@ var BeaconChainFlags = append(deprecatedFlags, []cli.Flag{ enableNativeState, enableVecHTR, enableForkChoiceDoublyLinkedTree, + enableGossipBatchAggregation, }...) // E2EBeaconChainFlags contains a list of the beacon chain feature flags to be tested in E2E. diff --git a/crypto/bls/BUILD.bazel b/crypto/bls/BUILD.bazel index 0dec737726..68b7f7822b 100644 --- a/crypto/bls/BUILD.bazel +++ b/crypto/bls/BUILD.bazel @@ -15,6 +15,7 @@ go_library( "//crypto/bls/blst:go_default_library", "//crypto/bls/common:go_default_library", "//crypto/bls/herumi:go_default_library", + "@com_github_pkg_errors//:go_default_library", ], ) diff --git a/crypto/bls/bls.go b/crypto/bls/bls.go index b0832610d4..02854c8b0c 100644 --- a/crypto/bls/bls.go +++ b/crypto/bls/bls.go @@ -39,11 +39,21 @@ func AggregatePublicKeys(pubs [][]byte) (PublicKey, error) { return blst.AggregatePublicKeys(pubs) } +// AggregateMultiplePubkeys aggregates the provided decompressed keys into a single key. +func AggregateMultiplePubkeys(pubs []PublicKey) PublicKey { + return blst.AggregateMultiplePubkeys(pubs) +} + // AggregateSignatures converts a list of signatures into a single, aggregated sig. func AggregateSignatures(sigs []common.Signature) common.Signature { return blst.AggregateSignatures(sigs) } +// AggregateCompressedSignatures converts a list of compressed signatures into a single, aggregated sig. +func AggregateCompressedSignatures(multiSigs [][]byte) (common.Signature, error) { + return blst.AggregateCompressedSignatures(multiSigs) +} + // VerifyMultipleSignatures verifies multiple signatures for distinct messages securely. func VerifyMultipleSignatures(sigs [][]byte, msgs [][32]byte, pubKeys []common.PublicKey) (bool, error) { return blst.VerifyMultipleSignatures(sigs, msgs, pubKeys) diff --git a/crypto/bls/blst/public_key.go b/crypto/bls/blst/public_key.go index df41b96d32..529e9157d4 100644 --- a/crypto/bls/blst/public_key.go +++ b/crypto/bls/blst/public_key.go @@ -93,6 +93,12 @@ func (p *PublicKey) IsInfinite() bool { return p.p.Equals(zeroKey) } +// Equals checks if the provided public key is equal to +// the current one. +func (p *PublicKey) Equals(p2 common.PublicKey) bool { + return p.p.Equals(p2.(*PublicKey).p) +} + // Aggregate two public keys. func (p *PublicKey) Aggregate(p2 common.PublicKey) common.PublicKey { if features.Get().SkipBLSVerify { @@ -107,3 +113,20 @@ func (p *PublicKey) Aggregate(p2 common.PublicKey) common.PublicKey { return p } + +// AggregateMultiplePubkeys aggregates the provided decompressed keys into a single key. +func AggregateMultiplePubkeys(pubkeys []common.PublicKey) common.PublicKey { + if features.Get().SkipBLSVerify { + return &PublicKey{} + } + mulP1 := make([]*blstPublicKey, 0, len(pubkeys)) + for _, pubkey := range pubkeys { + mulP1 = append(mulP1, pubkey.(*PublicKey).p) + } + agg := new(blstAggregatePublicKey) + // No group check needed here since it is done in PublicKeyFromBytes + // Note the checks could be moved from PublicKeyFromBytes into Aggregate + // and take advantage of multi-threading. + agg.Aggregate(mulP1, false) + return &PublicKey{p: agg.ToAffine()} +} diff --git a/crypto/bls/blst/public_key_test.go b/crypto/bls/blst/public_key_test.go index 0c9657c94c..ea51cf635d 100644 --- a/crypto/bls/blst/public_key_test.go +++ b/crypto/bls/blst/public_key_test.go @@ -10,6 +10,7 @@ import ( "testing" "github.com/prysmaticlabs/prysm/crypto/bls/blst" + "github.com/prysmaticlabs/prysm/crypto/bls/common" "github.com/prysmaticlabs/prysm/testing/assert" "github.com/prysmaticlabs/prysm/testing/require" ) @@ -78,6 +79,21 @@ func TestPublicKey_Copy(t *testing.T) { require.DeepEqual(t, pubkeyA.Marshal(), pubkeyBytes, "Pubkey was mutated after copy") } +func TestPublicKey_Aggregate(t *testing.T) { + priv, err := blst.RandKey() + require.NoError(t, err) + pubkeyA := priv.PublicKey() + + pubkeyB := pubkeyA.Copy() + priv2, err := blst.RandKey() + require.NoError(t, err) + resKey := pubkeyB.Aggregate(priv2.PublicKey()) + + aggKey := blst.AggregateMultiplePubkeys([]common.PublicKey{priv.PublicKey(), priv2.PublicKey()}) + + require.DeepEqual(t, resKey.Marshal(), aggKey.Marshal(), "Pubkey does not match up") +} + func TestPublicKeysEmpty(t *testing.T) { var pubs [][]byte _, err := blst.AggregatePublicKeys(pubs) diff --git a/crypto/bls/blst/signature.go b/crypto/bls/blst/signature.go index 93380f7c93..0e9d51248a 100644 --- a/crypto/bls/blst/signature.go +++ b/crypto/bls/blst/signature.go @@ -47,6 +47,16 @@ func SignatureFromBytes(sig []byte) (common.Signature, error) { return &Signature{s: signature}, nil } +// AggregateCompressedSignatures converts a list of compressed signatures into a single, aggregated sig. +func AggregateCompressedSignatures(multiSigs [][]byte) (common.Signature, error) { + signature := new(blstAggregateSignature) + valid := signature.AggregateCompressed(multiSigs, true) + if !valid { + return nil, errors.New("provided signatures fail the group check and cannot be compressed") + } + return &Signature{s: signature.ToAffine()}, nil +} + // MultipleSignaturesFromBytes creates a group of BLS signatures from a LittleEndian 2d-byte slice. func MultipleSignaturesFromBytes(multiSigs [][]byte) ([]common.Signature, error) { if features.Get().SkipBLSVerify { diff --git a/crypto/bls/blst/signature_test.go b/crypto/bls/blst/signature_test.go index 4bf3a8a1ab..37f23020fc 100644 --- a/crypto/bls/blst/signature_test.go +++ b/crypto/bls/blst/signature_test.go @@ -41,6 +41,30 @@ func TestAggregateVerify(t *testing.T) { assert.Equal(t, true, aggSig.AggregateVerify(pubkeys, msgs), "Signature did not verify") } +func TestAggregateVerify_CompressedSignatures(t *testing.T) { + pubkeys := make([]common.PublicKey, 0, 100) + sigs := make([]common.Signature, 0, 100) + sigBytes := [][]byte{} + var msgs [][32]byte + for i := 0; i < 100; i++ { + msg := [32]byte{'h', 'e', 'l', 'l', 'o', byte(i)} + priv, err := RandKey() + require.NoError(t, err) + pub := priv.PublicKey() + sig := priv.Sign(msg[:]) + pubkeys = append(pubkeys, pub) + sigs = append(sigs, sig) + sigBytes = append(sigBytes, sig.Marshal()) + msgs = append(msgs, msg) + } + aggSig := AggregateSignatures(sigs) + assert.Equal(t, true, aggSig.AggregateVerify(pubkeys, msgs), "Signature did not verify") + + aggSig2, err := AggregateCompressedSignatures(sigBytes) + assert.NoError(t, err) + assert.DeepEqual(t, aggSig.Marshal(), aggSig2.Marshal(), "Signature did not match up") +} + func TestFastAggregateVerify(t *testing.T) { pubkeys := make([]common.PublicKey, 0, 100) sigs := make([]common.Signature, 0, 100) diff --git a/crypto/bls/blst/stub.go b/crypto/bls/blst/stub.go index 4c56584769..b6c3323a22 100644 --- a/crypto/bls/blst/stub.go +++ b/crypto/bls/blst/stub.go @@ -56,6 +56,11 @@ func (p PublicKey) IsInfinite() bool { panic(err) } +// Equals -- stub +func (p PublicKey) Equals(_ common.PublicKey) bool { + panic(err) +} + // Signature -- stub type Signature struct{} @@ -119,6 +124,16 @@ func AggregateSignatures(_ []common.Signature) common.Signature { panic(err) } +// AggregateMultiplePubkeys -- stub +func AggregateMultiplePubkeys(pubs []common.PublicKey) common.PublicKey { + panic(err) +} + +// AggregateCompressedSignatures -- stub +func AggregateCompressedSignatures(multiSigs [][]byte) (common.Signature, error) { + panic(err) +} + // VerifyMultipleSignatures -- stub func VerifyMultipleSignatures(_ [][]byte, _ [][32]byte, _ []common.PublicKey) (bool, error) { panic(err) diff --git a/crypto/bls/common/interface.go b/crypto/bls/common/interface.go index f04a7de82b..87317272fb 100644 --- a/crypto/bls/common/interface.go +++ b/crypto/bls/common/interface.go @@ -18,6 +18,7 @@ type PublicKey interface { Copy() PublicKey Aggregate(p2 PublicKey) PublicKey IsInfinite() bool + Equals(p2 PublicKey) bool } // Signature represents a BLS signature. diff --git a/crypto/bls/signature_batch.go b/crypto/bls/signature_batch.go index 8c2975cea7..be4d046fb9 100644 --- a/crypto/bls/signature_batch.go +++ b/crypto/bls/signature_batch.go @@ -1,5 +1,7 @@ package bls +import "github.com/pkg/errors" + // SignatureBatch refers to the defined set of // signatures and its respective public keys and // messages required to verify it. @@ -54,3 +56,93 @@ func (s *SignatureBatch) Copy() *SignatureBatch { Messages: messages, } } + +// RemoveDuplicates removes duplicate signature sets from the signature batch. +func (s *SignatureBatch) RemoveDuplicates() (int, *SignatureBatch, error) { + if len(s.Signatures) == 0 || len(s.PublicKeys) == 0 || len(s.Messages) == 0 { + return 0, s, nil + } + if len(s.Signatures) != len(s.PublicKeys) || len(s.Signatures) != len(s.Messages) { + return 0, s, errors.Errorf("mismatch number of signatures, publickeys and messages in signature batch. "+ + "Signatures %d, Public Keys %d , Messages %d", s.Signatures, s.PublicKeys, s.Messages) + } + sigMap := make(map[string]int) + duplicateSet := make(map[int]bool) + for i := 0; i < len(s.Signatures); i++ { + if sigIdx, ok := sigMap[string(s.Signatures[i])]; ok { + if s.PublicKeys[sigIdx].Equals(s.PublicKeys[i]) && + s.Messages[sigIdx] == s.Messages[i] { + duplicateSet[i] = true + continue + } + } + sigMap[string(s.Signatures[i])] = i + } + + sigs := s.Signatures[:0] + pubs := s.PublicKeys[:0] + msgs := s.Messages[:0] + + for i := 0; i < len(s.Signatures); i++ { + if duplicateSet[i] { + continue + } + sigs = append(sigs, s.Signatures[i]) + pubs = append(pubs, s.PublicKeys[i]) + msgs = append(msgs, s.Messages[i]) + } + + s.Signatures = sigs + s.PublicKeys = pubs + s.Messages = msgs + + return len(duplicateSet), s, nil +} + +// AggregateBatch aggregates common messages in the provided batch to +// reduce the number of pairings required when we finally verify the +// whole batch. +func (s *SignatureBatch) AggregateBatch() (*SignatureBatch, error) { + if len(s.Signatures) == 0 || len(s.PublicKeys) == 0 || len(s.Messages) == 0 { + return s, nil + } + if len(s.Signatures) != len(s.PublicKeys) || len(s.Signatures) != len(s.Messages) { + return s, errors.Errorf("mismatch number of signatures, publickeys and messages in signature batch. "+ + "Signatures %d, Public Keys %d , Messages %d", s.Signatures, s.PublicKeys, s.Messages) + } + msgMap := make(map[[32]byte]*SignatureBatch) + + for i := 0; i < len(s.Messages); i++ { + currMsg := s.Messages[i] + currBatch, ok := msgMap[currMsg] + if ok { + currBatch.Signatures = append(currBatch.Signatures, s.Signatures[i]) + currBatch.Messages = append(currBatch.Messages, s.Messages[i]) + currBatch.PublicKeys = append(currBatch.PublicKeys, s.PublicKeys[i]) + continue + } + currBatch = &SignatureBatch{ + Signatures: [][]byte{s.Signatures[i]}, + Messages: [][32]byte{s.Messages[i]}, + PublicKeys: []PublicKey{s.PublicKeys[i]}, + } + msgMap[currMsg] = currBatch + } + newSt := NewSet() + for rt, b := range msgMap { + if len(b.PublicKeys) > 1 { + aggPub := AggregateMultiplePubkeys(b.PublicKeys) + aggSig, err := AggregateCompressedSignatures(b.Signatures) + if err != nil { + return nil, err + } + copiedRt := rt + b.PublicKeys = []PublicKey{aggPub} + b.Signatures = [][]byte{aggSig.Marshal()} + b.Messages = [][32]byte{copiedRt} + } + newObj := *b + newSt = newSt.Join(&newObj) + } + return newSt, nil +} diff --git a/crypto/bls/signature_batch_test.go b/crypto/bls/signature_batch_test.go index 207c2c1d16..07795bcd62 100644 --- a/crypto/bls/signature_batch_test.go +++ b/crypto/bls/signature_batch_test.go @@ -1,6 +1,9 @@ package bls import ( + "bytes" + "reflect" + "sort" "testing" "github.com/prysmaticlabs/prysm/testing/assert" @@ -44,3 +47,538 @@ func TestCopySignatureSet(t *testing.T) { assert.DeepEqual(t, aggSet, aggSet2) }) } + +func TestSignatureBatch_RemoveDuplicates(t *testing.T) { + keys := []SecretKey{} + for i := 0; i < 100; i++ { + key, err := RandKey() + assert.NoError(t, err) + keys = append(keys, key) + } + tests := []struct { + name string + batchCreator func() (input *SignatureBatch, output *SignatureBatch) + want int + }{ + { + name: "empty batch", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + return &SignatureBatch{}, &SignatureBatch{} + }, + want: 0, + }, + { + name: "valid duplicates in batch", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:20] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + allSigs := append(signatures, signatures...) + allPubs := append(pubs, pubs...) + allMsgs := append(messages, messages...) + return &SignatureBatch{ + Signatures: allSigs, + PublicKeys: allPubs, + Messages: allMsgs, + }, &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + } + }, + want: 20, + }, + { + name: "valid duplicates in batch with multiple messages", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + allSigs := append(signatures, signatures...) + allPubs := append(pubs, pubs...) + allMsgs := append(messages, messages...) + return &SignatureBatch{ + Signatures: allSigs, + PublicKeys: allPubs, + Messages: allMsgs, + }, &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + } + }, + want: 30, + }, + { + name: "no duplicates in batch with multiple messages", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + return &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + }, &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + } + }, + want: 0, + }, + { + name: "valid duplicates and invalid duplicates in batch with multiple messages", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + allSigs := append(signatures, signatures...) + // Make it a non-unique entry + allSigs[10] = make([]byte, 96) + allPubs := append(pubs, pubs...) + allMsgs := append(messages, messages...) + // Insert it back at the end + signatures = append(signatures, signatures[10]) + pubs = append(pubs, pubs[10]) + messages = append(messages, messages[10]) + // Zero out to expected result + signatures[10] = make([]byte, 96) + return &SignatureBatch{ + Signatures: allSigs, + PublicKeys: allPubs, + Messages: allMsgs, + }, &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + } + }, + want: 29, + }, + { + name: "valid duplicates and invalid duplicates with signature,pubkey,message in batch with multiple messages", + batchCreator: func() (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + allSigs := append(signatures, signatures...) + // Make it a non-unique entry + allSigs[10] = make([]byte, 96) + + allPubs := append(pubs, pubs...) + allPubs[20] = keys[len(keys)-1].PublicKey() + + allMsgs := append(messages, messages...) + allMsgs[29] = [32]byte{'j', 'u', 'n', 'k'} + + // Insert it back at the end + signatures = append(signatures, signatures[10]) + pubs = append(pubs, pubs[10]) + messages = append(messages, messages[10]) + // Zero out to expected result + signatures[10] = make([]byte, 96) + + // Insert it back at the end + signatures = append(signatures, signatures[20]) + pubs = append(pubs, pubs[20]) + messages = append(messages, messages[20]) + // Zero out to expected result + pubs[20] = keys[len(keys)-1].PublicKey() + + // Insert it back at the end + signatures = append(signatures, signatures[29]) + pubs = append(pubs, pubs[29]) + messages = append(messages, messages[29]) + messages[29] = [32]byte{'j', 'u', 'n', 'k'} + + return &SignatureBatch{ + Signatures: allSigs, + PublicKeys: allPubs, + Messages: allMsgs, + }, &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + } + }, + want: 27, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input, output := tt.batchCreator() + num, res, err := input.RemoveDuplicates() + assert.NoError(t, err) + if num != tt.want { + t.Errorf("RemoveDuplicates() got = %v, want %v", num, tt.want) + } + if !reflect.DeepEqual(res.Signatures, output.Signatures) { + t.Errorf("RemoveDuplicates() Signatures output = %v, want %v", res.Signatures, output.Signatures) + } + if !reflect.DeepEqual(res.PublicKeys, output.PublicKeys) { + t.Errorf("RemoveDuplicates() Publickeys output = %v, want %v", res.PublicKeys, output.PublicKeys) + } + if !reflect.DeepEqual(res.Messages, output.Messages) { + t.Errorf("RemoveDuplicates() Messages output = %v, want %v", res.Messages, output.Messages) + } + }) + } +} + +func TestSignatureBatch_AggregateBatch(t *testing.T) { + keys := []SecretKey{} + for i := 0; i < 100; i++ { + key, err := RandKey() + assert.NoError(t, err) + keys = append(keys, key) + } + tests := []struct { + name string + batchCreator func(t *testing.T) (input *SignatureBatch, output *SignatureBatch) + wantErr bool + }{ + { + name: "empty batch", + batchCreator: func(t *testing.T) (*SignatureBatch, *SignatureBatch) { + return &SignatureBatch{Signatures: nil, Messages: nil, PublicKeys: nil}, + &SignatureBatch{Signatures: nil, Messages: nil, PublicKeys: nil} + }, + wantErr: false, + }, + { + name: "valid signatures in batch", + batchCreator: func(t *testing.T) (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:20] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + aggSig, err := AggregateCompressedSignatures(signatures) + assert.NoError(t, err) + aggPub := AggregateMultiplePubkeys(pubs) + return &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + }, &SignatureBatch{ + Signatures: [][]byte{aggSig.Marshal()}, + PublicKeys: []PublicKey{aggPub}, + Messages: [][32]byte{msg}, + } + }, + wantErr: false, + }, + { + name: "invalid signatures in batch", + batchCreator: func(t *testing.T) (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:20] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + signatures[10] = make([]byte, 96) + return &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + }, nil + }, + wantErr: true, + }, + { + name: "valid aggregates in batch with multiple messages", + batchCreator: func(t *testing.T) (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + aggSig1, err := AggregateCompressedSignatures(signatures[:10]) + assert.NoError(t, err) + aggSig2, err := AggregateCompressedSignatures(signatures[10:20]) + assert.NoError(t, err) + aggSig3, err := AggregateCompressedSignatures(signatures[20:30]) + assert.NoError(t, err) + aggPub1 := AggregateMultiplePubkeys(pubs[:10]) + aggPub2 := AggregateMultiplePubkeys(pubs[10:20]) + aggPub3 := AggregateMultiplePubkeys(pubs[20:30]) + return &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + }, &SignatureBatch{ + Signatures: [][]byte{aggSig1.Marshal(), aggSig2.Marshal(), aggSig3.Marshal()}, + PublicKeys: []PublicKey{aggPub1, aggPub2, aggPub3}, + Messages: [][32]byte{msg, msg1, msg2}, + } + }, + wantErr: false, + }, + { + name: "common and uncommon messages in batch with multiple messages", + batchCreator: func(t *testing.T) (*SignatureBatch, *SignatureBatch) { + chosenKeys := keys[:30] + + msg := [32]byte{'r', 'a', 'n', 'd', 'o', 'm'} + msg1 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '1'} + msg2 := [32]byte{'r', 'a', 'n', 'd', 'o', 'm', '2'} + signatures := [][]byte{} + messages := [][32]byte{} + pubs := []PublicKey{} + for _, k := range chosenKeys[:10] { + s := k.Sign(msg[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[10:20] { + s := k.Sign(msg1[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg1) + pubs = append(pubs, k.PublicKey()) + } + for _, k := range chosenKeys[20:30] { + s := k.Sign(msg2[:]) + signatures = append(signatures, s.Marshal()) + messages = append(messages, msg2) + pubs = append(pubs, k.PublicKey()) + } + // Set a custom message + messages[5][31] ^= byte(100) + messages[15][31] ^= byte(100) + messages[25][31] ^= byte(100) + + newSigs := [][]byte{} + newSigs = append(newSigs, signatures[:5]...) + newSigs = append(newSigs, signatures[6:10]...) + + aggSig1, err := AggregateCompressedSignatures(newSigs) + assert.NoError(t, err) + + newSigs = [][]byte{} + newSigs = append(newSigs, signatures[10:15]...) + newSigs = append(newSigs, signatures[16:20]...) + aggSig2, err := AggregateCompressedSignatures(newSigs) + assert.NoError(t, err) + + newSigs = [][]byte{} + newSigs = append(newSigs, signatures[20:25]...) + newSigs = append(newSigs, signatures[26:30]...) + aggSig3, err := AggregateCompressedSignatures(newSigs) + assert.NoError(t, err) + + newPubs := []PublicKey{} + newPubs = append(newPubs, pubs[:5]...) + newPubs = append(newPubs, pubs[6:10]...) + + aggPub1 := AggregateMultiplePubkeys(newPubs) + + newPubs = []PublicKey{} + newPubs = append(newPubs, pubs[10:15]...) + newPubs = append(newPubs, pubs[16:20]...) + aggPub2 := AggregateMultiplePubkeys(newPubs) + + newPubs = []PublicKey{} + newPubs = append(newPubs, pubs[20:25]...) + newPubs = append(newPubs, pubs[26:30]...) + aggPub3 := AggregateMultiplePubkeys(newPubs) + + return &SignatureBatch{ + Signatures: signatures, + PublicKeys: pubs, + Messages: messages, + }, &SignatureBatch{ + Signatures: [][]byte{aggSig1.Marshal(), signatures[5], aggSig2.Marshal(), signatures[15], aggSig3.Marshal(), signatures[25]}, + PublicKeys: []PublicKey{aggPub1, pubs[5], aggPub2, pubs[15], aggPub3, pubs[25]}, + Messages: [][32]byte{msg, messages[5], msg1, messages[15], msg2, messages[25]}, + } + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input, output := tt.batchCreator(t) + got, err := input.AggregateBatch() + if (err != nil) != tt.wantErr { + t.Errorf("AggregateBatch() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr { + return + } + got = sortSet(got) + output = sortSet(output) + + if !reflect.DeepEqual(got.Signatures, output.Signatures) { + t.Errorf("AggregateBatch() Signatures got = %v, want %v", got.Signatures, output.Signatures) + } + if !reflect.DeepEqual(got.PublicKeys, output.PublicKeys) { + t.Errorf("AggregateBatch() PublicKeys got = %v, want %v", got.PublicKeys, output.PublicKeys) + } + if !reflect.DeepEqual(got.Messages, output.Messages) { + t.Errorf("AggregateBatch() Messages got = %v, want %v", got.Messages, output.Messages) + } + }) + } +} + +func sortSet(s *SignatureBatch) *SignatureBatch { + sort.Sort(sorter{set: s}) + return s +} + +type sorter struct { + set *SignatureBatch +} + +func (s sorter) Len() int { + return len(s.set.Messages) +} + +func (s sorter) Swap(i, j int) { + s.set.Signatures[i], s.set.Signatures[j] = s.set.Signatures[j], s.set.Signatures[i] + s.set.PublicKeys[i], s.set.PublicKeys[j] = s.set.PublicKeys[j], s.set.PublicKeys[i] + s.set.Messages[i], s.set.Messages[j] = s.set.Messages[j], s.set.Messages[i] +} + +func (s sorter) Less(i, j int) bool { + return bytes.Compare(s.set.Messages[i][:], s.set.Messages[j][:]) == -1 +}