diff --git a/cmd/root_test.go b/cmd/root_test.go index 4e47783385..fcd7fcd3a3 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -66,6 +66,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig { if c.AllowedOrigins == nil { c.AllowedOrigins = []string{"*"} } + if c.AllowedHosts == nil { + c.AllowedHosts = []string{"*"} + } return c } diff --git a/docs/en/how-to/deploy_gke.md b/docs/en/how-to/deploy_gke.md index 4c4bc6cca0..058976ee0a 100644 --- a/docs/en/how-to/deploy_gke.md +++ b/docs/en/how-to/deploy_gke.md @@ -188,16 +188,14 @@ description: > path: tools.yaml ``` - {{< notice tip >}} -To prevent DNS rebinding attack, use the `--allowed-hosts` flag to specify a -list of hosts for validation. E.g. `command: [ "toolbox", -"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0", -"--allowed-hosts", "localhost:5000"]` + {{< notice tip >}} +To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a +list of origins permitted to access the server. E.g. `args: ["--address", +"0.0.0.0", "--allowed-hosts", "foo.bar:5000"]` To implement CORs, use the `--allowed-origins` flag to specify a -list of origins permitted to access the server. E.g. `command: [ "toolbox", -"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0", -"--allowed-origins", "https://foo.bar"]` +list of origins permitted to access the server. E.g. `args: ["--address", +"0.0.0.0", "--allowed-origins", "https://foo.bar"]` {{< /notice >}} 1. Create the deployment. diff --git a/internal/server/server.go b/internal/server/server.go index 87069d18b0..4bbc85adc5 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -270,18 +270,12 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil } -func HostCheck(allowedHosts []string) func(http.Handler) http.Handler { +func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - isAllowed := false - for _, h := range allowedHosts { - if h == "*" || r.Host == h { - isAllowed = true - break - } - } - - if !isAllowed { + _, hasWildcard := allowedHosts["*"] + _, hostIsAllowed := allowedHosts[r.Host] + if !hasWildcard && !hostIsAllowed { // Return 400 Bad Request or 403 Forbidden to block the attack http.Error(w, "Invalid Host header", http.StatusBadRequest) return @@ -380,7 +374,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { if slices.Contains(cfg.AllowedHosts, "*") { s.logger.WarnContext(ctx, "wildcard (`*`) allows all hosts to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-hosts` flag to prevent DNS rebinding attacks") } - r.Use(HostCheck(cfg.AllowedHosts)) + allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts)) + for _, h := range cfg.AllowedHosts { + allowedHostsMap[h] = struct{}{} + } + r.Use(hostCheck(allowedHostsMap)) // control plane apiR, err := apiRouter(s) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 1d11379d9c..9c753e48bc 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -42,9 +42,10 @@ func TestServe(t *testing.T) { addr, port := "127.0.0.1", 5000 cfg := server.ServerConfig{ - Version: "0.0.0", - Address: addr, - Port: port, + Version: "0.0.0", + Address: addr, + Port: port, + AllowedHosts: []string{"*"}, } otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox")