mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-09 14:28:01 -05:00
feat(gemini): enable web search, citations, and search-location validation
CHANGES - Enable Gemini models to use web search tool - Validate search-location timezone or language code formats - Normalize language codes from underscores to hyphenated form - Inject Google Search tool when --search flag enabled - Append deduplicated web citations under standardized Sources section - Improve robustness for nil candidates and content parts - Factor generation config builder for reuse in streaming - Update CLI help and completions to include Gemini
This commit is contained in:
@@ -536,7 +536,7 @@ Application Options:
|
||||
--liststrategies List all strategies
|
||||
--listvendors List all vendors
|
||||
--shell-complete-list Output raw list without headers/formatting (for shell completion)
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI)
|
||||
--search Enable web search tool for supported models (Anthropic, OpenAI, Gemini)
|
||||
--search-location= Set location for web search results (e.g., 'America/Los_Angeles')
|
||||
--image-file= Save generated image to specified file path (e.g., 'output.png')
|
||||
--image-size= Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)
|
||||
|
||||
@@ -98,7 +98,7 @@ _fabric() {
|
||||
'(--api-key)--api-key[API key used to secure server routes]:api-key:' \
|
||||
'(--config)--config[Path to YAML config file]:config file:_files -g "*.yaml *.yml"' \
|
||||
'(--version)--version[Print current version]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI)]' \
|
||||
'(--search)--search[Enable web search tool for supported models (Anthropic, OpenAI, Gemini)]' \
|
||||
'(--search-location)--search-location[Set location for web search results]:location:' \
|
||||
'(--image-file)--image-file[Save generated image to specified file path]:image file:_files -g "*.png *.webp *.jpeg *.jpg"' \
|
||||
'(--image-size)--image-size[Image dimensions]:size:(1024x1024 1536x1024 1024x1536 auto)' \
|
||||
|
||||
@@ -99,7 +99,7 @@ complete -c fabric -l yt-dlp-args -d "Additional arguments to pass to yt-dlp (e.
|
||||
complete -c fabric -l readability -d "Convert HTML input into a clean, readable view"
|
||||
complete -c fabric -l input-has-vars -d "Apply variables to user input"
|
||||
complete -c fabric -l dry-run -d "Show what would be sent to the model without actually sending it"
|
||||
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI)"
|
||||
complete -c fabric -l search -d "Enable web search tool for supported models (Anthropic, OpenAI, Gemini)"
|
||||
complete -c fabric -l serve -d "Serve the Fabric Rest API"
|
||||
complete -c fabric -l serveOllama -d "Serve the Fabric Rest API with ollama endpoints"
|
||||
complete -c fabric -l version -d "Print current version"
|
||||
|
||||
@@ -79,7 +79,7 @@ type Flags struct {
|
||||
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
|
||||
ListVendors bool `long:"listvendors" description:"List all vendors"`
|
||||
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI)"`
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic, OpenAI, Gemini)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
ImageFile string `long:"image-file" description:"Save generated image to specified file path (e.g., 'output.png')"`
|
||||
ImageSize string `long:"image-size" description:"Image dimensions: 1024x1024, 1536x1024, 1024x1536, auto (default: auto)"`
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/danielmiessler/fabric/internal/chat"
|
||||
@@ -26,6 +27,24 @@ const (
|
||||
AudioDataPrefix = "FABRIC_AUDIO_DATA:"
|
||||
)
|
||||
|
||||
const (
|
||||
citationHeader = "\n\n## Sources\n\n"
|
||||
citationSeparator = "\n"
|
||||
citationFormat = "- [%s](%s)"
|
||||
|
||||
errInvalidLocationFormat = "invalid search location format %q: must be timezone (e.g., 'America/Los_Angeles') or language code (e.g., 'en-US')"
|
||||
locationSeparator = "/"
|
||||
langCodeSeparator = "_"
|
||||
langCodeNormalizedSep = "-"
|
||||
|
||||
modelPrefix = "models/"
|
||||
modelTypeTTS = "tts"
|
||||
modelTypePreviewTTS = "preview-tts"
|
||||
modelTypeTextToSpeech = "text-to-speech"
|
||||
)
|
||||
|
||||
var langCodeRegex = regexp.MustCompile(`^[a-z]{2}(-[A-Z]{2})?$`)
|
||||
|
||||
func NewClient() (ret *Client) {
|
||||
vendorName := "Gemini"
|
||||
ret = &Client{}
|
||||
@@ -93,14 +112,13 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
|
||||
// Convert messages to new SDK format
|
||||
contents := o.convertMessages(msgs)
|
||||
|
||||
// Generate content
|
||||
temperature := float32(opts.Temperature)
|
||||
topP := float32(opts.TopP)
|
||||
response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
MaxOutputTokens: int32(opts.ModelContextLength),
|
||||
})
|
||||
cfg, err := o.buildGenerateContentConfig(opts)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Generate content with optional tools
|
||||
response, err := client.Models.GenerateContent(ctx, o.buildModelNameFull(opts.Model), contents, cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -123,14 +141,13 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
|
||||
// Convert messages to new SDK format
|
||||
contents := o.convertMessages(msgs)
|
||||
|
||||
// Generate streaming content
|
||||
temperature := float32(opts.Temperature)
|
||||
topP := float32(opts.TopP)
|
||||
stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, &genai.GenerateContentConfig{
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
MaxOutputTokens: int32(opts.ModelContextLength),
|
||||
})
|
||||
cfg, err := o.buildGenerateContentConfig(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Generate streaming content with optional tools
|
||||
stream := client.Models.GenerateContentStream(ctx, o.buildModelNameFull(opts.Model), contents, cfg)
|
||||
|
||||
for response, err := range stream {
|
||||
if err != nil {
|
||||
@@ -153,20 +170,86 @@ func (o *Client) NeedsRawMode(modelName string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// buildGenerateContentConfig constructs the generation config with optional tools.
|
||||
// When search is enabled it injects the Google Search tool. The optional search
|
||||
// location accepts either:
|
||||
// - A timezone in the format "Continent/City" (e.g., "America/Los_Angeles")
|
||||
// - An ISO language code "ll" or "ll-CC" (e.g., "en" or "en-US")
|
||||
//
|
||||
// Underscores are normalized to hyphens. Returns an error if the location is
|
||||
// invalid.
|
||||
func (o *Client) buildGenerateContentConfig(opts *domain.ChatOptions) (*genai.GenerateContentConfig, error) {
|
||||
temperature := float32(opts.Temperature)
|
||||
topP := float32(opts.TopP)
|
||||
cfg := &genai.GenerateContentConfig{
|
||||
Temperature: &temperature,
|
||||
TopP: &topP,
|
||||
MaxOutputTokens: int32(opts.ModelContextLength),
|
||||
}
|
||||
|
||||
if opts.Search {
|
||||
cfg.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
|
||||
if loc := opts.SearchLocation; loc != "" {
|
||||
if isValidLocationFormat(loc) {
|
||||
loc = normalizeLocation(loc)
|
||||
cfg.ToolConfig = &genai.ToolConfig{
|
||||
RetrievalConfig: &genai.RetrievalConfig{LanguageCode: loc},
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf(errInvalidLocationFormat, loc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// buildModelNameFull adds the "models/" prefix for API calls
|
||||
func (o *Client) buildModelNameFull(modelName string) string {
|
||||
if strings.HasPrefix(modelName, "models/") {
|
||||
if strings.HasPrefix(modelName, modelPrefix) {
|
||||
return modelName
|
||||
}
|
||||
return "models/" + modelName
|
||||
return modelPrefix + modelName
|
||||
}
|
||||
|
||||
func isValidLocationFormat(location string) bool {
|
||||
if strings.Contains(location, locationSeparator) {
|
||||
parts := strings.Split(location, locationSeparator)
|
||||
return len(parts) == 2 && parts[0] != "" && parts[1] != ""
|
||||
}
|
||||
return isValidLanguageCode(location)
|
||||
}
|
||||
|
||||
func normalizeLocation(location string) string {
|
||||
if strings.Contains(location, locationSeparator) {
|
||||
return location
|
||||
}
|
||||
return strings.Replace(location, langCodeSeparator, langCodeNormalizedSep, 1)
|
||||
}
|
||||
|
||||
// isValidLanguageCode reports whether the input is an ISO 639-1 language code
|
||||
// optionally followed by an ISO 3166-1 country code. Underscores are
|
||||
// normalized to hyphens before validation.
|
||||
func isValidLanguageCode(code string) bool {
|
||||
normalized := strings.Replace(code, langCodeSeparator, langCodeNormalizedSep, 1)
|
||||
parts := strings.Split(normalized, langCodeNormalizedSep)
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
return langCodeRegex.MatchString(strings.ToLower(parts[0]))
|
||||
case 2:
|
||||
formatted := strings.ToLower(parts[0]) + langCodeNormalizedSep + strings.ToUpper(parts[1])
|
||||
return langCodeRegex.MatchString(formatted)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isTTSModel checks if the model is a text-to-speech model
|
||||
func (o *Client) isTTSModel(modelName string) bool {
|
||||
lowerModel := strings.ToLower(modelName)
|
||||
return strings.Contains(lowerModel, "tts") ||
|
||||
strings.Contains(lowerModel, "preview-tts") ||
|
||||
strings.Contains(lowerModel, "text-to-speech")
|
||||
return strings.Contains(lowerModel, modelTypeTTS) ||
|
||||
strings.Contains(lowerModel, modelTypePreviewTTS) ||
|
||||
strings.Contains(lowerModel, modelTypeTextToSpeech)
|
||||
}
|
||||
|
||||
// extractTextForTTS extracts text content from chat messages for TTS generation
|
||||
@@ -370,19 +453,71 @@ func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Co
|
||||
return contents
|
||||
}
|
||||
|
||||
// extractTextFromResponse extracts text content from the response
|
||||
// extractTextFromResponse extracts text content from the response and appends
|
||||
// any web citations in a standardized format.
|
||||
func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
|
||||
var result strings.Builder
|
||||
if response == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
text := o.extractTextParts(response)
|
||||
citations := o.extractCitations(response)
|
||||
if len(citations) > 0 {
|
||||
return text + citationHeader + strings.Join(citations, citationSeparator)
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func (o *Client) extractTextParts(response *genai.GenerateContentResponse) string {
|
||||
var builder strings.Builder
|
||||
for _, candidate := range response.Candidates {
|
||||
if candidate.Content != nil {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.Text != "" {
|
||||
result.WriteString(part.Text)
|
||||
}
|
||||
if candidate == nil || candidate.Content == nil {
|
||||
continue
|
||||
}
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part != nil && part.Text != "" {
|
||||
builder.WriteString(part.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result.String()
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func (o *Client) extractCitations(response *genai.GenerateContentResponse) []string {
|
||||
if response == nil || len(response.Candidates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
citationMap := make(map[string]bool)
|
||||
var citations []string
|
||||
for _, candidate := range response.Candidates {
|
||||
if candidate == nil || candidate.GroundingMetadata == nil {
|
||||
continue
|
||||
}
|
||||
chunks := candidate.GroundingMetadata.GroundingChunks
|
||||
if len(chunks) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, chunk := range chunks {
|
||||
if chunk == nil || chunk.Web == nil {
|
||||
continue
|
||||
}
|
||||
uri := chunk.Web.URI
|
||||
title := chunk.Web.Title
|
||||
if uri == "" || title == "" {
|
||||
continue
|
||||
}
|
||||
var keyBuilder strings.Builder
|
||||
keyBuilder.WriteString(uri)
|
||||
keyBuilder.WriteByte('|')
|
||||
keyBuilder.WriteString(title)
|
||||
key := keyBuilder.String()
|
||||
if !citationMap[key] {
|
||||
citationMap[key] = true
|
||||
citationText := fmt.Sprintf(citationFormat, title, uri)
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
return citations
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/genai"
|
||||
|
||||
"github.com/danielmiessler/fabric/internal/chat"
|
||||
"github.com/danielmiessler/fabric/internal/domain"
|
||||
)
|
||||
|
||||
// Test buildModelNameFull method
|
||||
@@ -53,6 +55,106 @@ func TestExtractTextFromResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTextFromResponse_Nil(t *testing.T) {
|
||||
client := &Client{}
|
||||
if got := client.extractTextFromResponse(nil); got != "" {
|
||||
t.Fatalf("expected empty string, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
|
||||
client := &Client{}
|
||||
response := &genai.GenerateContentResponse{
|
||||
Candidates: []*genai.Candidate{
|
||||
{
|
||||
Content: &genai.Content{Parts: []*genai.Part{{Text: "Hello"}}},
|
||||
GroundingMetadata: &genai.GroundingMetadata{GroundingChunks: nil},
|
||||
},
|
||||
},
|
||||
}
|
||||
if got := client.extractTextFromResponse(response); got != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGenerateContentConfig_WithSearch(t *testing.T) {
|
||||
client := &Client{}
|
||||
opts := &domain.ChatOptions{Search: true}
|
||||
|
||||
cfg, err := client.buildGenerateContentConfig(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.Tools == nil || len(cfg.Tools) != 1 || cfg.Tools[0].GoogleSearch == nil {
|
||||
t.Errorf("expected google search tool to be included")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGenerateContentConfig_WithSearchAndLocation(t *testing.T) {
|
||||
client := &Client{}
|
||||
opts := &domain.ChatOptions{Search: true, SearchLocation: "America/Los_Angeles"}
|
||||
|
||||
cfg, err := client.buildGenerateContentConfig(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.ToolConfig == nil || cfg.ToolConfig.RetrievalConfig == nil {
|
||||
t.Fatalf("expected retrieval config when search location provided")
|
||||
}
|
||||
if cfg.ToolConfig.RetrievalConfig.LanguageCode != opts.SearchLocation {
|
||||
t.Errorf("expected language code %s, got %s", opts.SearchLocation, cfg.ToolConfig.RetrievalConfig.LanguageCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGenerateContentConfig_InvalidLocation(t *testing.T) {
|
||||
client := &Client{}
|
||||
opts := &domain.ChatOptions{Search: true, SearchLocation: "invalid"}
|
||||
|
||||
_, err := client.buildGenerateContentConfig(opts)
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for invalid location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGenerateContentConfig_LanguageCodeNormalization(t *testing.T) {
|
||||
client := &Client{}
|
||||
opts := &domain.ChatOptions{Search: true, SearchLocation: "en_US"}
|
||||
|
||||
cfg, err := client.buildGenerateContentConfig(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if cfg.ToolConfig == nil || cfg.ToolConfig.RetrievalConfig.LanguageCode != "en-US" {
|
||||
t.Fatalf("expected normalized language code 'en-US', got %+v", cfg.ToolConfig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
client := &Client{}
|
||||
response := &genai.GenerateContentResponse{
|
||||
Candidates: []*genai.Candidate{
|
||||
{
|
||||
Content: &genai.Content{Parts: []*genai.Part{{Text: "Based on recent research, AI is advancing rapidly."}}},
|
||||
GroundingMetadata: &genai.GroundingMetadata{
|
||||
GroundingChunks: []*genai.GroundingChunk{
|
||||
{Web: &genai.GroundingChunkWeb{URI: "https://example.com/ai", Title: "AI Research"}},
|
||||
{Web: &genai.GroundingChunkWeb{URI: "https://news.com/tech", Title: "Tech News"}},
|
||||
{Web: &genai.GroundingChunkWeb{URI: "https://example.com/ai", Title: "AI Research"}}, // duplicate
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := client.extractTextFromResponse(response)
|
||||
if !strings.Contains(result, "## Sources") {
|
||||
t.Fatalf("expected sources section in result: %s", result)
|
||||
}
|
||||
if strings.Count(result, "- [") != 2 {
|
||||
t.Errorf("expected 2 unique citations, got %d", strings.Count(result, "- ["))
|
||||
}
|
||||
}
|
||||
|
||||
// Test convertMessages handles role mapping correctly
|
||||
func TestConvertMessagesRoles(t *testing.T) {
|
||||
client := &Client{}
|
||||
|
||||
Reference in New Issue
Block a user