refactor: rewrite Ollama chat handler to support proper streaming responses

- Add `json:"-"` tag to exclude UpdateChan from JSON serialization
- Extract URL building logic into dedicated `buildFabricChatURL` helper function
- Replace single-read body parsing with streaming `bufio.Scanner` approach
- Add proper SSE data prefix parsing for fabric response format
- Implement real-time streaming with `writeOllamaResponse` helper function
- Add `writeOllamaResponseStruct` for consistent JSON response writing
- Handle both streaming and non-streaming response modes separately
- Add proper error handling for fabric error response types
- Ensure response body is properly closed with defer statement
This commit is contained in:
Kayvan Sylvan
2026-01-17 00:52:29 -08:00
parent e3c2723988
commit e318a939aa
2 changed files with 127 additions and 48 deletions

View File

@@ -53,7 +53,7 @@ type ChatOptions struct {
NotificationCommand string
ShowMetadata bool
Quiet bool
UpdateChan chan StreamUpdate
UpdateChan chan StreamUpdate `json:"-"`
}
// NormalizeMessages removes empty messages and ensures messages follow the order user-assistant-user

View File

@@ -1,6 +1,7 @@
package restapi
import (
"bufio"
"bytes"
"context"
"encoding/json"
@@ -8,6 +9,7 @@ import (
"io"
"log"
"net/http"
"net/url"
"strings"
"time"
@@ -192,11 +194,13 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
}
ctx := context.Background()
var req *http.Request
if strings.Contains(*f.addr, "http") {
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
} else {
req, err = http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
baseURL, err := buildFabricChatURL(*f.addr)
if err != nil {
log.Printf("Error building /chat URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", baseURL), bytes.NewBuffer(fabricChatReq))
if err != nil {
log.Fatal(err)
}
@@ -209,62 +213,137 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
body, err = io.ReadAll(fabricRes.Body)
if err != nil {
log.Printf("Error reading body: %v", err)
defer fabricRes.Body.Close()
var contentBuilder strings.Builder
scanner := bufio.NewScanner(fabricRes.Body)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, "data: ") {
continue
}
payload := strings.TrimPrefix(line, "data: ")
var fabricResponse FabricResponseFormat
if err := json.Unmarshal([]byte(payload), &fabricResponse); err != nil {
log.Printf("Error unmarshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
if fabricResponse.Type == "error" {
c.JSON(http.StatusInternalServerError, gin.H{"error": fabricResponse.Content})
return
}
if fabricResponse.Type != "content" {
continue
}
contentBuilder.WriteString(fabricResponse.Content)
if prompt.Stream {
if err := writeOllamaResponse(c, prompt.Model, fabricResponse.Content, false); err != nil {
log.Printf("Error writing response: %v", err)
return
}
}
}
if err := scanner.Err(); err != nil {
log.Printf("Error scanning body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
var forwardedResponse OllamaResponse
var forwardedResponses []OllamaResponse
var fabricResponse FabricResponseFormat
err = json.Unmarshal([]byte(strings.Split(strings.Split(string(body), "\n")[0], "data: ")[1]), &fabricResponse)
if err != nil {
log.Printf("Error unmarshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
return
}
for word := range strings.SplitSeq(fabricResponse.Content, " ") {
forwardedResponse = OllamaResponse{
Model: "",
CreatedAt: "",
if !prompt.Stream {
response := OllamaResponse{
Model: prompt.Model,
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
Message: struct {
Role string `json:"role"`
Content string `json:"content"`
}(struct {
Role string
Content string
}{Content: fmt.Sprintf("%s ", word), Role: "assistant"}),
Done: false,
}{Content: contentBuilder.String(), Role: "assistant"}),
DoneReason: "stop",
Done: true,
TotalDuration: time.Since(now).Nanoseconds(),
LoadDuration: int(time.Since(now).Nanoseconds()),
PromptEvalCount: 42,
PromptEvalDuration: int(time.Since(now).Nanoseconds()),
EvalCount: 420,
EvalDuration: time.Since(now).Nanoseconds(),
}
forwardedResponses = append(forwardedResponses, forwardedResponse)
c.JSON(200, response)
return
}
forwardedResponse.Model = prompt.Model
forwardedResponse.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
forwardedResponse.Message.Role = "assistant"
forwardedResponse.Message.Content = ""
forwardedResponse.DoneReason = "stop"
forwardedResponse.Done = true
forwardedResponse.TotalDuration = time.Since(now).Nanoseconds()
forwardedResponse.LoadDuration = int(time.Since(now).Nanoseconds())
forwardedResponse.PromptEvalCount = 42
forwardedResponse.PromptEvalDuration = int(time.Since(now).Nanoseconds())
forwardedResponse.EvalCount = 420
forwardedResponse.EvalDuration = time.Since(now).Nanoseconds()
forwardedResponses = append(forwardedResponses, forwardedResponse)
var res []byte
for _, response := range forwardedResponses {
marshalled, err := json.Marshal(response)
if err != nil {
log.Printf("Error marshalling body: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
res = append(res, marshalled...)
res = append(res, '\n')
finalResponse := OllamaResponse{
Model: prompt.Model,
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
Message: struct {
Role string `json:"role"`
Content string `json:"content"`
}(struct {
Role string
Content string
}{Content: "", Role: "assistant"}),
DoneReason: "stop",
Done: true,
TotalDuration: time.Since(now).Nanoseconds(),
LoadDuration: int(time.Since(now).Nanoseconds()),
PromptEvalCount: 42,
PromptEvalDuration: int(time.Since(now).Nanoseconds()),
EvalCount: 420,
EvalDuration: time.Since(now).Nanoseconds(),
}
if err := writeOllamaResponseStruct(c, finalResponse); err != nil {
log.Printf("Error writing response: %v", err)
}
c.Data(200, "application/json", res)
//c.JSON(200, forwardedResponse)
}
// buildFabricChatURL turns the configured listen/target address into a base
// URL (without a trailing slash) to which the caller appends "/chat".
//
// Accepted forms of addr:
//   - a full "http://" or "https://" URL, which is parsed and validated;
//   - a bare port such as ":8080", which is resolved against 127.0.0.1;
//   - a bare "host" or "host:port", which is prefixed with "http://".
//
// It returns an error when addr is empty, cannot be parsed as a URL, or does
// not name a host (for example "http://" or "/").
func buildFabricChatURL(addr string) (string, error) {
	if addr == "" {
		return "", fmt.Errorf("empty address")
	}
	if strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://") {
		parsed, err := url.Parse(addr)
		if err != nil {
			return "", fmt.Errorf("invalid address: %w", err)
		}
		// url.Parse accepts "http://" and "http:///path"; without a host the
		// resulting request URL would be unusable, so reject it here.
		if parsed.Host == "" {
			return "", fmt.Errorf("invalid address %q: missing host", addr)
		}
		return strings.TrimRight(parsed.String(), "/"), nil
	}
	// Trim trailing slashes for bare addresses too, so that appending "/chat"
	// never yields a "//chat" path.
	trimmed := strings.TrimRight(addr, "/")
	if trimmed == "" {
		return "", fmt.Errorf("invalid address %q: missing host", addr)
	}
	if strings.HasPrefix(trimmed, ":") {
		return "http://127.0.0.1" + trimmed, nil
	}
	return "http://" + trimmed, nil
}
// writeOllamaResponse builds one Ollama-style chat message with the assistant
// role, the given model name and content, and the supplied done flag, stamps
// it with the current UTC time, and hands it to writeOllamaResponseStruct to
// be written to the client. Any error from that write is returned unchanged.
func writeOllamaResponse(c *gin.Context, model string, content string, done bool) error {
	var chunk OllamaResponse
	chunk.Model = model
	chunk.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
	chunk.Message.Role = "assistant"
	chunk.Message.Content = content
	chunk.Done = done
	return writeOllamaResponseStruct(c, chunk)
}
// writeOllamaResponseStruct serializes response as JSON, writes it to the
// Gin response writer followed by a newline, and flushes the writer so the
// line reaches the client immediately. It returns the first marshal or write
// error encountered; on error nothing is flushed.
func writeOllamaResponseStruct(c *gin.Context, response OllamaResponse) error {
	encoded, err := json.Marshal(response)
	if err != nil {
		return err
	}
	// Emit the JSON document and its line terminator as two successive writes.
	for _, segment := range [][]byte{encoded, []byte("\n")} {
		if _, writeErr := c.Writer.Write(segment); writeErr != nil {
			return writeErr
		}
	}
	c.Writer.Flush()
	return nil
}