feat(gemini): enable web search, citations, and search-location validation

CHANGES
- Enable Gemini models to use web search tool
- Validate search-location timezone or language code formats
- Normalize language codes from underscores to hyphenated form
- Inject Google Search tool when --search flag enabled
- Append deduplicated web citations under standardized Sources section
- Improve robustness for nil candidates and content parts
- Factor generation config builder for reuse in streaming
- Update CLI help and completions to include Gemini
This commit is contained in:
Kayvan Sylvan
2025-08-10 19:56:02 -07:00
parent f33d27f836
commit 558e7f877d
6 changed files with 271 additions and 34 deletions

View File

@@ -536,7 +536,7 @@ Application Options:
--liststrategies List all strategies
--listvendors List all vendors
--shell-complete-list Output raw list without headers/formatting (for shell completion)
--search Enable web search tool for supported models (Anthropic, OpenAI)
--search Enable web search tool for supported models (Anthropic, OpenAI, Gemini)
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
--image-file= Save generated image to specified file path (e.g., 'output.png')
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)

View File

@@ -98,7 +98,7 @@ _fabric() {
'(--api-key)--api-key[API key used to secure server routes]:api-key:' \
'(--config)--config[Path to YAML config file]:config file:_files -g "*.yaml *.yml"' \
'(--version)--version[Print current version]' \
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI, Gemini)]' \
'(--search-location)--search-location[Set location for web search results]:location:' \
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \

View File

@@ -99,7 +99,7 @@ complete -c fabric -l yt-dlp-args -d "Additional arguments to pass to yt-dlp (e.
complete -c fabric -l readability -d "Convert HTML input into a clean, readable view"
complete -c fabric -l input-has-vars -d "Apply variables to user input"
complete -c fabric -l dry-run -d "Show what would be sent to the model without actually sending it"
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI)"
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI, Gemini)"
complete -c fabric -l serve -d "Serve the Fabric Rest API"
complete -c fabric -l serveOllama -d "Serve the Fabric Rest API with ollama endpoints"
complete -c fabric -l version -d "Print current version"

View File

@@ -79,7 +79,7 @@ type Flags struct {
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
ListVendors bool `long:"listvendors" description:"List all vendors"`
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI, Gemini)"`
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`

View File

@@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"fmt"
"regexp"
"strings"
"github.com/danielmiessler/fabric/internal/chat"
@@ -26,6 +27,24 @@ const (
AudioDataPrefix = "FABRIC_AUDIO_DATA:"
)
const (
citationHeader = "\n\n## Sources\n\n"
citationSeparator = "\n"
citationFormat = "- [%s](%s)"
errInvalidLocationFormat = "invalid search location format %q: must be timezone (e.g., 'America/Los_Angeles') or language code (e.g., 'en-US')"
locationSeparator = "/"
langCodeSeparator = "_"
langCodeNormalizedSep = "-"
modelPrefix = "models/"
modelTypeTTS = "tts"
modelTypePreviewTTS = "preview-tts"
modelTypeTextToSpeech = "text-to-speech"
)
var langCodeRegex = regexp.MustCompile(`^[a-z]{2}(-[A-Z]{2})?$`)
func NewClient() (ret *Client) {
vendorName := "Gemini"
ret = &Client{}
@@ -93,14 +112,13 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
// Generate content
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
})
cfg, err := o.buildGenerateContentConfig(opts)
if err != nil {
return "", err
}
// Generate content with optional tools
response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, cfg)
if err != nil {
return "", err
}
@@ -123,14 +141,13 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
// Generate streaming content
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
})
cfg, err := o.buildGenerateContentConfig(opts)
if err != nil {
return err
}
// Generate streaming content with optional tools
stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, cfg)
for response, err := range stream {
if err != nil {
@@ -153,20 +170,86 @@ func (o *Client) NeedsRawMode(modelName string) bool {
return false
}
// buildGenerateContentConfig constructs the generation config with optional tools.
// When search is enabled it injects the Google Search tool. The optional search
// location accepts either:
// - A timezone in the format "Continent/City" (e.g., "America/Los_Angeles")
// - An ISO language code "ll" or "ll-CC" (e.g., "en" or "en-US")
//
// Underscores are normalized to hyphens. Returns an error if the location is
// invalid.
func (o *Client) buildGenerateContentConfig(opts *domain.ChatOptions) (*genai.GenerateContentConfig, error) {
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
cfg := &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
}
if opts.Search {
cfg.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
if loc := opts.SearchLocation; loc != "" {
if isValidLocationFormat(loc) {
loc = normalizeLocation(loc)
cfg.ToolConfig = &genai.ToolConfig{
RetrievalConfig: &genai.RetrievalConfig{LanguageCode: loc},
}
} else {
return nil, fmt.Errorf(errInvalidLocationFormat, loc)
}
}
}
return cfg, nil
}
// buildModelNameFull adds the "models/" prefix for API calls
func (o *Client) buildModelNameFull(modelName string) string {
if strings.HasPrefix(modelName, "models/") {
if strings.HasPrefix(modelName, modelPrefix) {
return modelName
}
return "models/" + modelName
return modelPrefix + modelName
}
func isValidLocationFormat(location string) bool {
if strings.Contains(location, locationSeparator) {
parts := strings.Split(location, locationSeparator)
return len(parts) == 2 && parts[0] != "" && parts[1] != ""
}
return isValidLanguageCode(location)
}
func normalizeLocation(location string) string {
if strings.Contains(location, locationSeparator) {
return location
}
return strings.Replace(location, langCodeSeparator, langCodeNormalizedSep, 1)
}
// isValidLanguageCode reports whether the input is an ISO 639-1 language code
// optionally followed by an ISO 3166-1 country code. Underscores are
// normalized to hyphens before validation.
func isValidLanguageCode(code string) bool {
normalized := strings.Replace(code, langCodeSeparator, langCodeNormalizedSep, 1)
parts := strings.Split(normalized, langCodeNormalizedSep)
switch len(parts) {
case 1:
return langCodeRegex.MatchString(strings.ToLower(parts[0]))
case 2:
formatted := strings.ToLower(parts[0]) + langCodeNormalizedSep + strings.ToUpper(parts[1])
return langCodeRegex.MatchString(formatted)
default:
return false
}
}
// isTTSModel checks if the model is a text-to-speech model
func (o *Client) isTTSModel(modelName string) bool {
lowerModel := strings.ToLower(modelName)
return strings.Contains(lowerModel, "tts") ||
strings.Contains(lowerModel, "preview-tts") ||
strings.Contains(lowerModel, "text-to-speech")
return strings.Contains(lowerModel, modelTypeTTS) ||
strings.Contains(lowerModel, modelTypePreviewTTS) ||
strings.Contains(lowerModel, modelTypeTextToSpeech)
}
// extractTextForTTS extracts text content from chat messages for TTS generation
@@ -370,19 +453,71 @@ func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Co
return contents
}
// extractTextFromResponse extracts text content from the response
// extractTextFromResponse extracts text content from the response and appends
// any web citations in a standardized format.
func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
var result strings.Builder
if response == nil {
return ""
}
text := o.extractTextParts(response)
citations := o.extractCitations(response)
if len(citations) > 0 {
return text + citationHeader + strings.Join(citations, citationSeparator)
}
return text
}
func (o *Client) extractTextParts(response *genai.GenerateContentResponse) string {
var builder strings.Builder
for _, candidate := range response.Candidates {
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
if part.Text != "" {
result.WriteString(part.Text)
}
if candidate == nil || candidate.Content == nil {
continue
}
for _, part := range candidate.Content.Parts {
if part != nil && part.Text != "" {
builder.WriteString(part.Text)
}
}
}
return result.String()
return builder.String()
}
func (o *Client) extractCitations(response *genai.GenerateContentResponse) []string {
if response == nil || len(response.Candidates) == 0 {
return nil
}
citationMap := make(map[string]bool)
var citations []string
for _, candidate := range response.Candidates {
if candidate == nil || candidate.GroundingMetadata == nil {
continue
}
chunks := candidate.GroundingMetadata.GroundingChunks
if len(chunks) == 0 {
continue
}
for _, chunk := range chunks {
if chunk == nil || chunk.Web == nil {
continue
}
uri := chunk.Web.URI
title := chunk.Web.Title
if uri == "" || title == "" {
continue
}
var keyBuilder strings.Builder
keyBuilder.WriteString(uri)
keyBuilder.WriteByte('|')
keyBuilder.WriteString(title)
key := keyBuilder.String()
if !citationMap[key] {
citationMap[key] = true
citationText := fmt.Sprintf(citationFormat, title, uri)
citations = append(citations, citationText)
}
}
}
return citations
}

View File

@@ -1,11 +1,13 @@
package gemini
import (
"strings"
"testing"
"google.golang.org/genai"
"github.com/danielmiessler/fabric/internal/chat"
"github.com/danielmiessler/fabric/internal/domain"
)
// Test buildModelNameFull method
@@ -53,6 +55,106 @@ func TestExtractTextFromResponse(t *testing.T) {
}
}
func TestExtractTextFromResponse_Nil(t *testing.T) {
client := &Client{}
if got := client.extractTextFromResponse(nil); got != "" {
t.Fatalf("expected empty string, got %q", got)
}
}
func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
Content: &genai.Content{Parts: []*genai.Part{{Text: "Hello"}}},
GroundingMetadata: &genai.GroundingMetadata{GroundingChunks: nil},
},
},
}
if got := client.extractTextFromResponse(response); got != "Hello" {
t.Fatalf("expected 'Hello', got %q", got)
}
}
func TestBuildGenerateContentConfig_WithSearch(t *testing.T) {
client := &Client{}
opts := &domain.ChatOptions{Search: true}
cfg, err := client.buildGenerateContentConfig(opts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Tools == nil || len(cfg.Tools) != 1 || cfg.Tools[0].GoogleSearch == nil {
t.Errorf("expected google search tool to be included")
}
}
func TestBuildGenerateContentConfig_WithSearchAndLocation(t *testing.T) {
client := &Client{}
opts := &domain.ChatOptions{Search: true, SearchLocation: "America/Los_Angeles"}
cfg, err := client.buildGenerateContentConfig(opts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.ToolConfig == nil || cfg.ToolConfig.RetrievalConfig == nil {
t.Fatalf("expected retrieval config when search location provided")
}
if cfg.ToolConfig.RetrievalConfig.LanguageCode != opts.SearchLocation {
t.Errorf("expected language code %s, got %s", opts.SearchLocation, cfg.ToolConfig.RetrievalConfig.LanguageCode)
}
}
func TestBuildGenerateContentConfig_InvalidLocation(t *testing.T) {
client := &Client{}
opts := &domain.ChatOptions{Search: true, SearchLocation: "invalid"}
_, err := client.buildGenerateContentConfig(opts)
if err == nil {
t.Fatalf("expected error for invalid location")
}
}
func TestBuildGenerateContentConfig_LanguageCodeNormalization(t *testing.T) {
client := &Client{}
opts := &domain.ChatOptions{Search: true, SearchLocation: "en_US"}
cfg, err := client.buildGenerateContentConfig(opts)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.ToolConfig == nil || cfg.ToolConfig.RetrievalConfig.LanguageCode != "en-US" {
t.Fatalf("expected normalized language code 'en-US', got %+v", cfg.ToolConfig)
}
}
func TestCitationFormatting(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
Content: &genai.Content{Parts: []*genai.Part{{Text: "Based on recent research, AI is advancing rapidly."}}},
GroundingMetadata: &genai.GroundingMetadata{
GroundingChunks: []*genai.GroundingChunk{
{Web: &genai.GroundingChunkWeb{URI: "https://example.com/ai", Title: "AI Research"}},
{Web: &genai.GroundingChunkWeb{URI: "https://news.com/tech", Title: "Tech News"}},
{Web: &genai.GroundingChunkWeb{URI: "https://example.com/ai", Title: "AI Research"}}, // duplicate
},
},
},
},
}
result := client.extractTextFromResponse(response)
if !strings.Contains(result, "## Sources") {
t.Fatalf("expected sources section in result: %s", result)
}
if strings.Count(result, "- [") != 2 {
t.Errorf("expected 2 unique citations, got %d", strings.Count(result, "- ["))
}
}
// Test convertMessages handles role mapping correctly
func TestConvertMessagesRoles(t *testing.T) {
client := &Client{}