diff --git a/cli/cli.go b/cli/cli.go index 5dd21281..011b2bcf 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -57,7 +57,7 @@ func Cli(version string) (err error) { if currentFlags.Serve { registry.ConfigureVendors() - err = restapi.Serve(registry, currentFlags.ServeAddress) + err = restapi.Serve(registry, currentFlags.ServeAddress, currentFlags.ServeApiKey) return } diff --git a/cli/flags.go b/cli/flags.go index fcd42522..0210c96c 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -64,6 +64,7 @@ type Flags struct { Serve bool `long:"serve" description:"Serve the Fabric Rest API"` ServeOllama bool `long:"serveOllama" description:"Serve the Fabric Rest API with ollama endpoints"` ServeAddress string `long:"address" description:"The address to bind the REST API" default:":8080"` + ServeApiKey string `long:"apikey" description:"API key used to secure server routes" default:""` Config string `long:"config" description:"Path to YAML config file"` Version bool `long:"version" description:"Print current version"` ListExtensions bool `long:"listextensions" description:"List all registered extensions"` diff --git a/restapi/auth.go b/restapi/auth.go new file mode 100644 index 00000000..c034bee9 --- /dev/null +++ b/restapi/auth.go @@ -0,0 +1,21 @@ +package restapi + +import ( + "net/http" + "fmt" + + "github.com/gin-gonic/gin" +) + +func ApiKeyMiddleware(apiKey string) gin.HandlerFunc { + return func(c *gin.Context) { + headerApiKey := c.GetHeader("X-API-Key") + + if headerApiKey != apiKey { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("Wrong or missing API Key")}) + return + } + + c.Next() + } +} \ No newline at end of file diff --git a/restapi/serve.go b/restapi/serve.go index 51f6cd2b..c9b364a9 100644 --- a/restapi/serve.go +++ b/restapi/serve.go @@ -5,13 +5,17 @@ import ( "github.com/gin-gonic/gin" ) -func Serve(registry *core.PluginRegistry, address string) (err error) { +func Serve(registry *core.PluginRegistry, address string, apiKey string) (err error) { r := gin.New() // Middleware r.Use(gin.Logger()) r.Use(gin.Recovery()) + if apiKey != "" { + r.Use(ApiKeyMiddleware(apiKey)) + } + // Register routes fabricDb := registry.Db NewPatternsHandler(r, fabricDb.Patterns)