adding nil checks on attestation interface (#14638)

* adding nil checks on interface

* changelog

* add linting

* adding missed checks

* review feedback

* attestation bits should not be in nil check

* fixing nil checks

* simplifying function

* fixing some missed items

* more missed items

* fixing more tests

* reverting some changes and fixing more tests

* adding in source check back in

* missed test

* sammy's review

* radek feedback
This commit is contained in:
james-prysm
2024-11-18 11:51:17 -06:00
committed by GitHub
parent 00aeea3656
commit a7ba11df37
16 changed files with 129 additions and 70 deletions

View File

@@ -25,6 +25,7 @@ type Att interface {
CommitteeBitsVal() bitfield.Bitfield
GetSignature() []byte
GetCommitteeIndex() (primitives.CommitteeIndex, error)
IsNil() bool
}
// IndexedAtt defines common functionality for all indexed attestation types.
@@ -37,6 +38,7 @@ type IndexedAtt interface {
GetAttestingIndices() []uint64
GetData() *AttestationData
GetSignature() []byte
IsNil() bool
}
// SignedAggregateAttAndProof defines common functionality for all signed aggregate attestation types.
@@ -48,6 +50,7 @@ type SignedAggregateAttAndProof interface {
Version() int
AggregateAttestationAndProof() AggregateAttAndProof
GetSignature() []byte
IsNil() bool
}
// AggregateAttAndProof defines common functionality for all aggregate attestation types.
@@ -60,6 +63,7 @@ type AggregateAttAndProof interface {
GetAggregatorIndex() primitives.ValidatorIndex
AggregateVal() Att
GetSelectionProof() []byte
IsNil() bool
}
// AttSlashing defines common functionality for all attestation slashing types.
@@ -71,6 +75,7 @@ type AttSlashing interface {
Version() int
FirstAttestation() IndexedAtt
SecondAttestation() IndexedAtt
IsNil() bool
}
// Copy --
@@ -103,20 +108,25 @@ func (a *Attestation) Version() int {
return version.Phase0
}
// IsNil --
func (a *Attestation) IsNil() bool {
return a == nil || a.Data == nil
}
// Clone --
func (a *Attestation) Clone() Att {
return a.Copy()
}
// Copy --
func (att *Attestation) Copy() *Attestation {
if att == nil {
func (a *Attestation) Copy() *Attestation {
if a == nil {
return nil
}
return &Attestation{
AggregationBits: bytesutil.SafeCopyBytes(att.AggregationBits),
Data: att.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(att.Signature),
AggregationBits: bytesutil.SafeCopyBytes(a.AggregationBits),
Data: a.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(a.Signature),
}
}
@@ -140,6 +150,11 @@ func (a *PendingAttestation) Version() int {
return version.Phase0
}
// IsNil --
func (a *PendingAttestation) IsNil() bool {
return a == nil || a.Data == nil
}
// Clone --
func (a *PendingAttestation) Clone() Att {
return a.Copy()
@@ -181,21 +196,26 @@ func (a *AttestationElectra) Version() int {
return version.Electra
}
// IsNil --
func (a *AttestationElectra) IsNil() bool {
return a == nil || a.Data == nil
}
// Clone --
func (a *AttestationElectra) Clone() Att {
return a.Copy()
}
// Copy --
func (att *AttestationElectra) Copy() *AttestationElectra {
if att == nil {
func (a *AttestationElectra) Copy() *AttestationElectra {
if a == nil {
return nil
}
return &AttestationElectra{
AggregationBits: bytesutil.SafeCopyBytes(att.AggregationBits),
CommitteeBits: bytesutil.SafeCopyBytes(att.CommitteeBits),
Data: att.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(att.Signature),
AggregationBits: bytesutil.SafeCopyBytes(a.AggregationBits),
CommitteeBits: bytesutil.SafeCopyBytes(a.CommitteeBits),
Data: a.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(a.Signature),
}
}
@@ -227,40 +247,50 @@ func (a *IndexedAttestation) Version() int {
return version.Phase0
}
// IsNil --
func (a *IndexedAttestation) IsNil() bool {
return a == nil || a.Data == nil
}
// Version --
func (a *IndexedAttestationElectra) Version() int {
return version.Electra
}
// IsNil --
func (a *IndexedAttestationElectra) IsNil() bool {
return a == nil || a.Data == nil
}
// Copy --
func (indexedAtt *IndexedAttestation) Copy() *IndexedAttestation {
func (a *IndexedAttestation) Copy() *IndexedAttestation {
var indices []uint64
if indexedAtt == nil {
if a == nil {
return nil
} else if indexedAtt.AttestingIndices != nil {
indices = make([]uint64, len(indexedAtt.AttestingIndices))
copy(indices, indexedAtt.AttestingIndices)
} else if a.AttestingIndices != nil {
indices = make([]uint64, len(a.AttestingIndices))
copy(indices, a.AttestingIndices)
}
return &IndexedAttestation{
AttestingIndices: indices,
Data: indexedAtt.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(indexedAtt.Signature),
Data: a.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(a.Signature),
}
}
// Copy --
func (indexedAtt *IndexedAttestationElectra) Copy() *IndexedAttestationElectra {
func (a *IndexedAttestationElectra) Copy() *IndexedAttestationElectra {
var indices []uint64
if indexedAtt == nil {
if a == nil {
return nil
} else if indexedAtt.AttestingIndices != nil {
indices = make([]uint64, len(indexedAtt.AttestingIndices))
copy(indices, indexedAtt.AttestingIndices)
} else if a.AttestingIndices != nil {
indices = make([]uint64, len(a.AttestingIndices))
copy(indices, a.AttestingIndices)
}
return &IndexedAttestationElectra{
AttestingIndices: indices,
Data: indexedAtt.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(indexedAtt.Signature),
Data: a.Data.Copy(),
Signature: bytesutil.SafeCopyBytes(a.Signature),
}
}
@@ -269,6 +299,13 @@ func (a *AttesterSlashing) Version() int {
return version.Phase0
}
// IsNil --
func (a *AttesterSlashing) IsNil() bool {
return a == nil ||
a.Attestation_1 == nil || a.Attestation_1.IsNil() ||
a.Attestation_2 == nil || a.Attestation_2.IsNil()
}
// FirstAttestation --
func (a *AttesterSlashing) FirstAttestation() IndexedAtt {
return a.Attestation_1
@@ -284,6 +321,13 @@ func (a *AttesterSlashingElectra) Version() int {
return version.Electra
}
// IsNil --
func (a *AttesterSlashingElectra) IsNil() bool {
return a == nil ||
a.Attestation_1 == nil || a.Attestation_1.IsNil() ||
a.Attestation_2 == nil || a.Attestation_2.IsNil()
}
// FirstAttestation --
func (a *AttesterSlashingElectra) FirstAttestation() IndexedAtt {
return a.Attestation_1
@@ -320,6 +364,11 @@ func (a *AggregateAttestationAndProof) Version() int {
return version.Phase0
}
// IsNil --
func (a *AggregateAttestationAndProof) IsNil() bool {
return a == nil || a.Aggregate == nil || a.Aggregate.IsNil()
}
// AggregateVal --
func (a *AggregateAttestationAndProof) AggregateVal() Att {
return a.Aggregate
@@ -330,6 +379,11 @@ func (a *AggregateAttestationAndProofElectra) Version() int {
return version.Electra
}
// IsNil --
func (a *AggregateAttestationAndProofElectra) IsNil() bool {
return a == nil || a.Aggregate == nil || a.Aggregate.IsNil()
}
// AggregateVal --
func (a *AggregateAttestationAndProofElectra) AggregateVal() Att {
return a.Aggregate
@@ -340,6 +394,11 @@ func (a *SignedAggregateAttestationAndProof) Version() int {
return version.Phase0
}
// IsNil --
func (a *SignedAggregateAttestationAndProof) IsNil() bool {
return a == nil || a.Message == nil || a.Message.IsNil()
}
// AggregateAttestationAndProof --
func (a *SignedAggregateAttestationAndProof) AggregateAttestationAndProof() AggregateAttAndProof {
return a.Message
@@ -350,6 +409,11 @@ func (a *SignedAggregateAttestationAndProofElectra) Version() int {
return version.Electra
}
// IsNil --
func (a *SignedAggregateAttestationAndProofElectra) IsNil() bool {
return a == nil || a.Message == nil || a.Message.IsNil()
}
// AggregateAttestationAndProof --
func (a *SignedAggregateAttestationAndProofElectra) AggregateAttestationAndProof() AggregateAttAndProof {
return a.Message

View File

@@ -39,7 +39,7 @@ import (
// data=attestation.data,
// signature=attestation.signature,
// )
func ConvertToIndexed(ctx context.Context, attestation ethpb.Att, committees ...[]primitives.ValidatorIndex) (ethpb.IndexedAtt, error) {
func ConvertToIndexed(_ context.Context, attestation ethpb.Att, committees ...[]primitives.ValidatorIndex) (ethpb.IndexedAtt, error) {
attIndices, err := AttestingIndices(attestation, committees...)
if err != nil {
return nil, err
@@ -185,12 +185,10 @@ func IsValidAttestationIndices(ctx context.Context, indexedAttestation ethpb.Ind
_, span := trace.StartSpan(ctx, "attestationutil.IsValidAttestationIndices")
defer span.End()
if indexedAttestation == nil ||
indexedAttestation.GetData() == nil ||
indexedAttestation.GetData().Target == nil ||
indexedAttestation.GetAttestingIndices() == nil {
if indexedAttestation == nil || indexedAttestation.IsNil() || indexedAttestation.GetData().Target == nil || indexedAttestation.GetData().Source == nil {
return errors.New("nil or missing indexed attestation data")
}
indices := indexedAttestation.GetAttestingIndices()
if len(indices) == 0 {
return errors.New("expected non-empty attesting indices")

View File

@@ -106,10 +106,11 @@ func TestIsValidAttestationIndices(t *testing.T) {
att: &eth.IndexedAttestation{
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
wantedErr: "nil or missing indexed attestation data",
wantedErr: "expected non-empty attesting indices",
},
{
name: "Indices should be non-empty",
@@ -117,6 +118,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: []uint64{},
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -128,6 +130,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: make([]uint64, params.BeaconConfig().MaxValidatorsPerCommittee+1),
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -139,6 +142,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: []uint64{3, 2, 1},
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -150,6 +154,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: []uint64{1, 2, 3},
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -160,6 +165,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: []uint64{1, 2},
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -170,6 +176,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: []uint64{1},
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -180,6 +187,7 @@ func TestIsValidAttestationIndices(t *testing.T) {
AttestingIndices: make([]uint64, params.BeaconConfig().MaxValidatorsPerCommittee*params.BeaconConfig().MaxCommitteesPerSlot+1),
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
},
@@ -218,6 +226,7 @@ func BenchmarkIsValidAttestationIndices(b *testing.B) {
AttestingIndices: indices,
Data: &eth.AttestationData{
Target: &eth.Checkpoint{},
Source: &eth.Checkpoint{},
},
Signature: make([]byte, fieldparams.BLSSignatureLength),
}