mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-13 23:45:08 -05:00
feat: implement OAuth token refresh and persistent storage for Claude authentication
## CHANGES - Add automatic OAuth token refresh when expired - Implement persistent token storage using common OAuth storage - Remove deprecated AuthToken setting from client configuration - Add token validation with 5-minute expiration buffer - Create refreshToken function for seamless token renewal - Update OAuth flow to save complete token information - Enhance error handling for OAuth authentication failures - Simplify client configuration by removing manual token management
This commit is contained in:
@@ -10,9 +10,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
@@ -44,9 +44,8 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login Fabric", false)
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
ret.AuthToken = ret.AddSetting("Auth Token", false)
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -67,7 +66,6 @@ type Client struct {
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
AuthToken *plugins.Setting
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -87,7 +85,10 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token := t.client.AuthToken.Value
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
@@ -104,19 +105,52 @@ func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *oauthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil {
|
||||
return "", fmt.Errorf("no OAuth token stored, please re-authenticate")
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := refreshToken()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) && an.AuthToken.Value == "" {
|
||||
var token string
|
||||
if token, err = runOAuthFlow(); err != nil {
|
||||
return
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
an.AuthToken.Value = token
|
||||
if an.AuthToken.EnvVariable != "" {
|
||||
_ = os.Setenv(an.AuthToken.EnvVariable, token)
|
||||
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = runOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,21 +183,17 @@ func (an *Client) configure() (err error) {
|
||||
opts = append(opts, option.WithBaseURL(baseURL))
|
||||
}
|
||||
|
||||
if an.AuthToken.Value != "" {
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: &oauthTransport{
|
||||
client: an,
|
||||
base: baseTransport,
|
||||
},
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAuthToken(an.AuthToken.Value))
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: &oauthTransport{
|
||||
client: an,
|
||||
base: baseTransport,
|
||||
},
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
@@ -461,12 +491,106 @@ func exchangeToken(params map[string]string) (token string, err error) {
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// refreshToken refreshes an expired OAuth token using the refresh token
|
||||
func refreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user