mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-31 00:57:59 -05:00
Merge pull request #1978 from ksylvan/kayvan/no-anthropic-oauth
chore: remove OAuth support from Anthropic client
This commit is contained in:
@@ -114,7 +114,6 @@ Below are the **new features and capabilities** we've added (newest first):
|
||||
- [v1.4.246](https://github.com/danielmiessler/fabric/releases/tag/v1.4.246) (Jul 14, 2025) — **Automatic ChangeLog Updates**: Add AI-powered changelog generation with high-performance Go tool and comprehensive caching
|
||||
- [v1.4.245](https://github.com/danielmiessler/fabric/releases/tag/v1.4.245) (Jul 11, 2025) — **Together AI**: Together AI Support with OpenAI Fallback Mechanism Added
|
||||
- [v1.4.232](https://github.com/danielmiessler/fabric/releases/tag/v1.4.232) (Jul 6, 2025) — **Add Custom**: Add Custom Patterns Directory Support
|
||||
- [v1.4.231](https://github.com/danielmiessler/fabric/releases/tag/v1.4.231) (Jul 5, 2025) — **OAuth Auto-Auth**: OAuth Authentication Support for Anthropic (Use your Max Subscription)
|
||||
- [v1.4.230](https://github.com/danielmiessler/fabric/releases/tag/v1.4.230) (Jul 5, 2025) — **Model Management**: Add advanced image generation parameters for OpenAI models with four new CLI flags
|
||||
- [v1.4.227](https://github.com/danielmiessler/fabric/releases/tag/v1.4.227) (Jul 4, 2025) — **Add Image**: Add Image Generation Support to Fabric
|
||||
- [v1.4.226](https://github.com/danielmiessler/fabric/releases/tag/v1.4.226) (Jul 4, 2025) — **Web Search**: OpenAI Plugin Now Supports Web Search Functionality
|
||||
|
||||
7
cmd/generate_changelog/incoming/1978.txt
Normal file
7
cmd/generate_changelog/incoming/1978.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
### PR [#1978](https://github.com/danielmiessler/Fabric/pull/1978) by [ksylvan](https://github.com/ksylvan): chore: remove OAuth support from Anthropic client
|
||||
|
||||
- Remove OAuth support from Anthropic client and delete related OAuth files
|
||||
- Simplify configuration handling to check only API key instead of OAuth credentials
|
||||
- Clean up imports and unused variables in anthropic.go
|
||||
- Update server configuration methods to remove OAuth references
|
||||
- Remove OAuth-related environment variables from configuration
|
||||
@@ -3,7 +3,6 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
neturl "net/url"
|
||||
"os"
|
||||
"path"
|
||||
@@ -16,7 +15,6 @@ import (
|
||||
"github.com/danielmiessler/fabric/internal/domain"
|
||||
debuglog "github.com/danielmiessler/fabric/internal/log"
|
||||
"github.com/danielmiessler/fabric/internal/plugins"
|
||||
"github.com/danielmiessler/fabric/internal/util"
|
||||
)
|
||||
|
||||
const defaultBaseUrl = "https://api.anthropic.com/"
|
||||
@@ -25,8 +23,6 @@ const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
const authTokenIdentifier = "claude"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
ret = &Client{}
|
||||
@@ -35,7 +31,6 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
|
||||
ret.maxTokens = 4096
|
||||
@@ -64,35 +59,13 @@ func NewClient() (ret *Client) {
|
||||
return
|
||||
}
|
||||
|
||||
// IsConfigured returns true if either the API key or OAuth is configured
|
||||
// IsConfigured returns true if the API key is configured
|
||||
func (an *Client) IsConfigured() bool {
|
||||
// Check if API key is configured
|
||||
if an.ApiKey.Value != "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if OAuth is enabled and has a valid token
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// If no valid token exists, automatically run OAuth flow
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
|
||||
_, err := RunOAuthFlow(authTokenIdentifier)
|
||||
if err != nil {
|
||||
fmt.Printf("OAuth authentication failed: %v\n", err)
|
||||
return false
|
||||
}
|
||||
// After successful OAuth flow, check again
|
||||
return storage.HasValidToken(authTokenIdentifier, 5)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -100,7 +73,6 @@ type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -115,21 +87,6 @@ func (an *Client) Setup() (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(authTokenIdentifier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
@@ -141,17 +98,7 @@ func (an *Client) configure() (err error) {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
@@ -264,17 +211,6 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *domain
|
||||
params.Temperature = anthropic.Opt(opts.Temperature)
|
||||
}
|
||||
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
debuglog "github.com/danielmiessler/fabric/internal/log"
|
||||
"github.com/danielmiessler/fabric/internal/util"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken(authTokenIdentifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth, preserving existing betas
|
||||
existing := newReq.Header.Get("anthropic-beta")
|
||||
beta := "oauth-2025-04-20"
|
||||
if existing != "" {
|
||||
beta = existing + "," + beta
|
||||
}
|
||||
newReq.Header.Set("anthropic-beta", beta)
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
debuglog.Log("No OAuth token found, initiating authentication...\n")
|
||||
newAccessToken, err := RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
debuglog.Log("OAuth token expired, refreshing...\n")
|
||||
newAccessToken, err := RefreshToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
debuglog.Log("Token refresh failed, re-authenticating...\n")
|
||||
newAccessToken, err = RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow(tokenIdentifier string) (token string, err error) {
|
||||
// First check if we have an existing token that can be refreshed
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err == nil {
|
||||
existingToken, err := storage.LoadToken(tokenIdentifier)
|
||||
if err == nil && existingToken != nil {
|
||||
// If token exists but is expired, try refreshing first
|
||||
if existingToken.IsExpired(5) {
|
||||
debuglog.Log("Found expired OAuth token, attempting refresh...\n")
|
||||
refreshedToken, refreshErr := RefreshToken(tokenIdentifier)
|
||||
if refreshErr == nil {
|
||||
debuglog.Log("Token refresh successful\n")
|
||||
return refreshedToken, nil
|
||||
}
|
||||
debuglog.Log("Token refresh failed (%v), proceeding with full OAuth flow...\n", refreshErr)
|
||||
} else {
|
||||
// Token exists and is still valid
|
||||
return existingToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
debuglog.Log("Open the following URL in your browser. Fabric would like to authorize:\n")
|
||||
debuglog.Log("%s\n", authURL)
|
||||
openBrowser(authURL)
|
||||
debuglog.Log("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenIdentifier, tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(tokenIdentifier string, params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &util.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken(tokenIdentifier, oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := util.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &util.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken(tokenIdentifier, newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -1,433 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
// OAuth Testing Strategy:
|
||||
//
|
||||
// This test suite covers OAuth functionality while avoiding real external calls.
|
||||
// Key principles:
|
||||
// 1. Never trigger real OAuth flows that would open browsers or call external APIs
|
||||
// 2. Use temporary directories and mock tokens for isolated testing
|
||||
// 3. Skip integration tests that would require real OAuth servers
|
||||
// 4. Test error paths and edge cases safely
|
||||
//
|
||||
// Tests are categorized as:
|
||||
// - Unit tests: Test individual functions with mocked data (SAFE)
|
||||
// - Integration tests: Would require real OAuth servers (SKIPPED)
|
||||
// - Error path tests: Test failure scenarios safely (SAFE)
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/internal/util"
|
||||
)
|
||||
|
||||
// createTestToken creates a test OAuth token
|
||||
func createTestToken(accessToken, refreshToken string, expiresIn int64) *util.OAuthToken {
|
||||
return &util.OAuthToken{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Unix() + expiresIn,
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
}
|
||||
|
||||
// createExpiredToken creates an expired test token
|
||||
func createExpiredToken(accessToken, refreshToken string) *util.OAuthToken {
|
||||
return &util.OAuthToken{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Unix() - 3600, // Expired 1 hour ago
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
}
|
||||
|
||||
// mockTokenServer creates a mock OAuth token server for testing
|
||||
func mockTokenServer(_ *testing.T, responses map[string]any) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/oauth/token" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]string
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := req["grant_type"]
|
||||
response, exists := responses[grantType]
|
||||
if !exists {
|
||||
http.Error(w, "Unsupported grant type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if errorResp, ok := response.(map[string]any); ok && errorResp["error"] != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestGeneratePKCE(t *testing.T) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if verifier == "" {
|
||||
t.Error("Expected non-empty verifier")
|
||||
}
|
||||
|
||||
if challenge == "" {
|
||||
t.Error("Expected non-empty challenge")
|
||||
}
|
||||
|
||||
if len(verifier) < 43 { // Base64 encoded 32 bytes should be at least 43 chars
|
||||
t.Errorf("Verifier too short: %d chars", len(verifier))
|
||||
}
|
||||
|
||||
if len(challenge) < 43 { // SHA256 hash should be at least 43 chars when base64 encoded
|
||||
t.Errorf("Challenge too short: %d chars", len(challenge))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeToken_Success(t *testing.T) {
|
||||
// Create mock server
|
||||
server := mockTokenServer(t, map[string]any{
|
||||
"authorization_code": map[string]any{
|
||||
"access_token": "test_access_token",
|
||||
"refresh_token": "test_refresh_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "org:create_api_key user:profile user:inference",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary directory for token storage
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Mock the storage creation to use our temp directory
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
|
||||
// Set up a fake home directory
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
os.MkdirAll(filepath.Join(fakeHome, ".config", "fabric"), 0755)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// This test would need the actual exchangeToken function to be modified to accept a custom URL
|
||||
// For now, we'll test the logic without the actual HTTP call
|
||||
t.Skip("Skipping integration test - would need URL injection for proper testing")
|
||||
}
|
||||
func TestRefreshToken_Success(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create an expired token
|
||||
expiredToken := createExpiredToken("old_access_token", "valid_refresh_token")
|
||||
|
||||
// Save the expired token
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(expiredToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create mock server for refresh
|
||||
server := mockTokenServer(t, map[string]any{
|
||||
"refresh_token": map[string]any{
|
||||
"access_token": "new_access_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "org:create_api_key user:profile user:inference",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// This test would need the RefreshToken function to accept a custom URL
|
||||
t.Skip("Skipping integration test - would need URL injection for proper testing")
|
||||
}
|
||||
|
||||
func TestRefreshToken_NoRefreshToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a token without refresh token
|
||||
tokenWithoutRefresh := &util.OAuthToken{
|
||||
AccessToken: "access_token",
|
||||
RefreshToken: "", // No refresh token
|
||||
ExpiresAt: time.Now().Unix() - 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
|
||||
// Save the token
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(tokenWithoutRefresh, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Test RefreshToken
|
||||
_, err := RefreshToken("test")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when no refresh token available")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no refresh token available") {
|
||||
t.Errorf("Expected 'no refresh token available' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_NoStoredToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Don't create any token file
|
||||
|
||||
// Test RefreshToken
|
||||
_, err := RefreshToken("nonexistent")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when no token stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthTransport_RoundTrip(t *testing.T) {
|
||||
// Create a mock client
|
||||
client := &Client{}
|
||||
|
||||
// Create the transport
|
||||
transport := NewOAuthTransport(client, http.DefaultTransport)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "https://api.anthropic.com/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "should-be-removed")
|
||||
|
||||
// Create temporary directory and set up fake home with valid token
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, fmt.Sprintf(".%s_oauth", authTokenIdentifier))
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create a mock server to handle the request
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that OAuth headers are set correctly
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "Bearer valid_access_token" {
|
||||
t.Errorf("Expected 'Bearer valid_access_token', got '%s'", auth)
|
||||
}
|
||||
|
||||
beta := r.Header.Get("anthropic-beta")
|
||||
if beta != "oauth-2025-04-20" {
|
||||
t.Errorf("Expected 'oauth-2025-04-20', got '%s'", beta)
|
||||
}
|
||||
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
if userAgent != "ai-sdk/anthropic" {
|
||||
t.Errorf("Expected 'ai-sdk/anthropic', got '%s'", userAgent)
|
||||
}
|
||||
|
||||
// Check that x-api-key header is removed
|
||||
if r.Header.Get("x-api-key") != "" {
|
||||
t.Error("Expected x-api-key header to be removed")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Update the request URL to point to our mock server
|
||||
req.URL.Host = strings.TrimPrefix(server.URL, "http://")
|
||||
req.URL.Scheme = "http"
|
||||
|
||||
// Execute the request
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("RoundTrip failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOAuthFlow_ExistingValidToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("existing_valid_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Test RunOAuthFlow - should return existing token without starting OAuth flow
|
||||
token, err := RunOAuthFlow("test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if token != "existing_valid_token" {
|
||||
t.Errorf("Expected 'existing_valid_token', got '%s'", token)
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestCreateTestToken(t *testing.T) {
|
||||
token := createTestToken("access", "refresh", 3600)
|
||||
|
||||
if token.AccessToken != "access" {
|
||||
t.Errorf("Expected access token 'access', got '%s'", token.AccessToken)
|
||||
}
|
||||
|
||||
if token.RefreshToken != "refresh" {
|
||||
t.Errorf("Expected refresh token 'refresh', got '%s'", token.RefreshToken)
|
||||
}
|
||||
|
||||
if token.IsExpired(5) {
|
||||
t.Error("Expected token to not be expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExpiredToken(t *testing.T) {
|
||||
token := createExpiredToken("access", "refresh")
|
||||
|
||||
if !token.IsExpired(5) {
|
||||
t.Error("Expected token to be expired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenExpirationLogic tests the token expiration detection without OAuth flows
|
||||
func TestTokenExpirationLogic(t *testing.T) {
|
||||
// Test valid token
|
||||
validToken := createTestToken("access", "refresh", 3600)
|
||||
if validToken.IsExpired(5) {
|
||||
t.Error("Valid token should not be expired")
|
||||
}
|
||||
|
||||
// Test expired token
|
||||
expiredToken := createExpiredToken("access", "refresh")
|
||||
if !expiredToken.IsExpired(5) {
|
||||
t.Error("Expired token should be expired")
|
||||
}
|
||||
|
||||
// Test token expiring soon (within buffer)
|
||||
soonExpiredToken := createTestToken("access", "refresh", 240) // 4 minutes
|
||||
if !soonExpiredToken.IsExpired(5) { // 5 minute buffer
|
||||
t.Error("Token expiring within buffer should be considered expired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetValidTokenWithValidToken tests the getValidToken method with a valid token
|
||||
func TestGetValidTokenWithValidToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create transport
|
||||
client := &Client{}
|
||||
transport := NewOAuthTransport(client, http.DefaultTransport)
|
||||
|
||||
// Test getValidToken - this should return the valid token without any OAuth flow
|
||||
token, err := transport.getValidToken("test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error with valid token, got: %v", err)
|
||||
}
|
||||
|
||||
if token != "valid_access_token" {
|
||||
t.Errorf("Expected 'valid_access_token', got '%s'", token)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGeneratePKCE(b *testing.B) {
|
||||
for b.Loop() {
|
||||
_, _, err := generatePKCE()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenIsExpired(b *testing.B) {
|
||||
token := createTestToken("access", "refresh", 3600)
|
||||
|
||||
for b.Loop() {
|
||||
token.IsExpired(5)
|
||||
}
|
||||
}
|
||||
@@ -57,18 +57,17 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -81,18 +80,17 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -101,18 +99,17 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
Reference in New Issue
Block a user