diff --git a/.vscode/settings.json b/.vscode/settings.json index a18f86d0..56447929 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -61,6 +61,7 @@ "jessevdk", "Jina", "joho", + "Kore", "ksylvan", "Langdock", "ldflags", @@ -108,6 +109,7 @@ "storer", "Streamlit", "stretchr", + "subchunk", "talkpanel", "Telos", "testpattern", diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d9dfc59..542067d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,6 @@ - Create comprehensive Docker testing environment with 6 scenarios - Add interactive test runner with shell access - ## v1.4.265 (2025-07-25) ### PR [#1647](https://github.com/danielmiessler/Fabric/pull/1647) by [ksylvan](https://github.com/ksylvan): Simplify Workflow with Single Version Retrieval Step diff --git a/cmd/generate_changelog/changelog.db b/cmd/generate_changelog/changelog.db index 6deab265..8b5bc3d0 100644 Binary files a/cmd/generate_changelog/changelog.db and b/cmd/generate_changelog/changelog.db differ diff --git a/go.mod b/go.mod index 6ba13a0d..b8209bce 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/gin-gonic/gin v1.10.1 github.com/go-git/go-git/v5 v5.16.2 github.com/go-shiori/go-readability v0.0.0-20250217085726-9f5bf5ca7612 - github.com/google/generative-ai-go v0.20.1 github.com/google/go-github/v66 v66.0.0 github.com/hasura/go-graphql-client v0.14.4 github.com/jessevdk/go-flags v1.6.1 @@ -35,13 +34,16 @@ require ( gopkg.in/yaml.v3 v3.0.1 ) +require ( + github.com/google/go-cmp v0.7.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect +) + require ( cloud.google.com/go v0.121.2 // indirect - cloud.google.com/go/ai v0.12.1 // indirect cloud.google.com/go/auth v0.16.2 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.7.0 // indirect - cloud.google.com/go/longrunning v0.6.7 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.3.0 // indirect @@ -109,7 +111,6 @@ require ( github.com/ugorji/go/codec v1.2.14 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect go.opentelemetry.io/otel v1.36.0 // indirect go.opentelemetry.io/otel/metric v1.36.0 // indirect @@ -120,7 +121,7 @@ require ( golang.org/x/net v0.41.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.34.0 // indirect - golang.org/x/time v0.12.0 // indirect + google.golang.org/genai v1.17.0 google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect google.golang.org/grpc v1.73.0 // indirect diff --git a/go.sum b/go.sum index adc68afd..19f1ab1a 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,11 @@ cloud.google.com/go v0.121.2 h1:v2qQpN6Dx9x2NmwrqlesOt3Ys4ol5/lFZ6Mg1B7OJCg= cloud.google.com/go v0.121.2/go.mod h1:nRFlrHq39MNVWu+zESP2PosMWA0ryJw8KUBZ2iZpxbw= -cloud.google.com/go/ai v0.12.1 h1:m1n/VjUuHS+pEO/2R4/VbuuEIkgk0w67fDQvFaMngM0= -cloud.google.com/go/ai v0.12.1/go.mod h1:5vIPNe1ZQsVZqCliXIPL4QnhObQQY4d9hAGHdVc4iw4= cloud.google.com/go/auth v0.16.2 h1:QvBAGFPLrDeoiNjyfVunhQ10HKNYuOwZ5noee0M5df4= cloud.google.com/go/auth v0.16.2/go.mod h1:sRBas2Y1fB1vZTdurouM0AzuYQBMZinrUYL8EufhtEA= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeOCw78U8ytSU= cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo= -cloud.google.com/go/longrunning v0.6.7 h1:IGtfDWHhQCgCjwQjV9iiLnUta9LBCo8R9QmAFsS/PrE= -cloud.google.com/go/longrunning v0.6.7/go.mod h1:EAFV3IZAKmM56TyiE6VAP3VoTzhZzySwI/YI1s/nRsY= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= @@ -126,8 +122,6 @@ github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8J github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/google/generative-ai-go v0.20.1 h1:6dEIujpgN2V0PgLhr6c/M1ynRdc7ARtiIDPFzj45uNQ= -github.com/google/generative-ai-go v0.20.1/go.mod h1:TjOnZJmZKzarWbjUJgy+r3Ee7HGBRVLhOIgupnwR4Bg= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -145,6 +139,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.14.2 h1:eBLnkZ9635krYIPD+ag1USrOAI0Nr0QYF3+/3GqO0k0= github.com/googleapis/gax-go/v2 v2.14.2/go.mod h1:ON64QhlJkhVtSqp4v1uaK92VyZ2gmvDQsweuyLV+8+w= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hasura/go-graphql-client v0.14.4 h1:bYU7/+V50T2YBGdNQXt6l4f2cMZPECPUd8cyCR+ixtw= github.com/hasura/go-graphql-client v0.14.4/go.mod h1:jfSZtBER3or+88Q9vFhWHiFMPppfYILRyl+0zsgPIIw= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -249,8 +245,6 @@ github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0 h1:q4XOmH/0opmeuJtPsbFNivyl7bCt7yRBbeEm2sC/XtQ= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.61.0/go.mod h1:snMWehoOh2wsEwnvvwtDyFCxVeDAODenXHtn5vzrKjo= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= @@ -345,8 +339,6 @@ golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= -golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= -golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= @@ -357,6 +349,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.236.0 h1:CAiEiDVtO4D/Qja2IA9VzlFrgPnK3XVMmRoJZlSWbc0= google.golang.org/api v0.236.0/go.mod h1:X1WF9CU2oTc+Jml1tiIxGmWFK/UZezdqEu09gcxZAj4= +google.golang.org/genai v1.17.0 h1:lXYSnWShPYjxTouxRj0zF8RsNmSF+SKo7SQ7dM35NlI= +google.golang.org/genai v1.17.0/go.mod h1:QPj5NGJw+3wEOHg+PrsWwJKvG6UC84ex5FR7qAYsN/M= google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2 h1:1tXaIXCracvtsRxSBsYDiSBN0cuJvM7QYW+MrpIRY78= google.golang.org/genproto v0.0.0-20250505200425-f936aa4a68b2/go.mod h1:49MsLSx0oWMOZqcpB3uL8ZOkAh1+TndpJ8ONoCBWiZk= google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= diff --git a/internal/cli/chat.go b/internal/cli/chat.go index e2ad0081..93a807a8 100644 --- a/internal/cli/chat.go +++ b/internal/cli/chat.go @@ -3,6 +3,7 @@ package cli import ( "fmt" "os" + "path/filepath" "strings" "github.com/danielmiessler/fabric/internal/core" @@ -35,6 +36,40 @@ func handleChatProcessing(currentFlags *Flags, registry *core.PluginRegistry, me if chatOptions, err = currentFlags.BuildChatOptions(); err != nil { return } + + // Check if user is requesting audio output or using a TTS model + isAudioOutput := currentFlags.Output != "" && IsAudioFormat(currentFlags.Output) + isTTSModel := isTTSModel(currentFlags.Model) + + if isTTSModel && !isAudioOutput { + err = fmt.Errorf("TTS model '%s' requires audio output. Please specify an audio output file with -o flag (e.g., -o output.wav)", currentFlags.Model) + return + } + + if isAudioOutput && !isTTSModel { + err = fmt.Errorf("audio output file '%s' specified but model '%s' is not a TTS model. Please use a TTS model like gemini-2.5-flash-preview-tts", currentFlags.Output, currentFlags.Model) + return + } + + // For TTS models, check if output file already exists BEFORE processing + if isTTSModel && isAudioOutput { + outputFile := currentFlags.Output + // Add .wav extension if not provided + if filepath.Ext(outputFile) == "" { + outputFile += ".wav" + } + if _, err = os.Stat(outputFile); err == nil { + err = fmt.Errorf("file %s already exists. Please choose a different filename or remove the existing file", outputFile) + return + } + } + + // Set audio options in chat config + chatOptions.AudioOutput = isAudioOutput + if isAudioOutput { + chatOptions.AudioFormat = "wav" // Default to WAV format + } + if session, err = chatter.Send(chatReq, chatOptions); err != nil { return } @@ -42,8 +77,13 @@ func handleChatProcessing(currentFlags *Flags, registry *core.PluginRegistry, me result := session.GetLastMessage().Content if !currentFlags.Stream || currentFlags.SuppressThink { - // print the result if it was not streamed already or suppress-think disabled streaming output - fmt.Println(result) + // For TTS models with audio output, show a user-friendly message instead of raw data + if isTTSModel && isAudioOutput && strings.HasPrefix(result, "FABRIC_AUDIO_DATA:") { + fmt.Printf("TTS audio generated successfully and saved to: %s\n", currentFlags.Output) + } else { + // print the result if it was not streamed already or suppress-think disabled streaming output + fmt.Println(result) + } } // if the copy flag is set, copy the message to the clipboard @@ -59,8 +99,29 @@ func handleChatProcessing(currentFlags *Flags, registry *core.PluginRegistry, me sessionAsString := session.String() err = CreateOutputFile(sessionAsString, currentFlags.Output) } else { - err = CreateOutputFile(result, currentFlags.Output) + // For TTS models, we need to handle audio output differently + if isTTSModel && isAudioOutput { + // Check if result contains actual audio data + if strings.HasPrefix(result, "FABRIC_AUDIO_DATA:") { + // Extract the binary audio data + audioData := []byte(result[len("FABRIC_AUDIO_DATA:"):]) + err = CreateAudioOutputFile(audioData, currentFlags.Output) + } else { + // Fallback for any error messages or unexpected responses + err = CreateOutputFile(result, currentFlags.Output) + } + } else { + err = CreateOutputFile(result, currentFlags.Output) + } } } return } + +// isTTSModel checks if the model is a text-to-speech model +func isTTSModel(modelName string) bool { + lowerModel := strings.ToLower(modelName) + return strings.Contains(lowerModel, "tts") || + strings.Contains(lowerModel, "preview-tts") || + strings.Contains(lowerModel, "text-to-speech") +} diff --git a/internal/cli/output.go b/internal/cli/output.go index 00793d92..90ae3cfc 100644 --- a/internal/cli/output.go +++ b/internal/cli/output.go @@ -3,6 +3,8 @@ package cli import ( "fmt" "os" + "path/filepath" + "strings" "github.com/atotto/clipboard" ) @@ -28,3 +30,37 @@ func CreateOutputFile(message string, fileName string) (err error) { } return } + +// CreateAudioOutputFile creates a binary file for audio data +func CreateAudioOutputFile(audioData []byte, fileName string) (err error) { + // If no extension is provided, default to .wav + if filepath.Ext(fileName) == "" { + fileName += ".wav" + } + + // File existence check is now done in the CLI layer before TTS generation + var file *os.File + if file, err = os.Create(fileName); err != nil { + err = fmt.Errorf("error creating audio file: %v", err) + return + } + defer file.Close() + + if _, err = file.Write(audioData); err != nil { + err = fmt.Errorf("error writing audio data to file: %v", err) + } + // No redundant output message here - the CLI layer handles success messaging + return +} + +// IsAudioFormat checks if the filename suggests an audio format +func IsAudioFormat(fileName string) bool { + ext := strings.ToLower(filepath.Ext(fileName)) + audioExts := []string{".wav", ".mp3", ".m4a", ".aac", ".ogg", ".flac"} + for _, audioExt := range audioExts { + if ext == audioExt { + return true + } + } + return false +} diff --git a/internal/domain/domain.go b/internal/domain/domain.go index 0daabc32..b37b794e 100644 --- a/internal/domain/domain.go +++ b/internal/domain/domain.go @@ -36,6 +36,8 @@ type ChatOptions struct { SuppressThink bool ThinkStartTag string ThinkEndTag string + AudioOutput bool + AudioFormat string } // NormalizeMessages remove empty messages and ensure messages order user-assist-user diff --git a/internal/plugins/ai/gemini/gemini.go b/internal/plugins/ai/gemini/gemini.go index 1dff3785..1d0cf987 100644 --- a/internal/plugins/ai/gemini/gemini.go +++ b/internal/plugins/ai/gemini/gemini.go @@ -1,8 +1,9 @@ package gemini import ( + "bytes" "context" - "errors" + "encoding/binary" "fmt" "strings" @@ -10,13 +11,9 @@ import ( "github.com/danielmiessler/fabric/internal/plugins" "github.com/danielmiessler/fabric/internal/domain" - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/iterator" - "google.golang.org/api/option" + "google.golang.org/genai" ) -const modelsNamePrefix = "models/" - func NewClient() (ret *Client) { vendorName := "Gemini" ret = &Client{} @@ -39,107 +36,104 @@ type Client struct { func (o *Client) ListModels() (ret []string, err error) { ctx := context.Background() var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { + if client, err = genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: o.ApiKey.Value, + Backend: genai.BackendGeminiAPI, + }); err != nil { return } - defer client.Close() - iter := client.ListModels(ctx) - for { - var resp *genai.ModelInfo - if resp, err = iter.Next(); err != nil { - if errors.Is(err, iterator.Done) { - err = nil - } - break - } + // List available models using the correct API + resp, err := client.Models.List(ctx, &genai.ListModelsConfig{}) + if err != nil { + return nil, err + } - name := o.buildModelNameSimple(resp.Name) - ret = append(ret, name) + for _, model := range resp.Items { + // Strip the "models/" prefix for user convenience + modelName := strings.TrimPrefix(model.Name, "models/") + ret = append(ret, modelName) } return } func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) { - systemInstruction, messages := toMessages(msgs) + // Check if this is a TTS model request + if o.isTTSModel(opts.Model) { + if !opts.AudioOutput { + err = fmt.Errorf("TTS model '%s' requires audio output. Please specify an audio output file with -o flag ending in .wav", opts.Model) + return + } + // Handle TTS generation + return o.generateTTSAudio(ctx, msgs, opts) + } + + // Regular text generation var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { - return - } - defer client.Close() - - model := client.GenerativeModel(o.buildModelNameFull(opts.Model)) - model.SetTemperature(float32(opts.Temperature)) - model.SetTopP(float32(opts.TopP)) - model.SystemInstruction = systemInstruction - - var response *genai.GenerateContentResponse - if response, err = model.GenerateContent(ctx, messages...); err != nil { + if client, err = genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: o.ApiKey.Value, + Backend: genai.BackendGeminiAPI, + }); err != nil { return } - ret = o.extractText(response) + // Convert messages to new SDK format + contents := o.convertMessages(msgs) + + // Generate content + temperature := float32(opts.Temperature) + topP := float32(opts.TopP) + response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{ + Temperature: &temperature, + TopP: &topP, + MaxOutputTokens: int32(opts.ModelContextLength), + }) + if err != nil { + return "", err + } + + // Extract text from response + ret = o.extractTextFromResponse(response) return } -func (o *Client) buildModelNameSimple(fullModelName string) string { - return strings.TrimPrefix(fullModelName, modelsNamePrefix) -} - -func (o *Client) buildModelNameFull(modelName string) string { - return fmt.Sprintf("%v%v", modelsNamePrefix, modelName) -} - func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) { ctx := context.Background() var client *genai.Client - if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil { + if client, err = genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: o.ApiKey.Value, + Backend: genai.BackendGeminiAPI, + }); err != nil { return } - defer client.Close() - systemInstruction, messages := toMessages(msgs) + // Convert messages to new SDK format + contents := o.convertMessages(msgs) - model := client.GenerativeModel(o.buildModelNameFull(opts.Model)) - model.SetTemperature(float32(opts.Temperature)) - model.SetTopP(float32(opts.TopP)) - model.SystemInstruction = systemInstruction + // Generate streaming content + temperature := float32(opts.Temperature) + topP := float32(opts.TopP) + stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{ + Temperature: &temperature, + TopP: &topP, + MaxOutputTokens: int32(opts.ModelContextLength), + }) - iter := model.GenerateContentStream(ctx, messages...) - for { - if resp, iterErr := iter.Next(); iterErr == nil { - for _, candidate := range resp.Candidates { - if candidate.Content != nil { - for _, part := range candidate.Content.Parts { - if text, ok := part.(genai.Text); ok { - channel <- string(text) - } - } - } - } - } else { - if !errors.Is(iterErr, iterator.Done) { - channel <- fmt.Sprintf("%v\n", iterErr) - } + for response, err := range stream { + if err != nil { + channel <- fmt.Sprintf("Error: %v\n", err) close(channel) break } - } - return -} -func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) { - for _, candidate := range response.Candidates { - if candidate.Content == nil { - break - } - for _, part := range candidate.Content.Parts { - if text, ok := part.(genai.Text); ok { - ret += string(text) - } + text := o.extractTextFromResponse(response) + if text != "" { + channel <- text } } + close(channel) + return } @@ -147,18 +141,174 @@ func (o *Client) NeedsRawMode(modelName string) bool { return false } -func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) { - if len(msgs) >= 2 { - systemInstruction = &genai.Content{ - Parts: []genai.Part{ - genai.Text(msgs[0].Content), - }, - } - for _, msg := range msgs[1:] { - messages = append(messages, genai.Text(msg.Content)) - } - } else { - messages = append(messages, genai.Text(msgs[0].Content)) +// buildModelNameFull adds the "models/" prefix for API calls +func (o *Client) buildModelNameFull(modelName string) string { + if strings.HasPrefix(modelName, "models/") { + return modelName } - return + return "models/" + modelName +} + +// isTTSModel checks if the model is a text-to-speech model +func (o *Client) isTTSModel(modelName string) bool { + lowerModel := strings.ToLower(modelName) + return strings.Contains(lowerModel, "tts") || + strings.Contains(lowerModel, "preview-tts") || + strings.Contains(lowerModel, "text-to-speech") +} + +// generateTTSAudio handles TTS audio generation using the new SDK +func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) { + // Extract the text to convert to speech from the user messages + var textToSpeak string + for i := len(msgs) - 1; i >= 0; i-- { + if msgs[i].Role == chat.ChatMessageRoleUser && msgs[i].Content != "" { + textToSpeak = msgs[i].Content + break + } + } + + if textToSpeak == "" { + err = fmt.Errorf("no text content found for TTS generation") + return + } + + var client *genai.Client + if client, err = genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: o.ApiKey.Value, + Backend: genai.BackendGeminiAPI, + }); err != nil { + return + } + + // Create content for TTS + contents := []*genai.Content{{ + Parts: []*genai.Part{{Text: textToSpeak}}, + }} + + // Configure for TTS generation + config := &genai.GenerateContentConfig{ + ResponseModalities: []string{"AUDIO"}, + SpeechConfig: &genai.SpeechConfig{ + VoiceConfig: &genai.VoiceConfig{ + PrebuiltVoiceConfig: &genai.PrebuiltVoiceConfig{ + VoiceName: "Kore", // Default voice + }, + }, + }, + } + + // Generate TTS content + response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, config) + if err != nil { + return "", fmt.Errorf("TTS generation failed: %w", err) + } + + // Extract and process audio data + if len(response.Candidates) > 0 && response.Candidates[0].Content != nil && len(response.Candidates[0].Content.Parts) > 0 { + part := response.Candidates[0].Content.Parts[0] + if part.InlineData != nil && len(part.InlineData.Data) > 0 { + // The data is already in binary format, not base64 + pcmData := part.InlineData.Data + + // Generate WAV file with proper headers and return the binary data + wavData, err := o.generateWAVFile(pcmData) + if err != nil { + return "", fmt.Errorf("failed to generate WAV file: %w", err) + } + + // Store the binary audio data in a special format that the CLI can detect + // We'll encode it as a special marker followed by the binary data + ret = fmt.Sprintf("FABRIC_AUDIO_DATA:%s", string(wavData)) + return ret, nil + } + } + + return "", fmt.Errorf("no audio data received from TTS model") +} + +// generateWAVFile creates WAV data from PCM data with proper headers +func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) { + // WAV file parameters (Gemini TTS default specs) + channels := 1 + sampleRate := 24000 + bitsPerSample := 16 + + // Calculate required values + byteRate := sampleRate * channels * bitsPerSample / 8 + blockAlign := channels * bitsPerSample / 8 + dataLen := uint32(len(pcmData)) + riffSize := 36 + dataLen + + buf := new(bytes.Buffer) + + // RIFF header + buf.WriteString("RIFF") + binary.Write(buf, binary.LittleEndian, riffSize) + buf.WriteString("WAVE") + + // fmt chunk + buf.WriteString("fmt ") + binary.Write(buf, binary.LittleEndian, uint32(16)) // subchunk1Size + binary.Write(buf, binary.LittleEndian, uint16(1)) // audioFormat = PCM + binary.Write(buf, binary.LittleEndian, uint16(channels)) // numChannels + binary.Write(buf, binary.LittleEndian, uint32(sampleRate)) // sampleRate + binary.Write(buf, binary.LittleEndian, uint32(byteRate)) // byteRate + binary.Write(buf, binary.LittleEndian, uint16(blockAlign)) // blockAlign + binary.Write(buf, binary.LittleEndian, uint16(bitsPerSample)) // bitsPerSample + + // data chunk + buf.WriteString("data") + binary.Write(buf, binary.LittleEndian, dataLen) + + // Write PCM data to buffer + buf.Write(pcmData) + + // Return the complete WAV data + return buf.Bytes(), nil +} + +// convertMessages converts fabric chat messages to genai Content format +func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content { + var contents []*genai.Content + + for _, msg := range msgs { + content := &genai.Content{} + + if msg.Content != "" { + content.Parts = append(content.Parts, &genai.Part{Text: msg.Content}) + } + + // Handle multi-content messages (images, etc.) + for _, part := range msg.MultiContent { + switch part.Type { + case chat.ChatMessagePartTypeText: + content.Parts = append(content.Parts, &genai.Part{Text: part.Text}) + case chat.ChatMessagePartTypeImageURL: + // TODO: Handle image URLs if needed + // This would require downloading and converting to inline data + } + } + + contents = append(contents, content) + } + + return contents +} + +// extractTextFromResponse extracts text content from the response +func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string { + var result strings.Builder + + for _, candidate := range response.Candidates { + if candidate.Content != nil { + for _, part := range candidate.Content.Parts { + if part.Text != "" { + result.WriteString(part.Text) + } + } + } + } + + return result.String() } diff --git a/internal/plugins/ai/gemini/gemini_test.go b/internal/plugins/ai/gemini/gemini_test.go index e3d58572..633aa239 100644 --- a/internal/plugins/ai/gemini/gemini_test.go +++ b/internal/plugins/ai/gemini/gemini_test.go @@ -3,32 +3,40 @@ package gemini import ( "testing" - "github.com/google/generative-ai-go/genai" + "google.golang.org/genai" ) -// Test generated using Keploy -func TestBuildModelNameSimple(t *testing.T) { +// Test buildModelNameFull method +func TestBuildModelNameFull(t *testing.T) { client := &Client{} - fullModelName := "models/chat-bison-001" - expected := "chat-bison-001" - result := client.buildModelNameSimple(fullModelName) + tests := []struct { + input string + expected string + }{ + {"chat-bison-001", "models/chat-bison-001"}, + {"models/chat-bison-001", "models/chat-bison-001"}, + {"gemini-2.5-flash-preview-tts", "models/gemini-2.5-flash-preview-tts"}, + } - if result != expected { - t.Errorf("Expected %v, got %v", expected, result) + for _, test := range tests { + result := client.buildModelNameFull(test.input) + if result != test.expected { + t.Errorf("For input %v, expected %v, got %v", test.input, test.expected, result) + } } } -// Test generated using Keploy -func TestExtractText(t *testing.T) { +// Test extractTextFromResponse method +func TestExtractTextFromResponse(t *testing.T) { client := &Client{} response := &genai.GenerateContentResponse{ Candidates: []*genai.Candidate{ { Content: &genai.Content{ - Parts: []genai.Part{ - genai.Text("Hello, "), - genai.Text("world!"), + Parts: []*genai.Part{ + {Text: "Hello, "}, + {Text: "world!"}, }, }, }, @@ -36,9 +44,56 @@ func TestExtractText(t *testing.T) { } expected := "Hello, world!" - result := client.extractText(response) + result := client.extractTextFromResponse(response) if result != expected { t.Errorf("Expected %v, got %v", expected, result) } } + +// Test isTTSModel method +func TestIsTTSModel(t *testing.T) { + client := &Client{} + + tests := []struct { + modelName string + expected bool + }{ + {"gemini-2.5-flash-preview-tts", true}, + {"text-to-speech-model", true}, + {"TTS-MODEL", true}, + {"gemini-pro", false}, + {"chat-bison", false}, + {"", false}, + } + + for _, test := range tests { + result := client.isTTSModel(test.modelName) + if result != test.expected { + t.Errorf("For model %v, expected %v, got %v", test.modelName, test.expected, result) + } + } +} + +// Test generateWAVFile method (basic test) +func TestGenerateWAVFile(t *testing.T) { + client := &Client{} + + // Test with minimal PCM data + pcmData := []byte{0x00, 0x01, 0x02, 0x03} + + result, err := client.generateWAVFile(pcmData) + if err != nil { + t.Errorf("generateWAVFile failed: %v", err) + } + + // Check that we got some data back + if len(result) == 0 { + t.Error("generateWAVFile returned empty data") + } + + // Check that it starts with RIFF header + if len(result) >= 4 && string(result[0:4]) != "RIFF" { + t.Error("Generated WAV data doesn't start with RIFF header") + } +} diff --git a/scripts/docker-test/README.md b/scripts/docker-test/README.md index cb846c7c..687ac5db 100644 --- a/scripts/docker-test/README.md +++ b/scripts/docker-test/README.md @@ -50,7 +50,7 @@ These files are volume-mounted into the Docker container and persist changes mad The interactive mode (`-i`) provides several options: -``` +```text Available test cases: 1) No APIs configured (no-config)