refactor: extract OAuth functionality from anthropic client to separate module

## CHANGES

- Remove OAuth transport implementation from main client
- Extract OAuth flow functions to separate module
- Remove unused imports and constants from client
- Replace inline OAuth transport with NewOAuthTransport call
- Update runOAuthFlow to exported RunOAuthFlow function
- Clean up token management and refresh logic
- Simplify client configuration by removing OAuth internals
This commit is contained in:
Kayvan Sylvan
2025-07-05 14:59:38 -07:00
parent f13a56685b
commit eda552dac5
2 changed files with 302 additions and 278 deletions

View File

@@ -1,32 +1,19 @@
package anthropic
import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os/exec"
"strings"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
"golang.org/x/oauth2"
)
const defaultBaseUrl = "https://api.anthropic.com/"
const oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
const oauthAuthURL = "https://claude.ai/oauth/authorize"
const oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
const oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
const webSearchToolName = "web_search"
const webSearchToolType = "web_search_20250305"
@@ -79,77 +66,6 @@ type Client struct {
client anthropic.Client
}
// oauthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
type oauthTransport struct {
client *Client
base http.RoundTripper
}
func (t *oauthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
newReq := req.Clone(req.Context())
// Get current token (may refresh if needed)
token, err := t.getValidToken()
if err != nil {
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
}
// Add OAuth Bearer token
newReq.Header.Set("Authorization", "Bearer "+token)
// Add the anthropic-beta header for OAuth
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
// Set User-Agent to match AI SDK exactly
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
// Remove x-api-key header if present (OAuth doesn't use it)
newReq.Header.Del("x-api-key")
return t.base.RoundTrip(newReq)
}
// getValidToken returns a valid access token, refreshing if necessary
func (t *oauthTransport) getValidToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load stored token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
// If no token exists, run OAuth flow
if token == nil {
fmt.Println("No OAuth token found, initiating authentication...")
newAccessToken, err := runOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to authenticate: %w", err)
}
return newAccessToken, nil
}
// Check if token needs refresh (5 minute buffer)
if token.IsExpired(5) {
fmt.Println("OAuth token expired, refreshing...")
newAccessToken, err := refreshToken()
if err != nil {
// If refresh fails, try re-authentication
fmt.Println("Token refresh failed, re-authenticating...")
newAccessToken, err = runOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
}
}
return newAccessToken, nil
}
return token.AccessToken, nil
}
func (an *Client) Setup() (err error) {
if err = an.PluginBase.Ask(an.Name); err != nil {
return
@@ -164,7 +80,7 @@ func (an *Client) Setup() (err error) {
if !storage.HasValidToken("claude", 5) {
// No valid token, run OAuth flow
if _, err = runOAuthFlow(); err != nil {
if _, err = RunOAuthFlow(); err != nil {
return err
}
}
@@ -186,10 +102,7 @@ func (an *Client) configure() (err error) {
// Create custom HTTP client that adds OAuth Bearer token and beta header
baseTransport := &http.Transport{}
httpClient := &http.Client{
Transport: &oauthTransport{
client: an,
base: baseTransport,
},
Transport: NewOAuthTransport(an, baseTransport),
}
opts = append(opts, option.WithHTTPClient(httpClient))
} else {
@@ -403,192 +316,3 @@ func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropi
func (an *Client) NeedsRawMode(modelName string) bool {
return false
}
func generatePKCE() (verifier, challenge string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return
}
verifier = base64.RawURLEncoding.EncodeToString(b)
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
return
}
func openBrowser(url string) {
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
for _, cmd := range commands {
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
return
}
}
}
func runOAuthFlow() (token string, err error) {
verifier, challenge, err := generatePKCE()
if err != nil {
return
}
cfg := oauth2.Config{
ClientID: oauthClientID,
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
RedirectURL: oauthRedirectURL,
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
}
authURL := cfg.AuthCodeURL(verifier,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code", "true"),
oauth2.SetAuthURLParam("state", verifier),
)
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
fmt.Println(authURL)
openBrowser(authURL)
fmt.Print("Paste the authorization code here: ")
var code string
fmt.Scanln(&code)
parts := strings.SplitN(code, "#", 2)
state := verifier
if len(parts) == 2 {
state = parts[1]
}
// Manual token exchange to match opencode implementation
tokenReq := map[string]string{
"code": parts[0],
"state": state,
"grant_type": "authorization_code",
"client_id": oauthClientID,
"redirect_uri": oauthRedirectURL,
"code_verifier": verifier,
}
token, err = exchangeToken(tokenReq)
return
}
func exchangeToken(params map[string]string) (token string, err error) {
reqBody, err := json.Marshal(params)
if err != nil {
return
}
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
return
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return
}
// Save the complete token information
storage, err := common.NewOAuthStorage()
if err != nil {
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
}
oauthToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
if err = storage.SaveToken("claude", oauthToken); err != nil {
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
}
token = result.AccessToken
return
}
// refreshToken refreshes an expired OAuth token using the refresh token
func refreshToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load existing token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
if token == nil || token.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}
// Prepare refresh request
refreshReq := map[string]string{
"grant_type": "refresh_token",
"refresh_token": token.RefreshToken,
"client_id": oauthClientID,
}
reqBody, err := json.Marshal(refreshReq)
if err != nil {
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
}
// Make refresh request
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse refresh response: %w", err)
}
// Update stored token
newToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
// Use existing refresh token if new one not provided
if newToken.RefreshToken == "" {
newToken.RefreshToken = token.RefreshToken
}
if err = storage.SaveToken("claude", newToken); err != nil {
return "", fmt.Errorf("failed to save refreshed token: %w", err)
}
return result.AccessToken, nil
}

View File

@@ -0,0 +1,300 @@
package anthropic
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"os/exec"
"strings"
"time"
"github.com/danielmiessler/fabric/common"
"golang.org/x/oauth2"
)
// OAuth configuration constants
const (
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauthAuthURL = "https://claude.ai/oauth/authorize"
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
)
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
type OAuthTransport struct {
client *Client
base http.RoundTripper
}
// RoundTrip implements the http.RoundTripper interface
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Clone the request to avoid modifying the original
newReq := req.Clone(req.Context())
// Get current token (may refresh if needed)
token, err := t.getValidToken()
if err != nil {
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
}
// Add OAuth Bearer token
newReq.Header.Set("Authorization", "Bearer "+token)
// Add the anthropic-beta header for OAuth
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
// Set User-Agent to match AI SDK exactly
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
// Remove x-api-key header if present (OAuth doesn't use it)
newReq.Header.Del("x-api-key")
return t.base.RoundTrip(newReq)
}
// getValidToken returns a valid access token, refreshing if necessary
func (t *OAuthTransport) getValidToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load stored token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
// If no token exists, run OAuth flow
if token == nil {
fmt.Println("No OAuth token found, initiating authentication...")
newAccessToken, err := RunOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to authenticate: %w", err)
}
return newAccessToken, nil
}
// Check if token needs refresh (5 minute buffer)
if token.IsExpired(5) {
fmt.Println("OAuth token expired, refreshing...")
newAccessToken, err := RefreshToken()
if err != nil {
// If refresh fails, try re-authentication
fmt.Println("Token refresh failed, re-authenticating...")
newAccessToken, err = RunOAuthFlow()
if err != nil {
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
}
}
return newAccessToken, nil
}
return token.AccessToken, nil
}
// NewOAuthTransport creates a new OAuth transport for the given client
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
return &OAuthTransport{
client: client,
base: base,
}
}
// generatePKCE generates PKCE code verifier and challenge
func generatePKCE() (verifier, challenge string, err error) {
b := make([]byte, 32)
if _, err = rand.Read(b); err != nil {
return
}
verifier = base64.RawURLEncoding.EncodeToString(b)
sum := sha256.Sum256([]byte(verifier))
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
return
}
// openBrowser attempts to open the given URL in the default browser
func openBrowser(url string) {
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
for _, cmd := range commands {
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
return
}
}
}
// RunOAuthFlow executes the complete OAuth authorization flow
func RunOAuthFlow() (token string, err error) {
verifier, challenge, err := generatePKCE()
if err != nil {
return
}
cfg := oauth2.Config{
ClientID: oauthClientID,
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
RedirectURL: oauthRedirectURL,
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
}
authURL := cfg.AuthCodeURL(verifier,
oauth2.SetAuthURLParam("code_challenge", challenge),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code", "true"),
oauth2.SetAuthURLParam("state", verifier),
)
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
fmt.Println(authURL)
openBrowser(authURL)
fmt.Print("Paste the authorization code here: ")
var code string
fmt.Scanln(&code)
parts := strings.SplitN(code, "#", 2)
state := verifier
if len(parts) == 2 {
state = parts[1]
}
// Manual token exchange to match opencode implementation
tokenReq := map[string]string{
"code": parts[0],
"state": state,
"grant_type": "authorization_code",
"client_id": oauthClientID,
"redirect_uri": oauthRedirectURL,
"code_verifier": verifier,
}
token, err = exchangeToken(tokenReq)
return
}
// exchangeToken exchanges authorization code for access token
func exchangeToken(params map[string]string) (token string, err error) {
reqBody, err := json.Marshal(params)
if err != nil {
return
}
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
return
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return
}
// Save the complete token information
storage, err := common.NewOAuthStorage()
if err != nil {
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
}
oauthToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
if err = storage.SaveToken("claude", oauthToken); err != nil {
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
}
token = result.AccessToken
return
}
// RefreshToken refreshes an expired OAuth token using the refresh token
func RefreshToken() (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load existing token
token, err := storage.LoadToken("claude")
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
if token == nil || token.RefreshToken == "" {
return "", fmt.Errorf("no refresh token available")
}
// Prepare refresh request
refreshReq := map[string]string{
"grant_type": "refresh_token",
"refresh_token": token.RefreshToken,
"client_id": oauthClientID,
}
reqBody, err := json.Marshal(refreshReq)
if err != nil {
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
}
// Make refresh request
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
if err != nil {
return "", fmt.Errorf("refresh request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
}
var result struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to parse refresh response: %w", err)
}
// Update stored token
newToken := &common.OAuthToken{
AccessToken: result.AccessToken,
RefreshToken: result.RefreshToken,
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
TokenType: result.TokenType,
Scope: result.Scope,
}
// Use existing refresh token if new one not provided
if newToken.RefreshToken == "" {
newToken.RefreshToken = token.RefreshToken
}
if err = storage.SaveToken("claude", newToken); err != nil {
return "", fmt.Errorf("failed to save refreshed token: %w", err)
}
return result.AccessToken, nil
}