mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 22:38:10 -05:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
265f2b807e | ||
|
|
dc63e0d1cc | ||
|
|
75842d8610 | ||
|
|
bcd4c6caea | ||
|
|
a6a63698e1 | ||
|
|
0528556b5c | ||
|
|
47cf24e19d | ||
|
|
3f07afbef4 | ||
|
|
38d714dccd | ||
|
|
d0b5c95d61 | ||
|
|
f8f80ca206 | ||
|
|
0af458872f | ||
|
|
24e46a6f37 | ||
|
|
d6a31e68b0 | ||
|
|
b1013ca61b | ||
|
|
6b4ce946a5 | ||
|
|
2d2830e9c8 | ||
|
|
115327fdab | ||
|
|
e672f9b73f | ||
|
|
ef4364a1aa |
@@ -66,17 +66,35 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (s
|
||||
message := ""
|
||||
|
||||
if o.Stream {
|
||||
channel := make(chan string)
|
||||
responseChan := make(chan string)
|
||||
errChan := make(chan error, 1)
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
if streamErr := o.vendor.SendStream(session.GetVendorMessages(), opts, channel); streamErr != nil {
|
||||
channel <- streamErr.Error()
|
||||
defer close(done)
|
||||
if streamErr := o.vendor.SendStream(session.GetVendorMessages(), opts, responseChan); streamErr != nil {
|
||||
errChan <- streamErr
|
||||
}
|
||||
}()
|
||||
|
||||
for response := range channel {
|
||||
for response := range responseChan {
|
||||
message += response
|
||||
fmt.Print(response)
|
||||
}
|
||||
|
||||
// Wait for goroutine to finish
|
||||
<-done
|
||||
|
||||
// Check for errors in errChan
|
||||
select {
|
||||
case streamErr := <-errChan:
|
||||
if streamErr != nil {
|
||||
err = streamErr
|
||||
return
|
||||
}
|
||||
default:
|
||||
// No errors, continue
|
||||
}
|
||||
} else {
|
||||
if message, err = o.vendor.Send(context.Background(), session.GetVendorMessages(), opts); err != nil {
|
||||
return
|
||||
|
||||
181
core/chatter_test.go
Normal file
181
core/chatter_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
"github.com/danielmiessler/fabric/plugins/db/fsdb"
|
||||
)
|
||||
|
||||
// mockVendor implements the ai.Vendor interface for testing
|
||||
type mockVendor struct {
|
||||
sendStreamError error
|
||||
streamChunks []string
|
||||
}
|
||||
|
||||
func (m *mockVendor) GetName() string {
|
||||
return "mock"
|
||||
}
|
||||
|
||||
func (m *mockVendor) GetSetupDescription() string {
|
||||
return "mock vendor"
|
||||
}
|
||||
|
||||
func (m *mockVendor) IsConfigured() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockVendor) Configure() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVendor) Setup() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockVendor) SetupFillEnvFileContent(*bytes.Buffer) {
|
||||
}
|
||||
|
||||
func (m *mockVendor) ListModels() ([]string, error) {
|
||||
return []string{"test-model"}, nil
|
||||
}
|
||||
|
||||
func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *common.ChatOptions, responseChan chan string) error {
|
||||
// Send chunks if provided (for successful streaming test)
|
||||
if m.streamChunks != nil {
|
||||
for _, chunk := range m.streamChunks {
|
||||
responseChan <- chunk
|
||||
}
|
||||
}
|
||||
// Close the channel like real vendors do
|
||||
close(responseChan)
|
||||
return m.sendStreamError
|
||||
}
|
||||
|
||||
func (m *mockVendor) Send(ctx context.Context, messages []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
|
||||
return "test response", nil
|
||||
}
|
||||
|
||||
func (m *mockVendor) NeedsRawMode(modelName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestChatter_Send_StreamingErrorPropagation(t *testing.T) {
|
||||
// Create a temporary database for testing
|
||||
tempDir := t.TempDir()
|
||||
db := fsdb.NewDb(tempDir)
|
||||
|
||||
// Create a mock vendor that will return an error from SendStream
|
||||
expectedError := errors.New("streaming error")
|
||||
mockVendor := &mockVendor{
|
||||
sendStreamError: expectedError,
|
||||
}
|
||||
|
||||
// Create chatter with streaming enabled
|
||||
chatter := &Chatter{
|
||||
db: db,
|
||||
Stream: true, // Enable streaming to trigger SendStream path
|
||||
vendor: mockVendor,
|
||||
model: "test-model",
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
request := &common.ChatRequest{
|
||||
Message: &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
Content: "test message",
|
||||
},
|
||||
}
|
||||
|
||||
// Create test options
|
||||
opts := &common.ChatOptions{
|
||||
Model: "test-model",
|
||||
}
|
||||
|
||||
// Call Send and expect it to return the streaming error
|
||||
session, err := chatter.Send(request, opts)
|
||||
|
||||
// Verify that the error from SendStream is propagated
|
||||
if err == nil {
|
||||
t.Fatal("Expected error to be returned, but got nil")
|
||||
}
|
||||
|
||||
if !errors.Is(err, expectedError) {
|
||||
t.Errorf("Expected error %q, but got %q", expectedError, err)
|
||||
}
|
||||
|
||||
// Session should still be returned (it was built successfully before the streaming error)
|
||||
if session == nil {
|
||||
t.Error("Expected session to be returned even when streaming error occurs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) {
|
||||
// Create a temporary database for testing
|
||||
tempDir := t.TempDir()
|
||||
db := fsdb.NewDb(tempDir)
|
||||
|
||||
// Create test chunks that should be aggregated
|
||||
testChunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."}
|
||||
expectedMessage := "Hello world! This is a test."
|
||||
|
||||
// Create a mock vendor that will send chunks successfully
|
||||
mockVendor := &mockVendor{
|
||||
sendStreamError: nil, // No error for successful streaming
|
||||
streamChunks: testChunks,
|
||||
}
|
||||
|
||||
// Create chatter with streaming enabled
|
||||
chatter := &Chatter{
|
||||
db: db,
|
||||
Stream: true, // Enable streaming to trigger SendStream path
|
||||
vendor: mockVendor,
|
||||
model: "test-model",
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
request := &common.ChatRequest{
|
||||
Message: &chat.ChatCompletionMessage{
|
||||
Role: chat.ChatMessageRoleUser,
|
||||
Content: "test message",
|
||||
},
|
||||
}
|
||||
|
||||
// Create test options
|
||||
opts := &common.ChatOptions{
|
||||
Model: "test-model",
|
||||
}
|
||||
|
||||
// Call Send and expect successful aggregation
|
||||
session, err := chatter.Send(request, opts)
|
||||
|
||||
// Verify no error occurred
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, but got: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was returned
|
||||
if session == nil {
|
||||
t.Fatal("Expected session to be returned")
|
||||
}
|
||||
|
||||
// Verify the message was aggregated correctly
|
||||
messages := session.GetVendorMessages()
|
||||
if len(messages) != 2 { // user message + assistant response
|
||||
t.Fatalf("Expected 2 messages, got %d", len(messages))
|
||||
}
|
||||
|
||||
// Check the assistant's response (last message)
|
||||
assistantMessage := messages[len(messages)-1]
|
||||
if assistantMessage.Role != chat.ChatMessageRoleAssistant {
|
||||
t.Errorf("Expected assistant role, got %s", assistantMessage.Role)
|
||||
}
|
||||
|
||||
if assistantMessage.Content != expectedMessage {
|
||||
t.Errorf("Expected aggregated message %q, got %q", expectedMessage, assistantMessage.Content)
|
||||
}
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
"1.4.238"
|
||||
"1.4.240"
|
||||
|
||||
@@ -19,7 +19,7 @@ const webSearchToolName = "web_search"
|
||||
const webSearchToolType = "web_search_20250305"
|
||||
const sourcesHeader = "## Sources"
|
||||
|
||||
const vendorTokenIdentifier = "claude"
|
||||
const authTokenIdentifier = "claude"
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Anthropic"
|
||||
@@ -65,15 +65,15 @@ func (an *Client) IsConfigured() bool {
|
||||
}
|
||||
|
||||
// If no valid token exists, automatically run OAuth flow
|
||||
if !storage.HasValidToken(vendorTokenIdentifier, 5) {
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
|
||||
_, err := RunOAuthFlow()
|
||||
_, err := RunOAuthFlow(authTokenIdentifier)
|
||||
if err != nil {
|
||||
fmt.Printf("OAuth authentication failed: %v\n", err)
|
||||
return false
|
||||
}
|
||||
// After successful OAuth flow, check again
|
||||
return storage.HasValidToken("claude", 5)
|
||||
return storage.HasValidToken(authTokenIdentifier, 5)
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -107,9 +107,9 @@ func (an *Client) Setup() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if !storage.HasValidToken("claude", 5) {
|
||||
if !storage.HasValidToken(authTokenIdentifier, 5) {
|
||||
// No valid token, run OAuth flow
|
||||
if _, err = RunOAuthFlow(); err != nil {
|
||||
if _, err = RunOAuthFlow(authTokenIdentifier); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,7 +37,7 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := req.Clone(req.Context())
|
||||
|
||||
// Get current token (may refresh if needed)
|
||||
token, err := t.getValidToken()
|
||||
token, err := t.getValidToken(authTokenIdentifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
|
||||
}
|
||||
@@ -58,21 +58,21 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
|
||||
// getValidToken returns a valid access token, refreshing if necessary
|
||||
func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
func (t *OAuthTransport) getValidToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load stored token
|
||||
token, err := storage.LoadToken("claude")
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
// If no token exists, run OAuth flow
|
||||
if token == nil {
|
||||
fmt.Println("No OAuth token found, initiating authentication...")
|
||||
newAccessToken, err := RunOAuthFlow()
|
||||
newAccessToken, err := RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
@@ -82,11 +82,11 @@ func (t *OAuthTransport) getValidToken() (string, error) {
|
||||
// Check if token needs refresh (5 minute buffer)
|
||||
if token.IsExpired(5) {
|
||||
fmt.Println("OAuth token expired, refreshing...")
|
||||
newAccessToken, err := RefreshToken()
|
||||
newAccessToken, err := RefreshToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
// If refresh fails, try re-authentication
|
||||
fmt.Println("Token refresh failed, re-authenticating...")
|
||||
newAccessToken, err = RunOAuthFlow()
|
||||
newAccessToken, err = RunOAuthFlow(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
|
||||
}
|
||||
@@ -129,7 +129,28 @@ func openBrowser(url string) {
|
||||
}
|
||||
|
||||
// RunOAuthFlow executes the complete OAuth authorization flow
|
||||
func RunOAuthFlow() (token string, err error) {
|
||||
func RunOAuthFlow(tokenIdentifier string) (token string, err error) {
|
||||
// First check if we have an existing token that can be refreshed
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err == nil {
|
||||
existingToken, err := storage.LoadToken(tokenIdentifier)
|
||||
if err == nil && existingToken != nil {
|
||||
// If token exists but is expired, try refreshing first
|
||||
if existingToken.IsExpired(5) {
|
||||
fmt.Println("Found expired OAuth token, attempting refresh...")
|
||||
refreshedToken, refreshErr := RefreshToken(tokenIdentifier)
|
||||
if refreshErr == nil {
|
||||
fmt.Println("Token refresh successful")
|
||||
return refreshedToken, nil
|
||||
}
|
||||
fmt.Printf("Token refresh failed (%v), proceeding with full OAuth flow...\n", refreshErr)
|
||||
} else {
|
||||
// Token exists and is still valid
|
||||
return existingToken.AccessToken, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifier, challenge, err := generatePKCE()
|
||||
if err != nil {
|
||||
return
|
||||
@@ -171,12 +192,12 @@ func RunOAuthFlow() (token string, err error) {
|
||||
"code_verifier": verifier,
|
||||
}
|
||||
|
||||
token, err = exchangeToken(tokenReq)
|
||||
token, err = exchangeToken(tokenIdentifier, tokenReq)
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeToken exchanges authorization code for access token
|
||||
func exchangeToken(params map[string]string) (token string, err error) {
|
||||
func exchangeToken(tokenIdentifier string, params map[string]string) (token string, err error) {
|
||||
reqBody, err := json.Marshal(params)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -219,7 +240,7 @@ func exchangeToken(params map[string]string) (token string, err error) {
|
||||
Scope: result.Scope,
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", oauthToken); err != nil {
|
||||
if err = storage.SaveToken(tokenIdentifier, oauthToken); err != nil {
|
||||
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
|
||||
}
|
||||
|
||||
@@ -228,14 +249,14 @@ func exchangeToken(params map[string]string) (token string, err error) {
|
||||
}
|
||||
|
||||
// RefreshToken refreshes an expired OAuth token using the refresh token
|
||||
func RefreshToken() (string, error) {
|
||||
func RefreshToken(tokenIdentifier string) (string, error) {
|
||||
storage, err := common.NewOAuthStorage()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
|
||||
}
|
||||
|
||||
// Load existing token
|
||||
token, err := storage.LoadToken("claude")
|
||||
token, err := storage.LoadToken(tokenIdentifier)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to load stored token: %w", err)
|
||||
}
|
||||
@@ -292,7 +313,7 @@ func RefreshToken() (string, error) {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
if err = storage.SaveToken("claude", newToken); err != nil {
|
||||
if err = storage.SaveToken(tokenIdentifier, newToken); err != nil {
|
||||
return "", fmt.Errorf("failed to save refreshed token: %w", err)
|
||||
}
|
||||
|
||||
|
||||
434
plugins/ai/anthropic/oauth_test.go
Normal file
434
plugins/ai/anthropic/oauth_test.go
Normal file
@@ -0,0 +1,434 @@
|
||||
package anthropic
|
||||
|
||||
// OAuth Testing Strategy:
|
||||
//
|
||||
// This test suite covers OAuth functionality while avoiding real external calls.
|
||||
// Key principles:
|
||||
// 1. Never trigger real OAuth flows that would open browsers or call external APIs
|
||||
// 2. Use temporary directories and mock tokens for isolated testing
|
||||
// 3. Skip integration tests that would require real OAuth servers
|
||||
// 4. Test error paths and edge cases safely
|
||||
//
|
||||
// Tests are categorized as:
|
||||
// - Unit tests: Test individual functions with mocked data (SAFE)
|
||||
// - Integration tests: Would require real OAuth servers (SKIPPED)
|
||||
// - Error path tests: Test failure scenarios safely (SAFE)
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
// createTestToken creates a test OAuth token
|
||||
func createTestToken(accessToken, refreshToken string, expiresIn int64) *common.OAuthToken {
|
||||
return &common.OAuthToken{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Unix() + expiresIn,
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
}
|
||||
|
||||
// createExpiredToken creates an expired test token
|
||||
func createExpiredToken(accessToken, refreshToken string) *common.OAuthToken {
|
||||
return &common.OAuthToken{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: refreshToken,
|
||||
ExpiresAt: time.Now().Unix() - 3600, // Expired 1 hour ago
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
}
|
||||
|
||||
// mockTokenServer creates a mock OAuth token server for testing
|
||||
func mockTokenServer(_ *testing.T, responses map[string]interface{}) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/oauth/token" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var req map[string]string
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := req["grant_type"]
|
||||
response, exists := responses[grantType]
|
||||
if !exists {
|
||||
http.Error(w, "Unsupported grant type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if errorResp, ok := response.(map[string]interface{}); ok && errorResp["error"] != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(response)
|
||||
}))
|
||||
}
|
||||
|
||||
func TestGeneratePKCE(t *testing.T) {
|
||||
verifier, challenge, err := generatePKCE()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if verifier == "" {
|
||||
t.Error("Expected non-empty verifier")
|
||||
}
|
||||
|
||||
if challenge == "" {
|
||||
t.Error("Expected non-empty challenge")
|
||||
}
|
||||
|
||||
if len(verifier) < 43 { // Base64 encoded 32 bytes should be at least 43 chars
|
||||
t.Errorf("Verifier too short: %d chars", len(verifier))
|
||||
}
|
||||
|
||||
if len(challenge) < 43 { // SHA256 hash should be at least 43 chars when base64 encoded
|
||||
t.Errorf("Challenge too short: %d chars", len(challenge))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeToken_Success(t *testing.T) {
|
||||
// Create mock server
|
||||
server := mockTokenServer(t, map[string]interface{}{
|
||||
"authorization_code": map[string]interface{}{
|
||||
"access_token": "test_access_token",
|
||||
"refresh_token": "test_refresh_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "org:create_api_key user:profile user:inference",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// Create a temporary directory for token storage
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Mock the storage creation to use our temp directory
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
|
||||
// Set up a fake home directory
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
os.MkdirAll(filepath.Join(fakeHome, ".config", "fabric"), 0755)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// This test would need the actual exchangeToken function to be modified to accept a custom URL
|
||||
// For now, we'll test the logic without the actual HTTP call
|
||||
t.Skip("Skipping integration test - would need URL injection for proper testing")
|
||||
}
|
||||
func TestRefreshToken_Success(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create an expired token
|
||||
expiredToken := createExpiredToken("old_access_token", "valid_refresh_token")
|
||||
|
||||
// Save the expired token
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(expiredToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create mock server for refresh
|
||||
server := mockTokenServer(t, map[string]interface{}{
|
||||
"refresh_token": map[string]interface{}{
|
||||
"access_token": "new_access_token",
|
||||
"refresh_token": "new_refresh_token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "org:create_api_key user:profile user:inference",
|
||||
},
|
||||
})
|
||||
defer server.Close()
|
||||
|
||||
// This test would need the RefreshToken function to accept a custom URL
|
||||
t.Skip("Skipping integration test - would need URL injection for proper testing")
|
||||
}
|
||||
|
||||
func TestRefreshToken_NoRefreshToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a token without refresh token
|
||||
tokenWithoutRefresh := &common.OAuthToken{
|
||||
AccessToken: "access_token",
|
||||
RefreshToken: "", // No refresh token
|
||||
ExpiresAt: time.Now().Unix() - 3600,
|
||||
TokenType: "Bearer",
|
||||
Scope: "org:create_api_key user:profile user:inference",
|
||||
}
|
||||
|
||||
// Save the token
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(tokenWithoutRefresh, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Test RefreshToken
|
||||
_, err := RefreshToken("test")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when no refresh token available")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no refresh token available") {
|
||||
t.Errorf("Expected 'no refresh token available' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshToken_NoStoredToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Don't create any token file
|
||||
|
||||
// Test RefreshToken
|
||||
_, err := RefreshToken("nonexistent")
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when no token stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthTransport_RoundTrip(t *testing.T) {
|
||||
// Create a mock client
|
||||
client := &Client{}
|
||||
|
||||
// Create the transport
|
||||
transport := NewOAuthTransport(client, http.DefaultTransport)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "https://api.anthropic.com/v1/messages", nil)
|
||||
req.Header.Set("x-api-key", "should-be-removed")
|
||||
|
||||
// Create temporary directory and set up fake home with valid token
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, fmt.Sprintf(".%s_oauth", authTokenIdentifier))
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create a mock server to handle the request
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that OAuth headers are set correctly
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != "Bearer valid_access_token" {
|
||||
t.Errorf("Expected 'Bearer valid_access_token', got '%s'", auth)
|
||||
}
|
||||
|
||||
beta := r.Header.Get("anthropic-beta")
|
||||
if beta != "oauth-2025-04-20" {
|
||||
t.Errorf("Expected 'oauth-2025-04-20', got '%s'", beta)
|
||||
}
|
||||
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
if userAgent != "ai-sdk/anthropic" {
|
||||
t.Errorf("Expected 'ai-sdk/anthropic', got '%s'", userAgent)
|
||||
}
|
||||
|
||||
// Check that x-api-key header is removed
|
||||
if r.Header.Get("x-api-key") != "" {
|
||||
t.Error("Expected x-api-key header to be removed")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Update the request URL to point to our mock server
|
||||
req.URL.Host = strings.TrimPrefix(server.URL, "http://")
|
||||
req.URL.Scheme = "http"
|
||||
|
||||
// Execute the request
|
||||
resp, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("RoundTrip failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunOAuthFlow_ExistingValidToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("existing_valid_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Test RunOAuthFlow - should return existing token without starting OAuth flow
|
||||
token, err := RunOAuthFlow("test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if token != "existing_valid_token" {
|
||||
t.Errorf("Expected 'existing_valid_token', got '%s'", token)
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func TestCreateTestToken(t *testing.T) {
|
||||
token := createTestToken("access", "refresh", 3600)
|
||||
|
||||
if token.AccessToken != "access" {
|
||||
t.Errorf("Expected access token 'access', got '%s'", token.AccessToken)
|
||||
}
|
||||
|
||||
if token.RefreshToken != "refresh" {
|
||||
t.Errorf("Expected refresh token 'refresh', got '%s'", token.RefreshToken)
|
||||
}
|
||||
|
||||
if token.IsExpired(5) {
|
||||
t.Error("Expected token to not be expired")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateExpiredToken(t *testing.T) {
|
||||
token := createExpiredToken("access", "refresh")
|
||||
|
||||
if !token.IsExpired(5) {
|
||||
t.Error("Expected token to be expired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenExpirationLogic tests the token expiration detection without OAuth flows
|
||||
func TestTokenExpirationLogic(t *testing.T) {
|
||||
// Test valid token
|
||||
validToken := createTestToken("access", "refresh", 3600)
|
||||
if validToken.IsExpired(5) {
|
||||
t.Error("Valid token should not be expired")
|
||||
}
|
||||
|
||||
// Test expired token
|
||||
expiredToken := createExpiredToken("access", "refresh")
|
||||
if !expiredToken.IsExpired(5) {
|
||||
t.Error("Expired token should be expired")
|
||||
}
|
||||
|
||||
// Test token expiring soon (within buffer)
|
||||
soonExpiredToken := createTestToken("access", "refresh", 240) // 4 minutes
|
||||
if !soonExpiredToken.IsExpired(5) { // 5 minute buffer
|
||||
t.Error("Token expiring within buffer should be considered expired")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetValidTokenWithValidToken tests the getValidToken method with a valid token
|
||||
func TestGetValidTokenWithValidToken(t *testing.T) {
|
||||
// Create temporary directory and set up fake home
|
||||
tempDir := t.TempDir()
|
||||
fakeHome := filepath.Join(tempDir, "home")
|
||||
configDir := filepath.Join(fakeHome, ".config", "fabric")
|
||||
os.MkdirAll(configDir, 0755)
|
||||
|
||||
originalHome := os.Getenv("HOME")
|
||||
defer os.Setenv("HOME", originalHome)
|
||||
os.Setenv("HOME", fakeHome)
|
||||
|
||||
// Create a valid token
|
||||
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
|
||||
tokenPath := filepath.Join(configDir, ".test_oauth")
|
||||
data, _ := json.MarshalIndent(validToken, "", " ")
|
||||
os.WriteFile(tokenPath, data, 0600)
|
||||
|
||||
// Create transport
|
||||
client := &Client{}
|
||||
transport := NewOAuthTransport(client, http.DefaultTransport)
|
||||
|
||||
// Test getValidToken - this should return the valid token without any OAuth flow
|
||||
token, err := transport.getValidToken("test")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error with valid token, got: %v", err)
|
||||
}
|
||||
|
||||
if token != "valid_access_token" {
|
||||
t.Errorf("Expected 'valid_access_token', got '%s'", token)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGeneratePKCE(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, err := generatePKCE()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkTokenIsExpired(b *testing.B) {
|
||||
token := createTestToken("access", "refresh", 3600)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
token.IsExpired(5)
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,3 @@
|
||||
package main
|
||||
|
||||
var version = "v1.4.238"
|
||||
var version = "v1.4.240"
|
||||
|
||||
Reference in New Issue
Block a user