Files
Fabric/internal/plugins/ai/gemini/gemini.go
Kayvan Sylvan 614b1322d5 feat: add Gemini TTS voice selection and listing functionality
## CHANGES

- Add `--voice` flag for TTS voice selection
- Add `--list-gemini-voices` command for voice discovery
- Implement voice validation for Gemini TTS models
- Update shell completions for voice options
- Add comprehensive Gemini TTS documentation
- Create voice samples directory structure
- Extend spell checker dictionary with voice names
2025-07-26 15:11:30 -07:00

376 lines
11 KiB
Go

package gemini
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"strings"
"github.com/danielmiessler/fabric/internal/chat"
"github.com/danielmiessler/fabric/internal/plugins"
"github.com/danielmiessler/fabric/internal/domain"
"google.golang.org/genai"
)
// WAV audio constants
//
// The Default* values are the audio format parameters that generateWAVFile
// writes into the WAV header. The remaining values are header sizes, payload
// size limits, and the marker used to tag audio returned as a string.
const (
// DefaultChannels is the channel count written to the WAV "fmt " chunk.
DefaultChannels = 1
// DefaultSampleRate is the sample rate, in Hz, written to the WAV "fmt " chunk.
DefaultSampleRate = 24000
// DefaultBitsPerSample is the sample width written to the WAV "fmt " chunk.
DefaultBitsPerSample = 16
// WAVHeaderSize is the number of bytes generateWAVFile emits before the PCM
// payload; generated WAV data shorter than this is rejected as invalid.
WAVHeaderSize = 44
// RIFFHeaderSize is added to the PCM payload length to obtain the size field
// stored in the RIFF chunk (everything after the 8-byte "RIFF"+size prefix).
RIFFHeaderSize = 36
// MaxAudioDataSize is the largest PCM payload generateWAVFile accepts.
MaxAudioDataSize = 100 * 1024 * 1024 // 100MB limit for security
// MinAudioDataSize is the smallest audio payload performTTSGeneration accepts
// from the model response.
MinAudioDataSize = 44 // Minimum viable audio data
// AudioDataPrefix is prepended to the WAV bytes in the string returned by
// performTTSGeneration so that the consumer can recognize binary audio output.
AudioDataPrefix = "FABRIC_AUDIO_DATA:"
)
// NewClient constructs the Gemini vendor client.
//
// The returned client embeds a PluginBase named "Gemini" whose environment
// variable prefix is derived from that name, and registers a single
// "API key" setup question whose answer is later used to authenticate
// requests.
func NewClient() (ret *Client) {
	const vendorName = "Gemini"

	base := &plugins.PluginBase{
		Name:          vendorName,
		EnvNamePrefix: plugins.BuildEnvVariablePrefix(vendorName),
	}

	ret = &Client{PluginBase: base}
	ret.ApiKey = base.AddSetupQuestion("API key", true)
	return ret
}
// Client is the Gemini vendor plugin. It talks to the Gemini API through the
// google.golang.org/genai SDK for model listing, text generation (blocking
// and streaming) and text-to-speech.
type Client struct {
// PluginBase supplies the shared plugin identity and setup-question plumbing.
*plugins.PluginBase
// ApiKey is the setup question holding the API key; its Value is passed to
// the genai client configuration on every request.
ApiKey *plugins.SetupQuestion
}
// ListModels asks the Gemini API for its models and returns their names with
// the leading "models/" segment removed.
//
// A fresh background context and a fresh genai client are used for the call.
// Any error from client creation or from the list request is returned as-is
// with a nil slice.
//
// NOTE(review): only the Items of the single List response are consumed; if
// the API paginates its results, later pages are not fetched — confirm.
func (o *Client) ListModels() (ret []string, err error) {
	ctx := context.Background()

	client, err := o.createGenaiClient(ctx)
	if err != nil {
		return nil, err
	}

	page, err := client.Models.List(ctx, &genai.ListModelsConfig{})
	if err != nil {
		return nil, err
	}

	// Users refer to models by their short name, so drop the API's prefix.
	const apiPrefix = "models/"
	for i := range page.Items {
		ret = append(ret, strings.TrimPrefix(page.Items[i].Name, apiPrefix))
	}
	return ret, nil
}
// Send performs a single, non-streaming request against the model named in
// opts and returns the produced output.
//
// When the model name looks like a text-to-speech model (see isTTSModel) the
// request is routed to generateTTSAudio; that path requires opts.AudioOutput
// to be set and fails otherwise. For every other model the messages are
// converted to genai contents and sent with the temperature and top-p taken
// from opts; the text of the response is returned.
//
// NOTE(review): MaxOutputTokens is populated from opts.ModelContextLength —
// confirm that a context-length setting is the intended output cap.
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	// Text-to-speech models take a dedicated path and only make sense when
	// the caller asked for audio output.
	if o.isTTSModel(opts.Model) {
		if opts.AudioOutput {
			return o.generateTTSAudio(ctx, msgs, opts)
		}
		return "", fmt.Errorf("TTS model '%s' requires audio output. Please specify an audio output file with -o flag ending in .wav", opts.Model)
	}

	// Regular text generation.
	client, err := o.createGenaiClient(ctx)
	if err != nil {
		return "", err
	}

	// The SDK config wants pointers to float32 values.
	var (
		temperature = float32(opts.Temperature)
		topP        = float32(opts.TopP)
	)
	genConfig := &genai.GenerateContentConfig{
		Temperature:     &temperature,
		TopP:            &topP,
		MaxOutputTokens: int32(opts.ModelContextLength),
	}

	response, err := client.Models.GenerateContent(
		ctx,
		o.buildModelNameFull(opts.Model),
		o.convertMessages(msgs),
		genConfig,
	)
	if err != nil {
		return "", err
	}

	return o.extractTextFromResponse(response), nil
}
// SendStream sends msgs to the model named in opts and forwards every
// non-empty text fragment of the streamed response to channel.
//
// This method owns the sending side of channel and closes it exactly once on
// every return path, including when the genai client cannot be created, so a
// receiver ranging over the channel always terminates.
//
// If client creation fails, that error is returned. If the stream itself
// yields an error, an "Error: ..." line is sent on the channel, streaming
// stops, and nil is returned (the failure is reported in-band only).
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
	// A single deferred close replaces the previous close-inside-the-loop plus
	// close-after-the-loop pair, which closed the channel twice (a runtime
	// panic) whenever the stream reported an error.
	defer close(channel)

	ctx := context.Background()

	var client *genai.Client
	if client, err = o.createGenaiClient(ctx); err != nil {
		return err
	}

	// Convert messages to the SDK format.
	contents := o.convertMessages(msgs)

	// The SDK config wants pointers to float32 values.
	temperature := float32(opts.Temperature)
	topP := float32(opts.TopP)
	stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
		Temperature:     &temperature,
		TopP:            &topP,
		MaxOutputTokens: int32(opts.ModelContextLength),
	})

	// streamErr is deliberately a distinct name so it does not shadow the
	// named result.
	for response, streamErr := range stream {
		if streamErr != nil {
			channel <- fmt.Sprintf("Error: %v\n", streamErr)
			return nil
		}
		if text := o.extractTextFromResponse(response); text != "" {
			channel <- text
		}
	}
	return nil
}
// NeedsRawMode reports whether the given model requires raw mode.
// This client never does: the result is false for every modelName, which is
// not inspected.
func (o *Client) NeedsRawMode(modelName string) bool {
return false
}
// buildModelNameFull returns modelName in the "models/<name>" form used for
// API calls. A name that already starts with "models/" is returned unchanged;
// any other name gets the prefix prepended.
func (o *Client) buildModelNameFull(modelName string) string {
	const apiPrefix = "models/"
	// Stripping an existing prefix before re-adding it yields the input itself
	// when the prefix was present and the prefixed name otherwise.
	return apiPrefix + strings.TrimPrefix(modelName, apiPrefix)
}
// isTTSModel reports whether modelName looks like a text-to-speech model.
// The decision is a case-insensitive substring match on the name: any name
// containing "tts" (which also covers "preview-tts") or "text-to-speech"
// is treated as a TTS model.
func (o *Client) isTTSModel(modelName string) bool {
	name := strings.ToLower(modelName)
	for _, marker := range []string{"tts", "text-to-speech"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// extractTextForTTS picks the text to be spoken from a conversation: the
// content of the most recent message whose role is user and whose content is
// not empty. An error is returned when no such message exists.
func (o *Client) extractTextForTTS(msgs []*chat.ChatCompletionMessage) (string, error) {
	// Walk from the newest message to the oldest.
	for remaining := len(msgs); remaining > 0; remaining-- {
		msg := msgs[remaining-1]
		if msg.Role != chat.ChatMessageRoleUser || msg.Content == "" {
			continue
		}
		return msg.Content, nil
	}
	return "", fmt.Errorf("no text content found for TTS generation")
}
// createGenaiClient returns a new genai client bound to ctx that targets the
// Gemini API backend and authenticates with the configured API key value.
func (o *Client) createGenaiClient(ctx context.Context) (*genai.Client, error) {
	clientConfig := &genai.ClientConfig{
		APIKey:  o.ApiKey.Value,
		Backend: genai.BackendGeminiAPI,
	}
	return genai.NewClient(ctx, clientConfig)
}
// generateTTSAudio runs the text-to-speech path for a conversation.
//
// It selects the text to speak (see extractTextForTTS), rejects a non-empty
// opts.Voice that IsValidGeminiVoice does not accept, creates a genai client
// and delegates to performTTSGeneration. The first failing step's error is
// returned together with an empty string.
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	var text string
	if text, err = o.extractTextForTTS(msgs); err != nil {
		return "", err
	}

	// Fail fast on an unknown voice instead of spending an API call on it.
	if voice := opts.Voice; voice != "" && !IsValidGeminiVoice(voice) {
		return "", fmt.Errorf("invalid voice '%s'. Valid voices are: %v", voice, GetGeminiVoiceNames())
	}

	var client *genai.Client
	if client, err = o.createGenaiClient(ctx); err != nil {
		return "", err
	}

	return o.performTTSGeneration(ctx, client, text, opts)
}
// performTTSGeneration asks the model named in opts to speak textToSpeak and
// returns the audio as a string: AudioDataPrefix followed by the raw bytes of
// a WAV file.
//
// The voice is opts.Voice, or "Kore" when none is given. Only the first part
// of the first candidate is inspected for inline data. An error is returned
// when the request fails, when no inline data is present, when the inline
// data declares a non-audio MIME type, when it is shorter than
// MinAudioDataSize, or when the WAV container cannot be built.
func (o *Client) performTTSGeneration(ctx context.Context, client *genai.Client, textToSpeak string, opts *domain.ChatOptions) (string, error) {
	// Fall back to the default voice when the caller did not pick one.
	voice := opts.Voice
	if voice == "" {
		voice = "Kore"
	}

	// Request audio output spoken with the chosen prebuilt voice.
	ttsConfig := &genai.GenerateContentConfig{
		ResponseModalities: []string{"AUDIO"},
		SpeechConfig: &genai.SpeechConfig{
			VoiceConfig: &genai.VoiceConfig{
				PrebuiltVoiceConfig: &genai.PrebuiltVoiceConfig{VoiceName: voice},
			},
		},
	}
	request := []*genai.Content{
		{Parts: []*genai.Part{{Text: textToSpeak}}},
	}

	response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), request, ttsConfig)
	if err != nil {
		return "", fmt.Errorf("TTS generation failed: %w", err)
	}

	noAudio := func() (string, error) {
		return "", fmt.Errorf("no audio data received from TTS model")
	}

	// Locate the inline payload of the first part of the first candidate.
	if len(response.Candidates) == 0 {
		return noAudio()
	}
	content := response.Candidates[0].Content
	if content == nil || len(content.Parts) == 0 {
		return noAudio()
	}
	inline := content.Parts[0].InlineData
	if inline == nil || len(inline.Data) == 0 {
		return noAudio()
	}

	// A declared MIME type must be an audio type; an empty one is tolerated.
	if mime := inline.MIMEType; mime != "" && !strings.HasPrefix(mime, "audio/") {
		return "", fmt.Errorf("unexpected data type: %s, expected audio data", mime)
	}

	pcm := inline.Data
	if size := len(pcm); size < MinAudioDataSize {
		return "", fmt.Errorf("audio data too small: %d bytes, minimum required: %d", size, MinAudioDataSize)
	}

	// Wrap the payload in a WAV container.
	wav, err := o.generateWAVFile(pcm)
	if err != nil {
		return "", fmt.Errorf("failed to generate WAV file: %w", err)
	}
	if size := len(wav); size < WAVHeaderSize {
		return "", fmt.Errorf("generated WAV data is invalid: %d bytes, minimum required: %d", size, WAVHeaderSize)
	}

	// The prefix marks the returned string as binary audio rather than text.
	return AudioDataPrefix + string(wav), nil
}
// generateWAVFile wraps pcmData in a canonical 44-byte PCM WAV header and
// returns the complete file contents.
//
// The header describes the audio using DefaultChannels, DefaultSampleRate and
// DefaultBitsPerSample; pcmData is appended unchanged after it. An error is
// returned when pcmData is empty or larger than MaxAudioDataSize.
//
// Header layout (all integers little-endian):
//
//	 0  "RIFF"          4  RIFF chunk size (RIFFHeaderSize + data length)
//	 8  "WAVE"         12  "fmt "
//	16  fmt chunk size (16)
//	20  audio format (1 = PCM)             22  channels
//	24  sample rate                        28  byte rate
//	32  block align                        34  bits per sample
//	36  "data"         40  data length
func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
	// Validate input size to prevent potential security issues. The upper
	// bound also guarantees the length fits the 32-bit size fields below.
	if len(pcmData) == 0 {
		return nil, fmt.Errorf("empty PCM data provided")
	}
	if len(pcmData) > MaxAudioDataSize {
		return nil, fmt.Errorf("PCM data too large: %d bytes, maximum allowed: %d", len(pcmData), MaxAudioDataSize)
	}

	// WAV format parameters and the values derived from them.
	const (
		channels      = DefaultChannels
		sampleRate    = DefaultSampleRate
		bitsPerSample = DefaultBitsPerSample
		byteRate      = sampleRate * channels * bitsPerSample / 8
		blockAlign    = channels * bitsPerSample / 8
	)
	dataLen := uint32(len(pcmData))

	// Fill the fixed-size header directly. Unlike binary.Write, the PutUintNN
	// helpers cannot fail, so there are no error returns to discard.
	var header [WAVHeaderSize]byte
	le := binary.LittleEndian

	// RIFF header
	copy(header[0:4], "RIFF")
	le.PutUint32(header[4:8], RIFFHeaderSize+dataLen)
	copy(header[8:12], "WAVE")

	// fmt chunk
	copy(header[12:16], "fmt ")
	le.PutUint32(header[16:20], 16)            // subchunk1Size
	le.PutUint16(header[20:22], 1)             // audioFormat = PCM
	le.PutUint16(header[22:24], channels)      // numChannels
	le.PutUint32(header[24:28], sampleRate)    // sampleRate
	le.PutUint32(header[28:32], byteRate)      // byteRate
	le.PutUint16(header[32:34], blockAlign)    // blockAlign
	le.PutUint16(header[34:36], bitsPerSample) // bitsPerSample

	// data chunk
	copy(header[36:40], "data")
	le.PutUint32(header[40:44], dataLen)

	// Pre-size the buffer for header plus payload; bytes.Buffer writes never
	// return an error.
	buf := bytes.NewBuffer(make([]byte, 0, WAVHeaderSize+len(pcmData)))
	buf.Write(header[:])
	buf.Write(pcmData)
	return buf.Bytes(), nil
}
// convertMessages converts fabric chat messages to the genai Content format.
//
// Each message becomes one Content, in order. Its parts are the message's
// Content string (when non-empty) followed by the text parts of its
// MultiContent; image-URL parts are currently skipped.
//
// The message role is carried over so that multi-turn conversations keep
// their speaker attribution: assistant messages are sent with role "model",
// and every other role is sent as "user", the only other role the genai
// Content type documents. Previously the role was dropped entirely.
func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
	// Role values accepted by genai.Content.Role.
	const (
		roleUser  = "user"
		roleModel = "model"
	)

	var contents []*genai.Content
	for _, msg := range msgs {
		role := roleUser
		if msg.Role == chat.ChatMessageRoleAssistant {
			role = roleModel
		}

		content := &genai.Content{Role: role, Parts: []*genai.Part{}}
		if msg.Content != "" {
			content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
		}

		// Handle multi-content messages (images, etc.)
		for _, part := range msg.MultiContent {
			switch part.Type {
			case chat.ChatMessagePartTypeText:
				content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
			case chat.ChatMessagePartTypeImageURL:
				// TODO: Handle image URLs if needed
				// This would require downloading and converting to inline data
			}
		}

		contents = append(contents, content)
	}
	return contents
}
// extractTextFromResponse concatenates, in order, the Text of every part of
// every candidate in response and returns the result. Candidates without
// content contribute nothing; the result is empty when there is no text.
func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
	var text strings.Builder
	for _, candidate := range response.Candidates {
		if candidate.Content == nil {
			continue
		}
		for _, part := range candidate.Content.Parts {
			// Writing an empty string is a no-op, so non-text parts vanish.
			text.WriteString(part.Text)
		}
	}
	return text.String()
}