mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-10 14:58:02 -05:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d081fd269c | ||
|
|
369a0a850d | ||
|
|
8dc5343ee6 | ||
|
|
eda552dac5 | ||
|
|
f13a56685b | ||
|
|
2f9afe0247 | ||
|
|
1ec525ad97 | ||
|
|
b7dc6748e0 | ||
|
|
f1b612d828 | ||
|
|
eac5a104f2 | ||
|
|
4bff88fae3 |
124
common/oauth_storage.go
Normal file
124
common/oauth_storage.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthToken represents stored OAuth token information
|
||||
type OAuthToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// IsExpired checks if the token is expired or will expire within the buffer time
|
||||
func (t *OAuthToken) IsExpired(bufferMinutes int) bool {
|
||||
if t.ExpiresAt == 0 {
|
||||
return true
|
||||
}
|
||||
bufferTime := time.Duration(bufferMinutes) * time.Minute
|
||||
return time.Now().Add(bufferTime).Unix() >= t.ExpiresAt
|
||||
}
|
||||
|
||||
// OAuthStorage handles persistent storage of OAuth tokens
|
||||
type OAuthStorage struct {
|
||||
configDir string
|
||||
}
|
||||
|
||||
// NewOAuthStorage creates a new OAuth storage instance
|
||||
func NewOAuthStorage() (*OAuthStorage, error) {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user home directory: %w", err)
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".config", "fabric")
|
||||
|
||||
// Ensure config directory exists
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
return &OAuthStorage{configDir: configDir}, nil
|
||||
}
|
||||
|
||||
// GetTokenPath returns the file path for a provider's OAuth token
|
||||
func (s *OAuthStorage) GetTokenPath(provider string) string {
|
||||
return filepath.Join(s.configDir, fmt.Sprintf(".%s_oauth", provider))
|
||||
}
|
||||
|
||||
// SaveToken saves an OAuth token to disk with proper permissions
|
||||
func (s *OAuthStorage) SaveToken(provider string, token *OAuthToken) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Marshal token to JSON
|
||||
data, err := json.MarshalIndent(token, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal token: %w", err)
|
||||
}
|
||||
|
||||
// Write to temporary file first for atomic operation
|
||||
tempPath := tokenPath + ".tmp"
|
||||
if err := os.WriteFile(tempPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write token file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, tokenPath); err != nil {
|
||||
os.Remove(tempPath) // Clean up temp file
|
||||
return fmt.Errorf("failed to save token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadToken loads an OAuth token from disk
|
||||
func (s *OAuthStorage) LoadToken(provider string) (*OAuthToken, error) {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
// Check if file exists
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
return nil, nil // No token stored
|
||||
}
|
||||
|
||||
// Read token file
|
||||
data, err := os.ReadFile(tokenPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read token file: %w", err)
|
||||
}
|
||||
|
||||
// Unmarshal token
|
||||
var token OAuthToken
|
||||
if err := json.Unmarshal(data, &token); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token file: %w", err)
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
// DeleteToken removes a stored OAuth token
|
||||
func (s *OAuthStorage) DeleteToken(provider string) error {
|
||||
tokenPath := s.GetTokenPath(provider)
|
||||
|
||||
if err := os.Remove(tokenPath); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("failed to delete token file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasValidToken checks if a valid (non-expired) token exists for a provider
|
||||
func (s *OAuthStorage) HasValidToken(provider string, bufferMinutes int) bool {
|
||||
token, err := s.LoadToken(provider)
|
||||
if err != nil || token == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return !token.IsExpired(bufferMinutes)
|
||||
}
|
||||
232
common/oauth_storage_test.go
Normal file
232
common/oauth_storage_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOAuthToken_IsExpired(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresAt int64
|
||||
bufferMinutes int
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "token not expired",
|
||||
expiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
bufferMinutes: 5,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "token expired",
|
||||
expiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "token expires within buffer",
|
||||
expiresAt: time.Now().Unix() + 120, // 2 minutes from now
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "zero expiry time",
|
||||
expiresAt: 0,
|
||||
bufferMinutes: 5,
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := &OAuthToken{ExpiresAt: tt.expiresAt}
|
||||
if got := token.IsExpired(tt.bufferMinutes); got != tt.expected {
|
||||
t.Errorf("IsExpired() = %v, want %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_SaveAndLoadToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
// Create storage with custom config dir
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_access_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "test_scope",
|
||||
}
|
||||
|
||||
// Test saving token
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists and has correct permissions
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
info, err := os.Stat(tokenPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Token file not created: %v", err)
|
||||
}
|
||||
if info.Mode().Perm() != 0600 {
|
||||
t.Errorf("Token file has wrong permissions: %v, want 0600", info.Mode().Perm())
|
||||
}
|
||||
|
||||
// Test loading token
|
||||
loadedToken, err := storage.LoadToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load token: %v", err)
|
||||
}
|
||||
if loadedToken == nil {
|
||||
t.Fatal("Loaded token is nil")
|
||||
}
|
||||
|
||||
// Verify token data
|
||||
if loadedToken.AccessToken != token.AccessToken {
|
||||
t.Errorf("AccessToken mismatch: got %v, want %v", loadedToken.AccessToken, token.AccessToken)
|
||||
}
|
||||
if loadedToken.RefreshToken != token.RefreshToken {
|
||||
t.Errorf("RefreshToken mismatch: got %v, want %v", loadedToken.RefreshToken, token.RefreshToken)
|
||||
}
|
||||
if loadedToken.ExpiresAt != token.ExpiresAt {
|
||||
t.Errorf("ExpiresAt mismatch: got %v, want %v", loadedToken.ExpiresAt, token.ExpiresAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_LoadNonExistentToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Try to load non-existent token
|
||||
token, err := storage.LoadToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error loading non-existent token: %v", err)
|
||||
}
|
||||
if token != nil {
|
||||
t.Error("Expected nil token for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_DeleteToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Create and save a token
|
||||
token := &OAuthToken{
|
||||
AccessToken: "test_token",
|
||||
RefreshToken: "test_refresh",
|
||||
ExpiresAt: time.Now().Unix() + 3600,
|
||||
}
|
||||
err = storage.SaveToken("test_provider", token)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token exists
|
||||
tokenPath := storage.GetTokenPath("test_provider")
|
||||
if _, err := os.Stat(tokenPath); os.IsNotExist(err) {
|
||||
t.Fatal("Token file should exist before deletion")
|
||||
}
|
||||
|
||||
// Delete token
|
||||
err = storage.DeleteToken("test_provider")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to delete token: %v", err)
|
||||
}
|
||||
|
||||
// Verify token is deleted
|
||||
if _, err := os.Stat(tokenPath); !os.IsNotExist(err) {
|
||||
t.Error("Token file should not exist after deletion")
|
||||
}
|
||||
|
||||
// Test deleting non-existent token (should not error)
|
||||
err = storage.DeleteToken("nonexistent")
|
||||
if err != nil {
|
||||
t.Errorf("Deleting non-existent token should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_HasValidToken(t *testing.T) {
|
||||
// Create temporary directory for testing
|
||||
tempDir, err := os.MkdirTemp("", "fabric_oauth_test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
storage := &OAuthStorage{configDir: tempDir}
|
||||
|
||||
// Test with no token
|
||||
if storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return false when no token exists")
|
||||
}
|
||||
|
||||
// Save valid token
|
||||
validToken := &OAuthToken{
|
||||
AccessToken: "valid_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() + 3600, // 1 hour from now
|
||||
}
|
||||
err = storage.SaveToken("test_provider", validToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save valid token: %v", err)
|
||||
}
|
||||
|
||||
// Test with valid token
|
||||
if !storage.HasValidToken("test_provider", 5) {
|
||||
t.Error("Should return true for valid token")
|
||||
}
|
||||
|
||||
// Save expired token
|
||||
expiredToken := &OAuthToken{
|
||||
AccessToken: "expired_token",
|
||||
RefreshToken: "refresh_token",
|
||||
ExpiresAt: time.Now().Unix() - 3600, // 1 hour ago
|
||||
}
|
||||
err = storage.SaveToken("expired_provider", expiredToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save expired token: %v", err)
|
||||
}
|
||||
|
||||
// Test with expired token
|
||||
if storage.HasValidToken("expired_provider", 5) {
|
||||
t.Error("Should return false for expired token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthStorage_GetTokenPath(t *testing.T) {
|
||||
storage := &OAuthStorage{configDir: "/test/config"}
|
||||
|
||||
expected := filepath.Join("/test/config", ".test_provider_oauth")
|
||||
actual := storage.GetTokenPath("test_provider")
|
||||
|
||||
if actual != expected {
|
||||
t.Errorf("GetTokenPath() = %v, want %v", actual, expected)
|
||||
}
|
||||
}
|
||||
2
go.mod
2
go.mod
@@ -25,6 +25,7 @@ require (
|
||||
github.com/samber/lo v1.50.0
|
||||
github.com/sgaunet/perplexity-go/v2 v2.8.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/text v0.26.0
|
||||
google.golang.org/api v0.236.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -108,7 +109,6 @@ require (
|
||||
golang.org/x/arch v0.18.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/time v0.12.0 // indirect
|
||||
|
||||
@@ -1 +1 @@
|
||||
"1.4.230"
|
||||
"1.4.231"
|
||||
|
||||
@@ -3,6 +3,7 @@ package anthropic
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
@@ -30,7 +31,12 @@ func NewClient() (ret *Client) {
|
||||
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseOAuth = ret.AddSetupQuestionBool("Use OAuth login", false)
|
||||
if plugins.ParseBoolElseFalse(ret.UseOAuth.Value) {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
|
||||
} else {
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
}
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -50,6 +56,7 @@ type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseOAuth *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -58,24 +65,50 @@ type Client struct {
|
||||
client anthropic.Client
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
baseURL := an.ApiBaseURL.Value
|
||||
func (an *Client) Setup() (err error) {
|
||||
if err = an.PluginBase.Ask(an.Name); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// As of 2.0beta1, using v2 API endpoint.
|
||||
// https://github.com/anthropics/anthropic-sdk-go/blob/main/CHANGELOG.md#020-beta1-2025-03-25
|
||||
if strings.Contains(baseURL, "-") && !strings.HasSuffix(baseURL, "/v2") {
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
baseURL = baseURL + "/v2"
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// Check if we have a valid stored token
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(
|
||||
option.WithAPIKey(an.ApiKey.Value),
|
||||
option.WithBaseURL(baseURL),
|
||||
)
|
||||
} else {
|
||||
an.client = anthropic.NewClient(option.WithAPIKey(an.ApiKey.Value))
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = an.configure()
|
||||
return
|
||||
}
|
||||
|
||||
func (an *Client) configure() (err error) {
|
||||
opts := []option.RequestOption{}
|
||||
|
||||
if an.ApiBaseURL.Value != "" {
|
||||
opts = append(opts, option.WithBaseURL(an.ApiBaseURL.Value))
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
// For OAuth, use Bearer token with custom headers
|
||||
// Create custom HTTP client that adds OAuth Bearer token and beta header
|
||||
baseTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: NewOAuthTransport(an, baseTransport),
|
||||
}
|
||||
opts = append(opts, option.WithHTTPClient(httpClient))
|
||||
} else {
|
||||
opts = append(opts, option.WithAPIKey(an.ApiKey.Value))
|
||||
}
|
||||
|
||||
an.client = anthropic.NewClient(opts...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -124,6 +157,17 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
// Add Claude Code spoofing system message for OAuth authentication
|
||||
if plugins.ParseBoolElseFalse(an.UseOAuth.Value) {
|
||||
params.System = []anthropic.TextBlockParam{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
@@ -207,6 +251,9 @@ func (an *Client) toMessages(msgs []*chat.ChatCompletionMessage) (ret []anthropi
|
||||
|
||||
var anthropicMessages []anthropic.MessageParam
|
||||
var systemContent string
|
||||
|
||||
// Note: Claude Code spoofing is now handled in buildMessageParams
|
||||
|
||||
isFirstUserMessage := true
|
||||
lastRoleWasUser := false
|
||||
|
||||
|
||||
300
plugins/ai/anthropic/oauth.go
Normal file
300
plugins/ai/anthropic/oauth.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth configuration constants
|
||||
const (
|
||||
oauthClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauthAuthURL = "https://claude.ai/oauth/authorize"
|
||||
oauthTokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
oauthRedirectURL = "https://console.anthropic.com/oauth/code/callback"
|
||||
)
|
||||
|
||||
// OAuthTransport is a custom HTTP transport that adds OAuth Bearer token and beta header
|
||||
type OAuthTransport struct {
|
||||
client *Client
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface
|
||||
func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Clone the request to avoid modifying the original
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
|
||||
// Add OAuth Bearer token
|
||||
newReq.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
// Add the anthropic-beta header for OAuth
|
||||
newReq.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// Set User-Agent to match AI SDK exactly
|
||||
newReq.Header.Set("User-Agent", "ai-sdk/anthropic")
|
||||
|
||||
// Remove x-api-key header if present (OAuth doesn't use it)
|
||||
newReq.Header.Del("x-api-key")
|
||||
|
||||
return t.base.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
}
|
||||
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// NewOAuthTransport creates a new OAuth transport for the given client
|
||||
func NewOAuthTransport(client *Client, base http.RoundTripper) *OAuthTransport {
|
||||
return &OAuthTransport{
|
||||
client: client,
|
||||
base: base,
|
||||
}
|
||||
}
|
||||
|
||||
// generatePKCE generates PKCE code verifier and challenge
|
||||
func generatePKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err = rand.Read(b); err != nil {
|
||||
return
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum := sha256.Sum256([]byte(verifier))
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum[:])
|
||||
return
|
||||
}
|
||||
|
||||
// openBrowser attempts to open the given URL in the default browser
|
||||
func openBrowser(url string) {
|
||||
commands := [][]string{{"xdg-open", url}, {"open", url}, {"cmd", "/c", "start", url}}
|
||||
for _, cmd := range commands {
|
||||
if exec.Command(cmd[0], cmd[1:]...).Start() == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
Endpoint: oauth2.Endpoint{AuthURL: oauthAuthURL, TokenURL: oauthTokenURL},
|
||||
RedirectURL: oauthRedirectURL,
|
||||
Scopes: []string{"org:create_api_key", "user:profile", "user:inference"},
|
||||
}
|
||||
|
||||
authURL := cfg.AuthCodeURL(verifier,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("code", "true"),
|
||||
oauth2.SetAuthURLParam("state", verifier),
|
||||
)
|
||||
|
||||
fmt.Println("Open the following URL in your browser. Fabric would like to authorize:")
|
||||
fmt.Println(authURL)
|
||||
openBrowser(authURL)
|
||||
fmt.Print("Paste the authorization code here: ")
|
||||
var code string
|
||||
fmt.Scanln(&code)
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
state := verifier
|
||||
if len(parts) == 2 {
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
// Manual token exchange to match opencode implementation
|
||||
tokenReq := map[string]string{
|
||||
"code": parts[0],
|
||||
"state": state,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauthClientID,
|
||||
"redirect_uri": oauthRedirectURL,
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
err = fmt.Errorf("token exchange failed: %s - %s", resp.Status, string(body))
|
||||
return
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save the complete token information
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
oauthToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
token = result.AccessToken
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
if token == nil || token.RefreshToken == "" {
|
||||
return "", fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Prepare refresh request
|
||||
refreshReq := map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": token.RefreshToken,
|
||||
"client_id": oauthClientID,
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(refreshReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to marshal refresh request: %w", err)
|
||||
}
|
||||
|
||||
// Make refresh request
|
||||
resp, err := http.Post(oauthTokenURL, "application/json", bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("refresh request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token refresh failed: %s - %s", resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", fmt.Errorf("failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
// Update stored token
|
||||
newToken := &common.OAuthToken{
|
||||
AccessToken: result.AccessToken,
|
||||
RefreshToken: result.RefreshToken,
|
||||
ExpiresAt: time.Now().Unix() + int64(result.ExpiresIn),
|
||||
TokenType: result.TokenType,
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
// Use existing refresh token if new one not provided
|
||||
if newToken.RefreshToken == "" {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
return result.AccessToken, nil
|
||||
}
|
||||
@@ -57,17 +57,18 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
config := map[string]string{
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
"openai": os.Getenv("OPENAI_API_KEY"),
|
||||
"anthropic": os.Getenv("ANTHROPIC_API_KEY"),
|
||||
"anthropic_use_oauth_login": os.Getenv("ANTHROPIC_USE_OAUTH_LOGIN"),
|
||||
"groq": os.Getenv("GROQ_API_KEY"),
|
||||
"mistral": os.Getenv("MISTRAL_API_KEY"),
|
||||
"gemini": os.Getenv("GEMINI_API_KEY"),
|
||||
"ollama": os.Getenv("OLLAMA_URL"),
|
||||
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
|
||||
"silicon": os.Getenv("SILICON_API_KEY"),
|
||||
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
|
||||
"grokai": os.Getenv("GROKAI_API_KEY"),
|
||||
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, config)
|
||||
@@ -80,17 +81,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
var config struct {
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
OpenAIApiKey string `json:"openai_api_key"`
|
||||
AnthropicApiKey string `json:"anthropic_api_key"`
|
||||
AnthropicUseAuthToken string `json:"anthropic_use_auth_token"`
|
||||
GroqApiKey string `json:"groq_api_key"`
|
||||
MistralApiKey string `json:"mistral_api_key"`
|
||||
GeminiApiKey string `json:"gemini_api_key"`
|
||||
OllamaURL string `json:"ollama_url"`
|
||||
OpenRouterApiKey string `json:"openrouter_api_key"`
|
||||
SiliconApiKey string `json:"silicon_api_key"`
|
||||
DeepSeekApiKey string `json:"deepseek_api_key"`
|
||||
GrokaiApiKey string `json:"grokai_api_key"`
|
||||
LMStudioURL string `json:"lm_studio_base_url"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&config); err != nil {
|
||||
@@ -99,17 +101,18 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
}
|
||||
|
||||
envVars := map[string]string{
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
"OPENAI_API_KEY": config.OpenAIApiKey,
|
||||
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
|
||||
"ANTHROPIC_USE_OAUTH_LOGIN": config.AnthropicUseAuthToken,
|
||||
"GROQ_API_KEY": config.GroqApiKey,
|
||||
"MISTRAL_API_KEY": config.MistralApiKey,
|
||||
"GEMINI_API_KEY": config.GeminiApiKey,
|
||||
"OLLAMA_URL": config.OllamaURL,
|
||||
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
|
||||
"SILICON_API_KEY": config.SiliconApiKey,
|
||||
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
|
||||
"GROKAI_API_KEY": config.GrokaiApiKey,
|
||||
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
|
||||
}
|
||||
|
||||
var envContent strings.Builder
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.230"
|
||||
var version = "v1.4.231"
|
||||
|
||||
Reference in New Issue
Block a user