|
|
|
|
@@ -1,13 +1,14 @@
|
|
|
|
|
package restapi
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bufio"
|
|
|
|
|
"bytes"
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"log"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/url"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
@@ -43,11 +44,11 @@ type APIConvert struct {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type OllamaRequestBody struct {
|
|
|
|
|
Messages []OllamaMessage `json:"messages"`
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
Options struct {
|
|
|
|
|
} `json:"options"`
|
|
|
|
|
Stream bool `json:"stream"`
|
|
|
|
|
Messages []OllamaMessage `json:"messages"`
|
|
|
|
|
Model string `json:"model"`
|
|
|
|
|
Options map[string]any `json:"options,omitempty"`
|
|
|
|
|
Stream bool `json:"stream"`
|
|
|
|
|
Variables map[string]string `json:"variables,omitempty"` // Fabric-specific: pattern variables (direct)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type OllamaMessage struct {
|
|
|
|
|
@@ -65,10 +66,10 @@ type OllamaResponse struct {
|
|
|
|
|
DoneReason string `json:"done_reason,omitempty"`
|
|
|
|
|
Done bool `json:"done"`
|
|
|
|
|
TotalDuration int64 `json:"total_duration,omitempty"`
|
|
|
|
|
LoadDuration int `json:"load_duration,omitempty"`
|
|
|
|
|
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
|
|
|
|
|
PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
|
|
|
|
|
EvalCount int `json:"eval_count,omitempty"`
|
|
|
|
|
LoadDuration int64 `json:"load_duration,omitempty"`
|
|
|
|
|
PromptEvalCount int64 `json:"prompt_eval_count,omitempty"`
|
|
|
|
|
PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"`
|
|
|
|
|
EvalCount int64 `json:"eval_count,omitempty"`
|
|
|
|
|
EvalDuration int64 `json:"eval_duration,omitempty"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -163,6 +164,29 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
|
|
|
|
now := time.Now()
|
|
|
|
|
var chat ChatRequest
|
|
|
|
|
|
|
|
|
|
// Extract variables from either top-level Variables field or Options.variables
|
|
|
|
|
variables := prompt.Variables
|
|
|
|
|
if variables == nil && prompt.Options != nil {
|
|
|
|
|
if optVars, ok := prompt.Options["variables"]; ok {
|
|
|
|
|
// Options.variables can be either a JSON string or a map
|
|
|
|
|
switch v := optVars.(type) {
|
|
|
|
|
case string:
|
|
|
|
|
// Parse JSON string into map
|
|
|
|
|
if err := json.Unmarshal([]byte(v), &variables); err != nil {
|
|
|
|
|
log.Printf("Warning: failed to parse options.variables as JSON: %v", err)
|
|
|
|
|
}
|
|
|
|
|
case map[string]any:
|
|
|
|
|
// Convert map[string]any to map[string]string
|
|
|
|
|
variables = make(map[string]string)
|
|
|
|
|
for k, val := range v {
|
|
|
|
|
if s, ok := val.(string); ok {
|
|
|
|
|
variables[k] = s
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(prompt.Messages) == 1 {
|
|
|
|
|
chat.Prompts = []PromptRequest{{
|
|
|
|
|
UserInput: prompt.Messages[0].Content,
|
|
|
|
|
@@ -170,6 +194,7 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
|
|
|
|
Model: "",
|
|
|
|
|
ContextName: "",
|
|
|
|
|
PatternName: strings.Split(prompt.Model, ":")[0],
|
|
|
|
|
Variables: variables,
|
|
|
|
|
}}
|
|
|
|
|
} else if len(prompt.Messages) > 1 {
|
|
|
|
|
var content string
|
|
|
|
|
@@ -182,89 +207,242 @@ func (f APIConvert) ollamaChat(c *gin.Context) {
|
|
|
|
|
Model: "",
|
|
|
|
|
ContextName: "",
|
|
|
|
|
PatternName: strings.Split(prompt.Model, ":")[0],
|
|
|
|
|
Variables: variables,
|
|
|
|
|
}}
|
|
|
|
|
}
|
|
|
|
|
fabricChatReq, err := json.Marshal(chat)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error marshalling body: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
ctx := context.Background()
|
|
|
|
|
var req *http.Request
|
|
|
|
|
if strings.Contains(*f.addr, "http") {
|
|
|
|
|
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
|
|
|
|
|
} else {
|
|
|
|
|
req, err = http.NewRequest("POST", fmt.Sprintf("http://127.0.0.1%s/chat", *f.addr), bytes.NewBuffer(fabricChatReq))
|
|
|
|
|
}
|
|
|
|
|
baseURL, err := buildFabricChatURL(*f.addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
log.Printf("Error building /chat URL: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
req, err = http.NewRequest("POST", fmt.Sprintf("%s/chat", baseURL), bytes.NewBuffer(fabricChatReq))
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error creating /chat request: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req = req.WithContext(ctx)
|
|
|
|
|
req = req.WithContext(c.Request.Context())
|
|
|
|
|
|
|
|
|
|
fabricRes, err := http.DefaultClient.Do(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error getting /chat body: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
body, err = io.ReadAll(fabricRes.Body)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error reading body: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
var forwardedResponse OllamaResponse
|
|
|
|
|
var forwardedResponses []OllamaResponse
|
|
|
|
|
var fabricResponse FabricResponseFormat
|
|
|
|
|
err = json.Unmarshal([]byte(strings.Split(strings.Split(string(body), "\n")[0], "data: ")[1]), &fabricResponse)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error unmarshalling body: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "testing endpoint"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
for word := range strings.SplitSeq(fabricResponse.Content, " ") {
|
|
|
|
|
forwardedResponse = OllamaResponse{
|
|
|
|
|
Model: "",
|
|
|
|
|
CreatedAt: "",
|
|
|
|
|
Message: struct {
|
|
|
|
|
Role string `json:"role"`
|
|
|
|
|
Content string `json:"content"`
|
|
|
|
|
}(struct {
|
|
|
|
|
Role string
|
|
|
|
|
Content string
|
|
|
|
|
}{Content: fmt.Sprintf("%s ", word), Role: "assistant"}),
|
|
|
|
|
Done: false,
|
|
|
|
|
}
|
|
|
|
|
forwardedResponses = append(forwardedResponses, forwardedResponse)
|
|
|
|
|
}
|
|
|
|
|
forwardedResponse.Model = prompt.Model
|
|
|
|
|
forwardedResponse.CreatedAt = time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z")
|
|
|
|
|
forwardedResponse.Message.Role = "assistant"
|
|
|
|
|
forwardedResponse.Message.Content = ""
|
|
|
|
|
forwardedResponse.DoneReason = "stop"
|
|
|
|
|
forwardedResponse.Done = true
|
|
|
|
|
forwardedResponse.TotalDuration = time.Since(now).Nanoseconds()
|
|
|
|
|
forwardedResponse.LoadDuration = int(time.Since(now).Nanoseconds())
|
|
|
|
|
forwardedResponse.PromptEvalCount = 42
|
|
|
|
|
forwardedResponse.PromptEvalDuration = int(time.Since(now).Nanoseconds())
|
|
|
|
|
forwardedResponse.EvalCount = 420
|
|
|
|
|
forwardedResponse.EvalDuration = time.Since(now).Nanoseconds()
|
|
|
|
|
forwardedResponses = append(forwardedResponses, forwardedResponse)
|
|
|
|
|
defer fabricRes.Body.Close()
|
|
|
|
|
|
|
|
|
|
var res []byte
|
|
|
|
|
for _, response := range forwardedResponses {
|
|
|
|
|
marshalled, err := json.Marshal(response)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("Error marshalling body: %v", err)
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err})
|
|
|
|
|
if fabricRes.StatusCode < http.StatusOK || fabricRes.StatusCode >= http.StatusMultipleChoices {
|
|
|
|
|
bodyBytes, readErr := io.ReadAll(fabricRes.Body)
|
|
|
|
|
if readErr != nil {
|
|
|
|
|
log.Printf("Upstream Fabric server returned non-2xx status %d and body could not be read: %v", fabricRes.StatusCode, readErr)
|
|
|
|
|
} else {
|
|
|
|
|
log.Printf("Upstream Fabric server returned non-2xx status %d: %s", fabricRes.StatusCode, string(bodyBytes))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
errorMessage := fmt.Sprintf("upstream Fabric server returned status %d", fabricRes.StatusCode)
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
_ = writeOllamaResponse(c, prompt.Model, fmt.Sprintf("Error: %s", errorMessage), true)
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(fabricRes.StatusCode, gin.H{"error": errorMessage})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
c.Header("Content-Type", "application/x-ndjson")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var contentBuilder strings.Builder
|
|
|
|
|
scanner := bufio.NewScanner(fabricRes.Body)
|
|
|
|
|
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
|
|
|
|
|
for scanner.Scan() {
|
|
|
|
|
line := scanner.Text()
|
|
|
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
payload := strings.TrimPrefix(line, "data: ")
|
|
|
|
|
var fabricResponse FabricResponseFormat
|
|
|
|
|
if err := json.Unmarshal([]byte(payload), &fabricResponse); err != nil {
|
|
|
|
|
log.Printf("Error unmarshalling body: %v", err)
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
// In streaming mode, send the error in the same streaming format
|
|
|
|
|
_ = writeOllamaResponse(c, prompt.Model, "Error: failed to parse upstream response", true)
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal Fabric response"})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
res = append(res, marshalled...)
|
|
|
|
|
res = append(res, '\n')
|
|
|
|
|
if fabricResponse.Type == "error" {
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
// In streaming mode, propagate the upstream error via a final streaming chunk
|
|
|
|
|
_ = writeOllamaResponse(c, prompt.Model, fmt.Sprintf("Error: %s", fabricResponse.Content), true)
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fabricResponse.Content})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if fabricResponse.Type != "content" {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
contentBuilder.WriteString(fabricResponse.Content)
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
if err := writeOllamaResponse(c, prompt.Model, fabricResponse.Content, false); err != nil {
|
|
|
|
|
log.Printf("Error writing response: %v", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
|
|
|
log.Printf("Error scanning body: %v", err)
|
|
|
|
|
errorMsg := fmt.Sprintf("failed to scan SSE response stream: %v", err)
|
|
|
|
|
// Check for buffer size exceeded error
|
|
|
|
|
if strings.Contains(err.Error(), "token too long") {
|
|
|
|
|
errorMsg = "SSE line exceeds 1MB buffer limit - data line too large"
|
|
|
|
|
}
|
|
|
|
|
if prompt.Stream {
|
|
|
|
|
// In streaming mode, send the error in the same streaming format
|
|
|
|
|
_ = writeOllamaResponse(c, prompt.Model, fmt.Sprintf("Error: %s", errorMsg), true)
|
|
|
|
|
} else {
|
|
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
|
|
|
|
|
}
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.Data(200, "application/json", res)
|
|
|
|
|
|
|
|
|
|
//c.JSON(200, forwardedResponse)
|
|
|
|
|
// Capture duration once for consistent timing values
|
|
|
|
|
duration := time.Since(now).Nanoseconds()
|
|
|
|
|
|
|
|
|
|
// Check if we received any content from upstream
|
|
|
|
|
if contentBuilder.Len() == 0 {
|
|
|
|
|
log.Printf("Warning: no content received from upstream Fabric server")
|
|
|
|
|
// In non-streaming mode, treat absence of content as an error
|
|
|
|
|
if !prompt.Stream {
|
|
|
|
|
c.JSON(http.StatusBadGateway, gin.H{"error": "no content received from upstream Fabric server"})
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !prompt.Stream {
|
|
|
|
|
response := buildFinalOllamaResponse(prompt.Model, contentBuilder.String(), duration)
|
|
|
|
|
c.JSON(200, response)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
finalResponse := buildFinalOllamaResponse(prompt.Model, "", duration)
|
|
|
|
|
if err := writeOllamaResponseStruct(c, finalResponse); err != nil {
|
|
|
|
|
log.Printf("Error writing response: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// buildFinalOllamaResponse constructs the final OllamaResponse with timing metrics
|
|
|
|
|
// and the complete message content. Used for both streaming and non-streaming final responses.
|
|
|
|
|
func buildFinalOllamaResponse(model string, content string, duration int64) OllamaResponse {
|
|
|
|
|
return OllamaResponse{
|
|
|
|
|
Model: model,
|
|
|
|
|
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
|
|
|
|
|
Message: struct {
|
|
|
|
|
Role string `json:"role"`
|
|
|
|
|
Content string `json:"content"`
|
|
|
|
|
}(struct {
|
|
|
|
|
Role string
|
|
|
|
|
Content string
|
|
|
|
|
}{Content: content, Role: "assistant"}),
|
|
|
|
|
DoneReason: "stop",
|
|
|
|
|
Done: true,
|
|
|
|
|
TotalDuration: duration,
|
|
|
|
|
LoadDuration: duration,
|
|
|
|
|
PromptEvalDuration: duration,
|
|
|
|
|
EvalDuration: duration,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// buildFabricChatURL constructs a valid HTTP/HTTPS base URL from various
// address formats and returns it without a trailing slash.
//
// Accepted forms:
//   - fully-qualified URLs starting with "http://" or "https://", which may
//     carry a path;
//   - ":port" shorthand, resolved against http://127.0.0.1;
//   - bare "host[:port]" addresses, which are given the http scheme.
//
// All three forms go through the same parsing and validation, so a malformed
// port in the ":port" shorthand is reported as an error instead of being
// returned as a broken URL.
//
// An error is returned when the address is empty, cannot be parsed, has no
// host or hostname, or - for the shorthand and bare forms - contains a path,
// query or fragment (callers append a path to the result, so anything after
// the authority would corrupt the final URL).
func buildFabricChatURL(addr string) (string, error) {
	if addr == "" {
		return "", fmt.Errorf("empty address")
	}

	hasScheme := strings.HasPrefix(addr, "http://") || strings.HasPrefix(addr, "https://")

	// Normalize every accepted form into a full URL before parsing.
	raw := addr
	switch {
	case hasScheme:
		// Already fully qualified.
	case strings.HasPrefix(addr, ":"):
		raw = "http://127.0.0.1" + addr
	default:
		raw = "http://" + addr
	}

	parsed, err := url.Parse(raw)
	if err != nil {
		return "", fmt.Errorf("invalid address: %w", err)
	}
	if parsed.Host == "" {
		return "", fmt.Errorf("invalid address: missing host")
	}
	if strings.HasPrefix(parsed.Host, ":") {
		return "", fmt.Errorf("invalid address: missing hostname")
	}

	if !hasScheme {
		// Bare addresses should be host[:port] only.
		if parsed.Path != "" && parsed.Path != "/" {
			return "", fmt.Errorf("invalid address: path component not allowed in bare address")
		}
		if parsed.RawQuery != "" || parsed.ForceQuery || parsed.Fragment != "" {
			return "", fmt.Errorf("invalid address: query or fragment not allowed in bare address")
		}
	}

	return strings.TrimRight(parsed.String(), "/"), nil
}
|
|
|
|
|
|
|
|
|
|
// writeOllamaResponse constructs an Ollama-formatted response chunk and writes it
|
|
|
|
|
// to the streaming output associated with the provided Gin context. The model
|
|
|
|
|
// parameter identifies the model, content is the assistant message text, and
|
|
|
|
|
// done indicates whether this is the final chunk in the stream.
|
|
|
|
|
func writeOllamaResponse(c *gin.Context, model string, content string, done bool) error {
|
|
|
|
|
response := OllamaResponse{
|
|
|
|
|
Model: model,
|
|
|
|
|
CreatedAt: time.Now().UTC().Format("2006-01-02T15:04:05.999999999Z"),
|
|
|
|
|
Message: struct {
|
|
|
|
|
Role string `json:"role"`
|
|
|
|
|
Content string `json:"content"`
|
|
|
|
|
}(struct {
|
|
|
|
|
Role string
|
|
|
|
|
Content string
|
|
|
|
|
}{Content: content, Role: "assistant"}),
|
|
|
|
|
Done: done,
|
|
|
|
|
}
|
|
|
|
|
return writeOllamaResponseStruct(c, response)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// writeOllamaResponseStruct marshals the provided OllamaResponse and writes it
|
|
|
|
|
// as newline-delimited JSON to the HTTP response stream.
|
|
|
|
|
func writeOllamaResponseStruct(c *gin.Context, response OllamaResponse) error {
|
|
|
|
|
marshalled, err := json.Marshal(response)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if _, err := c.Writer.Write(marshalled); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if _, err := c.Writer.Write([]byte("\n")); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if flusher, ok := c.Writer.(http.Flusher); ok {
|
|
|
|
|
flusher.Flush()
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|