package ollama

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"github.com/danielmiessler/fabric/internal/chat"
	"github.com/danielmiessler/fabric/internal/domain"
	debuglog "github.com/danielmiessler/fabric/internal/log"
	"github.com/danielmiessler/fabric/internal/plugins"
	ollamaapi "github.com/ollama/ollama/api"
)

const defaultBaseUrl = "http://localhost:11434"
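// NewClient builds the Ollama plugin with its setup questions (API URL,
// optional API key, and HTTP timeout) pre-populated with sensible defaults.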
func NewClient() (ret *Client) {
	vendorName := "Ollama"
	ret = &Client{}

	ret.PluginBase = &plugins.PluginBase{
		Name:            vendorName,
		EnvNamePrefix:   plugins.BuildEnvVariablePrefix(vendorName),
		ConfigureCustom: ret.configure,
	}

	ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true,
		"Enter your Ollama URL (as a reminder, it is usually http://localhost:11434)")
	ret.ApiUrl.Value = defaultBaseUrl
	ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false)
	ret.ApiKey.Value = ""
	ret.ApiHttpTimeout = ret.AddSetupQuestionCustom("HTTP Timeout", true,
		"Specify HTTP timeout duration for Ollama requests (e.g. 30s, 5m, 1h)")
	ret.ApiHttpTimeout.Value = "20m"

	return
}
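// Client is the Ollama vendor plugin. The exported SetupQuestion fields are
// filled in during setup; apiUrl, client, and httpClient are derived from
// them in configure.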
type Client struct {
	*plugins.PluginBase
	ApiUrl         *plugins.SetupQuestion
	ApiKey         *plugins.SetupQuestion
	apiUrl         *url.URL
	client         *ollamaapi.Client
	ApiHttpTimeout *plugins.SetupQuestion
	httpClient     *http.Client
}
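// transport_sec wraps an http.RoundTripper and injects a Bearer token when an
// API key has been configured.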
type transport_sec struct {
	underlyingTransport http.RoundTripper
	ApiKey              *plugins.SetupQuestion
}
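// RoundTrip implements http.RoundTripper, adding the Authorization header
// only when an API key is set.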
func (t *transport_sec) RoundTrip(req *http.Request) (*http.Response, error) {
	if t.ApiKey.Value != "" {
		req.Header.Add("Authorization", "Bearer "+t.ApiKey.Value)
	}
	return t.underlyingTransport.RoundTrip(req)
}
// IsConfigured returns true only if the OLLAMA_API_URL environment variable
// is explicitly set.
func (o *Client) IsConfigured() bool {
	return os.Getenv("OLLAMA_API_URL") != ""
}
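// configure parses the configured URL and timeout and builds the HTTP and
// Ollama API clients. An unparsable timeout value falls back to the 20m
// default.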
func (o *Client) configure() (err error) {
	if o.apiUrl, err = url.Parse(o.ApiUrl.Value); err != nil {
		fmt.Printf("cannot parse URL: %s: %v\n", o.ApiUrl.Value, err)
		return
	}

	timeout := 20 * time.Minute // Default timeout

	if o.ApiHttpTimeout != nil {
		parsed, err := time.ParseDuration(o.ApiHttpTimeout.Value)
		if err == nil && o.ApiHttpTimeout.Value != "" {
			timeout = parsed
		} else if o.ApiHttpTimeout.Value != "" {
			fmt.Printf("Invalid HTTP timeout format (%q), using default (20m): %v\n", o.ApiHttpTimeout.Value, err)
		}
	}

	o.httpClient = &http.Client{Timeout: timeout, Transport: &transport_sec{underlyingTransport: http.DefaultTransport, ApiKey: o.ApiKey}}
	o.client = ollamaapi.NewClient(o.apiUrl, o.httpClient)

	return
}
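// ListModels returns the names of all models currently available on the
// Ollama server.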
func (o *Client) ListModels() (ret []string, err error) {
	ctx := context.Background()

	var listResp *ollamaapi.ListResponse
	if listResp, err = o.client.List(ctx); err != nil {
		return
	}

	for _, mod := range listResp.Models {
		ret = append(ret, mod.Model)
	}
	return
}
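// SendStream streams chat responses to channel as content updates, emits a
// usage update when the final chunk arrives, and closes the channel once the
// stream completes successfully.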
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
	ctx := context.Background()

	var req ollamaapi.ChatRequest
	if req, err = o.createChatRequest(ctx, msgs, opts); err != nil {
		return
	}

	respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
		channel <- domain.StreamUpdate{
			Type:    domain.StreamTypeContent,
			Content: resp.Message.Content,
		}

		if resp.Done {
			channel <- domain.StreamUpdate{
				Type: domain.StreamTypeUsage,
				Usage: &domain.UsageMetadata{
					InputTokens:  resp.PromptEvalCount,
					OutputTokens: resp.EvalCount,
					TotalTokens:  resp.PromptEvalCount + resp.EvalCount,
				},
			}
		}
		return
	}

	if err = o.client.Chat(ctx, &req, respFunc); err != nil {
		return
	}

	close(channel)
	return
}
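// Send performs a non-streaming chat request and returns the final message
// content; errors from the underlying client are logged at debug level and
// returned to the caller.
//
// A minimal, hypothetical usage sketch (configure is unexported and is
// normally driven by fabric's plugin setup flow rather than called directly):
//
//	o := NewClient()
//	if err := o.configure(); err != nil {
//		// handle configuration error
//	}
//	msg := &chat.ChatCompletionMessage{Role: "user", Content: "Hello"}
//	out, err := o.Send(context.Background(),
//		[]*chat.ChatCompletionMessage{msg}, &domain.ChatOptions{Model: "llama3"})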
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	bf := false

	var req ollamaapi.ChatRequest
	if req, err = o.createChatRequest(ctx, msgs, opts); err != nil {
		return
	}
	req.Stream = &bf

	respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
		ret = resp.Message.Content
		return
	}

	if err = o.client.Chat(ctx, &req, respFunc); err != nil {
		debuglog.Debug(debuglog.Basic, "Ollama chat request failed: %v\n", err)
	}
	return
}
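// createChatRequest converts the generic chat messages and options into an
// ollamaapi.ChatRequest, mapping the sampling options and forwarding the
// model context length as num_ctx when it is set.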
func (o *Client) createChatRequest(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret ollamaapi.ChatRequest, err error) {
	messages := make([]ollamaapi.Message, len(msgs))
	for i, message := range msgs {
		if messages[i], err = o.convertMessage(ctx, message); err != nil {
			return
		}
	}

	options := map[string]any{
		"temperature":       opts.Temperature,
		"presence_penalty":  opts.PresencePenalty,
		"frequency_penalty": opts.FrequencyPenalty,
		"top_p":             opts.TopP,
	}

	if opts.ModelContextLength != 0 {
		options["num_ctx"] = opts.ModelContextLength
	}

	ret = ollamaapi.ChatRequest{
		Model:    opts.Model,
		Messages: messages,
		Options:  options,
	}
	return
}
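// convertMessage maps a chat.ChatCompletionMessage onto an ollamaapi.Message.
// Multi-part content is flattened: text parts are trimmed and joined with
// newlines, while image parts are loaded into the message's Images slice.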
func (o *Client) convertMessage(ctx context.Context, message *chat.ChatCompletionMessage) (ret ollamaapi.Message, err error) {
	ret = ollamaapi.Message{Role: message.Role, Content: message.Content}

	if len(message.MultiContent) == 0 {
		return
	}

	// Pre-allocate with capacity hint
	textParts := make([]string, 0, len(message.MultiContent))
	if strings.TrimSpace(ret.Content) != "" {
		textParts = append(textParts, strings.TrimSpace(ret.Content))
	}

	for _, part := range message.MultiContent {
		switch part.Type {
		case chat.ChatMessagePartTypeText:
			if trimmed := strings.TrimSpace(part.Text); trimmed != "" {
				textParts = append(textParts, trimmed)
			}
		case chat.ChatMessagePartTypeImageURL:
			// Nil guard
			if part.ImageURL == nil || part.ImageURL.URL == "" {
				continue
			}
			var img []byte
			if img, err = o.loadImageBytes(ctx, part.ImageURL.URL); err != nil {
				return
			}
			ret.Images = append(ret.Images, ollamaapi.ImageData(img))
		}
	}

	ret.Content = strings.Join(textParts, "\n")
	return
}
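// loadImageBytes resolves an image reference to raw bytes. data: URLs are
// base64-decoded locally; anything else is fetched over HTTP using the
// client's configured timeout and transport.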
func (o *Client) loadImageBytes(ctx context.Context, imageURL string) (ret []byte, err error) {
	// Handle data URLs (base64 encoded)
	if strings.HasPrefix(imageURL, "data:") {
		parts := strings.SplitN(imageURL, ",", 2)
		if len(parts) != 2 {
			err = fmt.Errorf("invalid data URL format")
			return
		}
		if ret, err = base64.StdEncoding.DecodeString(parts[1]); err != nil {
			err = fmt.Errorf("failed to decode data URL: %w", err)
		}
		return
	}

	// Handle HTTP URLs with context
	var req *http.Request
	if req, err = http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil); err != nil {
		return
	}

	var resp *http.Response
	if resp, err = o.httpClient.Do(req); err != nil {
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode >= http.StatusBadRequest {
		err = fmt.Errorf("failed to fetch image %s: %s", imageURL, resp.Status)
		return
	}

	ret, err = io.ReadAll(resp.Body)
	return
}
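// NeedsRawMode reports whether the given model name matches a family that is
// treated as requiring raw mode (llama2, llama3, mistral).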
func (o *Client) NeedsRawMode(modelName string) bool {
	ollamaSearchStrings := []string{
		"llama3",
		"llama2",
		"mistral",
	}
	for _, searchString := range ollamaSearchStrings {
		if strings.Contains(modelName, searchString) {
			return true
		}
	}
	return false
}