Compare commits

..

1 Commits

Author SHA1 Message Date
Yuan Teoh
347a480d0b chore: update host validation error to 403 2026-01-13 16:16:45 -08:00
5 changed files with 15 additions and 25 deletions

View File

@@ -384,7 +384,6 @@ func NewCommand(opts ...Option) *Command {
// TODO: Insecure by default. Might consider updating this for v1.0.0
flags.StringSliceVar(&cmd.cfg.AllowedOrigins, "allowed-origins", []string{"*"}, "Specifies a list of origins permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.AllowedHosts, "allowed-hosts", []string{"*"}, "Specifies a list of hosts permitted to access this server. Defaults to '*'.")
flags.StringSliceVar(&cmd.cfg.UserAgentExtra, "user-agent-extra", []string{}, "Appends additional metadata to the User-Agent.")
// wrap RunE command so that we have access to original Command object
cmd.RunE = func(*cobra.Command, []string) error { return run(cmd) }

View File

@@ -70,9 +70,6 @@ func withDefaults(c server.ServerConfig) server.ServerConfig {
if c.AllowedHosts == nil {
c.AllowedHosts = []string{"*"}
}
if c.UserAgentExtra == nil {
c.UserAgentExtra = []string{}
}
return c
}
@@ -233,13 +230,6 @@ func TestServerConfigFlags(t *testing.T) {
AllowedHosts: []string{"http://foo.com", "http://bar.com"},
}),
},
{
desc: "user agent extra",
args: []string{"--user-agent-extra", "foo,bar"},
want: withDefaults(server.ServerConfig{
UserAgentExtra: []string{"foo", "bar"},
}),
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {

View File

@@ -27,7 +27,6 @@ description: >
| | `--ui` | Launches the Toolbox UI web server. | |
| | `--allowed-origins` | Specifies a list of origins permitted to access this server for CORs access. | `*` |
| | `--allowed-hosts` | Specifies a list of hosts permitted to access this server to prevent DNS rebinding attacks. | `*` |
| | `--user-agent-extra` | Appends additional metadata to the User-Agent. | |
| `-v` | `--version` | version for toolbox | |
## Examples

View File

@@ -64,14 +64,12 @@ type ServerConfig struct {
Stdio bool
// DisableReload indicates if the user has disabled dynamic reloading for Toolbox.
DisableReload bool
// UI indicates if Toolbox UI endpoints (/ui) are available.
// UI indicates if Toolbox UI endpoints (/ui) are available
UI bool
// Specifies a list of origins permitted to access this server.
AllowedOrigins []string
// Specifies a list of hosts permitted to access this server.
// Specifies a list of hosts permitted to access this server
AllowedHosts []string
// UserAgentExtra specifies additional metadata to append to the User-Agent string.
UserAgentExtra []string
}
type logFormat string

View File

@@ -64,11 +64,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]prompts.Promptset,
error,
) {
metadataStr := cfg.Version
if len(cfg.UserAgentExtra) > 0 {
metadataStr += "+" + strings.Join(cfg.UserAgentExtra, ".")
}
ctx = util.WithUserAgent(ctx, metadataStr)
ctx = util.WithUserAgent(ctx, cfg.Version)
instrumentation, err := util.InstrumentationFromContext(ctx)
if err != nil {
panic(err)
@@ -308,10 +304,14 @@ func hostCheck(allowedHosts map[string]struct{}) func(http.Handler) http.Handler
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, hasWildcard := allowedHosts["*"]
_, hostIsAllowed := allowedHosts[r.Host]
hostname := r.Host
if host, _, err := net.SplitHostPort(r.Host); err == nil {
hostname = host
}
_, hostIsAllowed := allowedHosts[hostname]
if !hasWildcard && !hostIsAllowed {
// Return 400 Bad Request or 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusBadRequest)
// Return 403 Forbidden to block the attack
http.Error(w, "Invalid Host header", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
@@ -410,7 +410,11 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
}
allowedHostsMap := make(map[string]struct{}, len(cfg.AllowedHosts))
for _, h := range cfg.AllowedHosts {
allowedHostsMap[h] = struct{}{}
hostname := h
if host, _, err := net.SplitHostPort(h); err == nil {
hostname = host
}
allowedHostsMap[hostname] = struct{}{}
}
r.Use(hostCheck(allowedHostsMap))