Merge pull request #1912 from berniegreen/feature/metadata-refactor

refactor: implement structured streaming and metadata support
This commit is contained in:
Kayvan Sylvan
2026-01-03 14:50:15 -08:00
committed by GitHub
25 changed files with 480 additions and 85 deletions

View File

@@ -705,6 +705,7 @@ Application Options:
--yt-dlp-args= Additional arguments to pass to yt-dlp (e.g. '--cookies-from-browser brave')
--thinking= Set reasoning/thinking level (e.g., off, low, medium, high, or
numeric tokens for Anthropic or Google Gemini)
--show-metadata Print metadata (input/output tokens) to stderr
--debug= Set debug level (0: off, 1: basic, 2: detailed, 3: trace)
Help Options:
-h, --help Show this help message

View File

@@ -0,0 +1,7 @@
### PR [#1912](https://github.com/danielmiessler/Fabric/pull/1912) by [berniegreen](https://github.com/berniegreen): refactor: implement structured streaming and metadata support
- Feat: add domain types for structured streaming (Phase 1)
- Refactor: update Vendor interface and Chatter for structured streaming (Phase 2)
- Refactor: implement structured streaming in all AI vendors (Phase 3)
- Feat: implement CLI support for metadata display (Phase 4)
- Feat: implement REST API support for metadata streaming (Phase 5)

View File

@@ -289,6 +289,20 @@ const docTemplate = `{
"ThinkingHigh"
]
},
"domain.UsageMetadata": {
"type": "object",
"properties": {
"input_tokens": {
"type": "integer"
},
"output_tokens": {
"type": "integer"
},
"total_tokens": {
"type": "integer"
}
}
},
"fsdb.Pattern": {
"type": "object",
"properties": {
@@ -360,6 +374,9 @@ const docTemplate = `{
"$ref": "#/definitions/restapi.PromptRequest"
}
},
"quiet": {
"type": "boolean"
},
"raw": {
"type": "boolean"
},
@@ -372,6 +389,9 @@ const docTemplate = `{
"seed": {
"type": "integer"
},
"showMetadata": {
"type": "boolean"
},
"suppressThink": {
"type": "boolean"
},
@@ -392,6 +412,9 @@ const docTemplate = `{
"type": "number",
"format": "float64"
},
"updateChan": {
"type": "object"
},
"voice": {
"type": "string"
}
@@ -423,6 +446,10 @@ const docTemplate = `{
"patternName": {
"type": "string"
},
"sessionName": {
"description": "Session name for multi-turn conversations",
"type": "string"
},
"strategyName": {
"description": "Optional strategy name",
"type": "string"
@@ -446,7 +473,6 @@ const docTemplate = `{
"type": "object",
"properties": {
"content": {
"description": "The actual content",
"type": "string"
},
"format": {
@@ -454,8 +480,11 @@ const docTemplate = `{
"type": "string"
},
"type": {
"description": "\"content\", \"error\", \"complete\"",
"description": "\"content\", \"usage\", \"error\", \"complete\"",
"type": "string"
},
"usage": {
"$ref": "#/definitions/domain.UsageMetadata"
}
}
},

View File

@@ -283,6 +283,20 @@
"ThinkingHigh"
]
},
"domain.UsageMetadata": {
"type": "object",
"properties": {
"input_tokens": {
"type": "integer"
},
"output_tokens": {
"type": "integer"
},
"total_tokens": {
"type": "integer"
}
}
},
"fsdb.Pattern": {
"type": "object",
"properties": {
@@ -354,6 +368,9 @@
"$ref": "#/definitions/restapi.PromptRequest"
}
},
"quiet": {
"type": "boolean"
},
"raw": {
"type": "boolean"
},
@@ -366,6 +383,9 @@
"seed": {
"type": "integer"
},
"showMetadata": {
"type": "boolean"
},
"suppressThink": {
"type": "boolean"
},
@@ -386,6 +406,9 @@
"type": "number",
"format": "float64"
},
"updateChan": {
"type": "object"
},
"voice": {
"type": "string"
}
@@ -417,6 +440,10 @@
"patternName": {
"type": "string"
},
"sessionName": {
"description": "Session name for multi-turn conversations",
"type": "string"
},
"strategyName": {
"description": "Optional strategy name",
"type": "string"
@@ -440,7 +467,6 @@
"type": "object",
"properties": {
"content": {
"description": "The actual content",
"type": "string"
},
"format": {
@@ -448,8 +474,11 @@
"type": "string"
},
"type": {
"description": "\"content\", \"error\", \"complete\"",
"description": "\"content\", \"usage\", \"error\", \"complete\"",
"type": "string"
},
"usage": {
"$ref": "#/definitions/domain.UsageMetadata"
}
}
},

View File

@@ -12,6 +12,15 @@ definitions:
- ThinkingLow
- ThinkingMedium
- ThinkingHigh
domain.UsageMetadata:
properties:
input_tokens:
type: integer
output_tokens:
type: integer
total_tokens:
type: integer
type: object
fsdb.Pattern:
properties:
description:
@@ -60,6 +69,8 @@ definitions:
items:
$ref: '#/definitions/restapi.PromptRequest'
type: array
quiet:
type: boolean
raw:
type: boolean
search:
@@ -68,6 +79,8 @@ definitions:
type: string
seed:
type: integer
showMetadata:
type: boolean
suppressThink:
type: boolean
temperature:
@@ -82,6 +95,8 @@ definitions:
topP:
format: float64
type: number
updateChan:
type: object
voice:
type: string
type: object
@@ -102,6 +117,9 @@ definitions:
type: string
patternName:
type: string
sessionName:
description: Session name for multi-turn conversations
type: string
strategyName:
description: Optional strategy name
type: string
@@ -118,14 +136,15 @@ definitions:
restapi.StreamResponse:
properties:
content:
description: The actual content
type: string
format:
description: '"markdown", "mermaid", "plain"'
type: string
type:
description: '"content", "error", "complete"'
description: '"content", "usage", "error", "complete"'
type: string
usage:
$ref: '#/definitions/domain.UsageMetadata'
type: object
restapi.YouTubeRequest:
properties:

View File

@@ -104,6 +104,7 @@ type Flags struct {
Notification bool `long:"notification" yaml:"notification" description:"Send desktop notification when command completes"`
NotificationCommand string `long:"notification-command" yaml:"notificationCommand" description:"Custom command to run for notifications (overrides built-in notifications)"`
Thinking domain.ThinkingLevel `long:"thinking" yaml:"thinking" description:"Set reasoning/thinking level (e.g., off, low, medium, high, or numeric tokens for Anthropic or Google Gemini)"`
ShowMetadata bool `long:"show-metadata" description:"Print metadata to stderr"`
Debug int `long:"debug" description:"Set debug level (0=off, 1=basic, 2=detailed, 3=trace)" default:"0"`
}
@@ -459,6 +460,7 @@ func (o *Flags) BuildChatOptions() (ret *domain.ChatOptions, err error) {
Voice: o.Voice,
Notification: o.Notification || o.NotificationCommand != "",
NotificationCommand: o.NotificationCommand,
ShowMetadata: o.ShowMetadata,
}
return
}

View File

@@ -64,7 +64,7 @@ func (o *Chatter) Send(request *domain.ChatRequest, opts *domain.ChatOptions) (s
message := ""
if o.Stream {
responseChan := make(chan string)
responseChan := make(chan domain.StreamUpdate)
errChan := make(chan error, 1)
done := make(chan struct{})
printedStream := false
@@ -76,15 +76,31 @@ func (o *Chatter) Send(request *domain.ChatRequest, opts *domain.ChatOptions) (s
}
}()
for response := range responseChan {
message += response
if !opts.SuppressThink {
fmt.Print(response)
printedStream = true
for update := range responseChan {
if opts.UpdateChan != nil {
opts.UpdateChan <- update
}
switch update.Type {
case domain.StreamTypeContent:
message += update.Content
if !opts.SuppressThink && !opts.Quiet {
fmt.Print(update.Content)
printedStream = true
}
case domain.StreamTypeUsage:
if opts.ShowMetadata && update.Usage != nil && !opts.Quiet {
fmt.Fprintf(os.Stderr, "\n[Metadata] Input: %d | Output: %d | Total: %d\n",
update.Usage.InputTokens, update.Usage.OutputTokens, update.Usage.TotalTokens)
}
case domain.StreamTypeError:
if !opts.Quiet {
fmt.Fprintf(os.Stderr, "Error: %s\n", update.Content)
}
errChan <- errors.New(update.Content)
}
}
if printedStream && !opts.SuppressThink && !strings.HasSuffix(message, "\n") {
if printedStream && !opts.SuppressThink && !strings.HasSuffix(message, "\n") && !opts.Quiet {
fmt.Println()
}

View File

@@ -14,7 +14,7 @@ import (
// mockVendor implements the ai.Vendor interface for testing
type mockVendor struct {
sendStreamError error
streamChunks []string
streamChunks []domain.StreamUpdate
sendFunc func(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error)
}
@@ -45,7 +45,7 @@ func (m *mockVendor) ListModels() ([]string, error) {
return []string{"test-model"}, nil
}
func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan string) error {
func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan domain.StreamUpdate) error {
// Send chunks if provided (for successful streaming test)
if m.streamChunks != nil {
for _, chunk := range m.streamChunks {
@@ -169,7 +169,11 @@ func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) {
db := fsdb.NewDb(tempDir)
// Create test chunks that should be aggregated
testChunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."}
chunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."}
testChunks := make([]domain.StreamUpdate, len(chunks))
for i, c := range chunks {
testChunks[i] = domain.StreamUpdate{Type: domain.StreamTypeContent, Content: c}
}
expectedMessage := "Hello world! This is a test."
// Create a mock vendor that will send chunks successfully
@@ -228,3 +232,83 @@ func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) {
t.Errorf("Expected aggregated message %q, got %q", expectedMessage, assistantMessage.Content)
}
}
func TestChatter_Send_StreamingMetadataPropagation(t *testing.T) {
// Create a temporary database for testing
tempDir := t.TempDir()
db := fsdb.NewDb(tempDir)
// Create test chunks: one content, one usage metadata
testChunks := []domain.StreamUpdate{
{
Type: domain.StreamTypeContent,
Content: "Test content",
},
{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: 10,
OutputTokens: 5,
TotalTokens: 15,
},
},
}
// Create a mock vendor
mockVendor := &mockVendor{
sendStreamError: nil,
streamChunks: testChunks,
}
// Create chatter with streaming enabled
chatter := &Chatter{
db: db,
Stream: true,
vendor: mockVendor,
model: "test-model",
}
// Create a test request
request := &domain.ChatRequest{
Message: &chat.ChatCompletionMessage{
Role: chat.ChatMessageRoleUser,
Content: "test message",
},
}
// Create an update channel to capture stream events
updateChan := make(chan domain.StreamUpdate, 10)
// Create test options with UpdateChan
opts := &domain.ChatOptions{
Model: "test-model",
UpdateChan: updateChan,
Quiet: true, // Suppress stdout/stderr
}
// Call Send
_, err := chatter.Send(request, opts)
if err != nil {
t.Fatalf("Expected no error, but got: %v", err)
}
close(updateChan)
// Verify we received the metadata event
var usageReceived bool
for update := range updateChan {
if update.Type == domain.StreamTypeUsage {
usageReceived = true
if update.Usage == nil {
t.Error("Expected usage metadata to be non-nil")
} else {
if update.Usage.TotalTokens != 15 {
t.Errorf("Expected 15 total tokens, got %d", update.Usage.TotalTokens)
}
}
}
}
if !usageReceived {
t.Error("Expected to receive a usage metadata update, but didn't")
}
}

View File

@@ -43,7 +43,7 @@ func (m *testVendor) Configure() error { return nil }
func (m *testVendor) Setup() error { return nil }
func (m *testVendor) SetupFillEnvFileContent(*bytes.Buffer) {}
func (m *testVendor) ListModels() ([]string, error) { return m.models, nil }
func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error {
func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error {
return nil
}
func (m *testVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) {

View File

@@ -51,6 +51,9 @@ type ChatOptions struct {
Voice string
Notification bool
NotificationCommand string
ShowMetadata bool
Quiet bool
UpdateChan chan StreamUpdate
}
// NormalizeMessages remove empty messages and ensure messages order user-assist-user

24
internal/domain/stream.go Normal file
View File

@@ -0,0 +1,24 @@
package domain
// StreamType distinguishes between partial text content and metadata events.
type StreamType string
const (
StreamTypeContent StreamType = "content"
StreamTypeUsage StreamType = "usage"
StreamTypeError StreamType = "error"
)
// StreamUpdate is the unified payload sent through the internal channels.
type StreamUpdate struct {
Type StreamType `json:"type"`
Content string `json:"content,omitempty"` // For text deltas
Usage *UsageMetadata `json:"usage,omitempty"` // For token counts
}
// UsageMetadata normalizes token counts across different providers.
type UsageMetadata struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}

View File

@@ -184,7 +184,7 @@ func parseThinking(level domain.ThinkingLevel) (anthropic.ThinkingConfigParamUni
}
func (an *Client) SendStream(
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate,
) (err error) {
messages := an.toMessages(msgs)
if len(messages) == 0 {
@@ -210,9 +210,33 @@ func (an *Client) SendStream(
for stream.Next() {
event := stream.Current()
// directly send any non-empty delta text
// Handle Content
if event.Delta.Text != "" {
channel <- event.Delta.Text
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: event.Delta.Text,
}
}
// Handle Usage
if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(event.Message.Usage.InputTokens),
OutputTokens: int(event.Message.Usage.OutputTokens),
TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens),
},
}
} else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(event.Usage.InputTokens),
OutputTokens: int(event.Usage.OutputTokens),
TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens),
},
}
}
}

View File

@@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) {
}
// SendStream sends the messages to the Bedrock ConverseStream API
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
// Ensure channel is closed on all exit paths to prevent goroutine leaks
defer func() {
if r := recover(); r != nil {
@@ -186,18 +186,35 @@ func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *dom
case *types.ConverseStreamOutputMemberContentBlockDelta:
text, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText)
if ok {
channel <- text.Value
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: text.Value,
}
}
case *types.ConverseStreamOutputMemberMessageStop:
channel <- "\n"
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: "\n",
}
return nil // Let defer handle the close
case *types.ConverseStreamOutputMemberMetadata:
if v.Value.Usage != nil {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(*v.Value.Usage.InputTokens),
OutputTokens: int(*v.Value.Usage.OutputTokens),
TotalTokens: int(*v.Value.Usage.TotalTokens),
},
}
}
// Unused Events
case *types.ConverseStreamOutputMemberMessageStart,
*types.ConverseStreamOutputMemberContentBlockStart,
*types.ConverseStreamOutputMemberContentBlockStop,
*types.ConverseStreamOutputMemberMetadata:
*types.ConverseStreamOutputMemberContentBlockStop:
default:
return fmt.Errorf("unknown stream event type: %T", v)

View File

@@ -108,12 +108,30 @@ func (c *Client) constructRequest(msgs []*chat.ChatCompletionMessage, opts *doma
return builder.String()
}
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
defer close(channel)
request := c.constructRequest(msgs, opts)
channel <- request
channel <- "\n"
channel <- DryRunResponse
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: request,
}
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: "\n",
}
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: DryRunResponse,
}
// Simulated usage
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: 100,
OutputTokens: 50,
TotalTokens: 150,
},
}
return nil
}

View File

@@ -39,7 +39,7 @@ func TestSendStream_SendsMessages(t *testing.T) {
opts := &domain.ChatOptions{
Model: "dry-run-model",
}
channel := make(chan string)
channel := make(chan domain.StreamUpdate)
go func() {
err := client.SendStream(msgs, opts, channel)
if err != nil {
@@ -48,7 +48,7 @@ func TestSendStream_SendsMessages(t *testing.T) {
}()
var receivedMessages []string
for msg := range channel {
receivedMessages = append(receivedMessages, msg)
receivedMessages = append(receivedMessages, msg.Content)
}
if len(receivedMessages) == 0 {
t.Errorf("Expected to receive messages, but got none")

View File

@@ -129,7 +129,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
return
}
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
ctx := context.Background()
defer close(channel)
@@ -154,13 +154,30 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
for response, err := range stream {
if err != nil {
channel <- fmt.Sprintf("Error: %v\n", err)
channel <- domain.StreamUpdate{
Type: domain.StreamTypeError,
Content: fmt.Sprintf("Error: %v", err),
}
return err
}
text := o.extractTextFromResponse(response)
if text != "" {
channel <- text
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: text,
}
}
if response.UsageMetadata != nil {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(response.UsageMetadata.PromptTokenCount),
OutputTokens: int(response.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(response.UsageMetadata.TotalTokenCount),
},
}
}
}

View File

@@ -87,13 +87,16 @@ func (c *Client) ListModels() ([]string, error) {
return models, nil
}
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value)
payload := map[string]any{
"messages": msgs,
"model": opts.Model,
"stream": true, // Enable streaming
"stream_options": map[string]any{
"include_usage": true,
},
}
var jsonPayload []byte
@@ -144,7 +147,7 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
line = after
}
if string(line) == "[DONE]" {
if string(bytes.TrimSpace(line)) == "[DONE]" {
break
}
@@ -153,6 +156,24 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
continue
}
// Handle Usage
if usage, ok := result["usage"].(map[string]any); ok {
var metadata domain.UsageMetadata
if val, ok := usage["prompt_tokens"].(float64); ok {
metadata.InputTokens = int(val)
}
if val, ok := usage["completion_tokens"].(float64); ok {
metadata.OutputTokens = int(val)
}
if val, ok := usage["total_tokens"].(float64); ok {
metadata.TotalTokens = int(val)
}
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &metadata,
}
}
var choices []any
var ok bool
if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 {
@@ -166,7 +187,10 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
var content string
if content, _ = delta["content"].(string); content != "" {
channel <- content
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: content,
}
}
}

View File

@@ -106,7 +106,7 @@ func (o *Client) ListModels() (ret []string, err error) {
return
}
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) {
ctx := context.Background()
var req ollamaapi.ChatRequest
@@ -115,7 +115,21 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
}
respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) {
channel <- resp.Message.Content
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: resp.Message.Content,
}
if resp.Done {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: resp.PromptEvalCount,
OutputTokens: resp.EvalCount,
TotalTokens: resp.PromptEvalCount + resp.EvalCount,
},
}
}
return
}

View File

@@ -30,7 +30,7 @@ func (o *Client) sendChatCompletions(ctx context.Context, msgs []*chat.ChatCompl
// sendStreamChatCompletions sends a streaming request using the Chat Completions API
func (o *Client) sendStreamChatCompletions(
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate,
) (err error) {
defer close(channel)
@@ -39,11 +39,28 @@ func (o *Client) sendStreamChatCompletions(
for stream.Next() {
chunk := stream.Current()
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
channel <- chunk.Choices[0].Delta.Content
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: chunk.Choices[0].Delta.Content,
}
}
if chunk.Usage.TotalTokens > 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(chunk.Usage.PromptTokens),
OutputTokens: int(chunk.Usage.CompletionTokens),
TotalTokens: int(chunk.Usage.TotalTokens),
},
}
}
}
if stream.Err() == nil {
channel <- "\n"
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: "\n",
}
}
return stream.Err()
}
@@ -65,6 +82,9 @@ func (o *Client) buildChatCompletionParams(
ret = openai.ChatCompletionNewParams{
Model: shared.ChatModel(opts.Model),
Messages: messages,
StreamOptions: openai.ChatCompletionStreamOptionsParam{
IncludeUsage: openai.Bool(true),
},
}
if !opts.Raw {

View File

@@ -108,7 +108,7 @@ func (o *Client) ListModels() (ret []string, err error) {
}
func (o *Client) SendStream(
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate,
) (err error) {
// Use Responses API for OpenAI, Chat Completions API for other providers
if o.supportsResponsesAPI() {
@@ -118,7 +118,7 @@ func (o *Client) SendStream(
}
func (o *Client) sendStreamResponses(
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string,
msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate,
) (err error) {
defer close(channel)
@@ -128,7 +128,10 @@ func (o *Client) sendStreamResponses(
event := stream.Current()
switch event.Type {
case string(constant.ResponseOutputTextDelta("").Default()):
channel <- event.AsResponseOutputTextDelta().Delta
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: event.AsResponseOutputTextDelta().Delta,
}
case string(constant.ResponseOutputTextDone("").Default()):
// The Responses API sends the full text again in the
// final "done" event. Since we've already streamed all
@@ -138,7 +141,10 @@ func (o *Client) sendStreamResponses(
}
}
if stream.Err() == nil {
channel <- "\n"
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: "\n",
}
}
return stream.Err()
}

View File

@@ -123,7 +123,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
return content.String(), nil
}
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
if c.client == nil {
if err := c.Configure(); err != nil {
close(channel) // Ensure channel is closed on error
@@ -196,7 +196,21 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
content = resp.Choices[0].Message.Content
}
if content != "" {
channel <- content
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: content,
}
}
}
if resp.Usage.TotalTokens != 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(resp.Usage.PromptTokens),
OutputTokens: int(resp.Usage.CompletionTokens),
TotalTokens: int(resp.Usage.TotalTokens),
},
}
}
}
@@ -205,9 +219,14 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
if lastResponse != nil {
citations := lastResponse.GetCitations()
if len(citations) > 0 {
channel <- "\n\n# CITATIONS\n\n"
var citationsText strings.Builder
citationsText.WriteString("\n\n# CITATIONS\n\n")
for i, citation := range citations {
channel <- fmt.Sprintf("- [%d] %s\n", i+1, citation)
citationsText.WriteString(fmt.Sprintf("- [%d] %s\n", i+1, citation))
}
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: citationsText.String(),
}
}
}

View File

@@ -12,7 +12,7 @@ import (
type Vendor interface {
plugins.Plugin
ListModels() ([]string, error)
SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error
SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error
Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error)
NeedsRawMode(modelName string) bool
}

View File

@@ -20,7 +20,7 @@ func (v *stubVendor) Configure() error { return nil }
func (v *stubVendor) Setup() error { return nil }
func (v *stubVendor) SetupFillEnvFileContent(*bytes.Buffer) {}
func (v *stubVendor) ListModels() ([]string, error) { return nil, nil }
func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error {
func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error {
return nil
}
func (v *stubVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) {

View File

@@ -107,7 +107,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
return strings.Join(textParts, ""), nil
}
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error {
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
if c.client == nil {
close(channel)
return fmt.Errorf("VertexAI client not initialized")
@@ -133,8 +133,34 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
// Process stream
for stream.Next() {
event := stream.Current()
// Handle Content
if event.Delta.Text != "" {
channel <- event.Delta.Text
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: event.Delta.Text,
}
}
// Handle Usage
if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(event.Message.Usage.InputTokens),
OutputTokens: int(event.Message.Usage.OutputTokens),
TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens),
},
}
} else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(event.Usage.InputTokens),
OutputTokens: int(event.Usage.OutputTokens),
TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens),
},
}
}
}

View File

@@ -40,9 +40,10 @@ type ChatRequest struct {
}
type StreamResponse struct {
Type string `json:"type"` // "content", "error", "complete"
Format string `json:"format"` // "markdown", "mermaid", "plain"
Content string `json:"content"` // The actual content
Type string `json:"type"` // "content", "usage", "error", "complete"
Format string `json:"format,omitempty"` // "markdown", "mermaid", "plain"
Content string `json:"content,omitempty"`
Usage *domain.UsageMetadata `json:"usage,omitempty"`
}
func NewChatHandler(r *gin.Engine, registry *core.PluginRegistry, db *fsdb.Db) *ChatHandler {
@@ -98,7 +99,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) {
log.Printf("Processing prompt %d: Model=%s Pattern=%s Context=%s",
i+1, prompt.Model, prompt.PatternName, prompt.ContextName)
streamChan := make(chan string)
streamChan := make(chan domain.StreamUpdate)
go func(p PromptRequest) {
defer close(streamChan)
@@ -117,10 +118,10 @@ func (h *ChatHandler) HandleChat(c *gin.Context) {
}
}
chatter, err := h.registry.GetChatter(p.Model, 2048, p.Vendor, "", false, false)
chatter, err := h.registry.GetChatter(p.Model, 2048, p.Vendor, "", true, false)
if err != nil {
log.Printf("Error creating chatter: %v", err)
streamChan <- fmt.Sprintf("Error: %v", err)
streamChan <- domain.StreamUpdate{Type: domain.StreamTypeError, Content: fmt.Sprintf("Error: %v", err)}
return
}
@@ -144,49 +145,44 @@ func (h *ChatHandler) HandleChat(c *gin.Context) {
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
Thinking: request.Thinking,
UpdateChan: streamChan,
Quiet: true,
}
session, err := chatter.Send(chatReq, opts)
_, err = chatter.Send(chatReq, opts)
if err != nil {
log.Printf("Error from chatter.Send: %v", err)
streamChan <- fmt.Sprintf("Error: %v", err)
// Error already sent to streamChan via domain.StreamTypeError if occurred in Send loop
return
}
if session == nil {
log.Printf("No session returned from chatter.Send")
streamChan <- "Error: No response from model"
return
}
lastMsg := session.GetLastMessage()
if lastMsg != nil {
streamChan <- lastMsg.Content
} else {
log.Printf("No message content in session")
streamChan <- "Error: No response content"
}
}(prompt)
for content := range streamChan {
for update := range streamChan {
select {
case <-clientGone:
return
default:
var response StreamResponse
if strings.HasPrefix(content, "Error:") {
switch update.Type {
case domain.StreamTypeContent:
response = StreamResponse{
Type: "content",
Format: detectFormat(update.Content),
Content: update.Content,
}
case domain.StreamTypeUsage:
response = StreamResponse{
Type: "usage",
Usage: update.Usage,
}
case domain.StreamTypeError:
response = StreamResponse{
Type: "error",
Format: "plain",
Content: content,
}
} else {
response = StreamResponse{
Type: "content",
Format: detectFormat(content),
Content: content,
Content: update.Content,
}
}
if err := writeSSEResponse(c.Writer, response); err != nil {
log.Printf("Error writing response: %v", err)
return