Fix broadcast ssz (#3423)

* add two types of encoding/decoding ssz

* fix tests

* lint

* lint
This commit is contained in:
Preston Van Loon
2019-09-08 19:34:52 -07:00
committed by GitHub
parent 4dad28d1f6
commit 8d234014a4
18 changed files with 90 additions and 34 deletions

View File

@@ -1,7 +1,6 @@
package p2p
import (
"bytes"
"context"
"reflect"
"sync"
@@ -59,9 +58,8 @@ func TestService_Broadcast(t *testing.T) {
t.Fatal(err)
}
buf := bytes.NewBuffer(incomingMessage.Data)
result := &testpb.TestSimpleMessage{}
if err := p.Encoding().Decode(buf, result); err != nil {
if err := p.Encoding().Decode(incomingMessage.Data, result); err != nil {
t.Fatal(err)
}
if !proto.Equal(result, msg) {

View File

@@ -14,10 +14,14 @@ const (
// NetworkEncoding represents an encoder compatible with Ethereum 2.0 p2p.
type NetworkEncoding interface {
// Decode reads bytes from the reader and decodes it to the provided message.
Decode(io.Reader, proto.Message) error
// Decodes to the provided message.
Decode([]byte, proto.Message) error
// DecodeWithLength a bytes from a reader with a varint length prefix.
DecodeWithLength(io.Reader, proto.Message) error
// Encode an arbitrary message to the provided writer.
Encode(io.Writer, proto.Message) (int, error)
// EncodeWithLength an arbitrary message to the provided writer with a varint length prefix.
EncodeWithLength(io.Writer, proto.Message) (int, error)
// ProtocolSuffix returns the last part of the protocol ID to indicate the encoding scheme.
ProtocolSuffix() string
}

View File

@@ -16,26 +16,59 @@ type SszNetworkEncoder struct {
UseSnappyCompression bool
}
// Encode the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
// to indicate the size of the message.
func (e SszNetworkEncoder) doEncode(msg proto.Message) ([]byte, error) {
b, err := ssz.Marshal(msg)
if err != nil {
return nil, err
}
if e.UseSnappyCompression {
b = snappy.Encode(nil /*dst*/, b)
}
return b, nil
}
// Encode the proto message to the io.Writer.
func (e SszNetworkEncoder) Encode(w io.Writer, msg proto.Message) (int, error) {
if msg == nil {
return 0, nil
}
b, err := ssz.Marshal(msg)
b, err := e.doEncode(msg)
if err != nil {
return 0, err
}
if e.UseSnappyCompression {
b = snappy.Encode(nil /*dst*/, b)
return w.Write(b)
}
// EncodeWithLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
// to indicate the size of the message.
func (e SszNetworkEncoder) EncodeWithLength(w io.Writer, msg proto.Message) (int, error) {
if msg == nil {
return 0, nil
}
b, err := e.doEncode(msg)
if err != nil {
return 0, err
}
b = append(proto.EncodeVarint(uint64(len(b))), b...)
return w.Write(b)
}
// Decode the bytes from io.Reader to the protobuf message provided.
func (e SszNetworkEncoder) Decode(r io.Reader, to proto.Message) error {
// Decode the bytes to the protobuf message provided.
func (e SszNetworkEncoder) Decode(b []byte, to proto.Message) error {
if e.UseSnappyCompression {
var err error
b, err = snappy.Decode(nil /*dst*/, b)
if err != nil {
return err
}
}
return ssz.Unmarshal(b, to)
}
// DecodeWithLength the bytes from io.Reader to the protobuf message provided.
func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to proto.Message) error {
msgLen, err := readVarint(r)
if err != nil {
return err

View File

@@ -12,11 +12,13 @@ import (
func TestSszNetworkEncoder_RoundTrip(t *testing.T) {
e := &encoder.SszNetworkEncoder{UseSnappyCompression: false}
testRoundTrip(t, e)
testRoundTripWithLength(t, e)
}
func TestSszNetworkEncoder_RoundTrip_Snappy(t *testing.T) {
e := &encoder.SszNetworkEncoder{UseSnappyCompression: true}
testRoundTrip(t, e)
testRoundTripWithLength(t, e)
}
func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) {
@@ -30,7 +32,27 @@ func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) {
t.Fatal(err)
}
decoded := &testpb.TestSimpleMessage{}
if err := e.Decode(buf, decoded); err != nil {
if err := e.Decode(buf.Bytes(), decoded); err != nil {
t.Fatal(err)
}
if !proto.Equal(decoded, msg) {
t.Logf("decoded=%+v\n", decoded)
t.Error("Decoded message is not the same as original")
}
}
func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) {
buf := new(bytes.Buffer)
msg := &testpb.TestSimpleMessage{
Foo: []byte("fooooo"),
Bar: 9001,
}
_, err := e.EncodeWithLength(buf, msg)
if err != nil {
t.Fatal(err)
}
decoded := &testpb.TestSimpleMessage{}
if err := e.DecodeWithLength(buf, decoded); err != nil {
t.Fatal(err)
}
if !proto.Equal(decoded, msg) {

View File

@@ -25,7 +25,7 @@ func (s *Service) Send(ctx context.Context, message proto.Message, pid peer.ID)
return nil, err
}
if _, err := s.Encoding().Encode(stream, message); err != nil {
if _, err := s.Encoding().EncodeWithLength(stream, message); err != nil {
return nil, err
}

View File

@@ -38,10 +38,10 @@ func TestService_Send(t *testing.T) {
go func() {
p2.SetStreamHandler("/testing/1/ssz", func(stream network.Stream) {
rcvd := &testpb.TestSimpleMessage{}
if err := svc.Encoding().Decode(stream, rcvd); err != nil {
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if _, err := svc.Encoding().Encode(stream, rcvd); err != nil {
if _, err := svc.Encoding().EncodeWithLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if err := stream.Close(); err != nil {
@@ -59,7 +59,7 @@ func TestService_Send(t *testing.T) {
testutil.WaitTimeout(&wg, 1*time.Second)
rcvd := &testpb.TestSimpleMessage{}
if err := svc.Encoding().Decode(stream, rcvd); err != nil {
if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil {
t.Fatal(err)
}
if !proto.Equal(rcvd, msg) {

View File

@@ -81,7 +81,7 @@ func (p *TestP2P) ReceiveRPC(topic string, msg proto.Message) {
}
defer s.Close()
n, err := p.Encoding().Encode(s, msg)
n, err := p.Encoding().EncodeWithLength(s, msg)
if err != nil {
p.t.Fatalf("Failed to encode message: %v", err)
}
@@ -185,7 +185,7 @@ func (p *TestP2P) Send(ctx context.Context, msg proto.Message, pid peer.ID) (net
return nil, err
}
if _, err := p.Encoding().Encode(stream, msg); err != nil {
if _, err := p.Encoding().EncodeWithLength(stream, msg); err != nil {
return nil, err
}

View File

@@ -19,7 +19,7 @@ var responseCodeServerError = byte(0x02)
func (r *RegularSync) generateErrorResponse(code byte, reason string) ([]byte, error) {
buf := bytes.NewBuffer([]byte{code})
if _, err := r.p2p.Encoding().Encode(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil {
if _, err := r.p2p.Encoding().EncodeWithLength(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil {
return nil, err
}
@@ -39,7 +39,7 @@ func ReadStatusCode(stream io.Reader, encoding encoder.NetworkEncoding) (uint8,
}
msg := &pb.ErrorMessage{}
if err := encoding.Decode(stream, msg); err != nil {
if err := encoding.DecodeWithLength(stream, msg); err != nil {
return 0, nil, err
}

View File

@@ -26,7 +26,7 @@ func TestRegularSync_generateErrorResponse(t *testing.T) {
t.Errorf("The first byte was not the status code. Got %#x wanted %#x", b, responseCodeServerError)
}
msg := &pb.ErrorMessage{}
if err := r.p2p.Encoding().Decode(buf, msg); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(buf, msg); err != nil {
t.Fatal(err)
}
if msg.ErrorMessage != "something bad happened" {

View File

@@ -128,7 +128,7 @@ func (s *InitialSync) Start() {
}
resp := &pb.BeaconBlocksResponse{}
if err := s.p2p.Encoding().Decode(strm, resp); err != nil {
if err := s.p2p.Encoding().DecodeWithLength(strm, resp); err != nil {
panic(err)
}

View File

@@ -72,7 +72,7 @@ func (r *RegularSync) registerRPC(topic string, base proto.Message, handle rpcHa
// Clone the base message type so we have a newly initialized message as the decoding
// destination.
msg := proto.Clone(base)
if err := r.p2p.Encoding().Decode(stream, msg); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil {
log.WithError(err).Error("Failed to decode stream message")
return
}

View File

@@ -60,6 +60,6 @@ func (r *RegularSync) beaconBlocksRPCHandler(ctx context.Context, msg proto.Mess
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
log.WithError(err).Error("Failed to write to stream")
}
_, err = r.p2p.Encoding().Encode(stream, ret)
_, err = r.p2p.Encoding().EncodeWithLength(stream, ret)
return err
}

View File

@@ -47,7 +47,7 @@ func TestBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) {
defer wg.Done()
expectSuccess(t, r, stream)
res := &pb.BeaconBlocksResponse{}
if err := r.p2p.Encoding().Decode(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
t.Error(err)
}
if uint64(len(res.Blocks)) != req.Count {

View File

@@ -50,7 +50,7 @@ func (r *RegularSync) sendRPCHelloRequest(ctx context.Context, id peer.ID) error
}
msg := &pb.Hello{}
if err := r.p2p.Encoding().Decode(stream, msg); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil {
return err
}
r.helloTrackerLock.Lock()
@@ -121,7 +121,7 @@ func (r *RegularSync) helloRPCHandler(ctx context.Context, msg proto.Message, st
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
log.WithError(err).Error("Failed to write to stream")
}
_, err := r.p2p.Encoding().Encode(stream, resp)
_, err := r.p2p.Encoding().EncodeWithLength(stream, resp)
return err
}

View File

@@ -112,7 +112,7 @@ func TestHelloRPCHandler_ReturnsHelloMessage(t *testing.T) {
defer wg.Done()
expectSuccess(t, r, stream)
out := &pb.Hello{}
if err := r.p2p.Encoding().Decode(stream, out); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil {
t.Fatal(err)
}
expected := &pb.Hello{
@@ -187,7 +187,7 @@ func TestHelloRPCRequest_RequestSent(t *testing.T) {
p2.Host.SetStreamHandler(pcl, func(stream network.Stream) {
defer wg.Done()
out := &pb.Hello{}
if err := r.p2p.Encoding().Decode(stream, out); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil {
t.Fatal(err)
}
expected := &pb.Hello{

View File

@@ -54,6 +54,6 @@ func (r *RegularSync) recentBeaconBlocksRPCHandler(ctx context.Context, msg prot
if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil {
log.WithError(err).Error("Failed to write to stream")
}
_, err := r.p2p.Encoding().Encode(stream, ret)
_, err := r.p2p.Encoding().EncodeWithLength(stream, ret)
return err
}

View File

@@ -54,7 +54,7 @@ func TestRecentBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) {
defer wg.Done()
expectSuccess(t, r, stream)
res := &pb.BeaconBlocksResponse{}
if err := r.p2p.Encoding().Decode(stream, res); err != nil {
if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil {
t.Error(err)
}
if len(res.Blocks) != len(req.BlockRoots) {

View File

@@ -1,7 +1,6 @@
package sync
import (
"bytes"
"context"
"errors"
"fmt"
@@ -105,7 +104,7 @@ func (r *RegularSync) subscribe(topic string, validate validator, handle subHand
}
msg := proto.Clone(base)
if err := r.p2p.Encoding().Decode(bytes.NewBuffer(data), msg); err != nil {
if err := r.p2p.Encoding().Decode(data, msg); err != nil {
log.WithError(err).Warn("Failed to decode pubsub message")
return
}