diff --git a/internal/core/chatter_test.go b/internal/core/chatter_test.go index bc4e9aba..4da37377 100644 --- a/internal/core/chatter_test.go +++ b/internal/core/chatter_test.go @@ -14,7 +14,7 @@ import ( // mockVendor implements the ai.Vendor interface for testing type mockVendor struct { sendStreamError error - streamChunks []string + streamChunks []domain.StreamUpdate sendFunc func(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) } @@ -45,7 +45,7 @@ func (m *mockVendor) ListModels() ([]string, error) { return []string{"test-model"}, nil } -func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan string) error { +func (m *mockVendor) SendStream(messages []*chat.ChatCompletionMessage, opts *domain.ChatOptions, responseChan chan domain.StreamUpdate) error { // Send chunks if provided (for successful streaming test) if m.streamChunks != nil { for _, chunk := range m.streamChunks { @@ -169,7 +169,11 @@ func TestChatter_Send_StreamingSuccessfulAggregation(t *testing.T) { db := fsdb.NewDb(tempDir) // Create test chunks that should be aggregated - testChunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."} + chunks := []string{"Hello", " ", "world", "!", " This", " is", " a", " test."} + testChunks := make([]domain.StreamUpdate, len(chunks)) + for i, c := range chunks { + testChunks[i] = domain.StreamUpdate{Type: domain.StreamTypeContent, Content: c} + } expectedMessage := "Hello world! This is a test." // Create a mock vendor that will send chunks successfully diff --git a/internal/core/plugin_registry_test.go b/internal/core/plugin_registry_test.go index 23345985..629882dc 100644 --- a/internal/core/plugin_registry_test.go +++ b/internal/core/plugin_registry_test.go @@ -43,7 +43,7 @@ func (m *testVendor) Configure() error { return nil } func (m *testVendor) Setup() error { return nil } func (m *testVendor) SetupFillEnvFileContent(*bytes.Buffer) {} func (m *testVendor) ListModels() ([]string, error) { return m.models, nil } -func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error { +func (m *testVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error { return nil } func (m *testVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) { diff --git a/internal/plugins/ai/anthropic/anthropic.go b/internal/plugins/ai/anthropic/anthropic.go index 2176e74e..3c582935 100644 --- a/internal/plugins/ai/anthropic/anthropic.go +++ b/internal/plugins/ai/anthropic/anthropic.go @@ -184,7 +184,7 @@ func parseThinking(level domain.ThinkingLevel) (anthropic.ThinkingConfigParamUni } func (an *Client) SendStream( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { messages := an.toMessages(msgs) if len(messages) == 0 { @@ -210,9 +210,33 @@ func (an *Client) SendStream( for stream.Next() { event := stream.Current() - // directly send any non-empty delta text + // Handle Content if event.Delta.Text != "" { - channel <- event.Delta.Text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.Delta.Text, + } + } + + // Handle Usage + if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Message.Usage.InputTokens), + OutputTokens: int(event.Message.Usage.OutputTokens), + TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens), + }, + } + } else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Usage.InputTokens), + OutputTokens: int(event.Usage.OutputTokens), + TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens), + }, + } } } diff --git a/internal/plugins/ai/bedrock/bedrock.go b/internal/plugins/ai/bedrock/bedrock.go index 15f47c01..5421b2f6 100644 --- a/internal/plugins/ai/bedrock/bedrock.go +++ b/internal/plugins/ai/bedrock/bedrock.go @@ -154,7 +154,7 @@ func (c *BedrockClient) ListModels() ([]string, error) { } // SendStream sends the messages to the Bedrock ConverseStream API -func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { // Ensure channel is closed on all exit paths to prevent goroutine leaks defer func() { if r := recover(); r != nil { @@ -186,18 +186,35 @@ func (c *BedrockClient) SendStream(msgs []*chat.ChatCompletionMessage, opts *dom case *types.ConverseStreamOutputMemberContentBlockDelta: text, ok := v.Value.Delta.(*types.ContentBlockDeltaMemberText) if ok { - channel <- text.Value + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: text.Value, + } } case *types.ConverseStreamOutputMemberMessageStop: - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } return nil // Let defer handle the close + case *types.ConverseStreamOutputMemberMetadata: + if v.Value.Usage != nil { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(*v.Value.Usage.InputTokens), + OutputTokens: int(*v.Value.Usage.OutputTokens), + TotalTokens: int(*v.Value.Usage.TotalTokens), + }, + } + } + // Unused Events case *types.ConverseStreamOutputMemberMessageStart, *types.ConverseStreamOutputMemberContentBlockStart, - *types.ConverseStreamOutputMemberContentBlockStop, - *types.ConverseStreamOutputMemberMetadata: + *types.ConverseStreamOutputMemberContentBlockStop: default: return fmt.Errorf("unknown stream event type: %T", v) diff --git a/internal/plugins/ai/dryrun/dryrun.go b/internal/plugins/ai/dryrun/dryrun.go index 1f36d515..034011c3 100644 --- a/internal/plugins/ai/dryrun/dryrun.go +++ b/internal/plugins/ai/dryrun/dryrun.go @@ -108,12 +108,30 @@ func (c *Client) constructRequest(msgs []*chat.ChatCompletionMessage, opts *doma return builder.String() } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { defer close(channel) request := c.constructRequest(msgs, opts) - channel <- request - channel <- "\n" - channel <- DryRunResponse + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: request, + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: DryRunResponse, + } + // Simulated usage + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: 100, + OutputTokens: 50, + TotalTokens: 150, + }, + } return nil } diff --git a/internal/plugins/ai/gemini/gemini.go b/internal/plugins/ai/gemini/gemini.go index fe3035d1..40878f4a 100644 --- a/internal/plugins/ai/gemini/gemini.go +++ b/internal/plugins/ai/gemini/gemini.go @@ -129,7 +129,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return } -func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { ctx := context.Background() defer close(channel) @@ -154,13 +154,30 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha for response, err := range stream { if err != nil { - channel <- fmt.Sprintf("Error: %v\n", err) + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeError, + Content: fmt.Sprintf("Error: %v", err), + } return err } text := o.extractTextFromResponse(response) if text != "" { - channel <- text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: text, + } + } + + if response.UsageMetadata != nil { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(response.UsageMetadata.PromptTokenCount), + OutputTokens: int(response.UsageMetadata.CandidatesTokenCount), + TotalTokens: int(response.UsageMetadata.TotalTokenCount), + }, + } } } diff --git a/internal/plugins/ai/lmstudio/lmstudio.go b/internal/plugins/ai/lmstudio/lmstudio.go index 8e64b88f..f9cae99f 100644 --- a/internal/plugins/ai/lmstudio/lmstudio.go +++ b/internal/plugins/ai/lmstudio/lmstudio.go @@ -87,13 +87,16 @@ func (c *Client) ListModels() ([]string, error) { return models, nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { url := fmt.Sprintf("%s/chat/completions", c.ApiUrl.Value) payload := map[string]any{ "messages": msgs, "model": opts.Model, "stream": true, // Enable streaming + "stream_options": map[string]any{ + "include_usage": true, + }, } var jsonPayload []byte @@ -144,7 +147,7 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha line = after } - if string(line) == "[DONE]" { + if string(bytes.TrimSpace(line)) == "[DONE]" { break } @@ -153,6 +156,24 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha continue } + // Handle Usage + if usage, ok := result["usage"].(map[string]any); ok { + var metadata domain.UsageMetadata + if val, ok := usage["prompt_tokens"].(float64); ok { + metadata.InputTokens = int(val) + } + if val, ok := usage["completion_tokens"].(float64); ok { + metadata.OutputTokens = int(val) + } + if val, ok := usage["total_tokens"].(float64); ok { + metadata.TotalTokens = int(val) + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &metadata, + } + } + var choices []any var ok bool if choices, ok = result["choices"].([]any); !ok || len(choices) == 0 { @@ -166,7 +187,10 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha var content string if content, _ = delta["content"].(string); content != "" { - channel <- content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: content, + } } } diff --git a/internal/plugins/ai/ollama/ollama.go b/internal/plugins/ai/ollama/ollama.go index 5ab36278..03317dfe 100644 --- a/internal/plugins/ai/ollama/ollama.go +++ b/internal/plugins/ai/ollama/ollama.go @@ -106,7 +106,7 @@ func (o *Client) ListModels() (ret []string, err error) { return } -func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { +func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) (err error) { ctx := context.Background() var req ollamaapi.ChatRequest @@ -115,7 +115,21 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha } respFunc := func(resp ollamaapi.ChatResponse) (streamErr error) { - channel <- resp.Message.Content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: resp.Message.Content, + } + + if resp.Done { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: resp.PromptEvalCount, + OutputTokens: resp.EvalCount, + TotalTokens: resp.PromptEvalCount + resp.EvalCount, + }, + } + } return } diff --git a/internal/plugins/ai/openai/chat_completions.go b/internal/plugins/ai/openai/chat_completions.go index 9ae444a6..5fc65d9f 100644 --- a/internal/plugins/ai/openai/chat_completions.go +++ b/internal/plugins/ai/openai/chat_completions.go @@ -30,7 +30,7 @@ func (o *Client) sendChatCompletions(ctx context.Context, msgs []*chat.ChatCompl // sendStreamChatCompletions sends a streaming request using the Chat Completions API func (o *Client) sendStreamChatCompletions( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { defer close(channel) @@ -39,11 +39,28 @@ func (o *Client) sendStreamChatCompletions( for stream.Next() { chunk := stream.Current() if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { - channel <- chunk.Choices[0].Delta.Content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: chunk.Choices[0].Delta.Content, + } + } + + if chunk.Usage.TotalTokens > 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(chunk.Usage.PromptTokens), + OutputTokens: int(chunk.Usage.CompletionTokens), + TotalTokens: int(chunk.Usage.TotalTokens), + }, + } } } if stream.Err() == nil { - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } } return stream.Err() } @@ -65,6 +82,9 @@ func (o *Client) buildChatCompletionParams( ret = openai.ChatCompletionNewParams{ Model: shared.ChatModel(opts.Model), Messages: messages, + StreamOptions: openai.ChatCompletionStreamOptionsParam{ + IncludeUsage: openai.Bool(true), + }, } if !opts.Raw { diff --git a/internal/plugins/ai/openai/openai.go b/internal/plugins/ai/openai/openai.go index db48f856..e364c66c 100644 --- a/internal/plugins/ai/openai/openai.go +++ b/internal/plugins/ai/openai/openai.go @@ -108,7 +108,7 @@ func (o *Client) ListModels() (ret []string, err error) { } func (o *Client) SendStream( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { // Use Responses API for OpenAI, Chat Completions API for other providers if o.supportsResponsesAPI() { @@ -118,7 +118,7 @@ func (o *Client) SendStream( } func (o *Client) sendStreamResponses( - msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string, + msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate, ) (err error) { defer close(channel) @@ -128,7 +128,10 @@ func (o *Client) sendStreamResponses( event := stream.Current() switch event.Type { case string(constant.ResponseOutputTextDelta("").Default()): - channel <- event.AsResponseOutputTextDelta().Delta + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.AsResponseOutputTextDelta().Delta, + } case string(constant.ResponseOutputTextDone("").Default()): // The Responses API sends the full text again in the // final "done" event. Since we've already streamed all @@ -138,7 +141,10 @@ func (o *Client) sendStreamResponses( } } if stream.Err() == nil { - channel <- "\n" + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: "\n", + } } return stream.Err() } diff --git a/internal/plugins/ai/perplexity/perplexity.go b/internal/plugins/ai/perplexity/perplexity.go index 3f0a6198..4ec5f2b1 100644 --- a/internal/plugins/ai/perplexity/perplexity.go +++ b/internal/plugins/ai/perplexity/perplexity.go @@ -123,7 +123,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return content.String(), nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { if c.client == nil { if err := c.Configure(); err != nil { close(channel) // Ensure channel is closed on error @@ -196,7 +196,21 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha content = resp.Choices[0].Message.Content } if content != "" { - channel <- content + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: content, + } + } + } + + if resp.Usage.TotalTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(resp.Usage.PromptTokens), + OutputTokens: int(resp.Usage.CompletionTokens), + TotalTokens: int(resp.Usage.TotalTokens), + }, } } } @@ -205,9 +219,14 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha if lastResponse != nil { citations := lastResponse.GetCitations() if len(citations) > 0 { - channel <- "\n\n# CITATIONS\n\n" + var citationsText strings.Builder + citationsText.WriteString("\n\n# CITATIONS\n\n") for i, citation := range citations { - channel <- fmt.Sprintf("- [%d] %s\n", i+1, citation) + citationsText.WriteString(fmt.Sprintf("- [%d] %s\n", i+1, citation)) + } + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: citationsText.String(), } } } diff --git a/internal/plugins/ai/vendors_test.go b/internal/plugins/ai/vendors_test.go index 42cf40fb..4534712d 100644 --- a/internal/plugins/ai/vendors_test.go +++ b/internal/plugins/ai/vendors_test.go @@ -20,7 +20,7 @@ func (v *stubVendor) Configure() error { return nil } func (v *stubVendor) Setup() error { return nil } func (v *stubVendor) SetupFillEnvFileContent(*bytes.Buffer) {} func (v *stubVendor) ListModels() ([]string, error) { return nil, nil } -func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan string) error { +func (v *stubVendor) SendStream([]*chat.ChatCompletionMessage, *domain.ChatOptions, chan domain.StreamUpdate) error { return nil } func (v *stubVendor) Send(context.Context, []*chat.ChatCompletionMessage, *domain.ChatOptions) (string, error) { diff --git a/internal/plugins/ai/vertexai/vertexai.go b/internal/plugins/ai/vertexai/vertexai.go index 0860be0f..23ed4ac2 100644 --- a/internal/plugins/ai/vertexai/vertexai.go +++ b/internal/plugins/ai/vertexai/vertexai.go @@ -107,7 +107,7 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o return strings.Join(textParts, ""), nil } -func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) error { +func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error { if c.client == nil { close(channel) return fmt.Errorf("VertexAI client not initialized") @@ -133,8 +133,34 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha // Process stream for stream.Next() { event := stream.Current() + + // Handle Content if event.Delta.Text != "" { - channel <- event.Delta.Text + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeContent, + Content: event.Delta.Text, + } + } + + // Handle Usage + if event.Message.Usage.InputTokens != 0 || event.Message.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Message.Usage.InputTokens), + OutputTokens: int(event.Message.Usage.OutputTokens), + TotalTokens: int(event.Message.Usage.InputTokens + event.Message.Usage.OutputTokens), + }, + } + } else if event.Usage.InputTokens != 0 || event.Usage.OutputTokens != 0 { + channel <- domain.StreamUpdate{ + Type: domain.StreamTypeUsage, + Usage: &domain.UsageMetadata{ + InputTokens: int(event.Usage.InputTokens), + OutputTokens: int(event.Usage.OutputTokens), + TotalTokens: int(event.Usage.InputTokens + event.Usage.OutputTokens), + }, + } } }