diff --git a/beacon-chain/p2p/broadcaster_test.go b/beacon-chain/p2p/broadcaster_test.go index e7bb4b8f8a..5281912f6f 100644 --- a/beacon-chain/p2p/broadcaster_test.go +++ b/beacon-chain/p2p/broadcaster_test.go @@ -1,7 +1,6 @@ package p2p import ( - "bytes" "context" "reflect" "sync" @@ -59,9 +58,8 @@ func TestService_Broadcast(t *testing.T) { t.Fatal(err) } - buf := bytes.NewBuffer(incomingMessage.Data) result := &testpb.TestSimpleMessage{} - if err := p.Encoding().Decode(buf, result); err != nil { + if err := p.Encoding().Decode(incomingMessage.Data, result); err != nil { t.Fatal(err) } if !proto.Equal(result, msg) { diff --git a/beacon-chain/p2p/encoder/network_encoding.go b/beacon-chain/p2p/encoder/network_encoding.go index 0eb5626215..cf2b973eea 100644 --- a/beacon-chain/p2p/encoder/network_encoding.go +++ b/beacon-chain/p2p/encoder/network_encoding.go @@ -14,10 +14,14 @@ const ( // NetworkEncoding represents an encoder compatible with Ethereum 2.0 p2p. type NetworkEncoding interface { - // Decode reads bytes from the reader and decodes it to the provided message. - Decode(io.Reader, proto.Message) error + // Decodes to the provided message. + Decode([]byte, proto.Message) error + // DecodeWithLength a bytes from a reader with a varint length prefix. + DecodeWithLength(io.Reader, proto.Message) error // Encode an arbitrary message to the provided writer. Encode(io.Writer, proto.Message) (int, error) + // EncodeWithLength an arbitrary message to the provided writer with a varint length prefix. + EncodeWithLength(io.Writer, proto.Message) (int, error) // ProtocolSuffix returns the last part of the protocol ID to indicate the encoding scheme. ProtocolSuffix() string } diff --git a/beacon-chain/p2p/encoder/ssz.go b/beacon-chain/p2p/encoder/ssz.go index 24022e5758..db9d5bc968 100644 --- a/beacon-chain/p2p/encoder/ssz.go +++ b/beacon-chain/p2p/encoder/ssz.go @@ -16,26 +16,59 @@ type SszNetworkEncoder struct { UseSnappyCompression bool } -// Encode the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint -// to indicate the size of the message. +func (e SszNetworkEncoder) doEncode(msg proto.Message) ([]byte, error) { + b, err := ssz.Marshal(msg) + if err != nil { + return nil, err + } + if e.UseSnappyCompression { + b = snappy.Encode(nil /*dst*/, b) + } + return b, nil +} + +// Encode the proto message to the io.Writer. func (e SszNetworkEncoder) Encode(w io.Writer, msg proto.Message) (int, error) { if msg == nil { return 0, nil } - b, err := ssz.Marshal(msg) + b, err := e.doEncode(msg) if err != nil { return 0, err } - if e.UseSnappyCompression { - b = snappy.Encode(nil /*dst*/, b) + return w.Write(b) +} + +// EncodeWithLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint +// to indicate the size of the message. +func (e SszNetworkEncoder) EncodeWithLength(w io.Writer, msg proto.Message) (int, error) { + if msg == nil { + return 0, nil + } + b, err := e.doEncode(msg) + if err != nil { + return 0, err } b = append(proto.EncodeVarint(uint64(len(b))), b...) return w.Write(b) } -// Decode the bytes from io.Reader to the protobuf message provided. -func (e SszNetworkEncoder) Decode(r io.Reader, to proto.Message) error { +// Decode the bytes to the protobuf message provided. +func (e SszNetworkEncoder) Decode(b []byte, to proto.Message) error { + if e.UseSnappyCompression { + var err error + b, err = snappy.Decode(nil /*dst*/, b) + if err != nil { + return err + } + } + + return ssz.Unmarshal(b, to) +} + +// DecodeWithLength the bytes from io.Reader to the protobuf message provided. +func (e SszNetworkEncoder) DecodeWithLength(r io.Reader, to proto.Message) error { msgLen, err := readVarint(r) if err != nil { return err diff --git a/beacon-chain/p2p/encoder/ssz_test.go b/beacon-chain/p2p/encoder/ssz_test.go index 683ee44df2..7396f2f304 100644 --- a/beacon-chain/p2p/encoder/ssz_test.go +++ b/beacon-chain/p2p/encoder/ssz_test.go @@ -12,11 +12,13 @@ import ( func TestSszNetworkEncoder_RoundTrip(t *testing.T) { e := &encoder.SszNetworkEncoder{UseSnappyCompression: false} testRoundTrip(t, e) + testRoundTripWithLength(t, e) } func TestSszNetworkEncoder_RoundTrip_Snappy(t *testing.T) { e := &encoder.SszNetworkEncoder{UseSnappyCompression: true} testRoundTrip(t, e) + testRoundTripWithLength(t, e) } func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) { @@ -30,7 +32,27 @@ func testRoundTrip(t *testing.T, e *encoder.SszNetworkEncoder) { t.Fatal(err) } decoded := &testpb.TestSimpleMessage{} - if err := e.Decode(buf, decoded); err != nil { + if err := e.Decode(buf.Bytes(), decoded); err != nil { + t.Fatal(err) + } + if !proto.Equal(decoded, msg) { + t.Logf("decoded=%+v\n", decoded) + t.Error("Decoded message is not the same as original") + } +} + +func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) { + buf := new(bytes.Buffer) + msg := &testpb.TestSimpleMessage{ + Foo: []byte("fooooo"), + Bar: 9001, + } + _, err := e.EncodeWithLength(buf, msg) + if err != nil { + t.Fatal(err) + } + decoded := &testpb.TestSimpleMessage{} + if err := e.DecodeWithLength(buf, decoded); err != nil { t.Fatal(err) } if !proto.Equal(decoded, msg) { diff --git a/beacon-chain/p2p/sender.go b/beacon-chain/p2p/sender.go index 30e4875ff7..9941b76e49 100644 --- a/beacon-chain/p2p/sender.go +++ b/beacon-chain/p2p/sender.go @@ -25,7 +25,7 @@ func (s *Service) Send(ctx context.Context, message proto.Message, pid peer.ID) return nil, err } - if _, err := s.Encoding().Encode(stream, message); err != nil { + if _, err := s.Encoding().EncodeWithLength(stream, message); err != nil { return nil, err } diff --git a/beacon-chain/p2p/sender_test.go b/beacon-chain/p2p/sender_test.go index 242280f256..2a962ce9ed 100644 --- a/beacon-chain/p2p/sender_test.go +++ b/beacon-chain/p2p/sender_test.go @@ -38,10 +38,10 @@ func TestService_Send(t *testing.T) { go func() { p2.SetStreamHandler("/testing/1/ssz", func(stream network.Stream) { rcvd := &testpb.TestSimpleMessage{} - if err := svc.Encoding().Decode(stream, rcvd); err != nil { + if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil { t.Fatal(err) } - if _, err := svc.Encoding().Encode(stream, rcvd); err != nil { + if _, err := svc.Encoding().EncodeWithLength(stream, rcvd); err != nil { t.Fatal(err) } if err := stream.Close(); err != nil { @@ -59,7 +59,7 @@ func TestService_Send(t *testing.T) { testutil.WaitTimeout(&wg, 1*time.Second) rcvd := &testpb.TestSimpleMessage{} - if err := svc.Encoding().Decode(stream, rcvd); err != nil { + if err := svc.Encoding().DecodeWithLength(stream, rcvd); err != nil { t.Fatal(err) } if !proto.Equal(rcvd, msg) { diff --git a/beacon-chain/p2p/testing/p2p.go b/beacon-chain/p2p/testing/p2p.go index db3bb4607c..f6d7723b26 100644 --- a/beacon-chain/p2p/testing/p2p.go +++ b/beacon-chain/p2p/testing/p2p.go @@ -81,7 +81,7 @@ func (p *TestP2P) ReceiveRPC(topic string, msg proto.Message) { } defer s.Close() - n, err := p.Encoding().Encode(s, msg) + n, err := p.Encoding().EncodeWithLength(s, msg) if err != nil { p.t.Fatalf("Failed to encode message: %v", err) } @@ -185,7 +185,7 @@ func (p *TestP2P) Send(ctx context.Context, msg proto.Message, pid peer.ID) (net return nil, err } - if _, err := p.Encoding().Encode(stream, msg); err != nil { + if _, err := p.Encoding().EncodeWithLength(stream, msg); err != nil { return nil, err } diff --git a/beacon-chain/sync/error.go b/beacon-chain/sync/error.go index 632e657d76..165f243df2 100644 --- a/beacon-chain/sync/error.go +++ b/beacon-chain/sync/error.go @@ -19,7 +19,7 @@ var responseCodeServerError = byte(0x02) func (r *RegularSync) generateErrorResponse(code byte, reason string) ([]byte, error) { buf := bytes.NewBuffer([]byte{code}) - if _, err := r.p2p.Encoding().Encode(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil { + if _, err := r.p2p.Encoding().EncodeWithLength(buf, &pb.ErrorMessage{ErrorMessage: reason}); err != nil { return nil, err } @@ -39,7 +39,7 @@ func ReadStatusCode(stream io.Reader, encoding encoder.NetworkEncoding) (uint8, } msg := &pb.ErrorMessage{} - if err := encoding.Decode(stream, msg); err != nil { + if err := encoding.DecodeWithLength(stream, msg); err != nil { return 0, nil, err } diff --git a/beacon-chain/sync/error_test.go b/beacon-chain/sync/error_test.go index ee249eff0a..ac2dfceb27 100644 --- a/beacon-chain/sync/error_test.go +++ b/beacon-chain/sync/error_test.go @@ -26,7 +26,7 @@ func TestRegularSync_generateErrorResponse(t *testing.T) { t.Errorf("The first byte was not the status code. Got %#x wanted %#x", b, responseCodeServerError) } msg := &pb.ErrorMessage{} - if err := r.p2p.Encoding().Decode(buf, msg); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(buf, msg); err != nil { t.Fatal(err) } if msg.ErrorMessage != "something bad happened" { diff --git a/beacon-chain/sync/initial-sync/service.go b/beacon-chain/sync/initial-sync/service.go index da6faf3653..b78fb66670 100644 --- a/beacon-chain/sync/initial-sync/service.go +++ b/beacon-chain/sync/initial-sync/service.go @@ -128,7 +128,7 @@ func (s *InitialSync) Start() { } resp := &pb.BeaconBlocksResponse{} - if err := s.p2p.Encoding().Decode(strm, resp); err != nil { + if err := s.p2p.Encoding().DecodeWithLength(strm, resp); err != nil { panic(err) } diff --git a/beacon-chain/sync/rpc.go b/beacon-chain/sync/rpc.go index 647c129e71..6b1491d328 100644 --- a/beacon-chain/sync/rpc.go +++ b/beacon-chain/sync/rpc.go @@ -72,7 +72,7 @@ func (r *RegularSync) registerRPC(topic string, base proto.Message, handle rpcHa // Clone the base message type so we have a newly initialized message as the decoding // destination. msg := proto.Clone(base) - if err := r.p2p.Encoding().Decode(stream, msg); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil { log.WithError(err).Error("Failed to decode stream message") return } diff --git a/beacon-chain/sync/rpc_beacon_blocks.go b/beacon-chain/sync/rpc_beacon_blocks.go index 5268c7df9e..0b2f51ae40 100644 --- a/beacon-chain/sync/rpc_beacon_blocks.go +++ b/beacon-chain/sync/rpc_beacon_blocks.go @@ -60,6 +60,6 @@ func (r *RegularSync) beaconBlocksRPCHandler(ctx context.Context, msg proto.Mess if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil { log.WithError(err).Error("Failed to write to stream") } - _, err = r.p2p.Encoding().Encode(stream, ret) + _, err = r.p2p.Encoding().EncodeWithLength(stream, ret) return err } diff --git a/beacon-chain/sync/rpc_beacon_blocks_test.go b/beacon-chain/sync/rpc_beacon_blocks_test.go index 1db465c08a..204bccbcd2 100644 --- a/beacon-chain/sync/rpc_beacon_blocks_test.go +++ b/beacon-chain/sync/rpc_beacon_blocks_test.go @@ -47,7 +47,7 @@ func TestBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) { defer wg.Done() expectSuccess(t, r, stream) res := &pb.BeaconBlocksResponse{} - if err := r.p2p.Encoding().Decode(stream, res); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil { t.Error(err) } if uint64(len(res.Blocks)) != req.Count { diff --git a/beacon-chain/sync/rpc_hello.go b/beacon-chain/sync/rpc_hello.go index fcb41fccc7..fb187e05bd 100644 --- a/beacon-chain/sync/rpc_hello.go +++ b/beacon-chain/sync/rpc_hello.go @@ -50,7 +50,7 @@ func (r *RegularSync) sendRPCHelloRequest(ctx context.Context, id peer.ID) error } msg := &pb.Hello{} - if err := r.p2p.Encoding().Decode(stream, msg); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, msg); err != nil { return err } r.helloTrackerLock.Lock() @@ -121,7 +121,7 @@ func (r *RegularSync) helloRPCHandler(ctx context.Context, msg proto.Message, st if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil { log.WithError(err).Error("Failed to write to stream") } - _, err := r.p2p.Encoding().Encode(stream, resp) + _, err := r.p2p.Encoding().EncodeWithLength(stream, resp) return err } diff --git a/beacon-chain/sync/rpc_hello_test.go b/beacon-chain/sync/rpc_hello_test.go index b59832ee17..da9666c302 100644 --- a/beacon-chain/sync/rpc_hello_test.go +++ b/beacon-chain/sync/rpc_hello_test.go @@ -112,7 +112,7 @@ func TestHelloRPCHandler_ReturnsHelloMessage(t *testing.T) { defer wg.Done() expectSuccess(t, r, stream) out := &pb.Hello{} - if err := r.p2p.Encoding().Decode(stream, out); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil { t.Fatal(err) } expected := &pb.Hello{ @@ -187,7 +187,7 @@ func TestHelloRPCRequest_RequestSent(t *testing.T) { p2.Host.SetStreamHandler(pcl, func(stream network.Stream) { defer wg.Done() out := &pb.Hello{} - if err := r.p2p.Encoding().Decode(stream, out); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, out); err != nil { t.Fatal(err) } expected := &pb.Hello{ diff --git a/beacon-chain/sync/rpc_recent_beacon_blocks.go b/beacon-chain/sync/rpc_recent_beacon_blocks.go index 73beaf9048..c5f7a76327 100644 --- a/beacon-chain/sync/rpc_recent_beacon_blocks.go +++ b/beacon-chain/sync/rpc_recent_beacon_blocks.go @@ -54,6 +54,6 @@ func (r *RegularSync) recentBeaconBlocksRPCHandler(ctx context.Context, msg prot if _, err := stream.Write([]byte{responseCodeSuccess}); err != nil { log.WithError(err).Error("Failed to write to stream") } - _, err := r.p2p.Encoding().Encode(stream, ret) + _, err := r.p2p.Encoding().EncodeWithLength(stream, ret) return err } diff --git a/beacon-chain/sync/rpc_recent_beacon_blocks_test.go b/beacon-chain/sync/rpc_recent_beacon_blocks_test.go index 5ac4773124..4e9fa669dd 100644 --- a/beacon-chain/sync/rpc_recent_beacon_blocks_test.go +++ b/beacon-chain/sync/rpc_recent_beacon_blocks_test.go @@ -54,7 +54,7 @@ func TestRecentBeaconBlocksRPCHandler_ReturnsBlocks(t *testing.T) { defer wg.Done() expectSuccess(t, r, stream) res := &pb.BeaconBlocksResponse{} - if err := r.p2p.Encoding().Decode(stream, res); err != nil { + if err := r.p2p.Encoding().DecodeWithLength(stream, res); err != nil { t.Error(err) } if len(res.Blocks) != len(req.BlockRoots) { diff --git a/beacon-chain/sync/subscriber.go b/beacon-chain/sync/subscriber.go index ec8668f23d..adc6f80037 100644 --- a/beacon-chain/sync/subscriber.go +++ b/beacon-chain/sync/subscriber.go @@ -1,7 +1,6 @@ package sync import ( - "bytes" "context" "errors" "fmt" @@ -105,7 +104,7 @@ func (r *RegularSync) subscribe(topic string, validate validator, handle subHand } msg := proto.Clone(base) - if err := r.p2p.Encoding().Decode(bytes.NewBuffer(data), msg); err != nil { + if err := r.p2p.Encoding().Decode(data, msg); err != nil { log.WithError(err).Warn("Failed to decode pubsub message") return }