mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
updated auth
This commit is contained in:
48
rnd/gosrv/.gitignore
vendored
Normal file
48
rnd/gosrv/.gitignore
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
*.exe~
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
|
||||
# Output of the go coverage tool, specifically when used with LiteIDE
|
||||
*.out
|
||||
|
||||
# Dependency directories (remove the comment below to include it)
|
||||
# vendor/
|
||||
|
||||
# Go workspace file
|
||||
go.work
|
||||
|
||||
# IDE-specific files
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# OS-specific files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Binary output directory
|
||||
/bin/
|
||||
|
||||
# Log files
|
||||
*.log
|
||||
|
||||
# Environment variables file
|
||||
.env
|
||||
|
||||
# Air temporary files (if using Air for live reloading)
|
||||
tmp/
|
||||
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
|
||||
# Debug files
|
||||
debug
|
||||
|
||||
# Project-specific build outputs
|
||||
/gosrv
|
||||
@@ -19,6 +19,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -a -installsuffix cgo -o main .
|
||||
FROM alpine:latest
|
||||
|
||||
RUN apk --no-cache add ca-certificates
|
||||
ENV GIN_MODE=release
|
||||
|
||||
WORKDIR /root/
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
type Config struct {
|
||||
ServerAddress string
|
||||
DatabaseURL string
|
||||
AuthEnabled bool
|
||||
JWTSecret string
|
||||
JWTAlgorithm string
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v4 v4.5.0
|
||||
github.com/jackc/pgx/v4 v4.18.3
|
||||
github.com/spf13/viper v1.19.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
go.uber.org/zap v1.27.0
|
||||
)
|
||||
|
||||
@@ -17,6 +18,7 @@ require (
|
||||
github.com/bytedance/sonic/loader v0.2.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.5 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
@@ -42,6 +44,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/sagikazarmark/locafero v0.6.0 // indirect
|
||||
github.com/sagikazarmark/slog-shim v0.1.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.0 // indirect
|
||||
|
||||
14
rnd/gosrv/handlers/user_handlers.go
Normal file
14
rnd/gosrv/handlers/user_handlers.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/swiftyos/market/models"
|
||||
)
|
||||
|
||||
func GetUserFromContext(c *gin.Context) (models.User, bool) {
|
||||
user, exists := c.Get("user")
|
||||
if !exists {
|
||||
return models.User{}, false
|
||||
}
|
||||
return user.(models.User), true
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func main() {
|
||||
{
|
||||
agents := api.Group("/agents")
|
||||
{
|
||||
agents.POST("/submit", middleware.Auth(), handlers.SubmitAgent(db, logger))
|
||||
agents.POST("/submit", middleware.Auth(cfg), handlers.SubmitAgent(db, logger))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,10 +8,22 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/swiftyos/market/config"
|
||||
"github.com/swiftyos/market/models"
|
||||
)
|
||||
|
||||
func Auth() gin.HandlerFunc {
|
||||
func Auth(cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !cfg.AuthEnabled {
|
||||
// This handles the case when authentication is disabled
|
||||
defaultUser := models.User{
|
||||
UserID: "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
Role: "admin",
|
||||
}
|
||||
c.Set("user", defaultUser)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization header is missing"})
|
||||
@@ -19,7 +31,7 @@ func Auth() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token, err := parseJWTToken(tokenString)
|
||||
token, err := parseJWTToken(tokenString, cfg.JWTSecret)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -31,22 +43,36 @@ func Auth() gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", claims)
|
||||
user, err := verifyUser(claims, false) // Pass 'true' for admin-only routes
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user", user)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func parseJWTToken(tokenString string) (*jwt.Token, error) {
|
||||
cfg, err := config.Load()
|
||||
func verifyUser(payload jwt.MapClaims, adminOnly bool) (models.User, error) {
|
||||
user, err := models.NewUserFromPayload(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return models.User{}, err
|
||||
}
|
||||
|
||||
if adminOnly && user.Role != "admin" {
|
||||
return models.User{}, errors.New("Admin access required")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func parseJWTToken(tokenString string, secret string) (*jwt.Token, error) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("unexpected signing method")
|
||||
}
|
||||
return []byte(cfg.JWTSecret), nil
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
||||
213
rnd/gosrv/middleware/auth_test.go
Normal file
213
rnd/gosrv/middleware/auth_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/swiftyos/market/config"
|
||||
"github.com/swiftyos/market/models"
|
||||
)
|
||||
|
||||
func TestVerifyUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload jwt.MapClaims
|
||||
adminOnly bool
|
||||
wantUser models.User
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid user",
|
||||
payload: jwt.MapClaims{
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"role": "user",
|
||||
},
|
||||
adminOnly: false,
|
||||
wantUser: models.User{
|
||||
UserID: "test-user",
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid admin",
|
||||
payload: jwt.MapClaims{
|
||||
"sub": "admin-user",
|
||||
"email": "admin@example.com",
|
||||
"role": "admin",
|
||||
},
|
||||
adminOnly: true,
|
||||
wantUser: models.User{
|
||||
UserID: "admin-user",
|
||||
Email: "admin@example.com",
|
||||
Role: "admin",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Non-admin accessing admin-only route",
|
||||
payload: jwt.MapClaims{
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"role": "user",
|
||||
},
|
||||
adminOnly: true,
|
||||
wantUser: models.User{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Missing sub claim",
|
||||
payload: jwt.MapClaims{},
|
||||
adminOnly: false,
|
||||
wantUser: models.User{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotUser, err := verifyUser(tt.payload, tt.adminOnly)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantUser, gotUser)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJWTToken(t *testing.T) {
|
||||
secret := "test-secret"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tokenString string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Valid token",
|
||||
tokenString: createValidToken(secret),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
tokenString: "invalid.token.string",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Empty token",
|
||||
tokenString: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := parseJWTToken(tt.tokenString, secret)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, token)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, token)
|
||||
assert.True(t, token.Valid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createValidToken(secret string) string {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"role": "user",
|
||||
})
|
||||
tokenString, _ := token.SignedString([]byte(secret))
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func TestAuth(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
JWTSecret: "test-secret",
|
||||
AuthEnabled: true,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
expectedUser models.User
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid token",
|
||||
authHeader: "Bearer " + createValidToken(cfg.JWTSecret),
|
||||
expectedUser: models.User{
|
||||
UserID: "test-user",
|
||||
Email: "test@example.com",
|
||||
Role: "user",
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid token",
|
||||
authHeader: "Bearer invalid.token.string",
|
||||
expectedUser: models.User{},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing auth header",
|
||||
authHeader: "",
|
||||
expectedUser: models.User{},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a mock gin.Context
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
c.Request.Header.Set("Authorization", tt.authHeader)
|
||||
|
||||
// Call the Auth middleware
|
||||
Auth(cfg)(c)
|
||||
|
||||
// Check the results
|
||||
if tt.expectedError {
|
||||
assert.True(t, c.IsAborted())
|
||||
} else {
|
||||
assert.False(t, c.IsAborted())
|
||||
user, exists := c.Get("user")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, tt.expectedUser, user.(models.User))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthDisabled(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
JWTSecret: "test-secret",
|
||||
AuthEnabled: false,
|
||||
}
|
||||
|
||||
// Create a mock gin.Context
|
||||
c, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||
c.Request = httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
Auth(cfg)(c)
|
||||
|
||||
assert.False(t, c.IsAborted())
|
||||
user, exists := c.Get("user")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, models.User{
|
||||
UserID: "3e53486c-cf57-477e-ba2a-cb02dc828e1a",
|
||||
Role: "admin",
|
||||
}, user.(models.User))
|
||||
}
|
||||
28
rnd/gosrv/models/user.go
Normal file
28
rnd/gosrv/models/user.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func NewUserFromPayload(claims jwt.MapClaims) (User, error) {
|
||||
userID, ok := claims["sub"].(string)
|
||||
if !ok {
|
||||
return User{}, fmt.Errorf("invalid or missing 'sub' claim")
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
role, _ := claims["role"].(string)
|
||||
|
||||
return User{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
Role: role,
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user