feat: enhance OAuth authentication flow with automatic re-authentication and timeout handling

## CHANGES

- Add automatic OAuth flow initiation when no token exists
- Implement fallback re-authentication when token refresh fails
- Add timeout contexts for OAuth and refresh operations
- Create context-aware OAuth flow and token exchange functions
- Enhance error handling with graceful authentication recovery
- Add user input timeout protection for authorization codes
- Preserve refresh tokens during token exchange operations
This commit is contained in:
Kayvan Sylvan
2025-07-05 09:59:27 -07:00
parent f1b612d828
commit b7dc6748e0

View File

@@ -117,8 +117,14 @@ func (t *oauthTransport) getValidToken() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
// If no token exists, run OAuth flow
if token == nil {
return "", fmt.Errorf("no OAuth token stored, please re-authenticate")
fmt.Println("No OAuth token found, initiating authentication...")
newAccessToken, err := runOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to authenticate: %w", err)
}
return newAccessToken, nil
}
// Check if token needs refresh (5 minute buffer)
@@ -126,7 +132,12 @@ func (t *oauthTransport) getValidToken() (string, error) {
fmt.Println("OAuth token expired, refreshing...")
newAccessToken, err := refreshToken()
if err != nil {
return "", fmt.Errorf("failed to refresh token: %w", err)
// If refresh fails, try re-authentication
fmt.Println("Token refresh failed, re-authenticating...")
newAccessToken, err = runOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
}
}
return newAccessToken, nil
@@ -584,3 +595,182 @@ func refreshToken() (string, error) {
return result.AccessToken, nil
}
// runOAuthFlowWithTimeout runs the OAuth flow with an extended timeout for user interaction
func runOAuthFlowWithTimeout() (string, error) {
// Use a longer timeout context for OAuth flows (10 minutes)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// Store the context for use in HTTP requests
return runOAuthFlowWithContext(ctx)
}
// refreshTokenWithTimeout refreshes a token with extended timeout
func refreshTokenWithTimeout() (string, error) {
// Use a longer timeout context for token refresh (2 minutes)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
return refreshTokenWithContext(ctx)
}
// runOAuthFlowWithContext runs OAuth flow with the provided context
func runOAuthFlowWithContext(ctx context.Context) (string, error) {
verifier, challenge, err := generatePKCE()
if err != nil {
return "", err
}
cfg := oauth2.Config{
ClientID: oauthClientID,
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
RedirectURL: oauthRedirectURL,
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
}
authURL := cfg.AuthCodeURL(verifier,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code", "true"),
oauth2.SetAuthURLParam("state", verifier),
)
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
fmt.Println(authURL)
openBrowser(authURL)
fmt.Print("Paste the authorization code here: ")
// Create a channel to receive the user input
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
var code string
_, err := fmt.Scanln(&code)
if err != nil {
errChan <- err
return
}
codeChan <- code
}()
// Wait for either user input or context timeout
var code string
select {
case code = <-codeChan:
// User provided input
case err := <-errChan:
return "", fmt.Errorf("failed to read authorization code: %w", err)
case <-ctx.Done():
return "", fmt.Errorf("authentication timed out - please try again")
}
parts := strings.SplitN(code, "#", 2)
state := verifier
if len(parts) == 2 {
state = parts[1]
}
// Manual token exchange to match opencode implementation
tokenReq := map[string]string{
"code": parts[0],
"state": state,
"grant_type": "authorization_code",
"client_id": oauthClientID,
"redirect_uri": oauthRedirectURL,
"code_verifier": verifier,
}
return exchangeTokenWithContext(ctx, tokenReq)
}
// refreshTokenWithContext refreshes a token with the provided context
func refreshTokenWithContext(ctx context.Context) (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load existing token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
if token == nil || token.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}
// Prepare refresh request
refreshReq := map[string]string{
"grant_type": "refresh_token",
"refresh_token": token.RefreshToken,
"client_id": oauthClientID,
}
return exchangeTokenWithContext(ctx, refreshReq)
}
// exchangeTokenWithContext exchanges tokens with the provided context
func exchangeTokenWithContext(ctx context.Context, params map[string]string) (string, error) {
reqBody, err := json.Marshal(params)
if err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, "POST", oauthTokenURL, bytes.NewBuffer(reqBody))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse token response: %w", err)
}
// Save the complete token information
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
oauthToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
// Use existing refresh token if new one not provided (for refresh operations)
if oauthToken.RefreshToken == "" && params["grant_type"] == "refresh_token" {
if existingToken, err := storage.LoadToken("claude"); err == nil && existingToken != nil {
oauthToken.RefreshToken = existingToken.RefreshToken
}
}
if err = storage.SaveToken("claude", oauthToken); err != nil {
return "", fmt.Errorf("failed to save OAuth token: %w", err)
}
return result.AccessToken, nil
}