chore: update Gemini SDK to new genai library and add TTS audio output support

## CHANGES

- Replace deprecated generative-ai-go with google.golang.org/genai library
- Add TTS model detection and audio output validation
- Implement WAV file generation for TTS audio responses
- Add audio format checking utilities in CLI output
- Update Gemini client to support streaming with new SDK
- Add "Kore" and "subchunk" to VSCode spell checker dictionary
- Remove extra blank line from changelog formatting
- Update dependency imports and remove unused packages
This commit is contained in:
Kayvan Sylvan
2025-07-26 10:54:34 -07:00
parent 92aca524a4
commit 3e75aa260f
11 changed files with 423 additions and 123 deletions

View File

@@ -1,8 +1,9 @@
package gemini
import (
"bytes"
"context"
"errors"
"encoding/binary"
"fmt"
"strings"
@@ -10,13 +11,9 @@ import (
"github.com/danielmiessler/fabric/internal/plugins"
"github.com/danielmiessler/fabric/internal/domain"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
"google.golang.org/genai"
)
const modelsNamePrefix = "models/"
func NewClient() (ret *Client) {
vendorName := "Gemini"
ret = &Client{}
@@ -39,107 +36,104 @@ type Client struct {
// ListModels builds a Gemini client from o.ApiKey.Value and returns the listed
// model names with the "models/" prefix removed from each one.
//
// NOTE(review): this span is diff output in which the removed iterator-based
// implementation (client.ListModels / iter.Next) is interleaved with the added
// client.Models.List implementation. The added code reads only resp.Items of
// the value returned by List — confirm whether further pages must be fetched.
func (o *Client) ListModels() (ret []string, err error) {
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
if client, err = genai.NewClient(ctx, &genai.ClientConfig{
APIKey: o.ApiKey.Value,
Backend: genai.BackendGeminiAPI,
}); err != nil {
return
}
defer client.Close()
iter := client.ListModels(ctx)
for {
var resp *genai.ModelInfo
if resp, err = iter.Next(); err != nil {
if errors.Is(err, iterator.Done) {
err = nil
}
break
}
// List available models using the correct API
resp, err := client.Models.List(ctx, &genai.ListModelsConfig{})
if err != nil {
return nil, err
}
name := o.buildModelNameSimple(resp.Name)
ret = append(ret, name)
for _, model := range resp.Items {
// Strip the "models/" prefix for user convenience
modelName := strings.TrimPrefix(model.Name, "models/")
ret = append(ret, modelName)
}
return
}
// Send issues a single, non-streaming generation request for msgs against
// opts.Model and returns the text extracted from the response.
//
// Model names accepted by isTTSModel are handed to generateTTSAudio; for those
// models the call fails up front unless opts.AudioOutput is set.
//
// NOTE(review): this span is diff output in which the removed
// generative-ai-go implementation is interleaved with the added genai one.
// NOTE(review): the added request passes opts.ModelContextLength as
// MaxOutputTokens — confirm that field is really meant as an output budget.
func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
systemInstruction, messages := toMessages(msgs)
// Check if this is a TTS model request
if o.isTTSModel(opts.Model) {
if !opts.AudioOutput {
err = fmt.Errorf("TTS model '%s' requires audio output. Please specify an audio output file with -o flag ending in .wav", opts.Model)
return
}
// Handle TTS generation
return o.generateTTSAudio(ctx, msgs, opts)
}
// Regular text generation
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
return
}
defer client.Close()
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
var response *genai.GenerateContentResponse
if response, err = model.GenerateContent(ctx, messages...); err != nil {
if client, err = genai.NewClient(ctx, &genai.ClientConfig{
APIKey: o.ApiKey.Value,
Backend: genai.BackendGeminiAPI,
}); err != nil {
return
}
ret = o.extractText(response)
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
// Generate content
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
})
if err != nil {
return "", err
}
// Extract text from response
ret = o.extractTextFromResponse(response)
return
}
// buildModelNameSimple returns fullModelName without a leading modelsNamePrefix.
// Names that do not start with the prefix are returned unchanged.
func (o *Client) buildModelNameSimple(fullModelName string) string {
	// CutPrefix hands back the input untouched when the prefix is absent,
	// so the "found" flag is not needed.
	short, _ := strings.CutPrefix(fullModelName, modelsNamePrefix)
	return short
}
// buildModelNameFull returns modelName with modelsNamePrefix put in front of it.
// The prefix is added unconditionally, even if modelName already carries it.
func (o *Client) buildModelNameFull(modelName string) string {
	// Sprint inserts no separator between string operands, so this is plain
	// concatenation.
	return fmt.Sprint(modelsNamePrefix, modelName)
}
// SendStream requests a streamed generation for msgs against opts.Model and
// forwards the text of each received chunk to channel. Errors raised while
// iterating the stream are written to channel as formatted text rather than
// returned, and channel is closed when streaming stops.
//
// NOTE(review): this span is diff output in which the removed
// generative-ai-go implementation is interleaved with the added genai one;
// the body of the removed extractText helper is mixed into the tail as well.
// NOTE(review): in the added range-over-stream loop the error branch calls
// close(channel) and then breaks, after which the close(channel) that follows
// the loop runs too. Closing an already-closed channel panics — confirm and
// keep only one of the two closes.
func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan string) (err error) {
ctx := context.Background()
var client *genai.Client
if client, err = genai.NewClient(ctx, option.WithAPIKey(o.ApiKey.Value)); err != nil {
if client, err = genai.NewClient(ctx, &genai.ClientConfig{
APIKey: o.ApiKey.Value,
Backend: genai.BackendGeminiAPI,
}); err != nil {
return
}
defer client.Close()
systemInstruction, messages := toMessages(msgs)
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
model := client.GenerativeModel(o.buildModelNameFull(opts.Model))
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SystemInstruction = systemInstruction
// Generate streaming content
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
})
iter := model.GenerateContentStream(ctx, messages...)
for {
if resp, iterErr := iter.Next(); iterErr == nil {
for _, candidate := range resp.Candidates {
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
if text, ok := part.(genai.Text); ok {
channel <- string(text)
}
}
}
}
} else {
if !errors.Is(iterErr, iterator.Done) {
channel <- fmt.Sprintf("%v\n", iterErr)
}
for response, err := range stream {
if err != nil {
channel <- fmt.Sprintf("Error: %v\n", err)
close(channel)
break
}
}
return
}
func (o *Client) extractText(response *genai.GenerateContentResponse) (ret string) {
for _, candidate := range response.Candidates {
if candidate.Content == nil {
break
}
for _, part := range candidate.Content.Parts {
if text, ok := part.(genai.Text); ok {
ret += string(text)
}
text := o.extractTextFromResponse(response)
if text != "" {
channel <- text
}
}
close(channel)
return
}
@@ -147,18 +141,174 @@ func (o *Client) NeedsRawMode(modelName string) bool {
return false
}
// toMessages splits msgs into a system instruction and a list of text parts.
// With two or more messages, msgs[0] becomes the system instruction and every
// later message becomes one text part; otherwise msgs[0] alone is the single
// text part and no system instruction is produced.
//
// NOTE(review): the else branch reads msgs[0] without a length check, so an
// empty msgs would panic.
// NOTE(review): this span is diff output; the removed toMessages is
// interleaved with the added buildModelNameFull that follows.
func toMessages(msgs []*chat.ChatCompletionMessage) (systemInstruction *genai.Content, messages []genai.Part) {
if len(msgs) >= 2 {
systemInstruction = &genai.Content{
Parts: []genai.Part{
genai.Text(msgs[0].Content),
},
}
for _, msg := range msgs[1:] {
messages = append(messages, genai.Text(msg.Content))
}
} else {
messages = append(messages, genai.Text(msgs[0].Content))
// buildModelNameFull adds the "models/" prefix for API calls
// unless modelName already starts with it, in which case it is returned as is.
func (o *Client) buildModelNameFull(modelName string) string {
if strings.HasPrefix(modelName, "models/") {
return modelName
}
return
return "models/" + modelName
}
// isTTSModel reports whether modelName looks like a text-to-speech model.
// The match is a case-insensitive substring search for any of the known
// TTS markers.
func (o *Client) isTTSModel(modelName string) bool {
	name := strings.ToLower(modelName)
	for _, marker := range []string{"tts", "preview-tts", "text-to-speech"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// generateTTSAudio asks the TTS model named by opts.Model to speak the most
// recent non-empty user message in msgs.
//
// On success the result is the marker "FABRIC_AUDIO_DATA:" followed directly
// by the raw bytes of a WAV file built by generateWAVFile. An error is
// returned when there is no user text to speak, when the client cannot be
// created, when the request fails, or when the first part of the first
// candidate carries no inline audio data.
func (o *Client) generateTTSAudio(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (ret string, err error) {
	// Walk the conversation backwards and stop at the first user message
	// that actually has text.
	prompt := ""
	for idx := len(msgs) - 1; idx >= 0 && prompt == ""; idx-- {
		if m := msgs[idx]; m.Role == chat.ChatMessageRoleUser {
			prompt = m.Content
		}
	}
	if prompt == "" {
		return "", fmt.Errorf("no text content found for TTS generation")
	}

	clientCfg := &genai.ClientConfig{
		APIKey:  o.ApiKey.Value,
		Backend: genai.BackendGeminiAPI,
	}
	client, err := genai.NewClient(ctx, clientCfg)
	if err != nil {
		return "", err
	}

	// A single text part is the whole request payload.
	contents := []*genai.Content{
		{Parts: []*genai.Part{{Text: prompt}}},
	}

	// Ask for audio output, spoken with the default prebuilt voice.
	voice := &genai.PrebuiltVoiceConfig{VoiceName: "Kore"}
	config := &genai.GenerateContentConfig{
		ResponseModalities: []string{"AUDIO"},
		SpeechConfig: &genai.SpeechConfig{
			VoiceConfig: &genai.VoiceConfig{PrebuiltVoiceConfig: voice},
		},
	}

	resp, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, config)
	if err != nil {
		return "", fmt.Errorf("TTS generation failed: %w", err)
	}

	// Only the first part of the first candidate is inspected for audio.
	if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil || len(resp.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("no audio data received from TTS model")
	}
	first := resp.Candidates[0].Content.Parts[0]
	if first.InlineData == nil || len(first.InlineData.Data) == 0 {
		return "", fmt.Errorf("no audio data received from TTS model")
	}

	// The inline data is used as raw PCM bytes and only needs a WAV header.
	wav, wavErr := o.generateWAVFile(first.InlineData.Data)
	if wavErr != nil {
		return "", fmt.Errorf("failed to generate WAV file: %w", wavErr)
	}

	// The marker prefix lets the caller tell binary audio from ordinary text.
	return fmt.Sprintf("FABRIC_AUDIO_DATA:%s", string(wav)), nil
}
// generateWAVFile wraps raw PCM samples in a canonical 44-byte RIFF/WAVE header
// and returns the complete file contents.
//
// The header declares uncompressed PCM (format tag 1), one channel, 24000 Hz
// and 16 bits per sample, which are the parameters this client assumes for
// Gemini TTS output. pcmData is appended unchanged after the header; an empty
// or nil pcmData yields a valid header-only file.
//
// An error is returned when pcmData does not fit the 32-bit size fields of the
// RIFF container (instead of writing wrapped-around sizes), or when the header
// cannot be encoded.
func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
	const (
		channels      = 1
		sampleRate    = 24000
		bitsPerSample = 16

		// RIFF descriptor (12) + fmt chunk (24) + data chunk header (8).
		headerLen = 44
		// The RIFF size field counts the 36 header bytes that follow it plus
		// the payload, and the total must fit in a uint32.
		maxDataLen = uint64(1<<32 - 1 - 36)
	)

	if uint64(len(pcmData)) > maxDataLen {
		return nil, fmt.Errorf("pcm data of %d bytes exceeds the WAV limit of %d bytes", len(pcmData), maxDataLen)
	}
	dataLen := uint32(len(pcmData))

	// Field order and widths mirror the on-disk layout; binary.Write encodes
	// the struct field by field, little-endian, without padding.
	header := struct {
		RiffID        [4]byte
		RiffSize      uint32
		WaveID        [4]byte
		FmtID         [4]byte
		FmtSize       uint32
		AudioFormat   uint16
		NumChannels   uint16
		SampleRate    uint32
		ByteRate      uint32
		BlockAlign    uint16
		BitsPerSample uint16
		DataID        [4]byte
		DataSize      uint32
	}{
		RiffID:        [4]byte{'R', 'I', 'F', 'F'},
		RiffSize:      36 + dataLen,
		WaveID:        [4]byte{'W', 'A', 'V', 'E'},
		FmtID:         [4]byte{'f', 'm', 't', ' '},
		FmtSize:       16, // size of the PCM fmt chunk body
		AudioFormat:   1,  // PCM
		NumChannels:   channels,
		SampleRate:    sampleRate,
		ByteRate:      sampleRate * channels * bitsPerSample / 8,
		BlockAlign:    channels * bitsPerSample / 8,
		BitsPerSample: bitsPerSample,
		DataID:        [4]byte{'d', 'a', 't', 'a'},
		DataSize:      dataLen,
	}

	// Pre-size the buffer so header and payload are written without regrowth.
	buf := bytes.NewBuffer(make([]byte, 0, headerLen+len(pcmData)))
	if err := binary.Write(buf, binary.LittleEndian, header); err != nil {
		return nil, fmt.Errorf("encoding WAV header: %w", err)
	}
	buf.Write(pcmData)

	return buf.Bytes(), nil
}
// convertMessages maps each fabric chat message to one genai Content value.
// A non-empty Content string becomes the first text part, followed by one
// text part per text entry in MultiContent. Image URL entries are skipped for
// now (supporting them would mean downloading the image and sending it as
// inline data). A message with nothing to convert still yields a Content
// with no parts, so the output has one entry per input message.
func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
	var converted []*genai.Content

	for _, m := range msgs {
		var parts []*genai.Part

		if m.Content != "" {
			parts = append(parts, &genai.Part{Text: m.Content})
		}

		for _, piece := range m.MultiContent {
			// Only text pieces are forwarded; every other piece type,
			// including image URLs, is ignored.
			if piece.Type == chat.ChatMessagePartTypeText {
				parts = append(parts, &genai.Part{Text: piece.Text})
			}
		}

		converted = append(converted, &genai.Content{Parts: parts})
	}

	return converted
}
// extractTextFromResponse concatenates, in order, the Text of every part of
// every candidate in response. Candidates without Content contribute nothing.
func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
	var sb strings.Builder

	for _, cand := range response.Candidates {
		if cand.Content == nil {
			continue
		}
		for _, p := range cand.Content.Parts {
			// Writing an empty string is a no-op, so parts without text
			// need no special case.
			sb.WriteString(p.Text)
		}
	}

	return sb.String()
}

View File

@@ -3,32 +3,40 @@ package gemini
import (
"testing"
"github.com/google/generative-ai-go/genai"
"google.golang.org/genai"
)
// Test generated using Keploy
// NOTE(review): this span is diff output; the removed TestBuildModelNameSimple
// is interleaved with the added table-driven TestBuildModelNameFull.
func TestBuildModelNameSimple(t *testing.T) {
// TestBuildModelNameFull checks that buildModelNameFull yields a name carrying
// the "models/" prefix exactly once, whether or not the input already has it.
func TestBuildModelNameFull(t *testing.T) {
client := &Client{}
fullModelName := "models/chat-bison-001"
expected := "chat-bison-001"
result := client.buildModelNameSimple(fullModelName)
tests := []struct {
input string
expected string
}{
// bare name: prefix is added
{"chat-bison-001", "models/chat-bison-001"},
// already prefixed: returned unchanged
{"models/chat-bison-001", "models/chat-bison-001"},
// bare TTS model name: prefix is added
{"gemini-2.5-flash-preview-tts", "models/gemini-2.5-flash-preview-tts"},
}
if result != expected {
t.Errorf("Expected %v, got %v", expected, result)
for _, test := range tests {
result := client.buildModelNameFull(test.input)
if result != test.expected {
t.Errorf("For input %v, expected %v, got %v", test.input, test.expected, result)
}
}
}
// Test generated using Keploy
func TestExtractText(t *testing.T) {
// Test extractTextFromResponse method
func TestExtractTextFromResponse(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
Content: &genai.Content{
Parts: []genai.Part{
genai.Text("Hello, "),
genai.Text("world!"),
Parts: []*genai.Part{
{Text: "Hello, "},
{Text: "world!"},
},
},
},
@@ -36,9 +44,56 @@ func TestExtractText(t *testing.T) {
}
expected := "Hello, world!"
result := client.extractText(response)
result := client.extractTextFromResponse(response)
if result != expected {
t.Errorf("Expected %v, got %v", expected, result)
}
}
// TestIsTTSModel checks isTTSModel against names that must be recognised as
// TTS models (in any letter case) and names that must not, including the
// empty string.
func TestIsTTSModel(t *testing.T) {
	type ttsCase struct {
		modelName string
		expected  bool
	}
	cases := []ttsCase{
		{modelName: "gemini-2.5-flash-preview-tts", expected: true},
		{modelName: "text-to-speech-model", expected: true},
		{modelName: "TTS-MODEL", expected: true},
		{modelName: "gemini-pro", expected: false},
		{modelName: "chat-bison", expected: false},
		{modelName: "", expected: false},
	}

	c := &Client{}
	for i := range cases {
		tc := cases[i]
		if got := c.isTTSModel(tc.modelName); got != tc.expected {
			t.Errorf("For model %v, expected %v, got %v", tc.modelName, tc.expected, got)
		}
	}
}
// TestGenerateWAVFile is a smoke test for generateWAVFile: a tiny PCM payload
// must produce non-empty output that begins with the "RIFF" magic.
func TestGenerateWAVFile(t *testing.T) {
	samples := []byte{0x00, 0x01, 0x02, 0x03}

	wav, err := (&Client{}).generateWAVFile(samples)
	if err != nil {
		t.Errorf("generateWAVFile failed: %v", err)
	}

	if len(wav) == 0 {
		t.Error("generateWAVFile returned empty data")
	}

	// The magic can only be inspected when at least four bytes came back.
	if len(wav) < 4 {
		return
	}
	if magic := string(wav[:4]); magic != "RIFF" {
		t.Error("Generated WAV data doesn't start with RIFF header")
	}
}