From 481d77bfdea52a077ca5bac1a54f5083f5daf193 Mon Sep 17 00:00:00 2001 From: james-prysm <90280386+james-prysm@users.noreply.github.com> Date: Thu, 7 Dec 2023 22:24:18 -0600 Subject: [PATCH] APIs: reusing grpc cors middleware for rest (#13284) * reusing grpc cors middleware for rest * addressing radek's comments * Update api/server/middleware.go Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com> * fixing to recommended name * fixing naming * fixing rename on test --------- Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com> --- api/gateway/BUILD.bazel | 2 +- api/gateway/gateway.go | 15 ++--------- api/server/BUILD.bazel | 4 +++ api/server/middleware.go | 17 ++++++++++++ beacon-chain/node/node.go | 16 +++++++++-- beacon-chain/node/node_test.go | 49 ++++++++++++++++++++++++++++++++++ validator/node/node.go | 16 +++++++++-- 7 files changed, 101 insertions(+), 18 deletions(-) diff --git a/api/gateway/BUILD.bazel b/api/gateway/BUILD.bazel index c58d7d49e6..dcc3db1811 100644 --- a/api/gateway/BUILD.bazel +++ b/api/gateway/BUILD.bazel @@ -14,11 +14,11 @@ go_library( "//validator:__subpackages__", ], deps = [ + "//api/server:go_default_library", "//runtime:go_default_library", "@com_github_gorilla_mux//:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library", "@com_github_pkg_errors//:go_default_library", - "@com_github_rs_cors//:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//connectivity:go_default_library", diff --git a/api/gateway/gateway.go b/api/gateway/gateway.go index 32f47dabed..16f6a5ae4c 100644 --- a/api/gateway/gateway.go +++ b/api/gateway/gateway.go @@ -11,8 +11,8 @@ import ( "github.com/gorilla/mux" gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/v4/api/server" "github.com/prysmaticlabs/prysm/v4/runtime" - "github.com/rs/cors" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" @@ -104,7 +104,7 @@ func (g *Gateway) Start() { } } - corsMux := g.corsMiddleware(g.cfg.router) + corsMux := server.CorsHandler(g.cfg.allowedOrigins).Middleware(g.cfg.router) if g.cfg.muxHandler != nil { g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -158,17 +158,6 @@ func (g *Gateway) Stop() error { return nil } -func (g *Gateway) corsMiddleware(h http.Handler) http.Handler { - c := cors.New(cors.Options{ - AllowedOrigins: g.cfg.allowedOrigins, - AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions}, - AllowCredentials: true, - MaxAge: 600, - AllowedHeaders: []string{"*"}, - }) - return c.Handler(h) -} - // dial the gRPC server. func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientConn, error) { switch network { diff --git a/api/server/BUILD.bazel b/api/server/BUILD.bazel index 9ecdd5b454..8ab494a524 100644 --- a/api/server/BUILD.bazel +++ b/api/server/BUILD.bazel @@ -8,6 +8,10 @@ go_library( ], importpath = "github.com/prysmaticlabs/prysm/v4/api/server", visibility = ["//visibility:public"], + deps = [ + "@com_github_gorilla_mux//:go_default_library", + "@com_github_rs_cors//:go_default_library", + ], ) go_test( diff --git a/api/server/middleware.go b/api/server/middleware.go index 670dff2be5..13afd71fbc 100644 --- a/api/server/middleware.go +++ b/api/server/middleware.go @@ -2,8 +2,12 @@ package server import ( "net/http" + + "github.com/gorilla/mux" + "github.com/rs/cors" ) +// NormalizeQueryValuesHandler normalizes an input query of "key=value1,value2,value3" to "key=value1&key=value2&key=value3" func NormalizeQueryValuesHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() @@ -13,3 +17,16 @@ func NormalizeQueryValuesHandler(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// CorsHandler sets the cors settings on api endpoints +func CorsHandler(allowOrigins []string) mux.MiddlewareFunc { + c := cors.New(cors.Options{ + AllowedOrigins: allowOrigins, + AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions}, + AllowCredentials: true, + MaxAge: 600, + AllowedHeaders: []string{"*"}, + }) + + return c.Handler +} diff --git a/beacon-chain/node/node.go b/beacon-chain/node/node.go index c2fc26ea95..7a0fdd692d 100644 --- a/beacon-chain/node/node.go +++ b/beacon-chain/node/node.go @@ -271,8 +271,7 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco } log.Debugln("Registering RPC Service") - router := mux.NewRouter() - router.Use(server.NormalizeQueryValuesHandler) + router := newRouter(cliCtx) if err := beacon.registerRPCService(router); err != nil { return nil, err } @@ -306,6 +305,19 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco return beacon, nil } +func newRouter(cliCtx *cli.Context) *mux.Router { + var allowedOrigins []string + if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) { + allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",") + } else { + allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",") + } + r := mux.NewRouter() + r.Use(server.NormalizeQueryValuesHandler) + r.Use(server.CorsHandler(allowedOrigins)) + return r +} + // StateFeed implements statefeed.Notifier. func (b *BeaconNode) StateFeed() *event.Feed { return b.stateFeed diff --git a/beacon-chain/node/node_test.go b/beacon-chain/node/node_test.go index 024522e176..921a1b29d9 100644 --- a/beacon-chain/node/node_test.go +++ b/beacon-chain/node/node_test.go @@ -4,6 +4,8 @@ import ( "context" "flag" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "strconv" @@ -222,3 +224,50 @@ func Test_hasNetworkFlag(t *testing.T) { }) } } + +func TestCORS(t *testing.T) { + // Mock CLI context with a test CORS domain + app := cli.App{} + set := flag.NewFlagSet("test", 0) + set.String(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com", "") + cliCtx := cli.NewContext(&app, set, nil) + require.NoError(t, cliCtx.Set(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com")) + + router := newRouter(cliCtx) + + // Ensure a test route exists + router.HandleFunc("/some-path", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }).Methods(http.MethodGet) + + // Define test cases + tests := []struct { + name string + origin string + expectAllow bool + }{ + {"AllowedOrigin", "http://allowed-example.com", true}, + {"DisallowedOrigin", "http://disallowed-example.com", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + + // Create a request and response recorder + req := httptest.NewRequest("GET", "http://example.com/some-path", nil) + req.Header.Set("Origin", tc.origin) + rr := httptest.NewRecorder() + + // Serve HTTP + router.ServeHTTP(rr, req) + + // Check the CORS headers based on the expected outcome + if tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != tc.origin { + t.Errorf("Expected Access-Control-Allow-Origin header to be %v, got %v", tc.origin, rr.Header().Get("Access-Control-Allow-Origin")) + } + if !tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != "" { + t.Errorf("Expected Access-Control-Allow-Origin header to be empty for disallowed origin, got %v", rr.Header().Get("Access-Control-Allow-Origin")) + } + }) + } +} diff --git a/validator/node/node.go b/validator/node/node.go index 6f5b43638d..73d9c091c6 100644 --- a/validator/node/node.go +++ b/validator/node/node.go @@ -128,8 +128,7 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) { configureFastSSZHashingAlgorithm() // initialize router used for endpoints - router := mux.NewRouter() - router.Use(server.NormalizeQueryValuesHandler) + router := newRouter(cliCtx) // If the --web flag is enabled to administer the validator // client via a web portal, we start the validator client in a different way. // Change Web flag name to enable keymanager API, look at merging initializeFromCLI and initializeForWeb maybe after WebUI DEPRECATED. @@ -151,6 +150,19 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) { return validatorClient, nil } +func newRouter(cliCtx *cli.Context) *mux.Router { + var allowedOrigins []string + if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) { + allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",") + } else { + allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",") + } + r := mux.NewRouter() + r.Use(server.NormalizeQueryValuesHandler) + r.Use(server.CorsHandler(allowedOrigins)) + return r +} + // Start every service in the validator client. func (c *ValidatorClient) Start() { c.lock.Lock()