mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-09 13:28:01 -05:00
Set a max limit for decoding ssz objects from p2p (#5295)
* Set a max limit for decoding ssz objects from p2p
This commit is contained in:
@@ -11,6 +11,9 @@ import (
|
||||
|
||||
var _ = NetworkEncoding(&SszNetworkEncoder{})
|
||||
|
||||
// MaxChunkSize allowed for decoding messages.
|
||||
const MaxChunkSize = uint64(1 << 20) // 1Mb
|
||||
|
||||
// SszNetworkEncoder supports p2p networking encoding using SimpleSerialize
|
||||
// with snappy compression (if enabled).
|
||||
type SszNetworkEncoder struct {
|
||||
@@ -87,21 +90,15 @@ func (e SszNetworkEncoder) Decode(b []byte, to interface{}) error {
|
||||
|
||||
// DecodeWithLength the bytes from io.Reader to the protobuf message provided.
|
||||
func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to interface{}) error {
|
||||
msgLen, err := readVarint(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b := make([]byte, msgLen)
|
||||
_, err = r.Read(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return e.Decode(b, to)
|
||||
return e.DecodeWithMaxLength(r, to, MaxChunkSize)
|
||||
}
|
||||
|
||||
// DecodeWithMaxLength the bytes from io.Reader to the protobuf message provided.
|
||||
// This checks that the decoded message isn't larger than the provided max limit.
|
||||
func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}, maxSize uint64) error {
|
||||
if maxSize > MaxChunkSize {
|
||||
return fmt.Errorf("maxSize %d exceeds max chunk size %d", maxSize, MaxChunkSize)
|
||||
}
|
||||
msgLen, err := readVarint(r)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -103,3 +103,12 @@ func TestSszNetworkEncoder_DecodeWithMaxLength(t *testing.T) {
|
||||
t.Errorf("error did not contain wanted message. Wanted: %s but Got: %s", wanted, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSszNetworkEncoder_DecodeWithMaxLength_TooLarge(t *testing.T) {
|
||||
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
|
||||
if err := e.DecodeWithMaxLength(nil, nil, encoder.MaxChunkSize+1); err == nil {
|
||||
t.Fatal("Nil error")
|
||||
} else if !strings.Contains(err.Error(), "exceeds max chunk size") {
|
||||
t.Error("Expected error to contain 'exceeds max chunk size'")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user