mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-08 22:08:03 -05:00
376 lines
9.3 KiB
Go
376 lines
9.3 KiB
Go
package lmstudio
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
|
|
"github.com/danielmiessler/fabric/internal/chat"
|
|
|
|
"github.com/danielmiessler/fabric/internal/domain"
|
|
"github.com/danielmiessler/fabric/internal/plugins"
|
|
)
|
|
|
|
// NewClient creates a new LM Studio client with default configuration.
|
|
func NewClient() (ret *Client) {
|
|
return NewClientCompatible("LM Studio", "http://localhost:1234/v1", nil)
|
|
}
|
|
|
|
// NewClientCompatible creates a new LM Studio client with custom configuration.
|
|
func NewClientCompatible(vendorName string, defaultBaseUrl string, configureCustom func() error) (ret *Client) {
|
|
ret = &Client{}
|
|
|
|
if configureCustom == nil {
|
|
configureCustom = ret.configure
|
|
}
|
|
ret.PluginBase = &plugins.PluginBase{
|
|
Name: vendorName,
|
|
EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
|
|
ConfigureCustom: configureCustom,
|
|
}
|
|
ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true,
|
|
fmt.Sprintf("Enter your %v URL (as a reminder, it is usually %v')", vendorName, defaultBaseUrl))
|
|
return
|
|
}
|
|
|
|
// Client represents the LM Studio client.
|
|
type Client struct {
|
|
*plugins.PluginBase
|
|
ApiUrl *plugins.SetupQuestion
|
|
HttpClient *http.Client
|
|
}
|
|
|
|
// configure sets up the HTTP client.
|
|
func (c *Client) configure() error {
|
|
c.HttpClient = &http.Client{}
|
|
return nil
|
|
}
|
|
|
|
// ListModels returns a list of available models.
|
|
func (c *Client) ListModels() ([]string, error) {
|
|
url := fmt.Sprintf("%s/models", c.ApiUrl.Value)
|
|
|
|
req, err := http.NewRequest("GET", url, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
|
|
resp, err := c.HttpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
ID string `json:"id"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
models := make([]string, len(result.Data))
|
|
for i, model := range result.Data {
|
|
models[i] = model.ID
|
|
}
|
|
|
|
return models, nil
|
|
}
|
|
|
|
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
|
|
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
|
|
|
payload := map[string]any{
|
|
"messages": msgs,
|
|
"model": opts.Model,
|
|
"stream": true, // Enable streaming
|
|
"stream_options": map[string]any{
|
|
"include_usage": true,
|
|
},
|
|
}
|
|
|
|
var jsonPayload []byte
|
|
if jsonPayload, err = json.Marshal(payload); err != nil {
|
|
err = fmt.Errorf("failed to marshal payload: %w", err)
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
if req, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonPayload)); err != nil {
|
|
err = fmt.Errorf("failed to create request: %w", err)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
var resp *http.Response
|
|
if resp, err = c.HttpClient.Do(req); err != nil {
|
|
err = fmt.Errorf("failed to send request: %w", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
defer close(channel)
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
for {
|
|
var line []byte
|
|
if line, err = reader.ReadBytes('\n'); err != nil {
|
|
if err == io.EOF {
|
|
err = nil
|
|
break
|
|
}
|
|
err = fmt.Errorf("error reading response: %w", err)
|
|
return
|
|
}
|
|
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
if after, ok := bytes.CutPrefix(line, []byte("data: ")); ok {
|
|
line = after
|
|
}
|
|
|
|
if string(bytes.TrimSpace(line)) == "[DONE]" {
|
|
break
|
|
}
|
|
|
|
var result map[string]any
|
|
if err = json.Unmarshal(line, &result); err != nil {
|
|
continue
|
|
}
|
|
|
|
// Handle Usage
|
|
if usage, ok := result["usage"].(map[string]any); ok {
|
|
var metadata domain.UsageMetadata
|
|
if val, ok := usage["prompt_tokens"].(float64); ok {
|
|
metadata.InputTokens = int(val)
|
|
}
|
|
if val, ok := usage["completion_tokens"].(float64); ok {
|
|
metadata.OutputTokens = int(val)
|
|
}
|
|
if val, ok := usage["total_tokens"].(float64); ok {
|
|
metadata.TotalTokens = int(val)
|
|
}
|
|
channel <- domain.StreamUpdate{
|
|
Type: domain.StreamTypeUsage,
|
|
Usage: &metadata,
|
|
}
|
|
}
|
|
|
|
var choices []any
|
|
var ok bool
|
|
if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 {
|
|
continue
|
|
}
|
|
|
|
var delta map[string]any
|
|
if delta, ok = choices[0].(map[string]any)["delta"].(map[string]any); !ok {
|
|
continue
|
|
}
|
|
|
|
var content string
|
|
if content, _ = delta["content"].(string); content != "" {
|
|
channel <- domain.StreamUpdate{
|
|
Type: domain.StreamTypeContent,
|
|
Content: content,
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (content string, err error) {
|
|
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
|
|
|
|
payload := map[string]any{
|
|
"messages": msgs,
|
|
"model": opts.Model,
|
|
// Add other options from opts if supported by LM Studio
|
|
}
|
|
|
|
var jsonPayload []byte
|
|
if jsonPayload, err = json.Marshal(payload); err != nil {
|
|
err = fmt.Errorf("failed to marshal payload: %w", err)
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
if req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload)); err != nil {
|
|
err = fmt.Errorf("failed to create request: %w", err)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
var resp *http.Response
|
|
if resp, err = c.HttpClient.Do(req); err != nil {
|
|
err = fmt.Errorf("failed to send request: %w", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
var result map[string]any
|
|
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
err = fmt.Errorf("failed to decode response: %w", err)
|
|
return
|
|
}
|
|
|
|
var choices []any
|
|
var ok bool
|
|
if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 {
|
|
err = fmt.Errorf("invalid response format: missing or empty choices")
|
|
return
|
|
}
|
|
|
|
var message map[string]any
|
|
if message, ok = choices[0].(map[string]any)["message"].(map[string]any); !ok {
|
|
err = fmt.Errorf("invalid response format: missing message in first choice")
|
|
return
|
|
}
|
|
|
|
if content, ok = message["content"].(string); !ok {
|
|
err = fmt.Errorf("invalid response format: missing or non-string content in message")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) Complete(ctx context.Context, prompt string, opts *domain.ChatOptions) (text string, err error) {
|
|
url := fmt.Sprintf("%s/completions", c.ApiUrl.Value)
|
|
|
|
payload := map[string]any{
|
|
"prompt": prompt,
|
|
"model": opts.Model,
|
|
// Add other options from opts if supported by LM Studio
|
|
}
|
|
|
|
var jsonPayload []byte
|
|
if jsonPayload, err = json.Marshal(payload); err != nil {
|
|
err = fmt.Errorf("failed to marshal payload: %w", err)
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
if req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload)); err != nil {
|
|
err = fmt.Errorf("failed to create request: %w", err)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
var resp *http.Response
|
|
if resp, err = c.HttpClient.Do(req); err != nil {
|
|
err = fmt.Errorf("failed to send request: %w", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
var result map[string]any
|
|
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
err = fmt.Errorf("failed to decode response: %w", err)
|
|
return
|
|
}
|
|
|
|
var choices []any
|
|
var ok bool
|
|
if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 {
|
|
err = fmt.Errorf("invalid response format: missing or empty choices")
|
|
return
|
|
}
|
|
|
|
if text, ok = choices[0].(map[string]any)["text"].(string); !ok {
|
|
err = fmt.Errorf("invalid response format: missing or non-string text in first choice")
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *Client) GetEmbeddings(ctx context.Context, input string, opts *domain.ChatOptions) (embeddings []float64, err error) {
|
|
url := fmt.Sprintf("%s/embeddings", c.ApiUrl.Value)
|
|
|
|
payload := map[string]any{
|
|
"input": input,
|
|
"model": opts.Model,
|
|
// Add other options from opts if supported by LM Studio
|
|
}
|
|
|
|
var jsonPayload []byte
|
|
if jsonPayload, err = json.Marshal(payload); err != nil {
|
|
err = fmt.Errorf("failed to marshal payload: %w", err)
|
|
return
|
|
}
|
|
|
|
var req *http.Request
|
|
if req, err = http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonPayload)); err != nil {
|
|
err = fmt.Errorf("failed to create request: %w", err)
|
|
return
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
var resp *http.Response
|
|
if resp, err = c.HttpClient.Do(req); err != nil {
|
|
err = fmt.Errorf("failed to send request: %w", err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
err = fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
Embedding []float64 `json:"embedding"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err = json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
err = fmt.Errorf("failed to decode response: %w", err)
|
|
return
|
|
}
|
|
|
|
if len(result.Data) == 0 {
|
|
err = fmt.Errorf("no embeddings returned")
|
|
return
|
|
}
|
|
|
|
embeddings = result.Data[0].Embedding
|
|
return
|
|
}
|
|
|
|
func (c *Client) NeedsRawMode(modelName string) bool {
|
|
return false
|
|
}
|