diff --git a/plugins/ai/anthropic/anthropic.go b/plugins/ai/anthropic/anthropic.go index a755c8c7..7f0dfe24 100644 --- a/plugins/ai/anthropic/anthropic.go +++ b/plugins/ai/anthropic/anthropic.go @@ -117,8 +117,14 @@ func (t *oauthTransport) getValidToken() (string, error) { if err != nil { return "", fmt.Errorf("failed to load stored token: %w", err) } + // If no token exists, run OAuth flow if token == nil { - return "", fmt.Errorf("no OAuth token stored, please re-authenticate") + fmt.Println("No OAuth token found, initiating authentication...") + newAccessToken, err := runOAuthFlow() + if err != nil { + return "", fmt.Errorf("failed to authenticate: %w", err) + } + return newAccessToken, nil } // Check if token needs refresh (5 minute buffer) @@ -126,7 +132,12 @@ func (t *oauthTransport) getValidToken() (string, error) { fmt.Println("OAuth token expired, refreshing...") newAccessToken, err := refreshToken() if err != nil { - return "", fmt.Errorf("failed to refresh token: %w", err) + // If refresh fails, try re-authentication + fmt.Println("Token refresh failed, re-authenticating...") + newAccessToken, err = runOAuthFlow() + if err != nil { + return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err) + } } return newAccessToken, nil @@ -584,3 +595,182 @@ func refreshToken() (string, error) { return result.AccessToken, nil } + +// runOAuthFlowWithTimeout runs the OAuth flow with an extended timeout for user interaction +func runOAuthFlowWithTimeout() (string, error) { + // Use a longer timeout context for OAuth flows (10 minutes) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + // Store the context for use in HTTP requests + return runOAuthFlowWithContext(ctx) +} + +// refreshTokenWithTimeout refreshes a token with extended timeout +func refreshTokenWithTimeout() (string, error) { + // Use a longer timeout context for token refresh (2 minutes) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + return refreshTokenWithContext(ctx) +} + +// runOAuthFlowWithContext runs OAuth flow with the provided context +func runOAuthFlowWithContext(ctx context.Context) (string, error) { + verifier, challenge, err := generatePKCE() + if err != nil { + return "", err + } + + cfg := oauth2.Config{ + ClientID: oauthClientID, + Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL}, + RedirectURL: oauthRedirectURL, + Scopes: []string{"org:create_api_key", "user:profile", "user:inference"}, + } + + authURL := cfg.AuthCodeURL(verifier, + oauth2.SetAuthURLParam("code_challenge", challenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + oauth2.SetAuthURLParam("code", "true"), + oauth2.SetAuthURLParam("state", verifier), + ) + + fmt.Println("Open the following URL in your browser. Fabric would like to authorize:") + fmt.Println(authURL) + openBrowser(authURL) + fmt.Print("Paste the authorization code here: ") + + // Create a channel to receive the user input + codeChan := make(chan string, 1) + errChan := make(chan error, 1) + + go func() { + var code string + _, err := fmt.Scanln(&code) + if err != nil { + errChan <- err + return + } + codeChan <- code + }() + + // Wait for either user input or context timeout + var code string + select { + case code = <-codeChan: + // User provided input + case err := <-errChan: + return "", fmt.Errorf("failed to read authorization code: %w", err) + case <-ctx.Done(): + return "", fmt.Errorf("authentication timed out - please try again") + } + + parts := strings.SplitN(code, "#", 2) + state := verifier + if len(parts) == 2 { + state = parts[1] + } + + // Manual token exchange to match opencode implementation + tokenReq := map[string]string{ + "code": parts[0], + "state": state, + "grant_type": "authorization_code", + "client_id": oauthClientID, + "redirect_uri": oauthRedirectURL, + "code_verifier": verifier, + } + + return exchangeTokenWithContext(ctx, tokenReq) +} + +// refreshTokenWithContext refreshes a token with the provided context +func refreshTokenWithContext(ctx context.Context) (string, error) { + storage, err := common.NewOAuthStorage() + if err != nil { + return "", fmt.Errorf("failed to create OAuth storage: %w", err) + } + + // Load existing token + token, err := storage.LoadToken("claude") + if err != nil { + return "", fmt.Errorf("failed to load stored token: %w", err) + } + if token == nil || token.RefreshToken == "" { + return "", fmt.Errorf("no refresh token available") + } + + // Prepare refresh request + refreshReq := map[string]string{ + "grant_type": "refresh_token", + "refresh_token": token.RefreshToken, + "client_id": oauthClientID, + } + + return exchangeTokenWithContext(ctx, refreshReq) +} + +// exchangeTokenWithContext exchanges tokens with the provided context +func exchangeTokenWithContext(ctx context.Context, params map[string]string) (string, error) { + reqBody, err := json.Marshal(params) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, "POST", oauthTokenURL, bytes.NewBuffer(reqBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("token exchange request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body)) + } + + var result struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + Scope string `json:"scope"` + } + if err = json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse token response: %w", err) + } + + // Save the complete token information + storage, err := common.NewOAuthStorage() + if err != nil { + return "", fmt.Errorf("failed to create OAuth storage: %w", err) + } + + oauthToken := &common.OAuthToken{ + AccessToken: result.AccessToken, + RefreshToken: result.RefreshToken, + ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn), + TokenType: result.TokenType, + Scope: result.Scope, + } + + // Use existing refresh token if new one not provided (for refresh operations) + if oauthToken.RefreshToken == "" && params["grant_type"] == "refresh_token" { + if existingToken, err := storage.LoadToken("claude"); err == nil && existingToken != nil { + oauthToken.RefreshToken = existingToken.RefreshToken + } + } + + if err = storage.SaveToken("claude", oauthToken); err != nil { + return "", fmt.Errorf("failed to save OAuth token: %w", err) + } + + return result.AccessToken, nil +}