Compare commits

...

20 Commits

Author SHA1 Message Date
github-actions[bot]
265f2b807e Update version to v1.4.240 and commit 2025-07-07 21:25:55 +00:00
Kayvan Sylvan
dc63e0d1cc Merge pull request #1593 from ksylvan/0707-claude-oauth-improvement
Refactor: Generalize OAuth flow for improved token handling.
2025-07-07 14:24:22 -07:00
Kayvan Sylvan
75842d8610 chore: refactor token path to use authTokenIdentifier 2025-07-07 13:59:13 -07:00
Kayvan Sylvan
bcd4c6caea test: add comprehensive OAuth testing suite for Anthropic plugin
## CHANGES

- Add OAuth test file with 434 lines coverage
- Create mock token server for safe testing
- Implement PKCE generation and validation tests
- Add token expiration logic verification tests
- Create OAuth transport round-trip testing
- Add benchmark tests for performance validation
- Implement helper functions for test token creation
- Add comprehensive error path testing scenarios
2025-07-07 13:50:57 -07:00
Kayvan Sylvan
a6a63698e1 fix: update RefreshToken to use tokenIdentifier parameter 2025-07-07 13:31:08 -07:00
Kayvan Sylvan
0528556b5c refactor: replace hardcoded "claude" with configurable authTokenIdentifier parameter
## CHANGES

- Replace hardcoded "claude" string with `authTokenIdentifier` constant
- Update `RunOAuthFlow` to accept token identifier parameter
- Modify `RefreshToken` to use configurable token identifier
- Update `exchangeToken` to accept token identifier parameter
- Enhance `getValidToken` to use parameterized token identifier
- Add token refresh attempt before full OAuth flow
- Improve OAuth flow with existing token validation
2025-07-07 13:19:00 -07:00
github-actions[bot]
47cf24e19d Update version to v1.4.239 and commit 2025-07-07 18:58:38 +00:00
Kayvan Sylvan
3f07afbef4 Merge pull request #1592 from ksylvan/0707-possible-go-routine-race-condition-fix
Fix Streaming Error Handling in Chatter
2025-07-07 11:57:11 -07:00
Kayvan Sylvan
38d714dccd chore: improve error comparison in TestChatter_Send_StreamingErrorPropagation 2025-07-07 11:20:01 -07:00
Kayvan Sylvan
d0b5c95d61 chore: remove redundant channel closure in Send method
### CHANGES

- Remove redundant `close(responseChan)` in `Send` method
- Update `SendStream` to close `responseChan` properly
- Modify test to reflect channel closure logic
2025-07-07 11:02:04 -07:00
Kayvan Sylvan
f8f80ca206 chore: rename doneChan to done and add streaming aggregation test
## CHANGES

- Rename `doneChan` variable to `done` for consistency
- Add `streamChunks` field to mock vendor struct
- Implement chunk sending logic in mock SendStream method
- Add comprehensive streaming success aggregation test case
- Verify message aggregation from multiple stream chunks
- Test assistant response role and content validation
- Ensure proper session handling in streaming scenarios
2025-07-07 10:49:29 -07:00
Kayvan Sylvan
0af458872f feat: add test for Chatter's Send method error propagation
### CHANGES

- Implement mockVendor for testing ai.Vendor interface
- Add TestChatter_Send_StreamingErrorPropagation test case
- Verify error propagation in Chatter's Send method
- Ensure session returns even on streaming error
- Create temporary database for testing Chatter functionality
2025-07-07 10:36:40 -07:00
Kayvan Sylvan
24e46a6f37 chore: rename channels for clarity in Send method
### CHANGES

- Rename `done` to `doneChan` for clarity
- Adjust channel closure for `doneChan`
- Update channel listening logic to use `doneChan`
2025-07-07 10:28:54 -07:00
Kayvan Sylvan
d6a31e68b0 refactor: rename channel variable to responseChan for better clarity in streaming logic
## CHANGES

- Rename `channel` variable to `responseChan` for clarity
- Update channel references in goroutine defer statements
- Pass renamed channel to `SendStream` method call
- Maintain consistent naming throughout streaming flow
2025-07-07 10:23:42 -07:00
Kayvan Sylvan
b1013ca61b chore: close channel after sending stream in Send
### CHANGES

- Add `channel` closure after sending stream
- Ensure resource cleanup in `Send` method
2025-07-07 10:09:24 -07:00
Kayvan Sylvan
6b4ce946a5 chore: refactor error handling and response aggregation in Send
### CHANGES

- Simplify response aggregation loop in `Send`
- Remove redundant select case for closed channel
- Streamline error checking from `errChan`
- Ensure goroutine completion before returning
2025-07-07 09:39:58 -07:00
Kayvan Sylvan
2d2830e9c8 chore: enhance Chatter.Send method with proper goroutine synchronization
### CHANGES
- Add `done` channel to track goroutine completion.
- Replace `errChan` closure with `done` channel closure.
- Ensure main loop waits for goroutine on channel close.
- Synchronize error handling with `done` channel wait.
2025-07-07 09:09:04 -07:00
Kayvan Sylvan
115327fdab refactor: use select to handle stream and error channels concurrently
### CHANGES

- Replace for-range loop with a non-blocking select statement.
- Process message and error channels concurrently for better handling.
- Improve the robustness of streaming error detection.
- Exit loop cleanly when the message channel closes.
2025-07-07 08:37:31 -07:00
Kayvan Sylvan
e672f9b73f chore: simplify error handling in streaming chat response by removing unnecessary select statement 2025-07-07 08:15:24 -07:00
Kayvan Sylvan
ef4364a1aa fix: improve error handling in streaming chat functionality
## CHANGES

- Add dedicated error channel for stream operations
- Separate error handling from message streaming logic
- Check for streaming errors after channel closure
- Close error channel properly in goroutine cleanup
- Remove error messages from message stream channel
- Add proper error propagation for stream failures
2025-07-07 03:31:58 -07:00
7 changed files with 679 additions and 25 deletions

View File

@@ -66,17 +66,35 @@ func (o *Chatter) Send(request *common.ChatRequest, opts *common.ChatOptions) (s
message := ""
if o.Stream {
channel := make(chan string)
responseChan := make(chan string)
errChan := make(chan error, 1)
done := make(chan struct{})
go func() {
if streamErr := o.vendor.SendStream(session.GetVendorMessages(), opts, channel); streamErr != nil {
channel <- streamErr.Error()
defer close(done)
if streamErr := o.vendor.SendStream(session.GetVendorMessages(), opts, responseChan); streamErr != nil {
errChan <- streamErr
}
}()
for response := range channel {
for response := range responseChan {
message += response
fmt.Print(response)
}
// Wait for goroutine to finish
<-done
// Check for errors in errChan
select {
case streamErr := <-errChan:
if streamErr != nil {
err = streamErr
return
}
default:
// No errors, continue
}
} else {
if message, err = o.vendor.Send(context.Background(), session.GetVendorMessages(), opts); err != nil {
return

181
core/chatter_test.go Normal file
View File

@@ -0,0 +1,181 @@
package core
import (
"bytes"
"context"
"errors"
"testing"
"github.com/danielmiessler/fabric/chat"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins/db/fsdb"
)
// mockVendor implements the ai.Vendor interface for testing
type mockVendor struct {
sendStreamError error
streamChunks []string
}
func (m *mockVendor) GetName() string {
return "mock"
}
func (m *mockVendor) GetSetupDescription() string {
return "mock vendor"
}
func (m *mockVendor) IsConfigured() bool {
return true
}
func (m *mockVendor) Configure() error {
return nil
}
func (m *mockVendor) Setup() error {
return nil
}
func (m *mockVendor) SetupFillEnvFileContent(*bytes.Buffer) {
}
func (m *mockVendor) ListModels() ([]string, error) {
return []string{"test-model"}, nil
}
func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *common.ChatOptions, responseChan chan string) error {
// Send chunks if provided (for successful streaming test)
if m.streamChunks != nil {
for _, chunk := range m.streamChunks {
responseChan <- chunk
}
}
// Close the channel like real vendors do
close(responseChan)
return m.sendStreamError
}
func (m *mockVendor) Send(ctx context.Context, messages []*chat.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
return "test response", nil
}
func (m *mockVendor) NeedsRawMode(modelName string) bool {
return false
}
func TestChatter_Send_StreamingErrorPropagation(t *testing.T) {
// Create a temporary database for testing
tempDir := t.TempDir()
db := fsdb.NewDb(tempDir)
// Create a mock vendor that will return an error from SendStream
expectedError := errors.New("streaming error")
mockVendor := &mockVendor{
sendStreamError: expectedError,
}
// Create chatter with streaming enabled
chatter := &Chatter{
db: db,
Stream: true, // Enable streaming to trigger SendStream path
vendor: mockVendor,
model: "test-model",
}
// Create a test request
request := &common.ChatRequest{
Message: &chat.ChatCompletionMessage{
Role: chat.ChatMessageRoleUser,
Content: "test message",
},
}
// Create test options
opts := &common.ChatOptions{
Model: "test-model",
}
// Call Send and expect it to return the streaming error
session, err := chatter.Send(request, opts)
// Verify that the error from SendStream is propagated
if err == nil {
t.Fatal("Expected error to be returned, but got nil")
}
if !errors.Is(err, expectedError) {
t.Errorf("Expected error %q, but got %q", expectedError, err)
}
// Session should still be returned (it was built successfully before the streaming error)
if session == nil {
t.Error("Expected session to be returned even when streaming error occurs")
}
}
func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) {
// Create a temporary database for testing
tempDir := t.TempDir()
db := fsdb.NewDb(tempDir)
// Create test chunks that should be aggregated
testChunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."}
expectedMessage := "Hello world! This is a test."
// Create a mock vendor that will send chunks successfully
mockVendor := &mockVendor{
sendStreamError: nil, // No error for successful streaming
streamChunks: testChunks,
}
// Create chatter with streaming enabled
chatter := &Chatter{
db: db,
Stream: true, // Enable streaming to trigger SendStream path
vendor: mockVendor,
model: "test-model",
}
// Create a test request
request := &common.ChatRequest{
Message: &chat.ChatCompletionMessage{
Role: chat.ChatMessageRoleUser,
Content: "test message",
},
}
// Create test options
opts := &common.ChatOptions{
Model: "test-model",
}
// Call Send and expect successful aggregation
session, err := chatter.Send(request, opts)
// Verify no error occurred
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}
// Verify session was returned
if session == nil {
t.Fatal("Expected session to be returned")
}
// Verify the message was aggregated correctly
messages := session.GetVendorMessages()
if len(messages) != 2 { // user message + assistant response
t.Fatalf("Expected 2 messages, got %d", len(messages))
}
// Check the assistant's response (last message)
assistantMessage := messages[len(messages)-1]
if assistantMessage.Role != chat.ChatMessageRoleAssistant {
t.Errorf("Expected assistant role, got %s", assistantMessage.Role)
}
if assistantMessage.Content != expectedMessage {
t.Errorf("Expected aggregated message %q, got %q", expectedMessage, assistantMessage.Content)
}
}

View File

@@ -1 +1 @@
"1.4.238"
"1.4.240"

View File

@@ -19,7 +19,7 @@ const webSearchToolName = "web_search"
const webSearchToolType = "web_search_20250305"
const sourcesHeader = "## Sources"
const vendorTokenIdentifier = "claude"
const authTokenIdentifier = "claude"
func NewClient() (ret *Client) {
vendorName := "Anthropic"
@@ -65,15 +65,15 @@ func (an *Client) IsConfigured() bool {
}
// If no valid token exists, automatically run OAuth flow
if !storage.HasValidToken(vendorTokenIdentifier, 5) {
if !storage.HasValidToken(authTokenIdentifier, 5) {
fmt.Println("OAuth enabled but no valid token found. Starting authentication...")
_, err := RunOAuthFlow()
_, err := RunOAuthFlow(authTokenIdentifier)
if err != nil {
fmt.Printf("OAuth authentication failed: %v\n", err)
return false
}
// After successful OAuth flow, check again
return storage.HasValidToken("claude", 5)
return storage.HasValidToken(authTokenIdentifier, 5)
}
return true
@@ -107,9 +107,9 @@ func (an *Client) Setup() (err error) {
return err
}
if !storage.HasValidToken("claude", 5) {
if !storage.HasValidToken(authTokenIdentifier, 5) {
// No valid token, run OAuth flow
if _, err = RunOAuthFlow(); err != nil {
if _, err = RunOAuthFlow(authTokenIdentifier); err != nil {
return err
}
}

View File

@@ -37,7 +37,7 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
newReq := req.Clone(req.Context())
// Get current token (may refresh if needed)
token, err := t.getValidToken()
token, err := t.getValidToken(authTokenIdentifier)
if err != nil {
return nil, fmt.Errorf("failed to get valid OAuth token: %w", err)
}
@@ -58,21 +58,21 @@ func (t *OAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
}
// getValidToken returns a valid access token, refreshing if necessary
func (t *OAuthTransport) getValidToken() (string, error) {
func (t *OAuthTransport) getValidToken(tokenIdentifier string) (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load stored token
token, err := storage.LoadToken("claude")
token, err := storage.LoadToken(tokenIdentifier)
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
// If no token exists, run OAuth flow
if token == nil {
fmt.Println("No OAuth token found, initiating authentication...")
newAccessToken, err := RunOAuthFlow()
newAccessToken, err := RunOAuthFlow(tokenIdentifier)
if err != nil {
return "", fmt.Errorf("failed to authenticate: %w", err)
}
@@ -82,11 +82,11 @@ func (t *OAuthTransport) getValidToken() (string, error) {
// Check if token needs refresh (5 minute buffer)
if token.IsExpired(5) {
fmt.Println("OAuth token expired, refreshing...")
newAccessToken, err := RefreshToken()
newAccessToken, err := RefreshToken(tokenIdentifier)
if err != nil {
// If refresh fails, try re-authentication
fmt.Println("Token refresh failed, re-authenticating...")
newAccessToken, err = RunOAuthFlow()
newAccessToken, err = RunOAuthFlow(tokenIdentifier)
if err != nil {
return "", fmt.Errorf("failed to refresh or re-authenticate: %w", err)
}
@@ -129,7 +129,28 @@ func openBrowser(url string) {
}
// RunOAuthFlow executes the complete OAuth authorization flow
func RunOAuthFlow() (token string, err error) {
func RunOAuthFlow(tokenIdentifier string) (token string, err error) {
// First check if we have an existing token that can be refreshed
storage, err := common.NewOAuthStorage()
if err == nil {
existingToken, err := storage.LoadToken(tokenIdentifier)
if err == nil && existingToken != nil {
// If token exists but is expired, try refreshing first
if existingToken.IsExpired(5) {
fmt.Println("Found expired OAuth token, attempting refresh...")
refreshedToken, refreshErr := RefreshToken(tokenIdentifier)
if refreshErr == nil {
fmt.Println("Token refresh successful")
return refreshedToken, nil
}
fmt.Printf("Token refresh failed (%v), proceeding with full OAuth flow...\n", refreshErr)
} else {
// Token exists and is still valid
return existingToken.AccessToken, nil
}
}
}
verifier, challenge, err := generatePKCE()
if err != nil {
return
@@ -171,12 +192,12 @@ func RunOAuthFlow() (token string, err error) {
"code_verifier": verifier,
}
token, err = exchangeToken(tokenReq)
token, err = exchangeToken(tokenIdentifier, tokenReq)
return
}
// exchangeToken exchanges authorization code for access token
func exchangeToken(params map[string]string) (token string, err error) {
func exchangeToken(tokenIdentifier string, params map[string]string) (token string, err error) {
reqBody, err := json.Marshal(params)
if err != nil {
return
@@ -219,7 +240,7 @@ func exchangeToken(params map[string]string) (token string, err error) {
Scope: result.Scope,
}
if err = storage.SaveToken("claude", oauthToken); err != nil {
if err = storage.SaveToken(tokenIdentifier, oauthToken); err != nil {
return result.AccessToken, fmt.Errorf("failed to save OAuth token: %w", err)
}
@@ -228,14 +249,14 @@ func exchangeToken(params map[string]string) (token string, err error) {
}
// RefreshToken refreshes an expired OAuth token using the refresh token
func RefreshToken() (string, error) {
func RefreshToken(tokenIdentifier string) (string, error) {
storage, err := common.NewOAuthStorage()
if err != nil {
return "", fmt.Errorf("failed to create OAuth storage: %w", err)
}
// Load existing token
token, err := storage.LoadToken("claude")
token, err := storage.LoadToken(tokenIdentifier)
if err != nil {
return "", fmt.Errorf("failed to load stored token: %w", err)
}
@@ -292,7 +313,7 @@ func RefreshToken() (string, error) {
newToken.RefreshToken = token.RefreshToken
}
if err = storage.SaveToken("claude", newToken); err != nil {
if err = storage.SaveToken(tokenIdentifier, newToken); err != nil {
return "", fmt.Errorf("failed to save refreshed token: %w", err)
}

View File

@@ -0,0 +1,434 @@
package anthropic
// OAuth Testing Strategy:
//
// This test suite covers OAuth functionality while avoiding real external calls.
// Key principles:
// 1. Never trigger real OAuth flows that would open browsers or call external APIs
// 2. Use temporary directories and mock tokens for isolated testing
// 3. Skip integration tests that would require real OAuth servers
// 4. Test error paths and edge cases safely
//
// Tests are categorized as:
// - Unit tests: Test individual functions with mocked data (SAFE)
// - Integration tests: Would require real OAuth servers (SKIPPED)
// - Error path tests: Test failure scenarios safely (SAFE)
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/danielmiessler/fabric/common"
)
// createTestToken creates a test OAuth token
func createTestToken(accessToken, refreshToken string, expiresIn int64) *common.OAuthToken {
return &common.OAuthToken{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: time.Now().Unix() + expiresIn,
TokenType: "Bearer",
Scope: "org:create_api_key user:profile user:inference",
}
}
// createExpiredToken creates an expired test token
func createExpiredToken(accessToken, refreshToken string) *common.OAuthToken {
return &common.OAuthToken{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresAt: time.Now().Unix() - 3600, // Expired 1 hour ago
TokenType: "Bearer",
Scope: "org:create_api_key user:profile user:inference",
}
}
// mockTokenServer creates a mock OAuth token server for testing
func mockTokenServer(_ *testing.T, responses map[string]interface{}) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/oauth/token" {
http.NotFound(w, r)
return
}
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read body", http.StatusBadRequest)
return
}
var req map[string]string
if err := json.Unmarshal(body, &req); err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
grantType := req["grant_type"]
response, exists := responses[grantType]
if !exists {
http.Error(w, "Unsupported grant type", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
if errorResp, ok := response.(map[string]interface{}); ok && errorResp["error"] != nil {
w.WriteHeader(http.StatusBadRequest)
}
json.NewEncoder(w).Encode(response)
}))
}
func TestGeneratePKCE(t *testing.T) {
verifier, challenge, err := generatePKCE()
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if verifier == "" {
t.Error("Expected non-empty verifier")
}
if challenge == "" {
t.Error("Expected non-empty challenge")
}
if len(verifier) < 43 { // Base64 encoded 32 bytes should be at least 43 chars
t.Errorf("Verifier too short: %d chars", len(verifier))
}
if len(challenge) < 43 { // SHA256 hash should be at least 43 chars when base64 encoded
t.Errorf("Challenge too short: %d chars", len(challenge))
}
}
func TestExchangeToken_Success(t *testing.T) {
// Create mock server
server := mockTokenServer(t, map[string]interface{}{
"authorization_code": map[string]interface{}{
"access_token": "test_access_token",
"refresh_token": "test_refresh_token",
"expires_in": 3600,
"token_type": "Bearer",
"scope": "org:create_api_key user:profile user:inference",
},
})
defer server.Close()
// Create a temporary directory for token storage
tempDir := t.TempDir()
// Mock the storage creation to use our temp directory
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
// Set up a fake home directory
fakeHome := filepath.Join(tempDir, "home")
os.MkdirAll(filepath.Join(fakeHome, ".config", "fabric"), 0755)
os.Setenv("HOME", fakeHome)
// This test would need the actual exchangeToken function to be modified to accept a custom URL
// For now, we'll test the logic without the actual HTTP call
t.Skip("Skipping integration test - would need URL injection for proper testing")
}
func TestRefreshToken_Success(t *testing.T) {
// Create temporary directory and set up fake home
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Create an expired token
expiredToken := createExpiredToken("old_access_token", "valid_refresh_token")
// Save the expired token
tokenPath := filepath.Join(configDir, ".test_oauth")
data, _ := json.MarshalIndent(expiredToken, "", " ")
os.WriteFile(tokenPath, data, 0600)
// Create mock server for refresh
server := mockTokenServer(t, map[string]interface{}{
"refresh_token": map[string]interface{}{
"access_token": "new_access_token",
"refresh_token": "new_refresh_token",
"expires_in": 3600,
"token_type": "Bearer",
"scope": "org:create_api_key user:profile user:inference",
},
})
defer server.Close()
// This test would need the RefreshToken function to accept a custom URL
t.Skip("Skipping integration test - would need URL injection for proper testing")
}
func TestRefreshToken_NoRefreshToken(t *testing.T) {
// Create temporary directory and set up fake home
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Create a token without refresh token
tokenWithoutRefresh := &common.OAuthToken{
AccessToken: "access_token",
RefreshToken: "", // No refresh token
ExpiresAt: time.Now().Unix() - 3600,
TokenType: "Bearer",
Scope: "org:create_api_key user:profile user:inference",
}
// Save the token
tokenPath := filepath.Join(configDir, ".test_oauth")
data, _ := json.MarshalIndent(tokenWithoutRefresh, "", " ")
os.WriteFile(tokenPath, data, 0600)
// Test RefreshToken
_, err := RefreshToken("test")
if err == nil {
t.Error("Expected error when no refresh token available")
}
if !strings.Contains(err.Error(), "no refresh token available") {
t.Errorf("Expected 'no refresh token available' error, got: %v", err)
}
}
func TestRefreshToken_NoStoredToken(t *testing.T) {
// Create temporary directory and set up fake home
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Don't create any token file
// Test RefreshToken
_, err := RefreshToken("nonexistent")
if err == nil {
t.Error("Expected error when no token stored")
}
}
func TestOAuthTransport_RoundTrip(t *testing.T) {
// Create a mock client
client := &Client{}
// Create the transport
transport := NewOAuthTransport(client, http.DefaultTransport)
// Create a test request
req := httptest.NewRequest("GET", "https://api.anthropic.com/v1/messages", nil)
req.Header.Set("x-api-key", "should-be-removed")
// Create temporary directory and set up fake home with valid token
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Create a valid token
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
tokenPath := filepath.Join(configDir, fmt.Sprintf(".%s_oauth", authTokenIdentifier))
data, _ := json.MarshalIndent(validToken, "", " ")
os.WriteFile(tokenPath, data, 0600)
// Create a mock server to handle the request
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check that OAuth headers are set correctly
auth := r.Header.Get("Authorization")
if auth != "Bearer valid_access_token" {
t.Errorf("Expected 'Bearer valid_access_token', got '%s'", auth)
}
beta := r.Header.Get("anthropic-beta")
if beta != "oauth-2025-04-20" {
t.Errorf("Expected 'oauth-2025-04-20', got '%s'", beta)
}
userAgent := r.Header.Get("User-Agent")
if userAgent != "ai-sdk/anthropic" {
t.Errorf("Expected 'ai-sdk/anthropic', got '%s'", userAgent)
}
// Check that x-api-key header is removed
if r.Header.Get("x-api-key") != "" {
t.Error("Expected x-api-key header to be removed")
}
w.WriteHeader(http.StatusOK)
w.Write([]byte("success"))
}))
defer server.Close()
// Update the request URL to point to our mock server
req.URL.Host = strings.TrimPrefix(server.URL, "http://")
req.URL.Scheme = "http"
// Execute the request
resp, err := transport.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}
}
func TestRunOAuthFlow_ExistingValidToken(t *testing.T) {
// Create temporary directory and set up fake home
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Create a valid token
validToken := createTestToken("existing_valid_token", "refresh_token", 3600)
tokenPath := filepath.Join(configDir, ".test_oauth")
data, _ := json.MarshalIndent(validToken, "", " ")
os.WriteFile(tokenPath, data, 0600)
// Test RunOAuthFlow - should return existing token without starting OAuth flow
token, err := RunOAuthFlow("test")
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if token != "existing_valid_token" {
t.Errorf("Expected 'existing_valid_token', got '%s'", token)
}
}
// Test helper functions
func TestCreateTestToken(t *testing.T) {
token := createTestToken("access", "refresh", 3600)
if token.AccessToken != "access" {
t.Errorf("Expected access token 'access', got '%s'", token.AccessToken)
}
if token.RefreshToken != "refresh" {
t.Errorf("Expected refresh token 'refresh', got '%s'", token.RefreshToken)
}
if token.IsExpired(5) {
t.Error("Expected token to not be expired")
}
}
func TestCreateExpiredToken(t *testing.T) {
token := createExpiredToken("access", "refresh")
if !token.IsExpired(5) {
t.Error("Expected token to be expired")
}
}
// TestTokenExpirationLogic tests the token expiration detection without OAuth flows
func TestTokenExpirationLogic(t *testing.T) {
// Test valid token
validToken := createTestToken("access", "refresh", 3600)
if validToken.IsExpired(5) {
t.Error("Valid token should not be expired")
}
// Test expired token
expiredToken := createExpiredToken("access", "refresh")
if !expiredToken.IsExpired(5) {
t.Error("Expired token should be expired")
}
// Test token expiring soon (within buffer)
soonExpiredToken := createTestToken("access", "refresh", 240) // 4 minutes
if !soonExpiredToken.IsExpired(5) { // 5 minute buffer
t.Error("Token expiring within buffer should be considered expired")
}
}
// TestGetValidTokenWithValidToken tests the getValidToken method with a valid token
func TestGetValidTokenWithValidToken(t *testing.T) {
// Create temporary directory and set up fake home
tempDir := t.TempDir()
fakeHome := filepath.Join(tempDir, "home")
configDir := filepath.Join(fakeHome, ".config", "fabric")
os.MkdirAll(configDir, 0755)
originalHome := os.Getenv("HOME")
defer os.Setenv("HOME", originalHome)
os.Setenv("HOME", fakeHome)
// Create a valid token
validToken := createTestToken("valid_access_token", "refresh_token", 3600)
tokenPath := filepath.Join(configDir, ".test_oauth")
data, _ := json.MarshalIndent(validToken, "", " ")
os.WriteFile(tokenPath, data, 0600)
// Create transport
client := &Client{}
transport := NewOAuthTransport(client, http.DefaultTransport)
// Test getValidToken - this should return the valid token without any OAuth flow
token, err := transport.getValidToken("test")
if err != nil {
t.Fatalf("Expected no error with valid token, got: %v", err)
}
if token != "valid_access_token" {
t.Errorf("Expected 'valid_access_token', got '%s'", token)
}
}
// Benchmark tests
func BenchmarkGeneratePKCE(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, err := generatePKCE()
if err != nil {
b.Fatal(err)
}
}
}
func BenchmarkTokenIsExpired(b *testing.B) {
token := createTestToken("access", "refresh", 3600)
b.ResetTimer()
for i := 0; i < b.N; i++ {
token.IsExpired(5)
}
}

View File

@@ -1,3 +1,3 @@
package main
var version = "v1.4.238"
var version = "v1.4.240"