feat: implement OAuth token refresh and persistent storage for Claude authentication

## CHANGES

- Add automatic OAuth token refresh when expired
- Implement persistent token storage using common OAuth storage
- Remove deprecated AuthToken setting from client configuration
- Add token validation with 5-minute expiration buffer
- Create refreshToken function for seamless token renewal
- Update OAuth flow to save complete token information
- Enhance error handling for OAuth authentication failures
- Simplify client configuration by removing manual token management
This commit is contained in:
Kayvan Sylvan
2025-07-05 09:17:50 -07:00
parent 4bff88fae3
commit eac5a104f2
3 changed files with 506 additions and 26 deletions

View File

@@ -10,9 +10,9 @@ import (
"fmt"
"io"
"net/http"
"os"
"os/exec"
"strings"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
@@ -44,9 +44,8 @@ func NewClient() (ret *Client) {
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
ret.ApiBaseURL.Value = defaultBaseUrl
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login Fabric", false)
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
ret.AuthToken = ret.AddSetting("Auth Token", false)
ret.maxTokens = 4096
ret.defaultRequiredUserMessage = "Hi"
@@ -67,7 +66,6 @@ type Client struct {
ApiBaseURL *plugins.SetupQuestion
ApiKey *plugins.SetupQuestion
UseOAuth *plugins.SetupQuestion
AuthToken *plugins.Setting
maxTokens int
defaultRequiredUserMessage string
@@ -87,7 +85,10 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := req.Clone(req.Context())
// Get current token (may refresh if needed)
token := t.client.AuthToken.Value
token, err := t.getValidToken()
if err != nil {
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
}
// Add OAuth Bearer token
newReq.Header.Set("Authorization", "Bearer "+token)
@@ -104,19 +105,52 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return t.base.RoundTrip(newReq)
}
// getValidToken returns a valid access token, refreshing if necessary
func (t *oauthTransport) getValidToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load stored token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
if token == nil {
return "", fmt.Errorf("no OAuth token stored, please re-authenticate")
}
// Check if token needs refresh (5 minute buffer)
if token.IsExpired(5) {
fmt.Println("OAuth token expired, refreshing...")
newAccessToken, err := refreshToken()
if err != nil {
return "", fmt.Errorf("failed to refresh token: %w", err)
}
return newAccessToken, nil
}
return token.AccessToken, nil
}
func (an *Client) Setup() (err error) {
if err = an.PluginBase.Ask(an.Name); err != nil {
return
}
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) && an.AuthToken.Value == "" {
var token string
if token, err = runOAuthFlow(); err != nil {
return
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
// Check if we have a valid stored token
storage, err := common.NewOAuthStorage()
if err != nil {
return err
}
an.AuthToken.Value = token
if an.AuthToken.EnvVariable != "" {
_ = os.Setenv(an.AuthToken.EnvVariable, token)
if !storage.HasValidToken("claude", 5) {
// No valid token, run OAuth flow
if _, err = runOAuthFlow(); err != nil {
return err
}
}
}
@@ -149,21 +183,17 @@ func (an *Client) configure() (err error) {
opts = append(opts, option.WithBaseURL(baseURL))
}
if an.AuthToken.Value != "" {
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
// For OAuth, use Bearer token with custom headers
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
// Create custom HTTP client that adds OAuth Bearer token and beta header
baseTransport := &http.Transport{}
httpClient := &http.Client{
Transport: &oauthTransport{
client: an,
base: baseTransport,
},
}
opts = append(opts, option.WithHTTPClient(httpClient))
} else {
opts = append(opts, option.WithAuthToken(an.AuthToken.Value))
// Create custom HTTP client that adds OAuth Bearer token and beta header
baseTransport := &http.Transport{}
httpClient := &http.Client{
Transport: &oauthTransport{
client: an,
base: baseTransport,
},
}
opts = append(opts, option.WithHTTPClient(httpClient))
} else {
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
}
@@ -461,12 +491,106 @@ func exchangeToken(params map[string]string) (token string, err error) {
}
var result struct {
AccessToken string `json:"access_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return
}
// Save the complete token information
storage, err := common.NewOAuthStorage()
if err != nil {
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
}
oauthToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
if err = storage.SaveToken("claude", oauthToken); err != nil {
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
}
token = result.AccessToken
return
}
// refreshToken refreshes an expired OAuth token using the refresh token
func refreshToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load existing token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
if token == nil || token.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}
// Prepare refresh request
refreshReq := map[string]string{
"grant_type": "refresh_token",
"refresh_token": token.RefreshToken,
"client_id": oauthClientID,
}
reqBody, err := json.Marshal(refreshReq)
if err != nil {
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
}
// Make refresh request
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse refresh response: %w", err)
}
// Update stored token
newToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
// Use existing refresh token if new one not provided
if newToken.RefreshToken == "" {
newToken.RefreshToken = token.RefreshToken
}
if err = storage.SaveToken("claude", newToken); err != nil {
return "", fmt.Errorf("failed to save refreshed token: %w", err)
}
return result.AccessToken, nil
}