package ollama import ( "context" "encoding/base64" "fmt" "io" "net/http" "net/url" "os" "strings" "time" "github.com/danielmiessler/fabric/internal/chat" "github.com/danielmiessler/fabric/internal/domain" debuglog "github.com/danielmiessler/fabric/internal/log" "github.com/danielmiessler/fabric/internal/plugins" ollamaapi "github.com/ollama/ollama/api" ) const defaultBaseUrl = "http://localhost:11434" func NewClient() (ret *Client) { vendorName := "Ollama" ret = &Client{} ret.PluginBase = &plugins.PluginBase{ Name: vendorName, EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName), ConfigureCustom: ret.configure, } ret.ApiUrl = ret.AddSetupQuestionCustom("API URL", true, "Enter your Ollama URL (as a reminder, it is usually http://localhost:11434')") ret.ApiUrl.Value = defaultBaseUrl ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", false) ret.ApiKey.Value = "" ret.ApiHttpTimeout = ret.AddSetupQuestionCustom("HTTP Timeout", true, "Specify HTTP timeout duration for Ollama requests (e.g. 30s, 5m, 1h)") ret.ApiHttpTimeout.Value = "20m" return } type Client struct { *plugins.PluginBase ApiUrl *plugins.SetupQuestion ApiKey *plugins.SetupQuestion apiUrl *url.URL client *ollamaapi.Client ApiHttpTimeout *plugins.SetupQuestion httpClient *http.Client } type transport_sec struct { underlyingTransport http.RoundTripper ApiKey *plugins.SetupQuestion } func (t *transport_sec) RoundTrip(req *http.Request) (*http.Response, error) { if t.ApiKey.Value != "" { req.Header.Add("Authorization", "Bearer "+t.ApiKey.Value) } return t.underlyingTransport.RoundTrip(req) } // IsConfigured returns true only if OLLAMA_API_URL environment variable is explicitly set func (o *Client) IsConfigured() bool { return os.Getenv("OLLAMA_API_URL") != "" } func (o *Client) configure() (err error) { if o.apiUrl, err = url.Parse(o.ApiUrl.Value); err != nil { fmt.Printf("cannot parse URL: %s: %v\n", o.ApiUrl.Value, err) return } timeout := 20 * time.Minute // Default timeout if o.ApiHttpTimeout != nil { parsed, err := time.ParseDuration(o.ApiHttpTimeout.Value) if err == nil && o.ApiHttpTimeout.Value != "" { timeout = parsed } else if o.ApiHttpTimeout.Value != "" { fmt.Printf("Invalid HTTP timeout format (%q), using default (20m): %v\n", o.ApiHttpTimeout.Value, err) } } o.httpClient = &http.Client{Timeout: timeout, Transport: &transport_sec{underlyingTransport: http.DefaultTransport, ApiKey: o.ApiKey}} o.client = ollamaapi.NewClient(o.apiUrl, o.httpClient) return } func (o *Client) ListModels() (ret []string, err error) { ctx := context.Background() var listResp *ollamaapi.ListResponse if listResp, err = o.client.List(ctx); err != nil { return } for _, mod := range listResp.Models { ret = append(ret, mod.Model) } return } func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { ctx := context.Background() var req ollamaapi.ChatRequest if req, err = o.createChatRequest(ctx, msgs, opts); err != nil { return } respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) { channel <- domain.StreamUpdate{ Type: domain.StreamTypeContent, Content: resp.Message.Content, } if resp.Done { channel <- domain.StreamUpdate{ Type: domain.StreamTypeUsage, Usage: &domain.UsageMetadata{ InputTokens: resp.PromptEvalCount, OutputTokens: resp.EvalCount, TotalTokens: resp.PromptEvalCount + resp.EvalCount, }, } } return } if err = o.client.Chat(ctx, &req, respFunc); err != nil { return } close(channel) return } func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) { bf := false var req ollamaapi.ChatRequest if req, err = o.createChatRequest(ctx, msgs, opts); err != nil { return } req.Stream = &bf respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) { ret = resp.Message.Content return } if err = o.client.Chat(ctx, &req, respFunc); err != nil { debuglog.Debug(debuglog.Basic, "Ollama chat request failed: %v\n", err) } return } func (o *Client) createChatRequest(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret ollamaapi.ChatRequest, err error) { messages := make([]ollamaapi.Message, len(msgs)) for i, message := range msgs { if messages[i], err = o.convertMessage(ctx, message); err != nil { return } } options := map[string]any{ "temperature": opts.Temperature, "presence_penalty": opts.PresencePenalty, "frequency_penalty": opts.FrequencyPenalty, "top_p": opts.TopP, } if opts.ModelContextLength != 0 { options["num_ctx"] = opts.ModelContextLength } ret = ollamaapi.ChatRequest{ Model: opts.Model, Messages: messages, Options: options, } return } func (o *Client) convertMessage(ctx context.Context, message *chat.ChatCompletionMessage) (ret ollamaapi.Message, err error) { ret = ollamaapi.Message{Role: message.Role, Content: message.Content} if len(message.MultiContent) == 0 { return } // Pre-allocate with capacity hint textParts := make([]string, 0, len(message.MultiContent)) if strings.TrimSpace(ret.Content) != "" { textParts = append(textParts, strings.TrimSpace(ret.Content)) } for _, part := range message.MultiContent { switch part.Type { case chat.ChatMessagePartTypeText: if trimmed := strings.TrimSpace(part.Text); trimmed != "" { textParts = append(textParts, trimmed) } case chat.ChatMessagePartTypeImageURL: // Nil guard if part.ImageURL == nil || part.ImageURL.URL == "" { continue } var img []byte if img, err = o.loadImageBytes(ctx, part.ImageURL.URL); err != nil { return } ret.Images = append(ret.Images, ollamaapi.ImageData(img)) } } ret.Content = strings.Join(textParts, "\n") return } func (o *Client) loadImageBytes(ctx context.Context, imageURL string) (ret []byte, err error) { // Handle data URLs (base64 encoded) if strings.HasPrefix(imageURL, "data:") { parts := strings.SplitN(imageURL, ",", 2) if len(parts) != 2 { err = fmt.Errorf("invalid data URL format") return } if ret, err = base64.StdEncoding.DecodeString(parts[1]); err != nil { err = fmt.Errorf("failed to decode data URL: %w", err) } return } // Handle HTTP URLs with context var req *http.Request if req, err = http.NewRequestWithContext(ctx, http.MethodGet, imageURL, nil); err != nil { return } var resp *http.Response if resp, err = o.httpClient.Do(req); err != nil { return } defer resp.Body.Close() if resp.StatusCode >= http.StatusBadRequest { err = fmt.Errorf("failed to fetch image %s: %s", imageURL, resp.Status) return } ret, err = io.ReadAll(resp.Body) return } func (o *Client) NeedsRawMode(modelName string) bool { ollamaSearchStrings := []string{ "llama3", "llama2", "mistral", } for _, searchString := range ollamaSearchStrings { if strings.Contains(modelName, searchString) { return true } } return false }