mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-10 06:48:04 -05:00
refactor: replace hardcoded "claude" with configurable authTokenIdentifier parameter
## CHANGES - Replace hardcoded "claude" string with `authTokenIdentifier` constant - Update `RunOAuthFlow` to accept token identifier parameter - Modify `RefreshToken` to use configurable token identifier - Update `exchangeToken` to accept token identifier parameter - Enhance `getValidToken` to use parameterized token identifier - Add token refresh attempt before full OAuth flow - Improve OAuth flow with existing token validation
This commit is contained in:
@@ -19,7 +19,7 @@ const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
const vendorTokenIdentifier = "claude"
|
||||
const authTokenIdentifier = "claude"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
@@ -65,15 +65,15 @@ func (an *Client) IsConfigured() bool {
|
||||
}
|
||||
|
||||
// If no valid token exists, automatically run OAuth flow
|
||||
if !storage.HasValidToken(vendorTokenIdentifier, 5) {
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
|
||||
_, err := RunOAuthFlow()
|
||||
_, err := RunOAuthFlow(authTokenIdentifier)
|
||||
if err != nil {
|
||||
fmt.Printf("OAuth authentication failed: %v\n", err)
|
||||
return false
|
||||
}
|
||||
// After successful OAuth flow, check again
|
||||
return storage.HasValidToken("claude", 5)
|
||||
return storage.HasValidToken(authTokenIdentifier, 5)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -107,9 +107,9 @@ func (an *Client) Setup() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
if _, err = RunOAuthFlow(authTokenIdentifier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
token, err := t.getValidToken(authTokenIdentifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
@@ -58,21 +58,21 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
func (t *OAuthTransport) getValidToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
newAccessToken, err := RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
@@ -82,11 +82,11 @@ func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
newAccessToken, err := RefreshToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
newAccessToken, err = RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
@@ -129,7 +129,28 @@ func openBrowser(url string) {
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
func RunOAuthFlow(tokenIdentifier string) (token string, err error) {
|
||||
// First check if we have an existing token that can be refreshed
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err == nil {
|
||||
existingToken, err := storage.LoadToken(tokenIdentifier)
|
||||
if err == nil && existingToken != nil {
|
||||
// If token exists but is expired, try refreshing first
|
||||
if existingToken.IsExpired(5) {
|
||||
fmt.Println("Found expired OAuth token, attempting refresh...")
|
||||
refreshedToken, refreshErr := RefreshToken(token)
|
||||
if refreshErr == nil {
|
||||
fmt.Println("Token refresh successful")
|
||||
return refreshedToken, nil
|
||||
}
|
||||
fmt.Printf("Token refresh failed (%v), proceeding with full OAuth flow...\n", refreshErr)
|
||||
} else {
|
||||
// Token exists and is still valid
|
||||
return existingToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -171,12 +192,12 @@ func RunOAuthFlow() (token string, err error) {
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
token, err = exchangeToken(tokenIdentifier, tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
func exchangeToken(tokenIdentifier string, params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -219,7 +240,7 @@ func exchangeToken(params map[string]string) (token string, err error) {
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
if err = storage.SaveToken(tokenIdentifier, oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
@@ -228,14 +249,14 @@ func exchangeToken(params map[string]string) (token string, err error) {
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
func RefreshToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
@@ -292,7 +313,7 @@ func RefreshToken() (string, error) {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
if err = storage.SaveToken(tokenIdentifier, newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user