mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-02-19 10:14:21 -05:00
feat: add optional API key authentication to LM Studio client
- Add optional API key setup question to client configuration - Add `ApiKey` field to the LM Studio `Client` struct - Create `addAuthorizationHeader` helper to attach Bearer token to requests - Apply authorization header to all outgoing HTTP requests - Skip authorization header when API key is empty or unset
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/internal/chat"
|
||||
|
||||
@@ -31,6 +32,7 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust
|
||||
ret.PluginBase = plugins.NewVendorPluginBase(vendorName, configureCustom)
|
||||
ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true,
|
||||
fmt.Sprintf(i18n.T("lmstudio_api_url_question"), vendorName, defaultBaseUrl))
|
||||
ret.ApiKey = ret.AddSetupQuestion("API key", false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -38,6 +40,7 @@ func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCust
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiUrl *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
HttpClient *http.Client
|
||||
}
|
||||
|
||||
@@ -55,6 +58,7 @@ func (c *Client) ListModels() ([]string, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(i18n.T("lmstudio_failed_create_request"), err)
|
||||
}
|
||||
c.addAuthorizationHeader(req)
|
||||
|
||||
resp, err := c.HttpClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -109,6 +113,7 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.addAuthorizationHeader(req)
|
||||
|
||||
var resp *http.Response
|
||||
if resp, err = c.HttpClient.Do(req); err != nil {
|
||||
@@ -216,6 +221,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.addAuthorizationHeader(req)
|
||||
|
||||
var resp *http.Response
|
||||
if resp, err = c.HttpClient.Do(req); err != nil {
|
||||
@@ -278,6 +284,7 @@ func (c *Client) Complete(ctx context.Context, prompt string, opts *domain.ChatO
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.addAuthorizationHeader(req)
|
||||
|
||||
var resp *http.Response
|
||||
if resp, err = c.HttpClient.Do(req); err != nil {
|
||||
@@ -334,6 +341,7 @@ func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *domain.C
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.addAuthorizationHeader(req)
|
||||
|
||||
var resp *http.Response
|
||||
if resp, err = c.HttpClient.Do(req); err != nil {
|
||||
@@ -370,3 +378,14 @@ func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *domain.C
|
||||
func (c *Client) NeedsRawMode(modelName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) addAuthorizationHeader(req *http.Request) {
|
||||
if c.ApiKey == nil {
|
||||
return
|
||||
}
|
||||
apiKey := strings.TrimSpace(c.ApiKey.Value)
|
||||
if apiKey == "" {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
92
internal/plugins/ai/lmstudio/lmstudio_test.go
Normal file
92
internal/plugins/ai/lmstudio/lmstudio_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package lmstudio
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/danielmiessler/fabric/internal/chat"
|
||||
"github.com/danielmiessler/fabric/internal/domain"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestListModelsUsesBearerTokenWhenConfigured(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "/models", r.URL.Path)
|
||||
require.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":[{"id":"model-1"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client.ApiUrl.Value = server.URL
|
||||
client.ApiKey.Value = "secret"
|
||||
client.HttpClient = server.Client()
|
||||
|
||||
models, err := client.ListModels()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"model-1"}, models)
|
||||
}
|
||||
|
||||
func TestSendEndpointsUseBearerTokenWhenConfigured(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, "Bearer secret", r.Header.Get("Authorization"))
|
||||
switch r.URL.Path {
|
||||
case "/chat/completions":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"}}]}`))
|
||||
case "/completions":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"choices":[{"text":"ok"}]}`))
|
||||
case "/embeddings":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":[{"embedding":[1,2]}]}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client.ApiUrl.Value = server.URL
|
||||
client.ApiKey.Value = "secret"
|
||||
client.HttpClient = server.Client()
|
||||
|
||||
msgs := []*chat.ChatCompletionMessage{{Role: chat.ChatMessageRoleUser, Content: "hello"}}
|
||||
opts := &domain.ChatOptions{Model: "test-model"}
|
||||
|
||||
_, err := client.Send(context.Background(), msgs, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.Complete(context.Background(), "hello", opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.GetEmbeddings(context.Background(), "hello", opts)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestListModelsDoesNotSendBearerForWhitespaceOnlyKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Empty(t, r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"data":[{"id":"model-1"}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient()
|
||||
client.ApiUrl.Value = server.URL
|
||||
client.ApiKey.Value = " "
|
||||
client.HttpClient = server.Client()
|
||||
|
||||
models, err := client.ListModels()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []string{"model-1"}, models)
|
||||
}
|
||||
Reference in New Issue
Block a user