mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-07 22:54:06 -05:00
resolve gemini comment
This commit is contained in:
@@ -66,6 +66,9 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
|
||||
if c.AllowedOrigins == nil {
|
||||
c.AllowedOrigins = []string{"*"}
|
||||
}
|
||||
if c.AllowedHosts == nil {
|
||||
c.AllowedHosts = []string{"*"}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
|
||||
@@ -188,16 +188,14 @@ description: >
|
||||
path: tools.yaml
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
To prevent DNS rebinding attack, use the `--allowed-hosts` flag to specify a
|
||||
list of hosts for validation. E.g. `command: [ "toolbox",
|
||||
"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0",
|
||||
"--allowed-hosts", "localhost:5000"]`
|
||||
{{< notice tip >}}
|
||||
To prevent DNS rebinding attack, use the `--allowed-origins` flag to specify a
|
||||
list of origins permitted to access the server. E.g. `args: ["--address",
|
||||
"0.0.0.0", "--allowed-hosts", "foo.bar:5000"]`
|
||||
|
||||
To implement CORs, use the `--allowed-origins` flag to specify a
|
||||
list of origins permitted to access the server. E.g. `command: [ "toolbox",
|
||||
"--tools-file", "/config/tools.yaml", "--address", "0.0.0.0",
|
||||
"--allowed-origins", "https://foo.bar"]`
|
||||
list of origins permitted to access the server. E.g. `args: ["--address",
|
||||
"0.0.0.0", "--allowed-origins", "https://foo.bar"]`
|
||||
{{< /notice >}}
|
||||
|
||||
1. Create the deployment.
|
||||
|
||||
@@ -270,18 +270,12 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
func HostCheck(allowedHosts []string) func(http.Handler) http.Handler {
|
||||
func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
isAllowed := false
|
||||
for _, h := range allowedHosts {
|
||||
if h == "*" || r.Host == h {
|
||||
isAllowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isAllowed {
|
||||
_, hasWildcard := allowedHosts["*"]
|
||||
_, hostIsAllowed := allowedHosts[r.Host]
|
||||
if !hasWildcard && !hostIsAllowed {
|
||||
// Return 400 Bad Request or 403 Forbidden to block the attack
|
||||
http.Error(w, "Invalid Host header", http.StatusBadRequest)
|
||||
return
|
||||
@@ -380,7 +374,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
if slices.Contains(cfg.AllowedHosts, "*") {
|
||||
s.logger.WarnContext(ctx, "wildcard (`*`) allows all hosts to access the resource and is not secure. Use it with cautious for public, non-sensitive data, or during local development. Recommended to use `--allowed-hosts` flag to prevent DNS rebinding attacks")
|
||||
}
|
||||
r.Use(HostCheck(cfg.AllowedHosts))
|
||||
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
|
||||
for _, h := range cfg.AllowedHosts {
|
||||
allowedHostsMap[h] = struct{}{}
|
||||
}
|
||||
r.Use(hostCheck(allowedHostsMap))
|
||||
|
||||
// control plane
|
||||
apiR, err := apiRouter(s)
|
||||
|
||||
@@ -42,9 +42,10 @@ func TestServe(t *testing.T) {
|
||||
|
||||
addr, port := "127.0.0.1", 5000
|
||||
cfg := server.ServerConfig{
|
||||
Version: "0.0.0",
|
||||
Address: addr,
|
||||
Port: port,
|
||||
Version: "0.0.0",
|
||||
Address: addr,
|
||||
Port: port,
|
||||
AllowedHosts: []string{"*"},
|
||||
}
|
||||
|
||||
otelShutdown, err := telemetry.SetupOTel(ctx, "0.0.0", "", false, "toolbox")
|
||||
|
||||
Reference in New Issue
Block a user