mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-10 06:48:04 -05:00
refactor: extract TTS methods and add audio validation with security limits
## CHANGES - Extract text extraction logic into separate method - Add GenAI client creation helper function - Split TTS generation into focused helper methods - Add audio data size validation with security limits - Implement MIME type validation for audio responses - Add WAV file generation input validation checks - Pre-allocate buffer capacity for better performance - Define audio constants for reusable configuration - Add comprehensive error handling for edge cases - Validate generated WAV data before returning results
This commit is contained in:
@@ -14,6 +14,18 @@ import (
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
// WAV audio constants
|
||||
const (
|
||||
DefaultChannels = 1
|
||||
DefaultSampleRate = 24000
|
||||
DefaultBitsPerSample = 16
|
||||
WAVHeaderSize = 44
|
||||
RIFFHeaderSize = 36
|
||||
MaxAudioDataSize = 100 * 1024 * 1024 // 100MB limit for security
|
||||
MinAudioDataSize = 44 // Minimum viable audio data
|
||||
AudioDataPrefix = "FABRIC_AUDIO_DATA:"
|
||||
)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Gemini"
|
||||
ret = &Client{}
|
||||
@@ -157,30 +169,42 @@ func (o *Client) isTTSModel(modelName string) bool {
|
||||
strings.Contains(lowerModel, "text-to-speech")
|
||||
}
|
||||
|
||||
// generateTTSAudio handles TTS audio generation using the new SDK
|
||||
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
|
||||
// Extract the text to convert to speech from the user messages
|
||||
var textToSpeak string
|
||||
// extractTextForTTS extracts text content from chat messages for TTS generation
|
||||
func (o *Client) extractTextForTTS(msgs []*chat.ChatCompletionMessage) (string, error) {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
if msgs[i].Role == chat.ChatMessageRoleUser && msgs[i].Content != "" {
|
||||
textToSpeak = msgs[i].Content
|
||||
break
|
||||
return msgs[i].Content, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("no text content found for TTS generation")
|
||||
}
|
||||
|
||||
if textToSpeak == "" {
|
||||
err = fmt.Errorf("no text content found for TTS generation")
|
||||
return
|
||||
}
|
||||
|
||||
var client *genai.Client
|
||||
if client, err = genai.NewClient(ctx, &genai.ClientConfig{
|
||||
// createGenaiClient creates a new GenAI client for TTS operations
|
||||
func (o *Client) createGenaiClient(ctx context.Context) (*genai.Client, error) {
|
||||
return genai.NewClient(ctx, &genai.ClientConfig{
|
||||
APIKey: o.ApiKey.Value,
|
||||
Backend: genai.BackendGeminiAPI,
|
||||
}); err != nil {
|
||||
return
|
||||
})
|
||||
}
|
||||
|
||||
// generateTTSAudio handles TTS audio generation using the new SDK
|
||||
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
|
||||
textToSpeak, err := o.extractTextForTTS(msgs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
client, err := o.createGenaiClient(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return o.performTTSGeneration(ctx, client, textToSpeak, opts)
|
||||
}
|
||||
|
||||
// performTTSGeneration performs the actual TTS generation and audio processing
|
||||
func (o *Client) performTTSGeneration(ctx context.Context, client *genai.Client, textToSpeak string, opts *domain.ChatOptions) (string, error) {
|
||||
|
||||
// Create content for TTS
|
||||
contents := []*genai.Content{{
|
||||
Parts: []*genai.Part{{Text: textToSpeak}},
|
||||
@@ -208,8 +232,15 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
|
||||
if len(response.Candidates) > 0 && response.Candidates[0].Content != nil && len(response.Candidates[0].Content.Parts) > 0 {
|
||||
part := response.Candidates[0].Content.Parts[0]
|
||||
if part.InlineData != nil && len(part.InlineData.Data) > 0 {
|
||||
// The data is already in binary format, not base64
|
||||
// Validate audio data format and size
|
||||
if part.InlineData.MIMEType != "" && !strings.HasPrefix(part.InlineData.MIMEType, "audio/") {
|
||||
return "", fmt.Errorf("unexpected data type: %s, expected audio data", part.InlineData.MIMEType)
|
||||
}
|
||||
|
||||
pcmData := part.InlineData.Data
|
||||
if len(pcmData) < MinAudioDataSize {
|
||||
return "", fmt.Errorf("audio data too small: %d bytes, minimum required: %d", len(pcmData), MinAudioDataSize)
|
||||
}
|
||||
|
||||
// Generate WAV file with proper headers and return the binary data
|
||||
wavData, err := o.generateWAVFile(pcmData)
|
||||
@@ -217,10 +248,14 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
|
||||
return "", fmt.Errorf("failed to generate WAV file: %w", err)
|
||||
}
|
||||
|
||||
// Validate generated WAV data
|
||||
if len(wavData) < WAVHeaderSize {
|
||||
return "", fmt.Errorf("generated WAV data is invalid: %d bytes, minimum required: %d", len(wavData), WAVHeaderSize)
|
||||
}
|
||||
|
||||
// Store the binary audio data in a special format that the CLI can detect
|
||||
// We'll encode it as a special marker followed by the binary data
|
||||
ret = fmt.Sprintf("FABRIC_AUDIO_DATA:%s", string(wavData))
|
||||
return ret, nil
|
||||
// Use more efficient string concatenation
|
||||
return AudioDataPrefix + string(wavData), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -229,18 +264,28 @@ func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompleti
|
||||
|
||||
// generateWAVFile creates WAV data from PCM data with proper headers
|
||||
func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
|
||||
// Validate input size to prevent potential security issues
|
||||
if len(pcmData) == 0 {
|
||||
return nil, fmt.Errorf("empty PCM data provided")
|
||||
}
|
||||
if len(pcmData) > MaxAudioDataSize {
|
||||
return nil, fmt.Errorf("PCM data too large: %d bytes, maximum allowed: %d", len(pcmData), MaxAudioDataSize)
|
||||
}
|
||||
|
||||
// WAV file parameters (Gemini TTS default specs)
|
||||
channels := 1
|
||||
sampleRate := 24000
|
||||
bitsPerSample := 16
|
||||
channels := DefaultChannels
|
||||
sampleRate := DefaultSampleRate
|
||||
bitsPerSample := DefaultBitsPerSample
|
||||
|
||||
// Calculate required values
|
||||
byteRate := sampleRate * channels * bitsPerSample / 8
|
||||
blockAlign := channels * bitsPerSample / 8
|
||||
dataLen := uint32(len(pcmData))
|
||||
riffSize := 36 + dataLen
|
||||
riffSize := RIFFHeaderSize + dataLen
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
// Pre-allocate buffer with known size for better performance
|
||||
totalSize := int(riffSize + 8) // +8 for RIFF header
|
||||
buf := bytes.NewBuffer(make([]byte, 0, totalSize))
|
||||
|
||||
// RIFF header
|
||||
buf.WriteString("RIFF")
|
||||
@@ -264,8 +309,13 @@ func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
|
||||
// Write PCM data to buffer
|
||||
buf.Write(pcmData)
|
||||
|
||||
// Return the complete WAV data
|
||||
return buf.Bytes(), nil
|
||||
// Validate generated WAV data
|
||||
result := buf.Bytes()
|
||||
if len(result) < WAVHeaderSize {
|
||||
return nil, fmt.Errorf("generated WAV data is invalid: %d bytes, minimum required: %d", len(result), WAVHeaderSize)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// convertMessages converts fabric chat messages to genai Content format
|
||||
|
||||
Reference in New Issue
Block a user