mirror of https://github.com/danielmiessler/Fabric.git, synced 2026-01-10 06:48:04 -05:00
Merge pull request #1302 from verebes1/feat/add-lmstudio
feat: Add LM Studio compatibility
@@ -18,6 +18,7 @@ import (
 	"github.com/danielmiessler/fabric/plugins/ai/dryrun"
 	"github.com/danielmiessler/fabric/plugins/ai/gemini"
 	"github.com/danielmiessler/fabric/plugins/ai/groq"
+	"github.com/danielmiessler/fabric/plugins/ai/lmstudio"
 	"github.com/danielmiessler/fabric/plugins/ai/mistral"
 	"github.com/danielmiessler/fabric/plugins/ai/ollama"
 	"github.com/danielmiessler/fabric/plugins/ai/openai"
@@ -54,7 +55,7 @@ func NewPluginRegistry(db *fsdb.Db) (ret *PluginRegistry, err error) {
 		gemini.NewClient(),
 		//gemini_openai.NewClient(),
 		anthropic.NewClient(), siliconcloud.NewClient(),
-		openrouter.NewClient(), mistral.NewClient(), deepseek.NewClient())
+		openrouter.NewClient(), lmstudio.NewClient(), mistral.NewClient(), deepseek.NewClient())
 	_ = ret.Configure()

 	return
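The hunk above shows only the tail of the vendor-registration call. As a sketch, the full call presumably reads like the following (the AddVendors receiver name is an assumption; its opening line sits outside the hunk):

	// Sketch only: a variadic registration call whose last argument line is
	// the one changed by this commit.
	ret.VendorsAll.AddVendors( /* ...other vendors from the hunk context... */
		openrouter.NewClient(), lmstudio.NewClient(), mistral.NewClient(), deepseek.NewClient())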
358  plugins/ai/lmstudio/lmstudio.go  Normal file
@@ -0,0 +1,358 @@
package lmstudio

import (
    "bufio"
    "bytes"
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"

    goopenai "github.com/sashabaranov/go-openai"

    "github.com/danielmiessler/fabric/common"
    "github.com/danielmiessler/fabric/plugins"
)

// NewClient creates a new LM Studio client with default configuration.
func NewClient() (ret *Client) {
    return NewClientCompatible("LM Studio", "http://localhost:1234/v1", nil)
}

// NewClientCompatible creates a new LM Studio client with custom configuration.
func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
    ret = &Client{}

    if configureCustom == nil {
        configureCustom = ret.configure
    }
    ret.PluginBase = &plugins.PluginBase{
        Name:            vendorName,
        EnvNamePrefix:   plugins.BuildEnvVariablePrefix(vendorName),
        ConfigureCustom: configureCustom,
    }
    ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
    ret.ApiBaseURL.Value = defaultBaseUrl
    return
}
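
// Illustrative sketch (hypothetical host, not from the commit):
// NewClientCompatible lets callers point at a non-default server, e.g.
// LM Studio running on another machine:
//
//     client := NewClientCompatible("LM Studio", "http://192.168.1.50:1234/v1", nil)
//
// Passing nil for configureCustom falls back to the default configure method
// below, which just builds a plain http.Client.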

// Client represents the LM Studio client.
type Client struct {
    *plugins.PluginBase
    ApiBaseURL *plugins.SetupQuestion
    HttpClient *http.Client
}

// configure sets up the HTTP client.
func (c *Client) configure() error {
    c.HttpClient = &http.Client{}
    return nil
}

// Configure sets up the client configuration.
func (c *Client) Configure() error {
    return c.ConfigureCustom()
}

// ListModels returns a list of available models.
func (c *Client) ListModels() ([]string, error) {
    url := fmt.Sprintf("%s/models", c.ApiBaseURL.Value)

    req, err := http.NewRequest("GET", url, nil)
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result struct {
        Data []struct {
            ID string `json:"id"`
        } `json:"data"`
    }

    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return nil, fmt.Errorf("failed to decode response: %w", err)
    }

    models := make([]string, len(result.Data))
    for i, model := range result.Data {
        models[i] = model.ID
    }

    return models, nil
}
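
// Illustrative usage (the model id shown is hypothetical; real ids come from
// whatever is loaded in LM Studio):
//
//     models, err := client.ListModels() // GET http://localhost:1234/v1/models
//     // models -> e.g. []string{"llama-3.2-3b-instruct"}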

// SendStream streams a chat completion from LM Studio, writing content deltas
// to channel as they arrive; the channel is closed when the stream ends.
func (c *Client) SendStream(msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions, channel chan string) error {
    url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "messages": msgs,
        "model":    opts.Model,
        "stream":   true, // Enable streaming
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    // Close channel when function exits
    defer close(channel)

    reader := bufio.NewReader(resp.Body)
    for {
        line, err := reader.ReadBytes('\n')
        if err != nil {
            if err == io.EOF {
                break
            }
            return fmt.Errorf("error reading response: %w", err)
        }

        // Trim the trailing newline so the empty-line and [DONE] checks below
        // can actually match (ReadBytes keeps the delimiter).
        line = bytes.TrimSpace(line)

        // Ignore empty lines
        if len(line) == 0 {
            continue
        }

        // Remove OpenAI-style prefix
        if bytes.HasPrefix(line, []byte("data: ")) {
            line = bytes.TrimPrefix(line, []byte("data: "))
        }

        // Handle [DONE] signal
        if string(line) == "[DONE]" {
            break
        }

        // Parse JSON response
        var result map[string]interface{}
        if err := json.Unmarshal(line, &result); err != nil {
            continue
        }

        // Extract content from streaming chunks
        choices, ok := result["choices"].([]interface{})
        if !ok || len(choices) == 0 {
            continue
        }

        delta, ok := choices[0].(map[string]interface{})["delta"].(map[string]interface{})
        if !ok {
            continue
        }

        content, _ := delta["content"].(string)

        // Send data to channel
        channel <- content
    }

    return nil
}
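
// Illustrative usage: SendStream is typically run in a goroutine while the
// caller drains the channel, which is closed when the stream finishes:
//
//     ch := make(chan string)
//     go func() { _ = client.SendStream(msgs, opts, ch) }()
//     for chunk := range ch {
//         fmt.Print(chunk)
//     }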

// Send sends a single message and returns the response.
func (c *Client) Send(ctx context.Context, msgs []*goopenai.ChatCompletionMessage, opts *common.ChatOptions) (string, error) {
    url := fmt.Sprintf("%s/chat/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "messages": msgs,
        "model":    opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return "", fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return "", fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return "", fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result map[string]interface{}
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return "", fmt.Errorf("failed to decode response: %w", err)
    }

    choices, ok := result["choices"].([]interface{})
    if !ok || len(choices) == 0 {
        return "", fmt.Errorf("invalid response format: missing or empty choices")
    }

    message, ok := choices[0].(map[string]interface{})["message"].(map[string]interface{})
    if !ok {
        return "", fmt.Errorf("invalid response format: missing message in first choice")
    }

    content, ok := message["content"].(string)
    if !ok {
        return "", fmt.Errorf("invalid response format: missing or non-string content in message")
    }

    return content, nil
}
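
// For reference, the non-streaming response parsed above follows the OpenAI
// chat shape, roughly {"choices":[{"message":{"content":"..."}}]}, which is
// why Send returns choices[0].message.content.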

// Complete sends a completion request and returns the response.
func (c *Client) Complete(ctx context.Context, prompt string, opts *common.ChatOptions) (string, error) {
    url := fmt.Sprintf("%s/completions", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "prompt": prompt,
        "model":  opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return "", fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return "", fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return "", fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result map[string]interface{}
    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return "", fmt.Errorf("failed to decode response: %w", err)
    }

    choices, ok := result["choices"].([]interface{})
    if !ok || len(choices) == 0 {
        return "", fmt.Errorf("invalid response format: missing or empty choices")
    }

    text, ok := choices[0].(map[string]interface{})["text"].(string)
    if !ok {
        return "", fmt.Errorf("invalid response format: missing or non-string text in first choice")
    }

    return text, nil
}
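
// Note: /completions is the legacy OpenAI-style text-completion endpoint; its
// response carries choices[0].text rather than a message object, which is why
// Complete extracts "text" where Send extracts "message.content".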

// GetEmbeddings returns embeddings for the given input.
func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *common.ChatOptions) ([]float64, error) {
    url := fmt.Sprintf("%s/embeddings", c.ApiBaseURL.Value)

    payload := map[string]interface{}{
        "input": input,
        "model": opts.Model,
        // Add other options from opts if supported by LM Studio
    }

    jsonPayload, err := json.Marshal(payload)
    if err != nil {
        return nil, fmt.Errorf("failed to marshal payload: %w", err)
    }

    req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload))
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Content-Type", "application/json")

    resp, err := c.HttpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
    }

    var result struct {
        Data []struct {
            Embedding []float64 `json:"embedding"`
        } `json:"data"`
    }

    if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
        return nil, fmt.Errorf("failed to decode response: %w", err)
    }

    if len(result.Data) == 0 {
        return nil, fmt.Errorf("no embeddings returned")
    }

    return result.Data[0].Embedding, nil
}
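
// Illustrative usage (the model id is hypothetical; any embedding model
// loaded in LM Studio works):
//
//     vec, err := client.GetEmbeddings(ctx, "hello world",
//         &common.ChatOptions{Model: "nomic-embed-text-v1.5"})
//     // vec -> []float64 with one value per embedding dimension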

// GetName returns the name of the vendor.
func (c *Client) GetName() string {
    return c.Name
}

// IsConfigured checks if the client is configured.
func (c *Client) IsConfigured() bool {
    return c.ApiBaseURL != nil && c.ApiBaseURL.Value != ""
}

// Setup performs any necessary setup for the client.
func (c *Client) Setup() error {
    return c.Configure()
}

// SetupFillEnvFileContent fills the environment file content.
func (c *Client) SetupFillEnvFileContent(buffer *bytes.Buffer) {
    envName := fmt.Sprintf("%s_API_BASE_URL", c.EnvNamePrefix)
    buffer.WriteString(fmt.Sprintf("%s=%s\n", envName, c.ApiBaseURL.Value))
}
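Given the vendor name "LM Studio", plugins.BuildEnvVariablePrefix presumably yields LM_STUDIO (an assumption; that helper is not shown in this diff), so the line persisted to the env file is LM_STUDIO_API_BASE_URL=http://localhost:1234/v1 by default, matching the key the REST config handler reads below.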
@@ -65,6 +65,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
 		"openrouter": os.Getenv("OPENROUTER_API_KEY"),
 		"silicon":    os.Getenv("SILICON_API_KEY"),
 		"deepseek":   os.Getenv("DEEPSEEK_API_KEY"),
+		"lmstudio":   os.Getenv("LM_STUDIO_API_BASE_URL"),
 	}

 	c.JSON(http.StatusOK, config)
@@ -86,6 +87,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
 		OpenRouterApiKey string `json:"openrouter_api_key"`
 		SiliconApiKey    string `json:"silicon_api_key"`
 		DeepSeekApiKey   string `json:"deepseek_api_key"`
+		LMStudioURL      string `json:"lm_studio_base_url"`
 	}

 	if err := c.BindJSON(&config); err != nil {
@@ -94,15 +96,16 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
 	}

 	envVars := map[string]string{
-		"OPENAI_API_KEY":     config.OpenAIApiKey,
-		"ANTHROPIC_API_KEY":  config.AnthropicApiKey,
-		"GROQ_API_KEY":       config.GroqApiKey,
-		"MISTRAL_API_KEY":    config.MistralApiKey,
-		"GEMINI_API_KEY":     config.GeminiApiKey,
-		"OLLAMA_URL":         config.OllamaURL,
-		"OPENROUTER_API_KEY": config.OpenRouterApiKey,
-		"SILICON_API_KEY":    config.SiliconApiKey,
-		"DEEPSEEK_API_KEY":   config.DeepSeekApiKey,
+		"OPENAI_API_KEY":         config.OpenAIApiKey,
+		"ANTHROPIC_API_KEY":      config.AnthropicApiKey,
+		"GROQ_API_KEY":           config.GroqApiKey,
+		"MISTRAL_API_KEY":        config.MistralApiKey,
+		"GEMINI_API_KEY":         config.GeminiApiKey,
+		"OLLAMA_URL":             config.OllamaURL,
+		"OPENROUTER_API_KEY":     config.OpenRouterApiKey,
+		"SILICON_API_KEY":        config.SiliconApiKey,
+		"DEEPSEEK_API_KEY":       config.DeepSeekApiKey,
+		"LM_STUDIO_API_BASE_URL": config.LMStudioURL,
 	}

 	var envContent strings.Builder
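Putting the pieces together, a minimal end-to-end sketch of the new vendor as a standalone program (the model id is a placeholder; the endpoints and default port come from the code above):

package main

import (
    "context"
    "fmt"

    goopenai "github.com/sashabaranov/go-openai"

    "github.com/danielmiessler/fabric/common"
    "github.com/danielmiessler/fabric/plugins/ai/lmstudio"
)

func main() {
    client := lmstudio.NewClient() // defaults to http://localhost:1234/v1
    if err := client.Configure(); err != nil { // builds the HTTP client
        panic(err)
    }

    models, err := client.ListModels()
    if err != nil {
        panic(err)
    }
    fmt.Println("available:", models)

    // "my-local-model" is a placeholder; use one of the ids returned above.
    msgs := []*goopenai.ChatCompletionMessage{
        {Role: goopenai.ChatMessageRoleUser, Content: "Say hello."},
    }
    answer, err := client.Send(context.Background(), msgs, &common.ChatOptions{Model: "my-local-model"})
    if err != nil {
        panic(err)
    }
    fmt.Println(answer)
}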