diff --git a/internal/plugins/ai/lmstudio/lmstudio.go b/internal/plugins/ai/lmstudio/lmstudio.go index 169f7672..0b73e7c8 100644 --- a/internal/plugins/ai/lmstudio/lmstudio.go +++ b/internal/plugins/ai/lmstudio/lmstudio.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/danielmiessler/fabric/internal/chat" @@ -31,6 +32,7 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust ret.PluginBase = plugins.NewVendorPluginBase(vendorName, configureCustom) ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true, fmt.Sprintf(i18n.T("lmstudio_api_url_question"), vendorName, defaultBaseUrl)) + ret.ApiKey = ret.AddSetupQuestion("API key", false) return } @@ -38,6 +40,7 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust type Client struct { *plugins.PluginBase ApiUrl *plugins.SetupQuestion + ApiKey *plugins.SetupQuestion HttpClient *http.Client } @@ -55,6 +58,7 @@ func (c *Client) ListModels() ([]string, error) { if err != nil { return nil, fmt.Errorf(i18n.T("lmstudio_failed_create_request"), err) } + c.addAuthorizationHeader(req) resp, err := c.HttpClient.Do(req) if err != nil { @@ -109,6 +113,7 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha } req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) var resp *http.Response if resp, err = c.HttpClient.Do(req); err != nil { @@ -216,6 +221,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o } req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) var resp *http.Response if resp, err = c.HttpClient.Do(req); err != nil { @@ -278,6 +284,7 @@ func (c *Client) Complete(ctx context.Context, prompt string, opts *domain.ChatO } req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) var resp *http.Response if resp, err = c.HttpClient.Do(req); err != nil { @@ -334,6 +341,7 @@ func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *domain.C } req.Header.Set("Content-Type", "application/json") + c.addAuthorizationHeader(req) var resp *http.Response if resp, err = c.HttpClient.Do(req); err != nil { @@ -370,3 +378,14 @@ func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *domain.C func (c *Client) NeedsRawMode(modelName string) bool { return false } + +func (c *Client) addAuthorizationHeader(req *http.Request) { + if c.ApiKey == nil { + return + } + apiKey := strings.TrimSpace(c.ApiKey.Value) + if apiKey == "" { + return + } + req.Header.Set("Authorization", "Bearer "+apiKey) +} diff --git a/internal/plugins/ai/lmstudio/lmstudio_test.go b/internal/plugins/ai/lmstudio/lmstudio_test.go new file mode 100644 index 00000000..2bb1d96b --- /dev/null +++ b/internal/plugins/ai/lmstudio/lmstudio_test.go @@ -0,0 +1,92 @@ +package lmstudio + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/danielmiessler/fabric/internal/chat" + "github.com/danielmiessler/fabric/internal/domain" + "github.com/stretchr/testify/require" +) + +func TestListModelsUsesBearerTokenWhenConfigured(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/models", r.URL.Path) + require.Equal(t, "Bearer secret", r.Header.Get("Authorization")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"id":"model-1"}]}`)) + })) + defer server.Close() + + client := NewClient() + client.ApiUrl.Value = server.URL + client.ApiKey.Value = "secret" + client.HttpClient = server.Client() + + models, err := client.ListModels() + require.NoError(t, err) + require.Equal(t, []string{"model-1"}, models) +} + +func TestSendEndpointsUseBearerTokenWhenConfigured(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "Bearer secret", r.Header.Get("Authorization")) + switch r.URL.Path { + case "/chat/completions": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`)) + case "/completions": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"choices":[{"text":"ok"}]}`)) + case "/embeddings": + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"embedding":[1,2]}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := NewClient() + client.ApiUrl.Value = server.URL + client.ApiKey.Value = "secret" + client.HttpClient = server.Client() + + msgs := []*chat.ChatCompletionMessage{{Role: chat.ChatMessageRoleUser, Content: "hello"}} + opts := &domain.ChatOptions{Model: "test-model"} + + _, err := client.Send(context.Background(), msgs, opts) + require.NoError(t, err) + + _, err = client.Complete(context.Background(), "hello", opts) + require.NoError(t, err) + + _, err = client.GetEmbeddings(context.Background(), "hello", opts) + require.NoError(t, err) +} + +func TestListModelsDoesNotSendBearerForWhitespaceOnlyKey(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Empty(t, r.Header.Get("Authorization")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[{"id":"model-1"}]}`)) + })) + defer server.Close() + + client := NewClient() + client.ApiUrl.Value = server.URL + client.ApiKey.Value = " " + client.HttpClient = server.Client() + + models, err := client.ListModels() + require.NoError(t, err) + require.Equal(t, []string{"model-1"}, models) +}