mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-09 07:28:06 -05:00
Fix broadcast ssz (#3423)
* add two types of encoding/decoding ssz * fix tests * lint * lint
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package p2p
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"reflect"
|
||||
"sync"
|
||||
@@ -59,9 +58,8 @@ func TestService_Broadcast(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(incomingMessage.Data)
|
||||
result := &testpb.TestSimpleMessage{}
|
||||
if err := p.Encoding().Decode(buf, result); err != nil {
|
||||
if err := p.Encoding().Decode(incomingMessage.Data, result); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !proto.Equal(result, msg) {
|
||||
|
||||
@@ -14,10 +14,14 @@ const (
|
||||
|
||||
// NetworkEncoding represents an encoder compatible with Ethereum 2.0 p2p.
|
||||
type NetworkEncoding interface {
|
||||
// Decode reads bytes from the reader and decodes it to the provided message.
|
||||
Decode(io.Reader, proto.Message) error
|
||||
// Decodes to the provided message.
|
||||
Decode([]byte, proto.Message) error
|
||||
// DecodeWithLength a bytes from a reader with a varint length prefix.
|
||||
DecodeWithLength(io.Reader, proto.Message) error
|
||||
// Encode an arbitrary message to the provided writer.
|
||||
Encode(io.Writer, proto.Message) (int, error)
|
||||
// EncodeWithLength an arbitrary message to the provided writer with a varint length prefix.
|
||||
EncodeWithLength(io.Writer, proto.Message) (int, error)
|
||||
// ProtocolSuffix returns the last part of the protocol ID to indicate the encoding scheme.
|
||||
ProtocolSuffix() string
|
||||
}
|
||||
|
||||
@@ -16,26 +16,59 @@ type SszNetworkEncoder struct {
|
||||
UseSnappyCompression bool
|
||||
}
|
||||
|
||||
// Encode the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
|
||||
// to indicate the size of the message.
|
||||
func (e SszNetworkEncoder) doEncode(msg proto.Message) ([]byte, error) {
|
||||
b, err := ssz.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if e.UseSnappyCompression {
|
||||
b = snappy.Encode(nil /*dst*/, b)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// Encode the proto message to the io.Writer.
|
||||
func (e SszNetworkEncoder) Encode(w io.Writer, msg proto.Message) (int, error) {
|
||||
if msg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
b, err := ssz.Marshal(msg)
|
||||
b, err := e.doEncode(msg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if e.UseSnappyCompression {
|
||||
b = snappy.Encode(nil /*dst*/, b)
|
||||
return w.Write(b)
|
||||
}
|
||||
|
||||
// EncodeWithLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
|
||||
// to indicate the size of the message.
|
||||
func (e SszNetworkEncoder) EncodeWithLength(w io.Writer, msg proto.Message) (int, error) {
|
||||
if msg == nil {
|
||||
return 0, nil
|
||||
}
|
||||
b, err := e.doEncode(msg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
b = append(proto.EncodeVarint(uint64(len(b))), b...)
|
||||
return w.Write(b)
|
||||
}
|
||||
|
||||
// Decode the bytes from io.Reader to the protobuf message provided.
|
||||
func (e SszNetworkEncoder) Decode(r io.Reader, to proto.Message) error {
|
||||
// Decode the bytes to the protobuf message provided.
|
||||
func (e SszNetworkEncoder) Decode(b []byte, to proto.Message) error {
|
||||
if e.UseSnappyCompression {
|
||||
var err error
|
||||
b, err = snappy.Decode(nil /*dst*/, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return ssz.Unmarshal(b, to)
|
||||
}
|
||||
|
||||
// DecodeWithLength the bytes from io.Reader to the protobuf message provided.
|
||||
func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to proto.Message) error {
|
||||
msgLen, err := readVarint(r)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -12,11 +12,13 @@ import (
|
||||
func TestSszNetworkEncoder_RoundTrip(t *testing.T) {
|
||||
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
|
||||
testRoundTrip(t, e)
|
||||
testRoundTripWithLength(t, e)
|
||||
}
|
||||
|
||||
func TestSszNetworkEncoder_RoundTrip_Snappy(t *testing.T) {
|
||||
e := &encoder.SszNetworkEncoder{UseSnappyCompression: true}
|
||||
testRoundTrip(t, e)
|
||||
testRoundTripWithLength(t, e)
|
||||
}
|
||||
|
||||
func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) {
|
||||
@@ -30,7 +32,27 @@ func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded := &testpb.TestSimpleMessage{}
|
||||
if err := e.Decode(buf, decoded); err != nil {
|
||||
if err := e.Decode(buf.Bytes(), decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !proto.Equal(decoded, msg) {
|
||||
t.Logf("decoded=%+v\n", decoded)
|
||||
t.Error("Decoded message is not the same as original")
|
||||
}
|
||||
}
|
||||
|
||||
func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) {
|
||||
buf := new(bytes.Buffer)
|
||||
msg := &testpb.TestSimpleMessage{
|
||||
Foo: []byte("fooooo"),
|
||||
Bar: 9001,
|
||||
}
|
||||
_, err := e.EncodeWithLength(buf, msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
decoded := &testpb.TestSimpleMessage{}
|
||||
if err := e.DecodeWithLength(buf, decoded); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !proto.Equal(decoded, msg) {
|
||||
|
||||
@@ -25,7 +25,7 @@ func (s *Service) Send(ctx context.Context, message proto.Message, pid peer.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := s.Encoding().Encode(stream, message); err != nil {
|
||||
if _, err := s.Encoding().EncodeWithLength(stream, message); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -38,10 +38,10 @@ func TestService_Send(t *testing.T) {
|
||||
go func() {
|
||||
p2.SetStreamHandler("/testing/1/ssz", func(stream network.Stream) {
|
||||
rcvd := &testpb.TestSimpleMessage{}
|
||||
if err := svc.Encoding().Decode(stream, rcvd); err != nil {
|
||||
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := svc.Encoding().Encode(stream, rcvd); err != nil {
|
||||
if _, err := svc.Encoding().EncodeWithLength(stream, rcvd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := stream.Close(); err != nil {
|
||||
@@ -59,7 +59,7 @@ func TestService_Send(t *testing.T) {
|
||||
testutil.WaitTimeout(&wg, 1*time.Second)
|
||||
|
||||
rcvd := &testpb.TestSimpleMessage{}
|
||||
if err := svc.Encoding().Decode(stream, rcvd); err != nil {
|
||||
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !proto.Equal(rcvd, msg) {
|
||||
|
||||
@@ -81,7 +81,7 @@ func (p *TestP2P) ReceiveRPC(topic string, msg proto.Message) {
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
n, err := p.Encoding().Encode(s, msg)
|
||||
n, err := p.Encoding().EncodeWithLength(s, msg)
|
||||
if err != nil {
|
||||
p.t.Fatalf("Failed to encode message: %v", err)
|
||||
}
|
||||
@@ -185,7 +185,7 @@ func (p *TestP2P) Send(ctx context.Context, msg proto.Message, pid peer.ID) (net
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := p.Encoding().Encode(stream, msg); err != nil {
|
||||
if _, err := p.Encoding().EncodeWithLength(stream, msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ var responseCodeServerError = byte(0x02)
|
||||
|
||||
func (r *RegularSync) generateErrorResponse(code byte, reason string) ([]byte, error) {
|
||||
buf := bytes.NewBuffer([]byte{code})
|
||||
if _, err := r.p2p.Encoding().Encode(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil {
|
||||
if _, err := r.p2p.Encoding().EncodeWithLength(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func ReadStatusCode(stream io.Reader, encoding encoder.NetworkEncoding) (uint8,
|
||||
}
|
||||
|
||||
msg := &pb.ErrorMessage{}
|
||||
if err := encoding.Decode(stream, msg); err != nil {
|
||||
if err := encoding.DecodeWithLength(stream, msg); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ func TestRegularSync_generateErrorResponse(t *testing.T) {
|
||||
t.Errorf("The first byte was not the status code. Got %#x wanted %#x", b, responseCodeServerError)
|
||||
}
|
||||
msg := &pb.ErrorMessage{}
|
||||
if err := r.p2p.Encoding().Decode(buf, msg); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(buf, msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.ErrorMessage != "something bad happened" {
|
||||
|
||||
@@ -128,7 +128,7 @@ func (s *InitialSync) Start() {
|
||||
}
|
||||
|
||||
resp := &pb.BeaconBlocksResponse{}
|
||||
if err := s.p2p.Encoding().Decode(strm, resp); err != nil {
|
||||
if err := s.p2p.Encoding().DecodeWithLength(strm, resp); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ func (r *RegularSync) registerRPC(topic string, base proto.Message, handle rpcHa
|
||||
// Clone the base message type so we have a newly initialized message as the decoding
|
||||
// destination.
|
||||
msg := proto.Clone(base)
|
||||
if err := r.p2p.Encoding().Decode(stream, msg); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil {
|
||||
log.WithError(err).Error("Failed to decode stream message")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -60,6 +60,6 @@ func (r *RegularSync) beaconBlocksRPCHandler(ctx context.Context, msg proto.Mess
|
||||
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
|
||||
log.WithError(err).Error("Failed to write to stream")
|
||||
}
|
||||
_, err = r.p2p.Encoding().Encode(stream, ret)
|
||||
_, err = r.p2p.Encoding().EncodeWithLength(stream, ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) {
|
||||
defer wg.Done()
|
||||
expectSuccess(t, r, stream)
|
||||
res := &pb.BeaconBlocksResponse{}
|
||||
if err := r.p2p.Encoding().Decode(stream, res); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if uint64(len(res.Blocks)) != req.Count {
|
||||
|
||||
@@ -50,7 +50,7 @@ func (r *RegularSync) sendRPCHelloRequest(ctx context.Context, id peer.ID) error
|
||||
}
|
||||
|
||||
msg := &pb.Hello{}
|
||||
if err := r.p2p.Encoding().Decode(stream, msg); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
r.helloTrackerLock.Lock()
|
||||
@@ -121,7 +121,7 @@ func (r *RegularSync) helloRPCHandler(ctx context.Context, msg proto.Message, st
|
||||
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
|
||||
log.WithError(err).Error("Failed to write to stream")
|
||||
}
|
||||
_, err := r.p2p.Encoding().Encode(stream, resp)
|
||||
_, err := r.p2p.Encoding().EncodeWithLength(stream, resp)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func TestHelloRPCHandler_ReturnsHelloMessage(t *testing.T) {
|
||||
defer wg.Done()
|
||||
expectSuccess(t, r, stream)
|
||||
out := &pb.Hello{}
|
||||
if err := r.p2p.Encoding().Decode(stream, out); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := &pb.Hello{
|
||||
@@ -187,7 +187,7 @@ func TestHelloRPCRequest_RequestSent(t *testing.T) {
|
||||
p2.Host.SetStreamHandler(pcl, func(stream network.Stream) {
|
||||
defer wg.Done()
|
||||
out := &pb.Hello{}
|
||||
if err := r.p2p.Encoding().Decode(stream, out); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expected := &pb.Hello{
|
||||
|
||||
@@ -54,6 +54,6 @@ func (r *RegularSync) recentBeaconBlocksRPCHandler(ctx context.Context, msg prot
|
||||
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
|
||||
log.WithError(err).Error("Failed to write to stream")
|
||||
}
|
||||
_, err := r.p2p.Encoding().Encode(stream, ret)
|
||||
_, err := r.p2p.Encoding().EncodeWithLength(stream, ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ func TestRecentBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) {
|
||||
defer wg.Done()
|
||||
expectSuccess(t, r, stream)
|
||||
res := &pb.BeaconBlocksResponse{}
|
||||
if err := r.p2p.Encoding().Decode(stream, res); err != nil {
|
||||
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if len(res.Blocks) != len(req.BlockRoots) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package sync
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -105,7 +104,7 @@ func (r *RegularSync) subscribe(topic string, validate validator, handle subHand
|
||||
}
|
||||
|
||||
msg := proto.Clone(base)
|
||||
if err := r.p2p.Encoding().Decode(bytes.NewBuffer(data), msg); err != nil {
|
||||
if err := r.p2p.Encoding().Decode(data, msg); err != nil {
|
||||
log.WithError(err).Warn("Failed to decode pubsub message")
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user