Add CORS preflight support (#5177)

* Add CORS preflight support

* lint

* clarify description
This commit is contained in:
Preston Van Loon
2020-03-23 11:17:17 -07:00
committed by GitHub
parent b0128ad894
commit 5241582ece
6 changed files with 51 additions and 18 deletions

View File

@@ -57,6 +57,12 @@ var (
Name: "grpc-gateway-port",
Usage: "Enable gRPC gateway for JSON requests",
}
// GPRCGatewayCorsDomain serves preflight requests when serving gRPC JSON gateway.
GPRCGatewayCorsDomain = &cli.StringFlag{
Name: "grpc-gateway-corsdomain",
Usage: "Comma separated list of domains from which to accept cross origin requests " +
"(browser enforced). This flag has no effect if not used with --grpc-gateway-port.",
}
// MinSyncPeers specifies the required number of successful peer handshakes in order
// to start syncing with external peers.
MinSyncPeers = &cli.IntFlag{

View File

@@ -4,6 +4,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"cors.go",
"gateway.go",
"handlers.go",
"log.go",
@@ -16,6 +17,7 @@ go_library(
deps = [
"//shared:go_default_library",
"@com_github_prysmaticlabs_ethereumapis//eth/v1alpha1:go_grpc_gateway_library",
"@com_github_rs_cors//:go_default_library",
"@com_github_sirupsen_logrus//:go_default_library",
"@grpc_ecosystem_grpc_gateway//runtime:go_default_library",
"@org_golang_google_grpc//:go_default_library",

View File

@@ -0,0 +1,20 @@
package gateway
import (
"net/http"
"github.com/rs/cors"
)
func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler {
if len(allowedOrigins) == 0 {
return srv
}
c := cors.New(cors.Options{
AllowedOrigins: allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet},
MaxAge: 600,
AllowedHeaders: []string{"*"},
})
return c.Handler(srv)
}

View File

@@ -19,13 +19,14 @@ var _ = shared.Service(&Gateway{})
// Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward
// it to the beacon-chain gRPC server.
type Gateway struct {
conn *grpc.ClientConn
ctx context.Context
cancel context.CancelFunc
gatewayAddr string
remoteAddr string
server *http.Server
mux *http.ServeMux
conn *grpc.ClientConn
ctx context.Context
cancel context.CancelFunc
gatewayAddr string
remoteAddr string
server *http.Server
mux *http.ServeMux
allowedOrigins []string
startFailure error
}
@@ -64,7 +65,7 @@ func (g *Gateway) Start() {
g.server = &http.Server{
Addr: g.gatewayAddr,
Handler: g.mux,
Handler: newCorsHandler(g.mux, g.allowedOrigins),
}
go func() {
if err := g.server.ListenAndServe(); err != http.ErrServerClosed {
@@ -105,16 +106,17 @@ func (g *Gateway) Stop() error {
// New returns a new gateway server which translates HTTP into gRPC.
// Accepts a context and optional http.ServeMux.
func New(ctx context.Context, remoteAddress, gatewayAddress string, mux *http.ServeMux) *Gateway {
func New(ctx context.Context, remoteAddress, gatewayAddress string, mux *http.ServeMux, allowedOrigins []string) *Gateway {
if mux == nil {
mux = http.NewServeMux()
}
return &Gateway{
remoteAddr: remoteAddress,
gatewayAddr: gatewayAddress,
ctx: ctx,
mux: mux,
remoteAddr: remoteAddress,
gatewayAddr: gatewayAddress,
ctx: ctx,
mux: mux,
allowedOrigins: allowedOrigins,
}
}

View File

@@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"net/http"
"strings"
joonix "github.com/joonix/log"
"github.com/prysmaticlabs/prysm/beacon-chain/gateway"
@@ -13,9 +14,10 @@ import (
)
var (
beaconRPC = flag.String("beacon-rpc", "localhost:4000", "Beacon chain gRPC endpoint")
port = flag.Int("port", 8000, "Port to serve on")
debug = flag.Bool("debug", false, "Enable debug logging")
beaconRPC = flag.String("beacon-rpc", "localhost:4000", "Beacon chain gRPC endpoint")
port = flag.Int("port", 8000, "Port to serve on")
debug = flag.Bool("debug", false, "Enable debug logging")
allowedOrigins = flag.String("corsdomain", "", "A comma separated list of CORS domains to allow.")
)
func init() {
@@ -31,7 +33,7 @@ func main() {
}
mux := http.NewServeMux()
gw := gateway.New(context.Background(), *beaconRPC, fmt.Sprintf("0.0.0.0:%d", *port), mux)
gw := gateway.New(context.Background(), *beaconRPC, fmt.Sprintf("0.0.0.0:%d", *port), mux, strings.Split(*allowedOrigins, ","))
mux.HandleFunc("/swagger/", gateway.SwaggerServer())
mux.HandleFunc("/healthz", healthzServer(gw))
gw.Start()

View File

@@ -588,7 +588,8 @@ func (b *BeaconNode) registerGRPCGateway(ctx *cli.Context) error {
if gatewayPort > 0 {
selfAddress := fmt.Sprintf("127.0.0.1:%d", ctx.Int(flags.RPCPort.Name))
gatewayAddress := fmt.Sprintf("0.0.0.0:%d", gatewayPort)
return b.services.RegisterService(gateway.New(context.Background(), selfAddress, gatewayAddress, nil /*optional mux*/))
allowedOrigins := strings.Split(ctx.String(flags.GPRCGatewayCorsDomain.Name), ",")
return b.services.RegisterService(gateway.New(context.Background(), selfAddress, gatewayAddress, nil /*optional mux*/, allowedOrigins))
}
return nil
}