add universal auth to agent

This commit is contained in:
Maidul Islam
2023-12-12 19:36:48 -05:00
parent fbe0cf006f
commit c92c0f7288
5 changed files with 315 additions and 110 deletions

15
cli/agent-config.yaml Normal file
View File

@@ -0,0 +1,15 @@
infisical:
address: "http://localhost:8080"
auth:
type: "universal-auth"
config:
client-id: "./client-id"
client-secret: "./client-secret"
remove_client_secret_on_read: false
sinks:
- type: "file"
config:
path: "access-token"
templates:
- source-path: my-dot-ev-secret-template
destination-path: my-dot-env.env

View File

@@ -1,17 +0,0 @@
infisical:
address: "http://localhost:8080"
auth:
type: "token"
config:
token-path: "./role-id"
sinks:
- type: "file"
config:
path: "/Users/maidulislam/Desktop/test/infisical-token"
- type: "file"
config:
path: "access-token"
- type: "file"
config:
path: "maiduls-access-token"
templates:

View File

@@ -425,24 +425,44 @@ func CallCreateServiceToken(httpClient *resty.Client, request CreateServiceToken
return createServiceTokenResponse, nil return createServiceTokenResponse, nil
} }
func CallServiceTokenV3Refresh(httpClient *resty.Client, request ServiceTokenV3RefreshTokenRequest) (ServiceTokenV3RefreshTokenResponse, error) { func CallUniversalAuthLogin(httpClient *resty.Client, request UniversalAuthLoginRequest) (UniversalAuthLoginResponse, error) {
var serviceTokenV3RefreshTokenResponse ServiceTokenV3RefreshTokenResponse var universalAuthLoginResponse UniversalAuthLoginResponse
response, err := httpClient. response, err := httpClient.
R(). R().
SetResult(&serviceTokenV3RefreshTokenResponse). SetResult(&universalAuthLoginResponse).
SetHeader("User-Agent", USER_AGENT). SetHeader("User-Agent", USER_AGENT).
SetBody(request). SetBody(request).
Post(fmt.Sprintf("%v/v3/service-token/me/token", config.INFISICAL_URL)) Post(fmt.Sprintf("%v/v1/auth/universal-auth/login/", config.INFISICAL_URL))
if err != nil { if err != nil {
return ServiceTokenV3RefreshTokenResponse{}, fmt.Errorf("CallServiceTokenV3Refresh: Unable to complete api request [err=%s]", err) return UniversalAuthLoginResponse{}, fmt.Errorf("CallUniversalAuthLogin: Unable to complete api request [err=%s]", err)
} }
if response.IsError() { if response.IsError() {
return ServiceTokenV3RefreshTokenResponse{}, fmt.Errorf("CallServiceTokenV3Refresh: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String()) return UniversalAuthLoginResponse{}, fmt.Errorf("CallUniversalAuthLogin: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
} }
return serviceTokenV3RefreshTokenResponse, nil return universalAuthLoginResponse, nil
}
func CallUniversalAuthRefreshAccessToken(httpClient *resty.Client, request UniversalAuthRefreshRequest) (UniversalAuthRefreshResponse, error) {
var universalAuthRefreshResponse UniversalAuthRefreshResponse
response, err := httpClient.
R().
SetResult(&universalAuthRefreshResponse).
SetHeader("User-Agent", USER_AGENT).
SetBody(request).
Post(fmt.Sprintf("%v/v1/auth/token/renew", config.INFISICAL_URL))
if err != nil {
return UniversalAuthRefreshResponse{}, fmt.Errorf("CallUniversalAuthRefreshAccessToken: Unable to complete api request [err=%s]", err)
}
if response.IsError() {
return UniversalAuthRefreshResponse{}, fmt.Errorf("CallUniversalAuthRefreshAccessToken: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
}
return universalAuthRefreshResponse, nil
} }
func CallGetRawSecretsV3(httpClient *resty.Client, request GetRawSecretsV3Request) (GetRawSecretsV3Response, error) { func CallGetRawSecretsV3(httpClient *resty.Client, request GetRawSecretsV3Request) (GetRawSecretsV3Response, error) {
@@ -466,7 +486,7 @@ func CallGetRawSecretsV3(httpClient *resty.Client, request GetRawSecretsV3Reques
} }
if response.IsError() { if response.IsError() {
return GetRawSecretsV3Response{}, fmt.Errorf("CallGetRawSecretsV3: Unsuccessful response [%v %v] [status-code=%v]", response.Request.Method, response.Request.URL, response.StatusCode()) return GetRawSecretsV3Response{}, fmt.Errorf("CallUniversalAuthLogin: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String())
} }
return getRawSecretsV3Response, nil return getRawSecretsV3Response, nil

View File

@@ -463,14 +463,27 @@ type CreateServiceTokenResponse struct {
ServiceTokenData ServiceTokenData `json:"serviceTokenData"` ServiceTokenData ServiceTokenData `json:"serviceTokenData"`
} }
type ServiceTokenV3RefreshTokenRequest struct { type UniversalAuthLoginRequest struct {
RefreshToken string `json:"refresh_token"` ClientSecret string `json:"clientSecret"`
ClientId string `json:"clientId"`
} }
type ServiceTokenV3RefreshTokenResponse struct {
RefreshToken string `json:"refresh_token"` type UniversalAuthLoginResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"accessToken"`
ExpiresIn int `json:"expires_in"` AccessTokenTTL int `json:"expiresIn"`
TokenType string `json:"token_type"` TokenType string `json:"tokenType"`
AccessTokenMaxTTL int `json:"accessTokenMaxTTL"`
}
type UniversalAuthRefreshRequest struct {
AccessToken string `json:"accessToken"`
}
type UniversalAuthRefreshResponse struct {
AccessToken string `json:"accessToken"`
AccessTokenTTL int `json:"expiresIn"`
TokenType string `json:"tokenType"`
AccessTokenMaxTTL int `json:"accessTokenMaxTTL"`
} }
type GetRawSecretsV3Request struct { type GetRawSecretsV3Request struct {

View File

@@ -5,12 +5,12 @@ package cmd
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"os/signal" "os/signal"
"strings" "strings"
"sync"
"syscall" "syscall"
"text/template" "text/template"
"time" "time"
@@ -44,8 +44,10 @@ type AuthConfig struct {
Config interface{} `yaml:"config"` Config interface{} `yaml:"config"`
} }
type TokenAuthConfig struct { type UniversalAuth struct {
TokenPath string `yaml:"token-path"` ClientIDPath string `yaml:"client-id"`
ClientSecretPath string `yaml:"client-secret"`
RemoveClientSecretOnRead bool `yaml:"remove_client_secret_on_read"`
} }
type OAuthConfig struct { type OAuthConfig struct {
@@ -149,11 +151,12 @@ func ParseAgentConfig(filePath string) (*Config, error) {
} }
switch rawConfig.Auth.Type { switch rawConfig.Auth.Type {
case "token": case "universal-auth":
var tokenConfig TokenAuthConfig var tokenConfig UniversalAuth
if err := yaml.Unmarshal(configBytes, &tokenConfig); err != nil { if err := yaml.Unmarshal(configBytes, &tokenConfig); err != nil {
return nil, err return nil, err
} }
config.Auth.Config = tokenConfig config.Auth.Config = tokenConfig
case "oauth": // aws, gcp, k8s service account, etc case "oauth": // aws, gcp, k8s service account, etc
var oauthConfig OAuthConfig var oauthConfig OAuthConfig
@@ -199,59 +202,235 @@ func ProcessTemplate(templatePath string, data interface{}, accessToken string)
return &buf, nil return &buf, nil
} }
func refreshTokenAndProcessTemplate(refreshToken string, config *Config, errChan chan error) { type TokenManager struct {
for { accessToken string
httpClient := resty.New() accessTokenTTL time.Duration
httpClient.SetRetryCount(10000). accessTokenMaxTTL time.Duration
SetRetryMaxWaitTime(20 * time.Second). accessTokenFetchedTime time.Time
SetRetryWaitTime(5 * time.Second) accessTokenRefreshedTime time.Time
mutex sync.Mutex
filePaths []Sink // Store file paths if needed
templates []Template
clientIdPath string
clientSecretPath string
newAccessTokenNotificationChan chan bool
removeClientSecretOnRead bool
cachedClientSecret string
}
tokenResponse, err := api.CallServiceTokenV3Refresh(httpClient, api.ServiceTokenV3RefreshTokenRequest{RefreshToken: refreshToken}) func NewTokenManager(fileDeposits []Sink, templates []Template, clientIdPath string, clientSecretPath string, newAccessTokenNotificationChan chan bool, removeClientSecretOnRead bool) *TokenManager {
if err != nil { return &TokenManager{filePaths: fileDeposits, templates: templates, clientIdPath: clientIdPath, clientSecretPath: clientSecretPath, newAccessTokenNotificationChan: newAccessTokenNotificationChan, removeClientSecretOnRead: removeClientSecretOnRead}
errChan <- fmt.Errorf("unable to complete renewal because [%s]", err) }
}
for _, sinkFile := range config.Sinks { func (tm *TokenManager) SetToken(token string, accessTokenTTL time.Duration, accessTokenMaxTTL time.Duration) {
if sinkFile.Type == "file" { tm.mutex.Lock()
err = ioutil.WriteFile(sinkFile.Config.Path, []byte(tokenResponse.AccessToken), 0644) defer tm.mutex.Unlock()
if err != nil {
errChan <- err
return
}
} else {
errChan <- errors.New("unsupported sink type. Only 'file' type is supported")
return
}
}
refreshToken = tokenResponse.RefreshToken tm.accessToken = token
nextRefreshCycle := time.Duration(tokenResponse.ExpiresIn-5) * time.Second // when the next access refresh will happen tm.accessTokenTTL = accessTokenTTL
tm.accessTokenMaxTTL = accessTokenMaxTTL
d, err := time.ParseDuration(nextRefreshCycle.String()) tm.newAccessTokenNotificationChan <- true
if err != nil { }
errChan <- fmt.Errorf("unable to parse refresh time because %s", err)
return
}
log.Info().Msgf("token refreshed and saved to selected path; next cycle will occur in %s", d.String()) func (tm *TokenManager) GetToken() string {
tm.mutex.Lock()
defer tm.mutex.Unlock()
for _, secretTemplate := range config.Templates { return tm.accessToken
processedTemplate, err := ProcessTemplate(secretTemplate.SourcePath, nil, tokenResponse.AccessToken) }
if err != nil {
errChan <- err
return
}
if err := WriteBytesToFile(processedTemplate, secretTemplate.DestinationPath); err != nil { // Fetches a new access token using client credentials
errChan <- err func (tm *TokenManager) FetchNewAccessToken() error {
return clientIDAsByte, err := ReadFile(tm.clientIdPath)
} if err != nil {
return fmt.Errorf("unable to read client id from file path '%s' due to error: %v", tm.clientIdPath, err)
log.Info().Msgf("secret template at path %s has been rendered and saved to path %s", secretTemplate.SourcePath, secretTemplate.DestinationPath)
}
time.Sleep(nextRefreshCycle)
} }
clientSecretAsByte, err := ReadFile(tm.clientSecretPath)
if err != nil {
if len(tm.cachedClientSecret) == 0 {
return fmt.Errorf("unable to read client secret from file and no cached client secret found: %v", err)
} else {
clientSecretAsByte = []byte(tm.cachedClientSecret)
}
}
// remove client secret after first read
if tm.removeClientSecretOnRead {
os.Remove(tm.clientSecretPath)
}
clientId := string(clientIDAsByte)
clientSecret := string(clientSecretAsByte)
// save as cache in memory
tm.cachedClientSecret = clientSecret
err, loginResponse := universalAuthLogin(clientId, clientSecret)
if err != nil {
return err
}
accessTokenTTL := time.Duration(loginResponse.AccessTokenTTL * int(time.Second))
accessTokenMaxTTL := time.Duration(loginResponse.AccessTokenMaxTTL * int(time.Second))
if accessTokenTTL <= time.Duration(5)*time.Second {
util.PrintErrorMessageAndExit("At this this, agent does not support refresh of tokens with 5 seconds or less ttl. Please increase access token ttl and try again")
}
tm.accessTokenFetchedTime = time.Now()
tm.SetToken(loginResponse.AccessToken, accessTokenTTL, accessTokenMaxTTL)
return nil
}
// Refreshes the existing access token
func (tm *TokenManager) RefreshAccessToken() error {
httpClient := resty.New()
httpClient.SetRetryCount(10000).
SetRetryMaxWaitTime(20 * time.Second).
SetRetryWaitTime(5 * time.Second)
accessToken := tm.GetToken()
response, err := api.CallUniversalAuthRefreshAccessToken(httpClient, api.UniversalAuthRefreshRequest{AccessToken: accessToken})
if err != nil {
return err
}
accessTokenTTL := time.Duration(response.AccessTokenTTL * int(time.Second))
accessTokenMaxTTL := time.Duration(response.AccessTokenMaxTTL * int(time.Second))
tm.accessTokenRefreshedTime = time.Now()
tm.SetToken(response.AccessToken, accessTokenTTL, accessTokenMaxTTL)
return nil
}
func (tm *TokenManager) ManageTokenLifecycle() {
for {
accessTokenMaxTTLExpiresInTime := tm.accessTokenFetchedTime.Add(tm.accessTokenMaxTTL - (5 * time.Second))
accessTokenRefreshedTime := tm.accessTokenRefreshedTime
if accessTokenRefreshedTime.IsZero() {
accessTokenRefreshedTime = tm.accessTokenFetchedTime
}
nextAccessTokenExpiresInTime := accessTokenRefreshedTime.Add(tm.accessTokenTTL - (5 * time.Second))
if tm.accessTokenFetchedTime.IsZero() && tm.accessTokenRefreshedTime.IsZero() {
// case: init login to get access token
log.Info().Msg("attempting to authenticate...")
err := tm.FetchNewAccessToken()
if err != nil {
log.Error().Msgf("unable to authenticate because %v. Will retry in 30 seconds", err)
// wait a bit before trying again
time.Sleep((30 * time.Second))
continue
}
} else if time.Now().After(accessTokenMaxTTLExpiresInTime) {
log.Info().Msgf("token has reached max ttl, attempting to re authenticate...")
err := tm.FetchNewAccessToken()
if err != nil {
log.Error().Msgf("unable to authenticate because %v. Will retry in 30 seconds", err)
// wait a bit before trying again
time.Sleep((30 * time.Second))
continue
}
} else {
log.Info().Msgf("attempting to refresh existing token...")
err := tm.RefreshAccessToken()
if err != nil {
log.Error().Msgf("unable to refresh token because %v. Will retry in 30 seconds", err)
// wait a bit before trying again
time.Sleep((30 * time.Second))
continue
}
}
if accessTokenRefreshedTime.IsZero() {
accessTokenRefreshedTime = tm.accessTokenFetchedTime
} else {
accessTokenRefreshedTime = tm.accessTokenRefreshedTime
}
nextAccessTokenExpiresInTime = accessTokenRefreshedTime.Add(tm.accessTokenTTL - (5 * time.Second))
accessTokenMaxTTLExpiresInTime = tm.accessTokenFetchedTime.Add(tm.accessTokenMaxTTL - (5 * time.Second))
if nextAccessTokenExpiresInTime.After(accessTokenMaxTTLExpiresInTime) {
// case: Refreshed so close that the next refresh would occur beyond max ttl (this is because currently, token renew tries to add +access-token-ttl amount of time)
// example: access token ttl is 11 sec and max ttl is 30 sec. So it will start with 11 seconds, then 22 seconds but the next time you call refresh it would try to extend it to 33 but max ttl only allows 30, so the token will be valid until 30 before we need to reauth
time.Sleep(tm.accessTokenTTL - nextAccessTokenExpiresInTime.Sub(accessTokenMaxTTLExpiresInTime))
} else {
time.Sleep(tm.accessTokenTTL - (5 * time.Second))
}
}
}
func (tm *TokenManager) WriteTokenToFiles() {
token := tm.GetToken()
for _, sinkFile := range tm.filePaths {
if sinkFile.Type == "file" {
err := ioutil.WriteFile(sinkFile.Config.Path, []byte(token), 0644)
if err != nil {
log.Error().Msgf("unable to write file sink to path '%s' because %v", sinkFile.Config.Path, err)
}
log.Info().Msgf("new access token saved to file at path '%s'", sinkFile.Config.Path)
} else {
log.Error().Msg("unsupported sink type. Only 'file' type is supported")
}
}
}
func (tm *TokenManager) FetchSecrets() {
log.Info().Msgf("template engine started...")
for {
token := tm.GetToken()
if token != "" {
for _, secretTemplate := range tm.templates {
processedTemplate, err := ProcessTemplate(secretTemplate.SourcePath, nil, token)
if err != nil {
log.Error().Msgf("template engine: unable to render secrets because %s. Will try again in 30 seconds", err)
// wait a bit before trying again
time.Sleep((30 * time.Second))
continue
}
if err := WriteBytesToFile(processedTemplate, secretTemplate.DestinationPath); err != nil {
log.Error().Msgf("template engine: unable to write secrets to path because %s. Will try again in 30 seconds", err)
// wait a bit before trying again
time.Sleep((30 * time.Second))
continue
}
log.Info().Msgf("template engine: secret template at path %s has been rendered and saved to path %s", secretTemplate.SourcePath, secretTemplate.DestinationPath)
}
// fetch new secrets every 5 minutes (TODO: add PubSub in the future )
time.Sleep(5 * time.Minute)
}
}
}
func universalAuthLogin(clientId string, clientSecret string) (error, api.UniversalAuthLoginResponse) {
httpClient := resty.New()
httpClient.SetRetryCount(10000).
SetRetryMaxWaitTime(20 * time.Second).
SetRetryWaitTime(5 * time.Second)
tokenResponse, err := api.CallUniversalAuthLogin(httpClient, api.UniversalAuthLoginRequest{ClientId: clientId, ClientSecret: clientSecret})
if err != nil {
return err, api.UniversalAuthLoginResponse{}
}
return nil, tokenResponse
} }
// runCmd represents the run command // runCmd represents the run command
@@ -282,36 +461,31 @@ var agentCmd = &cobra.Command{
return return
} }
errChan := make(chan error) if agentConfig.Auth.Type != "universal-auth" {
sigChan := make(chan os.Signal, 1) util.PrintErrorMessageAndExit("Only auth type of 'universal-auth' is supported at this time")
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
switch configAuthType := agentConfig.Auth.Config.(type) {
case TokenAuthConfig:
content, err := ReadFile(configAuthType.TokenPath)
if err != nil {
log.Error().Msgf("unable to read initial token from file path %s because %v", configAuthType.TokenPath, err)
return
}
refreshToken := string(content)
go refreshTokenAndProcessTemplate(refreshToken, agentConfig, errChan)
case OAuthConfig:
// future auth types
default:
log.Error().Msgf("unknown auth config type. Only 'file' type is supported")
return
} }
select { configUniversalAuthType := agentConfig.Auth.Config.(UniversalAuth)
case err := <-errChan:
log.Fatal().Msgf("agent stopped due to error: %v", err) tokenRefreshNotifier := make(chan bool)
os.Exit(1) sigChan := make(chan os.Signal, 1)
case <-sigChan: signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Info().Msg("agent is gracefully shutting...")
os.Exit(1) filePaths := agentConfig.Sinks
tm := NewTokenManager(filePaths, agentConfig.Templates, configUniversalAuthType.ClientIDPath, configUniversalAuthType.ClientSecretPath, tokenRefreshNotifier, configUniversalAuthType.RemoveClientSecretOnRead)
go tm.ManageTokenLifecycle()
go tm.FetchSecrets()
for {
select {
case <-tokenRefreshNotifier:
go tm.WriteTokenToFiles()
case <-sigChan:
log.Info().Msg("agent is gracefully shutting...")
// TODO: check if we are in the middle of writing files to disk
os.Exit(1)
}
} }
}, },