mirror of
https://github.com/OffchainLabs/prysm.git
synced 2026-01-09 07:28:06 -05:00
APIs: reusing grpc cors middleware for rest (#13284)
* reusing grpc cors middleware for rest * addressing radek's comments * Update api/server/middleware.go Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com> * fixing to recommended name * fixing naming * fixing rename on test --------- Co-authored-by: Sammy Rosso <15244892+saolyn@users.noreply.github.com> Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
@@ -14,11 +14,11 @@ go_library(
|
||||
"//validator:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//api/server:go_default_library",
|
||||
"//runtime:go_default_library",
|
||||
"@com_github_gorilla_mux//:go_default_library",
|
||||
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
|
||||
"@com_github_pkg_errors//:go_default_library",
|
||||
"@com_github_rs_cors//:go_default_library",
|
||||
"@com_github_sirupsen_logrus//:go_default_library",
|
||||
"@org_golang_google_grpc//:go_default_library",
|
||||
"@org_golang_google_grpc//connectivity:go_default_library",
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/prysmaticlabs/prysm/v4/api/server"
|
||||
"github.com/prysmaticlabs/prysm/v4/runtime"
|
||||
"github.com/rs/cors"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -104,7 +104,7 @@ func (g *Gateway) Start() {
|
||||
}
|
||||
}
|
||||
|
||||
corsMux := g.corsMiddleware(g.cfg.router)
|
||||
corsMux := server.CorsHandler(g.cfg.allowedOrigins).Middleware(g.cfg.router)
|
||||
|
||||
if g.cfg.muxHandler != nil {
|
||||
g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -158,17 +158,6 @@ func (g *Gateway) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) corsMiddleware(h http.Handler) http.Handler {
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: g.cfg.allowedOrigins,
|
||||
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 600,
|
||||
AllowedHeaders: []string{"*"},
|
||||
})
|
||||
return c.Handler(h)
|
||||
}
|
||||
|
||||
// dial the gRPC server.
|
||||
func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientConn, error) {
|
||||
switch network {
|
||||
|
||||
@@ -8,6 +8,10 @@ go_library(
|
||||
],
|
||||
importpath = "github.com/prysmaticlabs/prysm/v4/api/server",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_github_gorilla_mux//:go_default_library",
|
||||
"@com_github_rs_cors//:go_default_library",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
|
||||
@@ -2,8 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
)
|
||||
|
||||
// NormalizeQueryValuesHandler normalizes an input query of "key=value1,value2,value3" to "key=value1&key=value2&key=value3"
|
||||
func NormalizeQueryValuesHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
@@ -13,3 +17,16 @@ func NormalizeQueryValuesHandler(next http.Handler) http.Handler {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CorsHandler sets the cors settings on api endpoints
|
||||
func CorsHandler(allowOrigins []string) mux.MiddlewareFunc {
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: allowOrigins,
|
||||
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 600,
|
||||
AllowedHeaders: []string{"*"},
|
||||
})
|
||||
|
||||
return c.Handler
|
||||
}
|
||||
|
||||
@@ -271,8 +271,7 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco
|
||||
}
|
||||
|
||||
log.Debugln("Registering RPC Service")
|
||||
router := mux.NewRouter()
|
||||
router.Use(server.NormalizeQueryValuesHandler)
|
||||
router := newRouter(cliCtx)
|
||||
if err := beacon.registerRPCService(router); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -306,6 +305,19 @@ func New(cliCtx *cli.Context, cancel context.CancelFunc, opts ...Option) (*Beaco
|
||||
return beacon, nil
|
||||
}
|
||||
|
||||
func newRouter(cliCtx *cli.Context) *mux.Router {
|
||||
var allowedOrigins []string
|
||||
if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) {
|
||||
allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",")
|
||||
} else {
|
||||
allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",")
|
||||
}
|
||||
r := mux.NewRouter()
|
||||
r.Use(server.NormalizeQueryValuesHandler)
|
||||
r.Use(server.CorsHandler(allowedOrigins))
|
||||
return r
|
||||
}
|
||||
|
||||
// StateFeed implements statefeed.Notifier.
|
||||
func (b *BeaconNode) StateFeed() *event.Feed {
|
||||
return b.stateFeed
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
@@ -222,3 +224,50 @@ func Test_hasNetworkFlag(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
// Mock CLI context with a test CORS domain
|
||||
app := cli.App{}
|
||||
set := flag.NewFlagSet("test", 0)
|
||||
set.String(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com", "")
|
||||
cliCtx := cli.NewContext(&app, set, nil)
|
||||
require.NoError(t, cliCtx.Set(flags.GPRCGatewayCorsDomain.Name, "http://allowed-example.com"))
|
||||
|
||||
router := newRouter(cliCtx)
|
||||
|
||||
// Ensure a test route exists
|
||||
router.HandleFunc("/some-path", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}).Methods(http.MethodGet)
|
||||
|
||||
// Define test cases
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
expectAllow bool
|
||||
}{
|
||||
{"AllowedOrigin", "http://allowed-example.com", true},
|
||||
{"DisallowedOrigin", "http://disallowed-example.com", false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
// Create a request and response recorder
|
||||
req := httptest.NewRequest("GET", "http://example.com/some-path", nil)
|
||||
req.Header.Set("Origin", tc.origin)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Serve HTTP
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
// Check the CORS headers based on the expected outcome
|
||||
if tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != tc.origin {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin header to be %v, got %v", tc.origin, rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
if !tc.expectAllow && rr.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Errorf("Expected Access-Control-Allow-Origin header to be empty for disallowed origin, got %v", rr.Header().Get("Access-Control-Allow-Origin"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,8 +128,7 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) {
|
||||
configureFastSSZHashingAlgorithm()
|
||||
|
||||
// initialize router used for endpoints
|
||||
router := mux.NewRouter()
|
||||
router.Use(server.NormalizeQueryValuesHandler)
|
||||
router := newRouter(cliCtx)
|
||||
// If the --web flag is enabled to administer the validator
|
||||
// client via a web portal, we start the validator client in a different way.
|
||||
// Change Web flag name to enable keymanager API, look at merging initializeFromCLI and initializeForWeb maybe after WebUI DEPRECATED.
|
||||
@@ -151,6 +150,19 @@ func NewValidatorClient(cliCtx *cli.Context) (*ValidatorClient, error) {
|
||||
return validatorClient, nil
|
||||
}
|
||||
|
||||
func newRouter(cliCtx *cli.Context) *mux.Router {
|
||||
var allowedOrigins []string
|
||||
if cliCtx.IsSet(flags.GPRCGatewayCorsDomain.Name) {
|
||||
allowedOrigins = strings.Split(cliCtx.String(flags.GPRCGatewayCorsDomain.Name), ",")
|
||||
} else {
|
||||
allowedOrigins = strings.Split(flags.GPRCGatewayCorsDomain.Value, ",")
|
||||
}
|
||||
r := mux.NewRouter()
|
||||
r.Use(server.NormalizeQueryValuesHandler)
|
||||
r.Use(server.CorsHandler(allowedOrigins))
|
||||
return r
|
||||
}
|
||||
|
||||
// Start every service in the validator client.
|
||||
func (c *ValidatorClient) Start() {
|
||||
c.lock.Lock()
|
||||
|
||||
Reference in New Issue
Block a user