Merge pull request #1302 from verebes1/feat/add-lmstudio

feat: Add LM Studio compatibility
This commit is contained in:
Eugen Eisler
2025-02-24 20:58:36 +01:00
committed by GitHub
3 changed files with 372 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/danielmiessler/fabric/plugins/ai/dryrun"
"github.com/danielmiessler/fabric/plugins/ai/gemini"
"github.com/danielmiessler/fabric/plugins/ai/groq"
"github.com/danielmiessler/fabric/plugins/ai/lmstudio"
"github.com/danielmiessler/fabric/plugins/ai/mistral"
"github.com/danielmiessler/fabric/plugins/ai/ollama"
"github.com/danielmiessler/fabric/plugins/ai/openai"
@@ -54,7 +55,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
gemini.NewClient(),
//gemini_openai.NewClient(),
anthropic.NewClient(), siliconcloud.NewClient(),
openrouter.NewClient(), mistral.NewClient(), deepseek.NewClient())
openrouter.NewClient(), lmstudio.NewClient(), mistral.NewClient(), deepseek.NewClient())
_ = ret.Configure()
return

View File

@@ -0,0 +1,358 @@
package lmstudio
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
goopenai "github.com/sashabaranov/go-openai"
"github.com/danielmiessler/fabric/common"
"github.com/danielmiessler/fabric/plugins"
)
// NewClient creates a new LM Studio client with default configuration.
func NewClient() (ret *Client) {
return NewClientCompatible("LM Studio", "http://localhost:1234/v1", nil)
}
// NewClientCompatible creates a new LM Studio client with custom configuration.
func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
ret = &Client{}
if configureCustom == nil {
configureCustom = ret.configure
}
ret.PluginBase = &plugins.PluginBase{
Name: vendorName,
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
ConfigureCustom: configureCustom,
}
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
ret.ApiBaseURL.Value = defaultBaseUrl
return
}
// Client represents the LM Studio client.
type Client struct {
*plugins.PluginBase
ApiBaseURL *plugins.SetupQuestion
HttpClient *http.Client
}
// configure sets up the HTTP client.
func (c *Client) configure() error {
c.HttpClient = &http.Client{}
return nil
}
// Configure sets up the client configuration.
func (c *Client) Configure() error {
return c.ConfigureCustom()
}
// ListModels returns a list of available models.
func (c *Client) ListModels() ([]string, error) {
url := fmt.Sprintf("%s/models", c.ApiBaseURL.Value)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result struct {
Data []struct {
ID string `json:"id"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
models := make([]string, len(result.Data))
for i, model := range result.Data {
models[i] = model.ID
}
return models, nil
}
// // SendStream sends a stream of messages (not implemented for LM Studio).
// func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
// return fmt.Errorf("streaming is not currently supported for LM Studio")
// }
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)
payload := map[string]interface{}{
"messages": msgs,
"model": opts.Model,
"stream": true, // Enable streaming
}
jsonPayload, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HttpClient.Do(req)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
// Close channel when function exits
defer close(channel)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return fmt.Errorf("error reading response: %w", err)
}
// Ignore empty lines
if len(line) == 0 {
continue
}
// Remove OpenAI-style prefix
if bytes.HasPrefix(line, []byte("data: ")) {
line = bytes.TrimPrefix(line, []byte("data: "))
}
// Handle [DONE] signal
if string(line) == "[DONE]" {
break
}
// Parse JSON response
var result map[string]interface{}
if err := json.Unmarshal(line, &result); err != nil {
continue
}
// Extract content from streaming chunks
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
delta, ok := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
if !ok {
continue
}
content, _ := delta["content"].(string)
// Send data to channel
channel <- content
}
return nil
}
// Send sends a single message and returns the response.
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)
payload := map[string]interface{}{
"messages": msgs,
"model": opts.Model,
// Add other options from opts if supported by LM Studio
}
jsonPayload, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HttpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return "", fmt.Errorf("invalid response format: missing or empty choices")
}
message, ok := choices[0].(map[string]interface{})["message"].(map[string]interface{})
if !ok {
return "", fmt.Errorf("invalid response format: missing message in first choice")
}
content, ok := message["content"].(string)
if !ok {
return "", fmt.Errorf("invalid response format: missing or non-string content in message")
}
return content, nil
}
// Complete sends a completion request and returns the response.
func (c *Client) Complete(ctx context.Context, prompt string, opts *common.ChatOptions) (string, error) {
url := fmt.Sprintf("%s/completions", c.ApiBaseURL.Value)
payload := map[string]interface{}{
"prompt": prompt,
"model": opts.Model,
// Add other options from opts if supported by LM Studio
}
jsonPayload, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HttpClient.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("failed to decode response: %w", err)
}
choices, ok := result["choices"].([]interface{})
if !ok || len(choices) == 0 {
return "", fmt.Errorf("invalid response format: missing or empty choices")
}
text, ok := choices[0].(map[string]interface{})["text"].(string)
if !ok {
return "", fmt.Errorf("invalid response format: missing or non-string text in first choice")
}
return text, nil
}
// GetEmbeddings returns embeddings for the given input.
func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *common.ChatOptions) ([]float64, error) {
url := fmt.Sprintf("%s/embeddings", c.ApiBaseURL.Value)
payload := map[string]interface{}{
"input": input,
"model": opts.Model,
// Add other options from opts if supported by LM Studio
}
jsonPayload, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HttpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result struct {
Data []struct {
Embedding []float64 `json:"embedding"`
} `json:"data"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no embeddings returned")
}
return result.Data[0].Embedding, nil
}
// GetName returns the name of the vendor.
func (c *Client) GetName() string {
return c.Name
}
// IsConfigured checks if the client is configured.
func (c *Client) IsConfigured() bool {
return c.ApiBaseURL != nil && c.ApiBaseURL.Value != ""
}
// Setup performs any necessary setup for the client.
func (c *Client) Setup() error {
return c.Configure()
}
// SetupFillEnvFileContent fills the environment file content.
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
envName := fmt.Sprintf("%s_API_BASE_URL", c.EnvNamePrefix)
buffer.WriteString(fmt.Sprintf("%s=%s\n", envName, c.ApiBaseURL.Value))
}

View File

@@ -65,6 +65,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
"openrouter": os.Getenv("OPENROUTER_API_KEY"),
"silicon": os.Getenv("SILICON_API_KEY"),
"deepseek": os.Getenv("DEEPSEEK_API_KEY"),
"lmstudio": os.Getenv("LM_STUDIO_API_BASE_URL"),
}
c.JSON(http.StatusOK, config)
@@ -86,6 +87,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
OpenRouterApiKey string `json:"openrouter_api_key"`
SiliconApiKey string `json:"silicon_api_key"`
DeepSeekApiKey string `json:"deepseek_api_key"`
LMStudioURL string `json:"lm_studio_base_url"`
}
if err := c.BindJSON(&config); err != nil {
@@ -94,15 +96,16 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
}
envVars := map[string]string{
"OPENAI_API_KEY": config.OpenAIApiKey,
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
"GROQ_API_KEY": config.GroqApiKey,
"MISTRAL_API_KEY": config.MistralApiKey,
"GEMINI_API_KEY": config.GeminiApiKey,
"OLLAMA_URL": config.OllamaURL,
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
"SILICON_API_KEY": config.SiliconApiKey,
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
"OPENAI_API_KEY": config.OpenAIApiKey,
"ANTHROPIC_API_KEY": config.AnthropicApiKey,
"GROQ_API_KEY": config.GroqApiKey,
"MISTRAL_API_KEY": config.MistralApiKey,
"GEMINI_API_KEY": config.GeminiApiKey,
"OLLAMA_URL": config.OllamaURL,
"OPENROUTER_API_KEY": config.OpenRouterApiKey,
"SILICON_API_KEY": config.SiliconApiKey,
"DEEPSEEK_API_KEY": config.DeepSeekApiKey,
"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
}
var envContent strings.Builder