APIs: reusing grpc cors middleware for rest (#13284)

* reusing grpc cors middleware for rest

* addressing radek's comments

* Update api/server/middleware.go

Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com>

* fixing to recommended name

* fixing naming

* fixing rename on test

---------

Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
james-prysm
2023-12-07 22:24:18 -06:00
committed by GitHub
parent 590317553c
commit 481d77bfde
7 changed files with 101 additions and 18 deletions

View File

@@ -14,11 +14,11 @@ go_library(
"//validator:__subpackages__",
],
deps = [
"//api/server:go_default_library",
"//runtime:go_default_library",
"@com_github_gorilla_mux//:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
"@com_github_pkg_errors//:go_default_library",
"@com_github_rs_cors//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//connectivity:go_default_library",

View File

@@ -11,8 +11,8 @@ import (
"github.com/gorilla/mux"
gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/v4/api/server"
"github.com/prysmaticlabs/prysm/v4/runtime"
"github.com/rs/cors"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
@@ -104,7 +104,7 @@ func (g *Gateway) Start() {
}
}
corsMux := g.corsMiddleware(g.cfg.router)
corsMux := server.CorsHandler(g.cfg.allowedOrigins).Middleware(g.cfg.router)
if g.cfg.muxHandler != nil {
g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -158,17 +158,6 @@ func (g *Gateway) Stop() error {
return nil
}
func (g *Gateway) corsMiddleware(h http.Handler) http.Handler {
c := cors.New(cors.Options{
AllowedOrigins: g.cfg.allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
AllowCredentials: true,
MaxAge: 600,
AllowedHeaders: []string{"*"},
})
return c.Handler(h)
}
// dial the gRPC server.
func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientConn, error) {
switch network {

View File

@@ -8,6 +8,10 @@ go_library(
],
importpath = "github.com/prysmaticlabs/prysm/v4/api/server",
visibility = ["//visibility:public"],
deps = [
"@com_github_gorilla_mux//:go_default_library",
"@com_github_rs_cors//:go_default_library",
],
)
go_test(

View File

@@ -2,8 +2,12 @@ package server
import (
"net/http"
"github.com/gorilla/mux"
"github.com/rs/cors"
)
// NormalizeQueryValuesHandler normalizes an input query of "key=value1,value2,value3" to "key=value1&key=value2&key=value3"
func NormalizeQueryValuesHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
@@ -13,3 +17,16 @@ func NormalizeQueryValuesHandler(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
// CorsHandler sets the cors settings on api endpoints
func CorsHandler(allowOrigins []string) mux.MiddlewareFunc {
c := cors.New(cors.Options{
AllowedOrigins: allowOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
AllowCredentials: true,
MaxAge: 600,
AllowedHeaders: []string{"*"},
})
return c.Handler
}

View File

@@ -271,8 +271,7 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco
}
log.Debugln("Registering RPC Service")
router := mux.NewRouter()
router.Use(server.NormalizeQueryValuesHandler)
router := newRouter(cliCtx)
if err := beacon.registerRPCService(router); err != nil {
return nil, err
}
@@ -306,6 +305,19 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco
return beacon, nil
}
func newRouter(cliCtx *cli.Context) *mux.Router {
var allowedOrigins []string
if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) {
allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",")
} else {
allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",")
}
r := mux.NewRouter()
r.Use(server.NormalizeQueryValuesHandler)
r.Use(server.CorsHandler(allowedOrigins))
return r
}
// StateFeed implements statefeed.Notifier.
func (b *BeaconNode) StateFeed() *event.Feed {
return b.stateFeed

View File

@@ -4,6 +4,8 @@ import (
"context"
"flag"
"fmt"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
@@ -222,3 +224,50 @@ func Test_hasNetworkFlag(t *testing.T) {
})
}
}
func TestCORS(t *testing.T) {
// Mock CLI context with a test CORS domain
app := cli.App{}
set := flag.NewFlagSet("test", 0)
set.String(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com", "")
cliCtx := cli.NewContext(&app, set, nil)
require.NoError(t, cliCtx.Set(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com"))
router := newRouter(cliCtx)
// Ensure a test route exists
router.HandleFunc("/some-path", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}).Methods(http.MethodGet)
// Define test cases
tests := []struct {
name string
origin string
expectAllow bool
}{
{"AllowedOrigin", "http://allowed-example.com", true},
{"DisallowedOrigin", "http://disallowed-example.com", false},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a request and response recorder
req := httptest.NewRequest("GET", "http://example.com/some-path", nil)
req.Header.Set("Origin", tc.origin)
rr := httptest.NewRecorder()
// Serve HTTP
router.ServeHTTP(rr, req)
// Check the CORS headers based on the expected outcome
if tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != tc.origin {
t.Errorf("Expected Access-Control-Allow-Origin header to be %v, got %v", tc.origin, rr.Header().Get("Access-Control-Allow-Origin"))
}
if !tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != "" {
t.Errorf("Expected Access-Control-Allow-Origin header to be empty for disallowed origin, got %v", rr.Header().Get("Access-Control-Allow-Origin"))
}
})
}
}

View File

@@ -128,8 +128,7 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) {
configureFastSSZHashingAlgorithm()
// initialize router used for endpoints
router := mux.NewRouter()
router.Use(server.NormalizeQueryValuesHandler)
router := newRouter(cliCtx)
// If the --web flag is enabled to administer the validator
// client via a web portal, we start the validator client in a different way.
// Change Web flag name to enable keymanager API, look at merging initializeFromCLI and initializeForWeb maybe after WebUI DEPRECATED.
@@ -151,6 +150,19 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) {
return validatorClient, nil
}
func newRouter(cliCtx *cli.Context) *mux.Router {
var allowedOrigins []string
if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) {
allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",")
} else {
allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",")
}
r := mux.NewRouter()
r.Use(server.NormalizeQueryValuesHandler)
r.Use(server.CorsHandler(allowedOrigins))
return r
}
// Start every service in the validator client.
func (c *ValidatorClient) Start() {
c.lock.Lock()