mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-13 15:34:59 -05:00
feat: enhance OAuth authentication flow with automatic re-authentication and timeout handling
## CHANGES - Add automatic OAuth flow initiation when no token exists - Implement fallback re-authentication when token refresh fails - Add timeout contexts for OAuth and refresh operations - Create context-aware OAuth flow and token exchange functions - Enhance error handling with graceful authentication recovery - Add user input timeout protection for authorization codes - Preserve refresh tokens during token exchange operations
This commit is contained in:
@@ -117,8 +117,14 @@ func (t *oauthTransport) getValidToken() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
return "", fmt.Errorf("no OAuth token stored, please re-authenticate")
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := runOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
@@ -126,7 +132,12 @@ func (t *oauthTransport) getValidToken() (string, error) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := refreshToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = runOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
@@ -584,3 +595,182 @@ func refreshToken() (string, error) {
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
|
||||
// runOAuthFlowWithTimeout runs the OAuth flow with an extended timeout for user interaction
|
||||
func runOAuthFlowWithTimeout() (string, error) {
|
||||
// Use a longer timeout context for OAuth flows (10 minutes)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// Store the context for use in HTTP requests
|
||||
return runOAuthFlowWithContext(ctx)
|
||||
}
|
||||
|
||||
// refreshTokenWithTimeout refreshes a token with extended timeout
|
||||
func refreshTokenWithTimeout() (string, error) {
|
||||
// Use a longer timeout context for token refresh (2 minutes)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
return refreshTokenWithContext(ctx)
|
||||
}
|
||||
|
||||
// runOAuthFlowWithContext runs OAuth flow with the provided context
|
||||
func runOAuthFlowWithContext(ctx context.Context) (string, error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
|
||||
// Create a channel to receive the user input
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
var code string
|
||||
_, err := fmt.Scanln(&code)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
codeChan <- code
|
||||
}()
|
||||
|
||||
// Wait for either user input or context timeout
|
||||
var code string
|
||||
select {
|
||||
case code = <-codeChan:
|
||||
// User provided input
|
||||
case err := <-errChan:
|
||||
return "", fmt.Errorf("failed to read authorization code: %w", err)
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("authentication timed out - please try again")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
return exchangeTokenWithContext(ctx, tokenReq)
|
||||
}
|
||||
|
||||
// refreshTokenWithContext refreshes a token with the provided context
|
||||
func refreshTokenWithContext(ctx context.Context) (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
return exchangeTokenWithContext(ctx, refreshReq)
|
||||
}
|
||||
|
||||
// exchangeTokenWithContext exchanges tokens with the provided context
|
||||
func exchangeTokenWithContext(ctx context.Context, params map[string]string) (string, error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", oauthTokenURL, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided (for refresh operations)
|
||||
if oauthToken.RefreshToken == "" && params["grant_type"] == "refresh_token" {
|
||||
if existingToken, err := storage.LoadToken("claude"); err == nil && existingToken != nil {
|
||||
oauthToken.RefreshToken = existingToken.RefreshToken
|
||||
}
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user