refactor: extract TTS methods and add audio validation with security limits

## CHANGES

- Move text extraction logic into a separate method
- Add GenAI client creation helper function
- Split TTS generation into focused helper methods
- Add audio data size validation with security limits (44-byte minimum, 100MB maximum)
- Implement MIME type validation for audio responses
- Add WAV file generation input validation checks
- Pre-allocate buffer capacity for better performance
- Define audio constants for reusable configuration
- Add comprehensive error handling for edge cases
- Validate generated WAV data before returning results
This commit is contained in:
Kayvan Sylvan
2025-07-26 11:29:12 -07:00
parent 5d7137804a
commit 5cdf297d85

View File

@@ -14,6 +14,18 @@ import (
"google.golang.org/genai"
)
// Audio constants shared by the TTS response validation and the WAV writer.
const (
	// AudioDataPrefix is the marker prepended to the WAV bytes in the string
	// returned from TTS generation.
	AudioDataPrefix = "FABRIC_AUDIO_DATA:"

	// PCM format written into the WAV header (Gemini TTS default specs).
	DefaultChannels      = 1
	DefaultSampleRate    = 24000
	DefaultBitsPerSample = 16

	// WAVHeaderSize is the smallest length accepted for generated WAV data.
	WAVHeaderSize = 44
	// RIFFHeaderSize is added to the PCM length to obtain the RIFF size field;
	// it is the WAV header minus the 8 leading bytes not counted by that field.
	RIFFHeaderSize = WAVHeaderSize - 8

	// MaxAudioDataSize caps the PCM input handed to the WAV writer (100MB limit
	// for security).
	MaxAudioDataSize = 100 << 20
	// MinAudioDataSize is the smallest audio payload accepted from a TTS response.
	MinAudioDataSize = 44
)
func NewClient() (ret *Client) {
vendorName := "Gemini"
ret = &Client{}
@@ -157,30 +169,42 @@ func (o *Client) isTTSModel(modelName string) bool {
strings.Contains(lowerModel, "text-to-speech")
}
// generateTTSAudio handles TTS audio generation using the new SDK
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
// Extract the text to convert to speech from the user messages
var textToSpeak string
// extractTextForTTS extracts text content from chat messages for TTS generation
func (o *Client) extractTextForTTS(msgs []*chat.ChatCompletionMessage) (string, error) {
for i := len(msgs) - 1; i >= 0; i-- {
if msgs[i].Role == chat.ChatMessageRoleUser && msgs[i].Content != "" {
textToSpeak = msgs[i].Content
break
return msgs[i].Content, nil
}
}
return "", fmt.Errorf("no text content found for TTS generation")
}
if textToSpeak == "" {
err = fmt.Errorf("no text content found for TTS generation")
return
}
var client *genai.Client
if client, err = genai.NewClient(ctx, &genai.ClientConfig{
// createGenaiClient creates a new GenAI client for TTS operations
func (o *Client) createGenaiClient(ctx context.Context) (*genai.Client, error) {
return genai.NewClient(ctx, &genai.ClientConfig{
APIKey: o.ApiKey.Value,
Backend: genai.BackendGeminiAPI,
}); err != nil {
return
})
}
// generateTTSAudio runs the text-to-speech pipeline for a chat request.
//
// It obtains the text to speak from msgs via extractTextForTTS, creates a
// GenAI client via createGenaiClient, and hands both to performTTSGeneration
// together with opts. The first error from any of these steps is returned
// unchanged alongside an empty string; otherwise the result of
// performTTSGeneration is returned as-is.
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	var text string
	if text, err = o.extractTextForTTS(msgs); err != nil {
		return "", err
	}

	var ttsClient *genai.Client
	if ttsClient, err = o.createGenaiClient(ctx); err != nil {
		return "", err
	}

	return o.performTTSGeneration(ctx, ttsClient, text, opts)
}
// performTTSGeneration performs the actual TTS generation and audio processing
func (o *Client) performTTSGeneration(ctx context.Context, client *genai.Client, textToSpeak string, opts *domain.ChatOptions) (string, error) {
// Create content for TTS
contents := []*genai.Content{{
Parts: []*genai.Part{{Text: textToSpeak}},
@@ -208,8 +232,15 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
if len(response.Candidates) > 0 && response.Candidates[0].Content != nil && len(response.Candidates[0].Content.Parts) > 0 {
part := response.Candidates[0].Content.Parts[0]
if part.InlineData != nil && len(part.InlineData.Data) > 0 {
// The data is already in binary format, not base64
// Validate audio data format and size
if part.InlineData.MIMEType != "" && !strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
return "", fmt.Errorf("unexpected data type: %s, expected audio data", part.InlineData.MIMEType)
}
pcmData := part.InlineData.Data
if len(pcmData) < MinAudioDataSize {
return "", fmt.Errorf("audio data too small: %d bytes, minimum required: %d", len(pcmData), MinAudioDataSize)
}
// Generate WAV file with proper headers and return the binary data
wavData, err := o.generateWAVFile(pcmData)
@@ -217,10 +248,14 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
return "", fmt.Errorf("failed to generate WAV file: %w", err)
}
// Validate generated WAV data
if len(wavData) < WAVHeaderSize {
return "", fmt.Errorf("generated WAV data is invalid: %d bytes, minimum required: %d", len(wavData), WAVHeaderSize)
}
// Store the binary audio data in a special format that the CLI can detect
// We'll encode it as a special marker followed by the binary data
ret = fmt.Sprintf("FABRIC_AUDIO_DATA:%s", string(wavData))
return ret, nil
// Use more efficient string concatenation
return AudioDataPrefix + string(wavData), nil
}
}
@@ -229,18 +264,28 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
// generateWAVFile creates WAV data from PCM data with proper headers
func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
// Validate input size to prevent potential security issues
if len(pcmData) == 0 {
return nil, fmt.Errorf("empty PCM data provided")
}
if len(pcmData) > MaxAudioDataSize {
return nil, fmt.Errorf("PCM data too large: %d bytes, maximum allowed: %d", len(pcmData), MaxAudioDataSize)
}
// WAV file parameters (Gemini TTS default specs)
channels := 1
sampleRate := 24000
bitsPerSample := 16
channels := DefaultChannels
sampleRate := DefaultSampleRate
bitsPerSample := DefaultBitsPerSample
// Calculate required values
byteRate := sampleRate * channels * bitsPerSample / 8
blockAlign := channels * bitsPerSample / 8
dataLen := uint32(len(pcmData))
riffSize := 36 + dataLen
riffSize := RIFFHeaderSize + dataLen
buf := new(bytes.Buffer)
// Pre-allocate buffer with known size for better performance
totalSize := int(riffSize + 8) // +8 for RIFF header
buf := bytes.NewBuffer(make([]byte, 0, totalSize))
// RIFF header
buf.WriteString("RIFF")
@@ -264,8 +309,13 @@ func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
// Write PCM data to buffer
buf.Write(pcmData)
// Return the complete WAV data
return buf.Bytes(), nil
// Validate generated WAV data
result := buf.Bytes()
if len(result) < WAVHeaderSize {
return nil, fmt.Errorf("generated WAV data is invalid: %d bytes, minimum required: %d", len(result), WAVHeaderSize)
}
return result, nil
}
// convertMessages converts fabric chat messages to genai Content format