Add flags for disabling selected API (#9606)

* Add flags for disabling selected API

* tests

* build file

* Use comma-separated modules

* test fix

* fix gateway tests

* fix import in flag tests
This commit is contained in:
Radosław Kapka
2021-09-24 11:25:42 +02:00
committed by GitHub
parent 7dd99de69f
commit 12480e12b2
15 changed files with 207 additions and 75 deletions

View File

@@ -39,7 +39,7 @@ type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request)
// Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server.
type Gateway struct {
conn *grpc.ClientConn
pbHandlers []PbMux
pbHandlers []*PbMux
muxHandler MuxHandler
maxCallRecvMsgSize uint64
router *mux.Router
@@ -57,7 +57,7 @@ type Gateway struct {
// New returns a new instance of the Gateway.
func New(
ctx context.Context,
pbHandlers []PbMux,
pbHandlers []*PbMux,
muxHandler MuxHandler,
remoteAddr,
gatewayAddress string,

View File

@@ -42,7 +42,7 @@ func TestGateway_Customized(t *testing.T) {
g := New(
context.Background(),
[]PbMux{},
[]*PbMux{},
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
},
@@ -77,7 +77,7 @@ func TestGateway_StartStop(t *testing.T) {
g := New(
ctx.Context,
[]PbMux{},
[]*PbMux{},
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
},
@@ -108,7 +108,7 @@ func TestGateway_NilHandler_NotFoundHandlerRegistered(t *testing.T) {
g := New(
ctx.Context,
[]PbMux{},
[]*PbMux{},
/* muxHandler */ nil,
selfAddress,
gatewayAddress,

View File

@@ -7,6 +7,7 @@ go_library(
visibility = ["//beacon-chain:__subpackages__"],
deps = [
"//api/gateway:go_default_library",
"//cmd/beacon-chain/flags:go_default_library",
"//proto/eth/service:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
@@ -19,6 +20,7 @@ go_test(
srcs = ["helpers_test.go"],
embed = [":go_default_library"],
deps = [
"//api/gateway:go_default_library",
"//testing/assert:go_default_library",
"//testing/require:go_default_library",
],

View File

@@ -3,6 +3,7 @@ package gateway
import (
gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/prysmaticlabs/prysm/api/gateway"
"github.com/prysmaticlabs/prysm/cmd/beacon-chain/flags"
ethpbservice "github.com/prysmaticlabs/prysm/proto/eth/service"
ethpbalpha "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
"google.golang.org/protobuf/encoding/protojson"
@@ -11,66 +12,74 @@ import (
// MuxConfig contains configuration that should be used when registering the beacon node in the gateway.
type MuxConfig struct {
Handler gateway.MuxHandler
EthPbMux gateway.PbMux
V1Alpha1PbMux gateway.PbMux
EthPbMux *gateway.PbMux
V1Alpha1PbMux *gateway.PbMux
}
// DefaultConfig returns a fully configured MuxConfig with standard gateway behavior.
func DefaultConfig(enableDebugRPCEndpoints bool) MuxConfig {
v1Alpha1Registrations := []gateway.PbHandlerRegistration{
ethpbalpha.RegisterNodeHandler,
ethpbalpha.RegisterBeaconChainHandler,
ethpbalpha.RegisterBeaconNodeValidatorHandler,
ethpbalpha.RegisterHealthHandler,
}
ethRegistrations := []gateway.PbHandlerRegistration{
ethpbservice.RegisterBeaconNodeHandler,
ethpbservice.RegisterBeaconChainHandler,
ethpbservice.RegisterBeaconValidatorHandler,
ethpbservice.RegisterEventsHandler,
}
if enableDebugRPCEndpoints {
v1Alpha1Registrations = append(v1Alpha1Registrations, ethpbalpha.RegisterDebugHandler)
ethRegistrations = append(ethRegistrations, ethpbservice.RegisterBeaconDebugHandler)
func DefaultConfig(enableDebugRPCEndpoints bool, httpModules string) MuxConfig {
var v1Alpha1PbHandler, ethPbHandler *gateway.PbMux
if flags.EnableHTTPPrysmAPI(httpModules) {
v1Alpha1Registrations := []gateway.PbHandlerRegistration{
ethpbalpha.RegisterNodeHandler,
ethpbalpha.RegisterBeaconChainHandler,
ethpbalpha.RegisterBeaconNodeValidatorHandler,
ethpbalpha.RegisterHealthHandler,
}
if enableDebugRPCEndpoints {
v1Alpha1Registrations = append(v1Alpha1Registrations, ethpbalpha.RegisterDebugHandler)
}
v1Alpha1Mux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}),
gwruntime.WithMarshalerOption(
"text/event-stream", &gwruntime.EventSourceJSONPb{},
),
)
v1Alpha1PbHandler = &gateway.PbMux{
Registrations: v1Alpha1Registrations,
Patterns: []string{"/eth/v1alpha1/"},
Mux: v1Alpha1Mux,
}
}
v1Alpha1Mux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
if flags.EnableHTTPEthAPI(httpModules) {
ethRegistrations := []gateway.PbHandlerRegistration{
ethpbservice.RegisterBeaconNodeHandler,
ethpbservice.RegisterBeaconChainHandler,
ethpbservice.RegisterBeaconValidatorHandler,
ethpbservice.RegisterEventsHandler,
}
if enableDebugRPCEndpoints {
ethRegistrations = append(ethRegistrations, ethpbservice.RegisterBeaconDebugHandler)
}
ethMux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
UseProtoNames: true,
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}),
gwruntime.WithMarshalerOption(
"text/event-stream", &gwruntime.EventSourceJSONPb{},
),
)
ethMux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
UseProtoNames: true,
EmitUnpopulated: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
},
},
}),
)
v1Alpha1PbHandler := gateway.PbMux{
Registrations: v1Alpha1Registrations,
Patterns: []string{"/eth/v1alpha1/"},
Mux: v1Alpha1Mux,
}
ethPbHandler := gateway.PbMux{
Registrations: ethRegistrations,
Patterns: []string{"/internal/eth/v1/", "/internal/eth/v2/"},
Mux: ethMux,
}),
)
ethPbHandler = &gateway.PbMux{
Registrations: ethRegistrations,
Patterns: []string{"/internal/eth/v1/", "/internal/eth/v2/"},
Mux: ethMux,
}
}
return MuxConfig{

View File

@@ -3,13 +3,14 @@ package gateway
import (
"testing"
"github.com/prysmaticlabs/prysm/api/gateway"
"github.com/prysmaticlabs/prysm/testing/assert"
"github.com/prysmaticlabs/prysm/testing/require"
)
func TestDefaultConfig(t *testing.T) {
t.Run("Without debug endpoints", func(t *testing.T) {
cfg := DefaultConfig(false)
cfg := DefaultConfig(false, "eth,prysm")
assert.NotNil(t, cfg.EthPbMux.Mux)
require.Equal(t, 2, len(cfg.EthPbMux.Patterns))
assert.Equal(t, "/internal/eth/v1/", cfg.EthPbMux.Patterns[0])
@@ -22,7 +23,7 @@ func TestDefaultConfig(t *testing.T) {
})
t.Run("With debug endpoints", func(t *testing.T) {
cfg := DefaultConfig(true)
cfg := DefaultConfig(true, "eth,prysm")
assert.NotNil(t, cfg.EthPbMux.Mux)
require.Equal(t, 2, len(cfg.EthPbMux.Patterns))
assert.Equal(t, "/internal/eth/v1/", cfg.EthPbMux.Patterns[0])
@@ -33,4 +34,20 @@ func TestDefaultConfig(t *testing.T) {
assert.Equal(t, "/eth/v1alpha1/", cfg.V1Alpha1PbMux.Patterns[0])
assert.Equal(t, 5, len(cfg.V1Alpha1PbMux.Registrations))
})
t.Run("Without Prysm API", func(t *testing.T) {
cfg := DefaultConfig(true, "eth")
assert.NotNil(t, cfg.EthPbMux.Mux)
require.Equal(t, 2, len(cfg.EthPbMux.Patterns))
assert.Equal(t, "/internal/eth/v1/", cfg.EthPbMux.Patterns[0])
assert.Equal(t, 5, len(cfg.EthPbMux.Registrations))
assert.Equal(t, (*gateway.PbMux)(nil), cfg.V1Alpha1PbMux)
})
t.Run("Without Eth API", func(t *testing.T) {
cfg := DefaultConfig(true, "prysm")
assert.Equal(t, (*gateway.PbMux)(nil), cfg.EthPbMux)
assert.NotNil(t, cfg.V1Alpha1PbMux.Mux)
require.Equal(t, 1, len(cfg.V1Alpha1PbMux.Patterns))
assert.Equal(t, "/eth/v1alpha1/", cfg.V1Alpha1PbMux.Patterns[0])
assert.Equal(t, 5, len(cfg.V1Alpha1PbMux.Registrations))
})
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/api/gateway"
apigateway "github.com/prysmaticlabs/prysm/api/gateway"
"github.com/prysmaticlabs/prysm/async/event"
"github.com/prysmaticlabs/prysm/beacon-chain/blockchain"
"github.com/prysmaticlabs/prysm/beacon-chain/cache/depositcache"
@@ -26,7 +26,7 @@ import (
"github.com/prysmaticlabs/prysm/beacon-chain/db/kv"
"github.com/prysmaticlabs/prysm/beacon-chain/forkchoice"
"github.com/prysmaticlabs/prysm/beacon-chain/forkchoice/protoarray"
gateway2 "github.com/prysmaticlabs/prysm/beacon-chain/gateway"
"github.com/prysmaticlabs/prysm/beacon-chain/gateway"
interopcoldstart "github.com/prysmaticlabs/prysm/beacon-chain/interop-cold-start"
"github.com/prysmaticlabs/prysm/beacon-chain/node/registration"
"github.com/prysmaticlabs/prysm/beacon-chain/operations/attestations"
@@ -670,22 +670,32 @@ func (b *BeaconNode) registerGRPCGateway() error {
enableDebugRPCEndpoints := b.cliCtx.Bool(flags.EnableDebugRPCEndpoints.Name)
selfCert := b.cliCtx.String(flags.CertFlag.Name)
maxCallSize := b.cliCtx.Uint64(cmd.GrpcMaxCallRecvMsgSizeFlag.Name)
httpModules := b.cliCtx.String(flags.HTTPModules.Name)
if enableDebugRPCEndpoints {
maxCallSize = uint64(math.Max(float64(maxCallSize), debugGrpcMaxMsgSize))
}
gatewayConfig := gateway2.DefaultConfig(enableDebugRPCEndpoints)
gatewayConfig := gateway.DefaultConfig(enableDebugRPCEndpoints, httpModules)
muxs := make([]*apigateway.PbMux, 0)
if gatewayConfig.V1Alpha1PbMux != nil {
muxs = append(muxs, gatewayConfig.V1Alpha1PbMux)
}
if gatewayConfig.EthPbMux != nil {
muxs = append(muxs, gatewayConfig.EthPbMux)
}
g := gateway.New(
g := apigateway.New(
b.ctx,
[]gateway.PbMux{gatewayConfig.V1Alpha1PbMux, gatewayConfig.EthPbMux},
muxs,
gatewayConfig.Handler,
selfAddress,
gatewayAddress,
).WithAllowedOrigins(allowedOrigins).
WithRemoteCert(selfCert).
WithMaxCallRecvMsgSize(maxCallSize).
WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
WithMaxCallRecvMsgSize(maxCallSize)
if flags.EnableHTTPEthAPI(httpModules) {
g.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
}
return b.services.RegisterService(g)
}

View File

@@ -13,6 +13,7 @@ go_library(
"//api/gateway:go_default_library",
"//beacon-chain/gateway:go_default_library",
"//beacon-chain/rpc/apimiddleware:go_default_library",
"//cmd/beacon-chain/flags:go_default_library",
"//runtime/maxprocs:go_default_library",
"@com_github_gorilla_mux//:go_default_library",
"@com_github_joonix_log//:go_default_library",

View File

@@ -14,6 +14,7 @@ import (
"github.com/prysmaticlabs/prysm/api/gateway"
beaconGateway "github.com/prysmaticlabs/prysm/beacon-chain/gateway"
"github.com/prysmaticlabs/prysm/beacon-chain/rpc/apimiddleware"
"github.com/prysmaticlabs/prysm/cmd/beacon-chain/flags"
_ "github.com/prysmaticlabs/prysm/runtime/maxprocs"
"github.com/sirupsen/logrus"
)
@@ -26,6 +27,11 @@ var (
allowedOrigins = flag.String("corsdomain", "localhost:4242", "A comma separated list of CORS domains to allow")
enableDebugRPCEndpoints = flag.Bool("enable-debug-rpc-endpoints", false, "Enable debug rpc endpoints such as /eth/v1alpha1/beacon/state")
grpcMaxMsgSize = flag.Int("grpc-max-msg-size", 1<<22, "Integer to define max recieve message call size")
httpModules = flag.String(
"http-modules",
strings.Join([]string{flags.PrysmAPIModule, flags.EthAPIModule}, ","),
"Comma-separated list of API module names. Possible values: `"+flags.PrysmAPIModule+`,`+flags.EthAPIModule+"`.",
)
)
func init() {
@@ -38,17 +44,26 @@ func main() {
log.SetLevel(logrus.DebugLevel)
}
gatewayConfig := beaconGateway.DefaultConfig(*enableDebugRPCEndpoints)
gatewayConfig := beaconGateway.DefaultConfig(*enableDebugRPCEndpoints, *httpModules)
muxs := make([]*gateway.PbMux, 0)
if gatewayConfig.V1Alpha1PbMux != nil {
muxs = append(muxs, gatewayConfig.V1Alpha1PbMux)
}
if gatewayConfig.EthPbMux != nil {
muxs = append(muxs, gatewayConfig.EthPbMux)
}
gw := gateway.New(
context.Background(),
[]gateway.PbMux{gatewayConfig.V1Alpha1PbMux, gatewayConfig.EthPbMux},
muxs,
gatewayConfig.Handler,
*beaconRPC,
fmt.Sprintf("%s:%d", *host, *port),
).WithAllowedOrigins(strings.Split(*allowedOrigins, ",")).
WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize)).
WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize))
if flags.EnableHTTPEthAPI(*httpModules) {
gw.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
}
r := mux.NewRouter()
r.HandleFunc("/swagger/", gateway.SwaggerServer())

View File

@@ -1,8 +1,9 @@
load("@prysm//tools/go:def.bzl", "go_library")
load("@prysm//tools/go:def.bzl", "go_library", "go_test")
go_library(
name = "go_default_library",
srcs = [
"api_module.go",
"base.go",
"config.go",
"interop.go",
@@ -22,3 +23,10 @@ go_library(
"@com_github_urfave_cli_v2//:go_default_library",
],
)
go_test(
name = "go_default_test",
srcs = ["api_module_test.go"],
embed = [":go_default_library"],
deps = ["//testing/assert:go_default_library"],
)

View File

@@ -0,0 +1,23 @@
package flags
import "strings"
const PrysmAPIModule string = "prysm"
const EthAPIModule string = "eth"
func EnableHTTPPrysmAPI(httpModules string) bool {
return enableAPI(httpModules, PrysmAPIModule)
}
func EnableHTTPEthAPI(httpModules string) bool {
return enableAPI(httpModules, EthAPIModule)
}
func enableAPI(httpModules string, api string) bool {
for _, m := range strings.Split(httpModules, ",") {
if strings.EqualFold(m, api) {
return true
}
}
return false
}

View File

@@ -0,0 +1,37 @@
package flags
import (
"testing"
"github.com/prysmaticlabs/prysm/testing/assert"
)
func TestEnableHTTPPrysmAPI(t *testing.T) {
assert.Equal(t, true, EnableHTTPPrysmAPI("prysm"))
assert.Equal(t, true, EnableHTTPPrysmAPI("prysm,foo"))
assert.Equal(t, true, EnableHTTPPrysmAPI("foo,prysm"))
assert.Equal(t, true, EnableHTTPPrysmAPI("prysm,prysm"))
assert.Equal(t, true, EnableHTTPPrysmAPI("PrYsM"))
assert.Equal(t, false, EnableHTTPPrysmAPI("foo"))
assert.Equal(t, false, EnableHTTPPrysmAPI(""))
}
func TestEnableHTTPEthAPI(t *testing.T) {
assert.Equal(t, true, EnableHTTPEthAPI("eth"))
assert.Equal(t, true, EnableHTTPEthAPI("eth,foo"))
assert.Equal(t, true, EnableHTTPEthAPI("foo,eth"))
assert.Equal(t, true, EnableHTTPEthAPI("eth,eth"))
assert.Equal(t, true, EnableHTTPEthAPI("EtH"))
assert.Equal(t, false, EnableHTTPEthAPI("foo"))
assert.Equal(t, false, EnableHTTPEthAPI(""))
}
func TestEnableApi(t *testing.T) {
assert.Equal(t, true, enableAPI("foo", "foo"))
assert.Equal(t, true, enableAPI("foo,bar", "foo"))
assert.Equal(t, true, enableAPI("bar,foo", "foo"))
assert.Equal(t, true, enableAPI("foo,foo", "foo"))
assert.Equal(t, true, enableAPI("FoO", "foo"))
assert.Equal(t, false, enableAPI("bar", "foo"))
assert.Equal(t, false, enableAPI("", "foo"))
}

View File

@@ -3,6 +3,8 @@
package flags
import (
"strings"
"github.com/prysmaticlabs/prysm/config/params"
"github.com/urfave/cli/v2"
)
@@ -53,6 +55,12 @@ var (
Name: "tls-key",
Usage: "Key for secure gRPC. Pass this and the tls-cert flag in order to use gRPC securely.",
}
// HTTPModules define the set of enabled HTTP APIs.
HTTPModules = &cli.StringFlag{
Name: "http-modules",
Usage: "Comma-separated list of API module names. Possible values: `" + PrysmAPIModule + `,` + EthAPIModule + "`.",
Value: strings.Join([]string{PrysmAPIModule, EthAPIModule}, ","),
}
// DisableGRPCGateway for JSON-HTTP requests to the beacon node.
DisableGRPCGateway = &cli.BoolFlag{
Name: "disable-grpc-gateway",

View File

@@ -36,6 +36,7 @@ var appFlags = []cli.Flag{
flags.RPCPort,
flags.CertFlag,
flags.KeyFlag,
flags.HTTPModules,
flags.DisableGRPCGateway,
flags.GRPCGatewayHost,
flags.GRPCGatewayPort,

View File

@@ -99,6 +99,7 @@ var appHelpFlagGroups = []flagGroup{
flags.RPCPort,
flags.CertFlag,
flags.KeyFlag,
flags.HTTPModules,
flags.DisableGRPCGateway,
flags.GRPCGatewayHost,
flags.GRPCGatewayPort,

View File

@@ -552,7 +552,7 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
}
}
pbHandler := gateway.PbMux{
pbHandler := &gateway.PbMux{
Registrations: registrations,
Patterns: []string{"/accounts/", "/v2/"},
Mux: mux,
@@ -560,7 +560,7 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
gw := gateway.New(
cliCtx.Context,
[]gateway.PbMux{pbHandler},
[]*gateway.PbMux{pbHandler},
muxHandler,
rpcAddr,
gatewayAddress,