From c92c0f72889b3a0d4c4f77f6dc1b4503b5824b87 Mon Sep 17 00:00:00 2001 From: Maidul Islam Date: Tue, 12 Dec 2023 19:36:48 -0500 Subject: [PATCH] add universal auth to agent --- cli/agent-config.yaml | 15 ++ cli/example-agent-config.yaml | 17 -- cli/packages/api/api.go | 36 +++- cli/packages/api/model.go | 27 ++- cli/packages/cmd/agent.go | 330 ++++++++++++++++++++++++++-------- 5 files changed, 315 insertions(+), 110 deletions(-) create mode 100644 cli/agent-config.yaml delete mode 100644 cli/example-agent-config.yaml diff --git a/cli/agent-config.yaml b/cli/agent-config.yaml new file mode 100644 index 0000000000..ae130d3a8e --- /dev/null +++ b/cli/agent-config.yaml @@ -0,0 +1,15 @@ +infisical: + address: "http://localhost:8080" +auth: + type: "universal-auth" + config: + client-id: "./client-id" + client-secret: "./client-secret" + remove_client_secret_on_read: false +sinks: + - type: "file" + config: + path: "access-token" +templates: + - source-path: my-dot-ev-secret-template + destination-path: my-dot-env.env diff --git a/cli/example-agent-config.yaml b/cli/example-agent-config.yaml deleted file mode 100644 index 0afbe66f01..0000000000 --- a/cli/example-agent-config.yaml +++ /dev/null @@ -1,17 +0,0 @@ -infisical: - address: "http://localhost:8080" -auth: - type: "token" - config: - token-path: "./role-id" -sinks: - - type: "file" - config: - path: "/Users/maidulislam/Desktop/test/infisical-token" - - type: "file" - config: - path: "access-token" - - type: "file" - config: - path: "maiduls-access-token" -templates: diff --git a/cli/packages/api/api.go b/cli/packages/api/api.go index c57cec59f3..305df6e461 100644 --- a/cli/packages/api/api.go +++ b/cli/packages/api/api.go @@ -425,24 +425,44 @@ func CallCreateServiceToken(httpClient *resty.Client, request CreateServiceToken return createServiceTokenResponse, nil } -func CallServiceTokenV3Refresh(httpClient *resty.Client, request ServiceTokenV3RefreshTokenRequest) (ServiceTokenV3RefreshTokenResponse, error) { - var serviceTokenV3RefreshTokenResponse ServiceTokenV3RefreshTokenResponse +func CallUniversalAuthLogin(httpClient *resty.Client, request UniversalAuthLoginRequest) (UniversalAuthLoginResponse, error) { + var universalAuthLoginResponse UniversalAuthLoginResponse response, err := httpClient. R(). - SetResult(&serviceTokenV3RefreshTokenResponse). + SetResult(&universalAuthLoginResponse). SetHeader("User-Agent", USER_AGENT). SetBody(request). - Post(fmt.Sprintf("%v/v3/service-token/me/token", config.INFISICAL_URL)) + Post(fmt.Sprintf("%v/v1/auth/universal-auth/login/", config.INFISICAL_URL)) if err != nil { - return ServiceTokenV3RefreshTokenResponse{}, fmt.Errorf("CallServiceTokenV3Refresh: Unable to complete api request [err=%s]", err) + return UniversalAuthLoginResponse{}, fmt.Errorf("CallUniversalAuthLogin: Unable to complete api request [err=%s]", err) } if response.IsError() { - return ServiceTokenV3RefreshTokenResponse{}, fmt.Errorf("CallServiceTokenV3Refresh: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String()) + return UniversalAuthLoginResponse{}, fmt.Errorf("CallUniversalAuthLogin: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String()) } - return serviceTokenV3RefreshTokenResponse, nil + return universalAuthLoginResponse, nil +} + +func CallUniversalAuthRefreshAccessToken(httpClient *resty.Client, request UniversalAuthRefreshRequest) (UniversalAuthRefreshResponse, error) { + var universalAuthRefreshResponse UniversalAuthRefreshResponse + response, err := httpClient. + R(). + SetResult(&universalAuthRefreshResponse). + SetHeader("User-Agent", USER_AGENT). + SetBody(request). + Post(fmt.Sprintf("%v/v1/auth/token/renew", config.INFISICAL_URL)) + + if err != nil { + return UniversalAuthRefreshResponse{}, fmt.Errorf("CallUniversalAuthRefreshAccessToken: Unable to complete api request [err=%s]", err) + } + + if response.IsError() { + return UniversalAuthRefreshResponse{}, fmt.Errorf("CallUniversalAuthRefreshAccessToken: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String()) + } + + return universalAuthRefreshResponse, nil } func CallGetRawSecretsV3(httpClient *resty.Client, request GetRawSecretsV3Request) (GetRawSecretsV3Response, error) { @@ -466,7 +486,7 @@ func CallGetRawSecretsV3(httpClient *resty.Client, request GetRawSecretsV3Reques } if response.IsError() { - return GetRawSecretsV3Response{}, fmt.Errorf("CallGetRawSecretsV3: Unsuccessful response [%v %v] [status-code=%v]", response.Request.Method, response.Request.URL, response.StatusCode()) + return GetRawSecretsV3Response{}, fmt.Errorf("CallUniversalAuthLogin: Unsuccessful response [%v %v] [status-code=%v] [response=%v]", response.Request.Method, response.Request.URL, response.StatusCode(), response.String()) } return getRawSecretsV3Response, nil diff --git a/cli/packages/api/model.go b/cli/packages/api/model.go index ceff88e71f..3c6466382e 100644 --- a/cli/packages/api/model.go +++ b/cli/packages/api/model.go @@ -463,14 +463,27 @@ type CreateServiceTokenResponse struct { ServiceTokenData ServiceTokenData `json:"serviceTokenData"` } -type ServiceTokenV3RefreshTokenRequest struct { - RefreshToken string `json:"refresh_token"` +type UniversalAuthLoginRequest struct { + ClientSecret string `json:"clientSecret"` + ClientId string `json:"clientId"` } -type ServiceTokenV3RefreshTokenResponse struct { - RefreshToken string `json:"refresh_token"` - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` - TokenType string `json:"token_type"` + +type UniversalAuthLoginResponse struct { + AccessToken string `json:"accessToken"` + AccessTokenTTL int `json:"expiresIn"` + TokenType string `json:"tokenType"` + AccessTokenMaxTTL int `json:"accessTokenMaxTTL"` +} + +type UniversalAuthRefreshRequest struct { + AccessToken string `json:"accessToken"` +} + +type UniversalAuthRefreshResponse struct { + AccessToken string `json:"accessToken"` + AccessTokenTTL int `json:"expiresIn"` + TokenType string `json:"tokenType"` + AccessTokenMaxTTL int `json:"accessTokenMaxTTL"` } type GetRawSecretsV3Request struct { diff --git a/cli/packages/cmd/agent.go b/cli/packages/cmd/agent.go index 83986021de..21f9613f7c 100644 --- a/cli/packages/cmd/agent.go +++ b/cli/packages/cmd/agent.go @@ -5,12 +5,12 @@ package cmd import ( "bytes" - "errors" "fmt" "io/ioutil" "os" "os/signal" "strings" + "sync" "syscall" "text/template" "time" @@ -44,8 +44,10 @@ type AuthConfig struct { Config interface{} `yaml:"config"` } -type TokenAuthConfig struct { - TokenPath string `yaml:"token-path"` +type UniversalAuth struct { + ClientIDPath string `yaml:"client-id"` + ClientSecretPath string `yaml:"client-secret"` + RemoveClientSecretOnRead bool `yaml:"remove_client_secret_on_read"` } type OAuthConfig struct { @@ -149,11 +151,12 @@ func ParseAgentConfig(filePath string) (*Config, error) { } switch rawConfig.Auth.Type { - case "token": - var tokenConfig TokenAuthConfig + case "universal-auth": + var tokenConfig UniversalAuth if err := yaml.Unmarshal(configBytes, &tokenConfig); err != nil { return nil, err } + config.Auth.Config = tokenConfig case "oauth": // aws, gcp, k8s service account, etc var oauthConfig OAuthConfig @@ -199,59 +202,235 @@ func ProcessTemplate(templatePath string, data interface{}, accessToken string) return &buf, nil } -func refreshTokenAndProcessTemplate(refreshToken string, config *Config, errChan chan error) { - for { - httpClient := resty.New() - httpClient.SetRetryCount(10000). - SetRetryMaxWaitTime(20 * time.Second). - SetRetryWaitTime(5 * time.Second) +type TokenManager struct { + accessToken string + accessTokenTTL time.Duration + accessTokenMaxTTL time.Duration + accessTokenFetchedTime time.Time + accessTokenRefreshedTime time.Time + mutex sync.Mutex + filePaths []Sink // Store file paths if needed + templates []Template + clientIdPath string + clientSecretPath string + newAccessTokenNotificationChan chan bool + removeClientSecretOnRead bool + cachedClientSecret string +} - tokenResponse, err := api.CallServiceTokenV3Refresh(httpClient, api.ServiceTokenV3RefreshTokenRequest{RefreshToken: refreshToken}) - if err != nil { - errChan <- fmt.Errorf("unable to complete renewal because [%s]", err) - } +func NewTokenManager(fileDeposits []Sink, templates []Template, clientIdPath string, clientSecretPath string, newAccessTokenNotificationChan chan bool, removeClientSecretOnRead bool) *TokenManager { + return &TokenManager{filePaths: fileDeposits, templates: templates, clientIdPath: clientIdPath, clientSecretPath: clientSecretPath, newAccessTokenNotificationChan: newAccessTokenNotificationChan, removeClientSecretOnRead: removeClientSecretOnRead} +} - for _, sinkFile := range config.Sinks { - if sinkFile.Type == "file" { - err = ioutil.WriteFile(sinkFile.Config.Path, []byte(tokenResponse.AccessToken), 0644) - if err != nil { - errChan <- err - return - } - } else { - errChan <- errors.New("unsupported sink type. Only 'file' type is supported") - return - } - } +func (tm *TokenManager) SetToken(token string, accessTokenTTL time.Duration, accessTokenMaxTTL time.Duration) { + tm.mutex.Lock() + defer tm.mutex.Unlock() - refreshToken = tokenResponse.RefreshToken - nextRefreshCycle := time.Duration(tokenResponse.ExpiresIn-5) * time.Second // when the next access refresh will happen + tm.accessToken = token + tm.accessTokenTTL = accessTokenTTL + tm.accessTokenMaxTTL = accessTokenMaxTTL - d, err := time.ParseDuration(nextRefreshCycle.String()) - if err != nil { - errChan <- fmt.Errorf("unable to parse refresh time because %s", err) - return - } + tm.newAccessTokenNotificationChan <- true +} - log.Info().Msgf("token refreshed and saved to selected path; next cycle will occur in %s", d.String()) +func (tm *TokenManager) GetToken() string { + tm.mutex.Lock() + defer tm.mutex.Unlock() - for _, secretTemplate := range config.Templates { - processedTemplate, err := ProcessTemplate(secretTemplate.SourcePath, nil, tokenResponse.AccessToken) - if err != nil { - errChan <- err - return - } + return tm.accessToken +} - if err := WriteBytesToFile(processedTemplate, secretTemplate.DestinationPath); err != nil { - errChan <- err - return - } - - log.Info().Msgf("secret template at path %s has been rendered and saved to path %s", secretTemplate.SourcePath, secretTemplate.DestinationPath) - } - - time.Sleep(nextRefreshCycle) +// Fetches a new access token using client credentials +func (tm *TokenManager) FetchNewAccessToken() error { + clientIDAsByte, err := ReadFile(tm.clientIdPath) + if err != nil { + return fmt.Errorf("unable to read client id from file path '%s' due to error: %v", tm.clientIdPath, err) } + + clientSecretAsByte, err := ReadFile(tm.clientSecretPath) + if err != nil { + if len(tm.cachedClientSecret) == 0 { + return fmt.Errorf("unable to read client secret from file and no cached client secret found: %v", err) + } else { + clientSecretAsByte = []byte(tm.cachedClientSecret) + } + } + + // remove client secret after first read + if tm.removeClientSecretOnRead { + os.Remove(tm.clientSecretPath) + } + + clientId := string(clientIDAsByte) + clientSecret := string(clientSecretAsByte) + + // save as cache in memory + tm.cachedClientSecret = clientSecret + + err, loginResponse := universalAuthLogin(clientId, clientSecret) + if err != nil { + return err + } + + accessTokenTTL := time.Duration(loginResponse.AccessTokenTTL * int(time.Second)) + accessTokenMaxTTL := time.Duration(loginResponse.AccessTokenMaxTTL * int(time.Second)) + + if accessTokenTTL <= time.Duration(5)*time.Second { + util.PrintErrorMessageAndExit("At this this, agent does not support refresh of tokens with 5 seconds or less ttl. Please increase access token ttl and try again") + } + + tm.accessTokenFetchedTime = time.Now() + tm.SetToken(loginResponse.AccessToken, accessTokenTTL, accessTokenMaxTTL) + + return nil +} + +// Refreshes the existing access token +func (tm *TokenManager) RefreshAccessToken() error { + httpClient := resty.New() + httpClient.SetRetryCount(10000). + SetRetryMaxWaitTime(20 * time.Second). + SetRetryWaitTime(5 * time.Second) + + accessToken := tm.GetToken() + response, err := api.CallUniversalAuthRefreshAccessToken(httpClient, api.UniversalAuthRefreshRequest{AccessToken: accessToken}) + if err != nil { + return err + } + + accessTokenTTL := time.Duration(response.AccessTokenTTL * int(time.Second)) + accessTokenMaxTTL := time.Duration(response.AccessTokenMaxTTL * int(time.Second)) + tm.accessTokenRefreshedTime = time.Now() + + tm.SetToken(response.AccessToken, accessTokenTTL, accessTokenMaxTTL) + + return nil +} + +func (tm *TokenManager) ManageTokenLifecycle() { + for { + accessTokenMaxTTLExpiresInTime := tm.accessTokenFetchedTime.Add(tm.accessTokenMaxTTL - (5 * time.Second)) + accessTokenRefreshedTime := tm.accessTokenRefreshedTime + + if accessTokenRefreshedTime.IsZero() { + accessTokenRefreshedTime = tm.accessTokenFetchedTime + } + + nextAccessTokenExpiresInTime := accessTokenRefreshedTime.Add(tm.accessTokenTTL - (5 * time.Second)) + + if tm.accessTokenFetchedTime.IsZero() && tm.accessTokenRefreshedTime.IsZero() { + // case: init login to get access token + log.Info().Msg("attempting to authenticate...") + err := tm.FetchNewAccessToken() + if err != nil { + log.Error().Msgf("unable to authenticate because %v. Will retry in 30 seconds", err) + + // wait a bit before trying again + time.Sleep((30 * time.Second)) + continue + } + } else if time.Now().After(accessTokenMaxTTLExpiresInTime) { + log.Info().Msgf("token has reached max ttl, attempting to re authenticate...") + err := tm.FetchNewAccessToken() + if err != nil { + log.Error().Msgf("unable to authenticate because %v. Will retry in 30 seconds", err) + + // wait a bit before trying again + time.Sleep((30 * time.Second)) + continue + } + } else { + log.Info().Msgf("attempting to refresh existing token...") + err := tm.RefreshAccessToken() + if err != nil { + log.Error().Msgf("unable to refresh token because %v. Will retry in 30 seconds", err) + + // wait a bit before trying again + time.Sleep((30 * time.Second)) + continue + } + } + + if accessTokenRefreshedTime.IsZero() { + accessTokenRefreshedTime = tm.accessTokenFetchedTime + } else { + accessTokenRefreshedTime = tm.accessTokenRefreshedTime + } + + nextAccessTokenExpiresInTime = accessTokenRefreshedTime.Add(tm.accessTokenTTL - (5 * time.Second)) + accessTokenMaxTTLExpiresInTime = tm.accessTokenFetchedTime.Add(tm.accessTokenMaxTTL - (5 * time.Second)) + + if nextAccessTokenExpiresInTime.After(accessTokenMaxTTLExpiresInTime) { + // case: Refreshed so close that the next refresh would occur beyond max ttl (this is because currently, token renew tries to add +access-token-ttl amount of time) + // example: access token ttl is 11 sec and max ttl is 30 sec. So it will start with 11 seconds, then 22 seconds but the next time you call refresh it would try to extend it to 33 but max ttl only allows 30, so the token will be valid until 30 before we need to reauth + time.Sleep(tm.accessTokenTTL - nextAccessTokenExpiresInTime.Sub(accessTokenMaxTTLExpiresInTime)) + } else { + time.Sleep(tm.accessTokenTTL - (5 * time.Second)) + } + } +} + +func (tm *TokenManager) WriteTokenToFiles() { + token := tm.GetToken() + for _, sinkFile := range tm.filePaths { + if sinkFile.Type == "file" { + err := ioutil.WriteFile(sinkFile.Config.Path, []byte(token), 0644) + if err != nil { + log.Error().Msgf("unable to write file sink to path '%s' because %v", sinkFile.Config.Path, err) + } + + log.Info().Msgf("new access token saved to file at path '%s'", sinkFile.Config.Path) + + } else { + log.Error().Msg("unsupported sink type. Only 'file' type is supported") + } + } +} + +func (tm *TokenManager) FetchSecrets() { + log.Info().Msgf("template engine started...") + for { + token := tm.GetToken() + if token != "" { + for _, secretTemplate := range tm.templates { + processedTemplate, err := ProcessTemplate(secretTemplate.SourcePath, nil, token) + if err != nil { + log.Error().Msgf("template engine: unable to render secrets because %s. Will try again in 30 seconds", err) + + // wait a bit before trying again + time.Sleep((30 * time.Second)) + continue + } + + if err := WriteBytesToFile(processedTemplate, secretTemplate.DestinationPath); err != nil { + log.Error().Msgf("template engine: unable to write secrets to path because %s. Will try again in 30 seconds", err) + + // wait a bit before trying again + time.Sleep((30 * time.Second)) + continue + } + + log.Info().Msgf("template engine: secret template at path %s has been rendered and saved to path %s", secretTemplate.SourcePath, secretTemplate.DestinationPath) + } + + // fetch new secrets every 5 minutes (TODO: add PubSub in the future ) + time.Sleep(5 * time.Minute) + } + } +} + +func universalAuthLogin(clientId string, clientSecret string) (error, api.UniversalAuthLoginResponse) { + httpClient := resty.New() + httpClient.SetRetryCount(10000). + SetRetryMaxWaitTime(20 * time.Second). + SetRetryWaitTime(5 * time.Second) + + tokenResponse, err := api.CallUniversalAuthLogin(httpClient, api.UniversalAuthLoginRequest{ClientId: clientId, ClientSecret: clientSecret}) + if err != nil { + return err, api.UniversalAuthLoginResponse{} + } + + return nil, tokenResponse } // runCmd represents the run command @@ -282,36 +461,31 @@ var agentCmd = &cobra.Command{ return } - errChan := make(chan error) - sigChan := make(chan os.Signal, 1) - - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - switch configAuthType := agentConfig.Auth.Config.(type) { - case TokenAuthConfig: - content, err := ReadFile(configAuthType.TokenPath) - if err != nil { - log.Error().Msgf("unable to read initial token from file path %s because %v", configAuthType.TokenPath, err) - return - } - - refreshToken := string(content) - go refreshTokenAndProcessTemplate(refreshToken, agentConfig, errChan) - - case OAuthConfig: - // future auth types - default: - log.Error().Msgf("unknown auth config type. Only 'file' type is supported") - return + if agentConfig.Auth.Type != "universal-auth" { + util.PrintErrorMessageAndExit("Only auth type of 'universal-auth' is supported at this time") } - select { - case err := <-errChan: - log.Fatal().Msgf("agent stopped due to error: %v", err) - os.Exit(1) - case <-sigChan: - log.Info().Msg("agent is gracefully shutting...") - os.Exit(1) + configUniversalAuthType := agentConfig.Auth.Config.(UniversalAuth) + + tokenRefreshNotifier := make(chan bool) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + filePaths := agentConfig.Sinks + tm := NewTokenManager(filePaths, agentConfig.Templates, configUniversalAuthType.ClientIDPath, configUniversalAuthType.ClientSecretPath, tokenRefreshNotifier, configUniversalAuthType.RemoveClientSecretOnRead) + + go tm.ManageTokenLifecycle() + go tm.FetchSecrets() + + for { + select { + case <-tokenRefreshNotifier: + go tm.WriteTokenToFiles() + case <-sigChan: + log.Info().Msg("agent is gracefully shutting...") + // TODO: check if we are in the middle of writing files to disk + os.Exit(1) + } } },