mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-04-24 03:00:15 -04:00
refactor: rewrite Ollama chat handler to support proper streaming responses
- Add `json:"-"` tag to exclude UpdateChan from JSON serialization - Extract URL building logic into dedicated `buildFabricChatURL` helper function - Replace single-read body parsing with streaming `bufio.Scanner` approach - Add proper SSE data prefix parsing for fabric response format - Implement real-time streaming with `writeOllamaResponse` helper function - Add `writeOllamaResponseStruct` for consistent JSON response writing - Handle both streaming and non-streaming response modes separately - Add proper error handling for fabric error response types - Ensure response body is properly closed with defer statement
This commit is contained in:
@@ -53,7 +53,7 @@ type ChatOptions struct {
|
||||
NotificationCommand string
|
||||
ShowMetadata bool
|
||||
Quiet bool
|
||||
UpdateChan chan StreamUpdate
|
||||
UpdateChan chan StreamUpdate `json:"-"`
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package restapi
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -192,11 +194,13 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
||||
}
|
||||
ctx := context.Background()
|
||||
var req *http.Request
|
||||
if strings.Contains(*f.addr, "http") {
|
||||
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
|
||||
} else {
|
||||
req, err = http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
|
||||
baseURL, err := buildFabricChatURL(*f.addr)
|
||||
if err != nil {
|
||||
log.Printf("Error building /chat URL: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", baseURL), bytes.NewBuffer(fabricChatReq))
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -209,62 +213,137 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
body, err = io.ReadAll(fabricRes.Body)
|
||||
if err != nil {
|
||||
log.Printf("Error reading body: %v", err)
|
||||
defer fabricRes.Body.Close()
|
||||
var contentBuilder strings.Builder
|
||||
scanner := bufio.NewScanner(fabricRes.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data: ")
|
||||
var fabricResponse FabricResponseFormat
|
||||
if err := json.Unmarshal([]byte(payload), &fabricResponse); err != nil {
|
||||
log.Printf("Error unmarshalling body: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
|
||||
return
|
||||
}
|
||||
if fabricResponse.Type == "error" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fabricResponse.Content})
|
||||
return
|
||||
}
|
||||
if fabricResponse.Type != "content" {
|
||||
continue
|
||||
}
|
||||
contentBuilder.WriteString(fabricResponse.Content)
|
||||
if prompt.Stream {
|
||||
if err := writeOllamaResponse(c, prompt.Model, fabricResponse.Content, false); err != nil {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Printf("Error scanning body: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
|
||||
return
|
||||
}
|
||||
var forwardedResponse OllamaResponse
|
||||
var forwardedResponses []OllamaResponse
|
||||
var fabricResponse FabricResponseFormat
|
||||
err = json.Unmarshal([]byte(strings.Split(strings.Split(string(body), "\n")[0], "data: ")[1]), &fabricResponse)
|
||||
if err != nil {
|
||||
log.Printf("Error unmarshalling body: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
|
||||
return
|
||||
}
|
||||
for word := range strings.SplitSeq(fabricResponse.Content, " ") {
|
||||
forwardedResponse = OllamaResponse{
|
||||
Model: "",
|
||||
CreatedAt: "",
|
||||
|
||||
if !prompt.Stream {
|
||||
response := OllamaResponse{
|
||||
Model: prompt.Model,
|
||||
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
|
||||
Message: struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}(struct {
|
||||
Role string
|
||||
Content string
|
||||
}{Content: fmt.Sprintf("%s ", word), Role: "assistant"}),
|
||||
Done: false,
|
||||
}{Content: contentBuilder.String(), Role: "assistant"}),
|
||||
DoneReason: "stop",
|
||||
Done: true,
|
||||
TotalDuration: time.Since(now).Nanoseconds(),
|
||||
LoadDuration: int(time.Since(now).Nanoseconds()),
|
||||
PromptEvalCount: 42,
|
||||
PromptEvalDuration: int(time.Since(now).Nanoseconds()),
|
||||
EvalCount: 420,
|
||||
EvalDuration: time.Since(now).Nanoseconds(),
|
||||
}
|
||||
forwardedResponses = append(forwardedResponses, forwardedResponse)
|
||||
c.JSON(200, response)
|
||||
return
|
||||
}
|
||||
forwardedResponse.Model = prompt.Model
|
||||
forwardedResponse.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
|
||||
forwardedResponse.Message.Role = "assistant"
|
||||
forwardedResponse.Message.Content = ""
|
||||
forwardedResponse.DoneReason = "stop"
|
||||
forwardedResponse.Done = true
|
||||
forwardedResponse.TotalDuration = time.Since(now).Nanoseconds()
|
||||
forwardedResponse.LoadDuration = int(time.Since(now).Nanoseconds())
|
||||
forwardedResponse.PromptEvalCount = 42
|
||||
forwardedResponse.PromptEvalDuration = int(time.Since(now).Nanoseconds())
|
||||
forwardedResponse.EvalCount = 420
|
||||
forwardedResponse.EvalDuration = time.Since(now).Nanoseconds()
|
||||
forwardedResponses = append(forwardedResponses, forwardedResponse)
|
||||
|
||||
var res []byte
|
||||
for _, response := range forwardedResponses {
|
||||
marshalled, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
log.Printf("Error marshalling body: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
res = append(res, marshalled...)
|
||||
res = append(res, '\n')
|
||||
finalResponse := OllamaResponse{
|
||||
Model: prompt.Model,
|
||||
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
|
||||
Message: struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}(struct {
|
||||
Role string
|
||||
Content string
|
||||
}{Content: "", Role: "assistant"}),
|
||||
DoneReason: "stop",
|
||||
Done: true,
|
||||
TotalDuration: time.Since(now).Nanoseconds(),
|
||||
LoadDuration: int(time.Since(now).Nanoseconds()),
|
||||
PromptEvalCount: 42,
|
||||
PromptEvalDuration: int(time.Since(now).Nanoseconds()),
|
||||
EvalCount: 420,
|
||||
EvalDuration: time.Since(now).Nanoseconds(),
|
||||
}
|
||||
if err := writeOllamaResponseStruct(c, finalResponse); err != nil {
|
||||
log.Printf("Error writing response: %v", err)
|
||||
}
|
||||
c.Data(200, "application/json", res)
|
||||
|
||||
//c.JSON(200, forwardedResponse)
|
||||
}
|
||||
|
||||
func buildFabricChatURL(addr string) (string, error) {
|
||||
if addr == "" {
|
||||
return "", fmt.Errorf("empty address")
|
||||
}
|
||||
if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
|
||||
parsed, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid address: %w", err)
|
||||
}
|
||||
return strings.TrimRight(parsed.String(), "/"), nil
|
||||
}
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
return fmt.Sprintf("http://127.0.0.1%s", addr), nil
|
||||
}
|
||||
return fmt.Sprintf("http://%s", addr), nil
|
||||
}
|
||||
|
||||
func writeOllamaResponse(c *gin.Context, model string, content string, done bool) error {
|
||||
response := OllamaResponse{
|
||||
Model: model,
|
||||
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
|
||||
Message: struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}(struct {
|
||||
Role string
|
||||
Content string
|
||||
}{Content: content, Role: "assistant"}),
|
||||
Done: done,
|
||||
}
|
||||
return writeOllamaResponseStruct(c, response)
|
||||
}
|
||||
|
||||
func writeOllamaResponseStruct(c *gin.Context, response OllamaResponse) error {
|
||||
marshalled, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.Writer.Write(marshalled); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := c.Writer.Write([]byte("\n")); err != nil {
|
||||
return err
|
||||
}
|
||||
c.Writer.Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user