diff --git a/README.md b/README.md index 3107d2c4..da6b72ed 100644 --- a/README.md +++ b/README.md @@ -705,6 +705,7 @@ Application Options: --yt-dlp-args= Additional arguments to pass to yt-dlp (e.g. '--cookies-from-browser brave') --thinking= Set reasoning/thinking level (e.g., off, low, medium, high, or numeric tokens for Anthropic or Google Gemini) + --show-metadata Print metadata (input/output tokens) to stderr --debug= Set debug level (0: off, 1: basic, 2: detailed, 3: trace) Help Options: -h, --help Show this help message diff --git a/cmd/generate_changelog/incoming/1912.txt b/cmd/generate_changelog/incoming/1912.txt new file mode 100644 index 00000000..c2dbb159 --- /dev/null +++ b/cmd/generate_changelog/incoming/1912.txt @@ -0,0 +1,7 @@ +### PR [#1912](https://github.com/danielmiessler/Fabric/pull/1912) by [berniegreen](https://github.com/berniegreen): refactor: implement structured streaming and metadata support + +- Feat: add domain types for structured streaming (Phase 1) +- Refactor: update Vendor interface and Chatter for structured streaming (Phase 2) +- Refactor: implement structured streaming in all AI vendors (Phase 3) +- Feat: implement CLI support for metadata display (Phase 4) +- Feat: implement REST API support for metadata streaming (Phase 5) diff --git a/docs/docs.go b/docs/docs.go index 62c61101..e67df397 100644 --- a/docs/docs.go +++ b/docs/docs.go @@ -289,6 +289,20 @@ const docTemplate = `{ "ThinkingHigh" ] }, + "domain.UsageMetadata": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + }, + "total_tokens": { + "type": "integer" + } + } + }, "fsdb.Pattern": { "type": "object", "properties": { @@ -360,6 +374,9 @@ const docTemplate = `{ "$ref": "#/definitions/restapi.PromptRequest" } }, + "quiet": { + "type": "boolean" + }, "raw": { "type": "boolean" }, @@ -372,6 +389,9 @@ const docTemplate = `{ "seed": { "type": "integer" }, + "showMetadata": { + "type": "boolean" + }, "suppressThink": { "type": "boolean" }, @@ -392,6 +412,9 @@ const docTemplate = `{ "type": "number", "format": "float64" }, + "updateChan": { + "type": "object" + }, "voice": { "type": "string" } @@ -423,6 +446,10 @@ const docTemplate = `{ "patternName": { "type": "string" }, + "sessionName": { + "description": "Session name for multi-turn conversations", + "type": "string" + }, "strategyName": { "description": "Optional strategy name", "type": "string" @@ -446,7 +473,6 @@ const docTemplate = `{ "type": "object", "properties": { "content": { - "description": "The actual content", "type": "string" }, "format": { @@ -454,8 +480,11 @@ const docTemplate = `{ "type": "string" }, "type": { - "description": "\"content\", \"error\", \"complete\"", + "description": "\"content\", \"usage\", \"error\", \"complete\"", "type": "string" + }, + "usage": { + "$ref": "#/definitions/domain.UsageMetadata" } } }, diff --git a/docs/swagger.json b/docs/swagger.json index 547443dd..749c6e06 100644 --- a/docs/swagger.json +++ b/docs/swagger.json @@ -283,6 +283,20 @@ "ThinkingHigh" ] }, + "domain.UsageMetadata": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + }, + "total_tokens": { + "type": "integer" + } + } + }, "fsdb.Pattern": { "type": "object", "properties": { @@ -354,6 +368,9 @@ "$ref": "#/definitions/restapi.PromptRequest" } }, + "quiet": { + "type": "boolean" + }, "raw": { "type": "boolean" }, @@ -366,6 +383,9 @@ "seed": { "type": "integer" }, + "showMetadata": { + "type": "boolean" + }, "suppressThink": { "type": "boolean" }, @@ -386,6 +406,9 @@ "type": "number", "format": "float64" }, + "updateChan": { + "type": "object" + }, "voice": { "type": "string" } @@ -417,6 +440,10 @@ "patternName": { "type": "string" }, + "sessionName": { + "description": "Session name for multi-turn conversations", + "type": "string" + }, "strategyName": { "description": "Optional strategy name", "type": "string" @@ -440,7 +467,6 @@ "type": "object", "properties": { "content": { - "description": "The actual content", "type": "string" }, "format": { @@ -448,8 +474,11 @@ "type": "string" }, "type": { - "description": "\"content\", \"error\", \"complete\"", + "description": "\"content\", \"usage\", \"error\", \"complete\"", "type": "string" + }, + "usage": { + "$ref": "#/definitions/domain.UsageMetadata" } } }, diff --git a/docs/swagger.yaml b/docs/swagger.yaml index f1629f92..53184584 100644 --- a/docs/swagger.yaml +++ b/docs/swagger.yaml @@ -12,6 +12,15 @@ definitions: - ThinkingLow - ThinkingMedium - ThinkingHigh + domain.UsageMetadata: + properties: + input_tokens: + type: integer + output_tokens: + type: integer + total_tokens: + type: integer + type: object fsdb.Pattern: properties: description: @@ -60,6 +69,8 @@ definitions: items: $ref: '#/definitions/restapi.PromptRequest' type: array + quiet: + type: boolean raw: type: boolean search: @@ -68,6 +79,8 @@ definitions: type: string seed: type: integer + showMetadata: + type: boolean suppressThink: type: boolean temperature: @@ -82,6 +95,8 @@ definitions: topP: format: float64 type: number + updateChan: + type: object voice: type: string type: object @@ -102,6 +117,9 @@ definitions: type: string patternName: type: string + sessionName: + description: Session name for multi-turn conversations + type: string strategyName: description: Optional strategy name type: string @@ -118,14 +136,15 @@ definitions: restapi.StreamResponse: properties: content: - description: The actual content type: string format: description: '"markdown", "mermaid", "plain"' type: string type: - description: '"content", "error", "complete"' + description: '"content", "usage", "error", "complete"' type: string + usage: + $ref: '#/definitions/domain.UsageMetadata' type: object restapi.YouTubeRequest: properties: diff --git a/internal/cli/flags.go b/internal/cli/flags.go index e018bc47..274c856c 100644 --- a/internal/cli/flags.go +++ b/internal/cli/flags.go @@ -104,6 +104,7 @@ type Flags struct { Notification bool `long:"notification" yaml:"notification" description:"Send desktop notification when command completes"` NotificationCommand string `long:"notification-command" yaml:"notificationCommand" description:"Custom command to run for notifications (overrides built-in notifications)"` Thinking domain.ThinkingLevel `long:"thinking" yaml:"thinking" description:"Set reasoning/thinking level (e.g., off, low, medium, high, or numeric tokens for Anthropic or Google Gemini)"` + ShowMetadata bool `long:"show-metadata" description:"Print metadata to stderr"` Debug int `long:"debug" description:"Set debug level (0=off, 1=basic, 2=detailed, 3=trace)" default:"0"` } @@ -459,6 +460,7 @@ func (o *Flags) BuildChatOptions() (ret *domain.ChatOptions, err error) { Voice: o.Voice, Notification: o.Notification || o.NotificationCommand != "", NotificationCommand: o.NotificationCommand, + ShowMetadata: o.ShowMetadata, } return } diff --git a/internal/core/chatter.go b/internal/core/chatter.go index adbd581f..101ed364 100644 --- a/internal/core/chatter.go +++ b/internal/core/chatter.go @@ -64,7 +64,7 @@ func (o *Chatter) Send(request *domain.ChatRequest, opts *domain.ChatOptions) (s message := "" if o.Stream { - responseChan := make(chan string) + responseChan := make(chan domain.StreamUpdate) errChan := make(chan error, 1) done := make(chan struct{}) printedStream := false @@ -76,15 +76,31 @@ func (o *Chatter) Send(request *domain.ChatRequest, opts *domain.ChatOptions) (s } }() - for response := range responseChan { - message += response - if !opts.SuppressThink { - fmt.Print(response) - printedStream = true + for update := range responseChan { + if opts.UpdateChan != nil { + opts.UpdateChan <- update + } + switch update.Type { + case domain.StreamTypeContent: + message += update.Content + if !opts.SuppressThink && !opts.Quiet { + fmt.Print(update.Content) + printedStream = true + } + case domain.StreamTypeUsage: + if opts.ShowMetadata && update.Usage != nil && !opts.Quiet { + fmt.Fprintf(os.Stderr, "\n[Metadata] Input: %d | Output: %d | Total: %d\n", + update.Usage.InputTokens, update.Usage.OutputTokens, update.Usage.TotalTokens) + } + case domain.StreamTypeError: + if !opts.Quiet { + fmt.Fprintf(os.Stderr, "Error: %s\n", update.Content) + } + errChan <- errors.New(update.Content) } } - if printedStream && !opts.SuppressThink && !strings.HasSuffix(message, "\n") { + if printedStream && !opts.SuppressThink && !strings.HasSuffix(message, "\n") && !opts.Quiet { fmt.Println() } diff --git a/internal/core/chatter_test.go b/internal/core/chatter_test.go index bc4e9aba..6dcc1ce0 100644 --- a/internal/core/chatter_test.go +++ b/internal/core/chatter_test.go @@ -14,7 +14,7 @@ import ( // mockVendor implements the ai.Vendor interface for testing type mockVendor struct { sendStreamError error - streamChunks []string + streamChunks []domain.StreamUpdate sendFunc func(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) } @@ -45,7 +45,7 @@ func (m *mockVendor) ListModels() ([]string, error) { return []string{"test-model"}, nil } -func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan string) error { +func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan domain.StreamUpdate) error { // Send chunks if provided (for successful streaming test) if m.streamChunks != nil { for _, chunk := range m.streamChunks { @@ -169,7 +169,11 @@ func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) { db := fsdb.NewDb(tempDir) // Create test chunks that should be aggregated - testChunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."} + chunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."} + testChunks := make([]domain.StreamUpdate, len(chunks)) + for i, c := range chunks { + testChunks[i] = domain.StreamUpdate{Type: domain.StreamTypeContent, Content: c} + } expectedMessage := "Hello world! This is a test." // Create a mock vendor that will send chunks successfully @@ -228,3 +232,83 @@ func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) { t.Errorf("Expected aggregated message %q, got %q", expectedMessage, assistantMessage.Content) } } + +func TestChatter_Send_StreamingMetadataPropagation(t *testing.T) { + // Create a temporary database for testing + tempDir := t.TempDir() + db := fsdb.NewDb(tempDir) + + // Create test chunks: one content, one usage metadata + testChunks := []domain.StreamUpdate{ + { + Type: domain.StreamTypeContent, + Content: "Test content", + }, + { + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: 10, + OutputTokens: 5, + TotalTokens: 15, + }, + }, + } + + // Create a mock vendor + mockVendor := &mockVendor{ + sendStreamError: nil, + streamChunks: testChunks, + } + + // Create chatter with streaming enabled + chatter := &Chatter{ + db: db, + Stream: true, + vendor: mockVendor, + model: "test-model", + } + + // Create a test request + request := &domain.ChatRequest{ + Message: &chat.ChatCompletionMessage{ + Role: chat.ChatMessageRoleUser, + Content: "test message", + }, + } + + // Create an update channel to capture stream events + updateChan := make(chan domain.StreamUpdate, 10) + + // Create test options with UpdateChan + opts := &domain.ChatOptions{ + Model: "test-model", + UpdateChan: updateChan, + Quiet: true, // Suppress stdout/stderr + } + + // Call Send + _, err := chatter.Send(request, opts) + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + close(updateChan) + + // Verify we received the metadata event + var usageReceived bool + for update := range updateChan { + if update.Type == domain.StreamTypeUsage { + usageReceived = true + if update.Usage == nil { + t.Error("Expected usage metadata to be non-nil") + } else { + if update.Usage.TotalTokens != 15 { + t.Errorf("Expected 15 total tokens, got %d", update.Usage.TotalTokens) + } + } + } + } + + if !usageReceived { + t.Error("Expected to receive a usage metadata update, but didn't") + } +} diff --git a/internal/core/plugin_registry_test.go b/internal/core/plugin_registry_test.go index 23345985..629882dc 100644 --- a/internal/core/plugin_registry_test.go +++ b/internal/core/plugin_registry_test.go @@ -43,7 +43,7 @@ func (m *testVendor) Configure() error { return nil } func (m *testVendor) Setup() error { return nil } func (m *testVendor) SetupFillEnvFileContent(*bytes.Buffer) {} func (m *testVendor) ListModels() ([]string, error) { return m.models, nil } -func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error { +func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error { return nil } func (m *testVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) { diff --git a/internal/domain/domain.go b/internal/domain/domain.go index bd6fbbdf..2179e28b 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -51,6 +51,9 @@ type ChatOptions struct { Voice string Notification bool NotificationCommand string + ShowMetadata bool + Quiet bool + UpdateChan chan StreamUpdate } // NormalizeMessages remove empty messages and ensure messages order user-assist-user diff --git a/internal/domain/stream.go b/internal/domain/stream.go new file mode 100644 index 00000000..91bdcc4f --- /dev/null +++ b/internal/domain/stream.go @@ -0,0 +1,24 @@ +package domain + +// StreamType distinguishes between partial text content and metadata events. +type StreamType string + +const ( + StreamTypeContent StreamType = "content" + StreamTypeUsage StreamType = "usage" + StreamTypeError StreamType = "error" +) + +// StreamUpdate is the unified payload sent through the internal channels. +type StreamUpdate struct { + Type StreamType `json:"type"` + Content string `json:"content,omitempty"` // For text deltas + Usage *UsageMetadata `json:"usage,omitempty"` // For token counts +} + +// UsageMetadata normalizes token counts across different providers. +type UsageMetadata struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} diff --git a/internal/plugins/ai/anthropic/anthropic.go b/internal/plugins/ai/anthropic/anthropic.go index 2176e74e..3c582935 100644 --- a/internal/plugins/ai/anthropic/anthropic.go +++ b/internal/plugins/ai/anthropic/anthropic.go @@ -184,7 +184,7 @@ func parseThinking(level domain.ThinkingLevel) (anthropic.ThinkingConfigParamUni } func (an *Client) SendStream( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { messages := an.toMessages(msgs) if len(messages) == 0 { @@ -210,9 +210,33 @@ func (an *Client) SendStream( for stream.Next() { event := stream.Current() - // directly send any non-empty delta text + // Handle Content if event.Delta.Text != "" { - channel <- event.Delta.Text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.Delta.Text, + } + } + + // Handle Usage + if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Message.Usage.InputTokens), + OutputTokens: int(event.Message.Usage.OutputTokens), + TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens), + }, + } + } else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Usage.InputTokens), + OutputTokens: int(event.Usage.OutputTokens), + TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens), + }, + } } } diff --git a/internal/plugins/ai/bedrock/bedrock.go b/internal/plugins/ai/bedrock/bedrock.go index 15f47c01..5421b2f6 100644 --- a/internal/plugins/ai/bedrock/bedrock.go +++ b/internal/plugins/ai/bedrock/bedrock.go @@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) { } // SendStream sends the messages to the Bedrock ConverseStream API -func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { // Ensure channel is closed on all exit paths to prevent goroutine leaks defer func() { if r := recover(); r != nil { @@ -186,18 +186,35 @@ func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *dom case *types.ConverseStreamOutputMemberContentBlockDelta: text, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText) if ok { - channel <- text.Value + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: text.Value, + } } case *types.ConverseStreamOutputMemberMessageStop: - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } return nil // Let defer handle the close + case *types.ConverseStreamOutputMemberMetadata: + if v.Value.Usage != nil { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(*v.Value.Usage.InputTokens), + OutputTokens: int(*v.Value.Usage.OutputTokens), + TotalTokens: int(*v.Value.Usage.TotalTokens), + }, + } + } + // Unused Events case *types.ConverseStreamOutputMemberMessageStart, *types.ConverseStreamOutputMemberContentBlockStart, - *types.ConverseStreamOutputMemberContentBlockStop, - *types.ConverseStreamOutputMemberMetadata: + *types.ConverseStreamOutputMemberContentBlockStop: default: return fmt.Errorf("unknown stream event type: %T", v) diff --git a/internal/plugins/ai/dryrun/dryrun.go b/internal/plugins/ai/dryrun/dryrun.go index 1f36d515..034011c3 100644 --- a/internal/plugins/ai/dryrun/dryrun.go +++ b/internal/plugins/ai/dryrun/dryrun.go @@ -108,12 +108,30 @@ func (c *Client) constructRequest(msgs []*chat.ChatCompletionMessage, opts *doma return builder.String() } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { defer close(channel) request := c.constructRequest(msgs, opts) - channel <- request - channel <- "\n" - channel <- DryRunResponse + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: request, + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: DryRunResponse, + } + // Simulated usage + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + }, + } return nil } diff --git a/internal/plugins/ai/dryrun/dryrun_test.go b/internal/plugins/ai/dryrun/dryrun_test.go index 0f8825ac..a5db58c3 100644 --- a/internal/plugins/ai/dryrun/dryrun_test.go +++ b/internal/plugins/ai/dryrun/dryrun_test.go @@ -39,7 +39,7 @@ func TestSendStream_SendsMessages(t *testing.T) { opts := &domain.ChatOptions{ Model: "dry-run-model", } - channel := make(chan string) + channel := make(chan domain.StreamUpdate) go func() { err := client.SendStream(msgs, opts, channel) if err != nil { @@ -48,7 +48,7 @@ func TestSendStream_SendsMessages(t *testing.T) { }() var receivedMessages []string for msg := range channel { - receivedMessages = append(receivedMessages, msg) + receivedMessages = append(receivedMessages, msg.Content) } if len(receivedMessages) == 0 { t.Errorf("Expected to receive messages, but got none") diff --git a/internal/plugins/ai/gemini/gemini.go b/internal/plugins/ai/gemini/gemini.go index fe3035d1..40878f4a 100644 --- a/internal/plugins/ai/gemini/gemini.go +++ b/internal/plugins/ai/gemini/gemini.go @@ -129,7 +129,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return } -func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { ctx := context.Background() defer close(channel) @@ -154,13 +154,30 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha for response, err := range stream { if err != nil { - channel <- fmt.Sprintf("Error: %v\n", err) + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeError, + Content: fmt.Sprintf("Error: %v", err), + } return err } text := o.extractTextFromResponse(response) if text != "" { - channel <- text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: text, + } + } + + if response.UsageMetadata != nil { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(response.UsageMetadata.PromptTokenCount), + OutputTokens: int(response.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(response.UsageMetadata.TotalTokenCount), + }, + } } } diff --git a/internal/plugins/ai/lmstudio/lmstudio.go b/internal/plugins/ai/lmstudio/lmstudio.go index 8e64b88f..f9cae99f 100644 --- a/internal/plugins/ai/lmstudio/lmstudio.go +++ b/internal/plugins/ai/lmstudio/lmstudio.go @@ -87,13 +87,16 @@ func (c *Client) ListModels() ([]string, error) { return models, nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) payload := map[string]any{ "messages": msgs, "model": opts.Model, "stream": true, // Enable streaming + "stream_options": map[string]any{ + "include_usage": true, + }, } var jsonPayload []byte @@ -144,7 +147,7 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha line = after } - if string(line) == "[DONE]" { + if string(bytes.TrimSpace(line)) == "[DONE]" { break } @@ -153,6 +156,24 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha continue } + // Handle Usage + if usage, ok := result["usage"].(map[string]any); ok { + var metadata domain.UsageMetadata + if val, ok := usage["prompt_tokens"].(float64); ok { + metadata.InputTokens = int(val) + } + if val, ok := usage["completion_tokens"].(float64); ok { + metadata.OutputTokens = int(val) + } + if val, ok := usage["total_tokens"].(float64); ok { + metadata.TotalTokens = int(val) + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &metadata, + } + } + var choices []any var ok bool if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 { @@ -166,7 +187,10 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha var content string if content, _ = delta["content"].(string); content != "" { - channel <- content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: content, + } } } diff --git a/internal/plugins/ai/ollama/ollama.go b/internal/plugins/ai/ollama/ollama.go index 5ab36278..03317dfe 100644 --- a/internal/plugins/ai/ollama/ollama.go +++ b/internal/plugins/ai/ollama/ollama.go @@ -106,7 +106,7 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { ctx := context.Background() var req ollamaapi.ChatRequest @@ -115,7 +115,21 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha } respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) { - channel <- resp.Message.Content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: resp.Message.Content, + } + + if resp.Done { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: resp.PromptEvalCount, + OutputTokens: resp.EvalCount, + TotalTokens: resp.PromptEvalCount + resp.EvalCount, + }, + } + } return } diff --git a/internal/plugins/ai/openai/chat_completions.go b/internal/plugins/ai/openai/chat_completions.go index 9ae444a6..5fc65d9f 100644 --- a/internal/plugins/ai/openai/chat_completions.go +++ b/internal/plugins/ai/openai/chat_completions.go @@ -30,7 +30,7 @@ func (o *Client) sendChatCompletions(ctx context.Context, msgs []*chat.ChatCompl // sendStreamChatCompletions sends a streaming request using the Chat Completions API func (o *Client) sendStreamChatCompletions( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { defer close(channel) @@ -39,11 +39,28 @@ func (o *Client) sendStreamChatCompletions( for stream.Next() { chunk := stream.Current() if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { - channel <- chunk.Choices[0].Delta.Content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: chunk.Choices[0].Delta.Content, + } + } + + if chunk.Usage.TotalTokens > 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(chunk.Usage.PromptTokens), + OutputTokens: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + }, + } } } if stream.Err() == nil { - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } } return stream.Err() } @@ -65,6 +82,9 @@ func (o *Client) buildChatCompletionParams( ret = openai.ChatCompletionNewParams{ Model: shared.ChatModel(opts.Model), Messages: messages, + StreamOptions: openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + }, } if !opts.Raw { diff --git a/internal/plugins/ai/openai/openai.go b/internal/plugins/ai/openai/openai.go index db48f856..e364c66c 100644 --- a/internal/plugins/ai/openai/openai.go +++ b/internal/plugins/ai/openai/openai.go @@ -108,7 +108,7 @@ func (o *Client) ListModels() (ret []string, err error) { } func (o *Client) SendStream( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { // Use Responses API for OpenAI, Chat Completions API for other providers if o.supportsResponsesAPI() { @@ -118,7 +118,7 @@ func (o *Client) SendStream( } func (o *Client) sendStreamResponses( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { defer close(channel) @@ -128,7 +128,10 @@ func (o *Client) sendStreamResponses( event := stream.Current() switch event.Type { case string(constant.ResponseOutputTextDelta("").Default()): - channel <- event.AsResponseOutputTextDelta().Delta + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.AsResponseOutputTextDelta().Delta, + } case string(constant.ResponseOutputTextDone("").Default()): // The Responses API sends the full text again in the // final "done" event. Since we've already streamed all @@ -138,7 +141,10 @@ func (o *Client) sendStreamResponses( } } if stream.Err() == nil { - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } } return stream.Err() } diff --git a/internal/plugins/ai/perplexity/perplexity.go b/internal/plugins/ai/perplexity/perplexity.go index 3f0a6198..4ec5f2b1 100644 --- a/internal/plugins/ai/perplexity/perplexity.go +++ b/internal/plugins/ai/perplexity/perplexity.go @@ -123,7 +123,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return content.String(), nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { if c.client == nil { if err := c.Configure(); err != nil { close(channel) // Ensure channel is closed on error @@ -196,7 +196,21 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha content = resp.Choices[0].Message.Content } if content != "" { - channel <- content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: content, + } + } + } + + if resp.Usage.TotalTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(resp.Usage.PromptTokens), + OutputTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + }, } } } @@ -205,9 +219,14 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha if lastResponse != nil { citations := lastResponse.GetCitations() if len(citations) > 0 { - channel <- "\n\n# CITATIONS\n\n" + var citationsText strings.Builder + citationsText.WriteString("\n\n# CITATIONS\n\n") for i, citation := range citations { - channel <- fmt.Sprintf("- [%d] %s\n", i+1, citation) + citationsText.WriteString(fmt.Sprintf("- [%d] %s\n", i+1, citation)) + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: citationsText.String(), } } } diff --git a/internal/plugins/ai/vendor.go b/internal/plugins/ai/vendor.go index a9c9b929..c3134618 100644 --- a/internal/plugins/ai/vendor.go +++ b/internal/plugins/ai/vendor.go @@ -12,7 +12,7 @@ import ( type Vendor interface { plugins.Plugin ListModels() ([]string, error) - SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error + SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) NeedsRawMode(modelName string) bool } diff --git a/internal/plugins/ai/vendors_test.go b/internal/plugins/ai/vendors_test.go index 42cf40fb..4534712d 100644 --- a/internal/plugins/ai/vendors_test.go +++ b/internal/plugins/ai/vendors_test.go @@ -20,7 +20,7 @@ func (v *stubVendor) Configure() error { return nil } func (v *stubVendor) Setup() error { return nil } func (v *stubVendor) SetupFillEnvFileContent(*bytes.Buffer) {} func (v *stubVendor) ListModels() ([]string, error) { return nil, nil } -func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error { +func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error { return nil } func (v *stubVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) { diff --git a/internal/plugins/ai/vertexai/vertexai.go b/internal/plugins/ai/vertexai/vertexai.go index 0860be0f..23ed4ac2 100644 --- a/internal/plugins/ai/vertexai/vertexai.go +++ b/internal/plugins/ai/vertexai/vertexai.go @@ -107,7 +107,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return strings.Join(textParts, ""), nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { if c.client == nil { close(channel) return fmt.Errorf("VertexAI client not initialized") @@ -133,8 +133,34 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha // Process stream for stream.Next() { event := stream.Current() + + // Handle Content if event.Delta.Text != "" { - channel <- event.Delta.Text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.Delta.Text, + } + } + + // Handle Usage + if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Message.Usage.InputTokens), + OutputTokens: int(event.Message.Usage.OutputTokens), + TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens), + }, + } + } else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Usage.InputTokens), + OutputTokens: int(event.Usage.OutputTokens), + TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens), + }, + } } } diff --git a/internal/server/chat.go b/internal/server/chat.go index fe53ee9e..9854d824 100755 --- a/internal/server/chat.go +++ b/internal/server/chat.go @@ -40,9 +40,10 @@ type ChatRequest struct { } type StreamResponse struct { - Type string `json:"type"` // "content", "error", "complete" - Format string `json:"format"` // "markdown", "mermaid", "plain" - Content string `json:"content"` // The actual content + Type string `json:"type"` // "content", "usage", "error", "complete" + Format string `json:"format,omitempty"` // "markdown", "mermaid", "plain" + Content string `json:"content,omitempty"` + Usage *domain.UsageMetadata `json:"usage,omitempty"` } func NewChatHandler(r *gin.Engine, registry *core.PluginRegistry, db *fsdb.Db) *ChatHandler { @@ -98,7 +99,7 @@ func (h *ChatHandler) HandleChat(c *gin.Context) { log.Printf("Processing prompt %d: Model=%s Pattern=%s Context=%s", i+1, prompt.Model, prompt.PatternName, prompt.ContextName) - streamChan := make(chan string) + streamChan := make(chan domain.StreamUpdate) go func(p PromptRequest) { defer close(streamChan) @@ -117,10 +118,10 @@ func (h *ChatHandler) HandleChat(c *gin.Context) { } } - chatter, err := h.registry.GetChatter(p.Model, 2048, p.Vendor, "", false, false) + chatter, err := h.registry.GetChatter(p.Model, 2048, p.Vendor, "", true, false) if err != nil { log.Printf("Error creating chatter: %v", err) - streamChan <- fmt.Sprintf("Error: %v", err) + streamChan <- domain.StreamUpdate{Type: domain.StreamTypeError, Content: fmt.Sprintf("Error: %v", err)} return } @@ -144,49 +145,44 @@ func (h *ChatHandler) HandleChat(c *gin.Context) { FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, Thinking: request.Thinking, + UpdateChan: streamChan, + Quiet: true, } - session, err := chatter.Send(chatReq, opts) + _, err = chatter.Send(chatReq, opts) if err != nil { log.Printf("Error from chatter.Send: %v", err) - streamChan <- fmt.Sprintf("Error: %v", err) + // Error already sent to streamChan via domain.StreamTypeError if occurred in Send loop return } - - if session == nil { - log.Printf("No session returned from chatter.Send") - streamChan <- "Error: No response from model" - return - } - - lastMsg := session.GetLastMessage() - if lastMsg != nil { - streamChan <- lastMsg.Content - } else { - log.Printf("No message content in session") - streamChan <- "Error: No response content" - } }(prompt) - for content := range streamChan { + for update := range streamChan { select { case <-clientGone: return default: var response StreamResponse - if strings.HasPrefix(content, "Error:") { + switch update.Type { + case domain.StreamTypeContent: + response = StreamResponse{ + Type: "content", + Format: detectFormat(update.Content), + Content: update.Content, + } + case domain.StreamTypeUsage: + response = StreamResponse{ + Type: "usage", + Usage: update.Usage, + } + case domain.StreamTypeError: response = StreamResponse{ Type: "error", Format: "plain", - Content: content, - } - } else { - response = StreamResponse{ - Type: "content", - Format: detectFormat(content), - Content: content, + Content: update.Content, } } + if err := writeSSEResponse(c.Writer, response); err != nil { log.Printf("Error writing response: %v", err) return