From d23222278731929cabf1367e676fcb892610c50b Mon Sep 17 00:00:00 2001 From: Kayvan Sylvan Date: Thu, 3 Jul 2025 22:40:39 -0700 Subject: [PATCH] feat: add web search tool support for Anthropic models ## CHANGES - Add --search flag to enable web search - Add --search-location for timezone-based results - Pass search options through ChatOptions struct - Implement web search tool in Anthropic client - Format search citations with sources section - Add comprehensive tests for search functionality - Remove plugin-level web search configuration --- cli/flags.go | 4 + common/domain.go | 2 + plugins/ai/anthropic/anthropic.go | 53 ++++--- plugins/ai/anthropic/anthropic_test.go | 193 +++++++++++++++++++++++++ plugins/plugin.go | 2 - 5 files changed, 234 insertions(+), 20 deletions(-) diff --git a/cli/flags.go b/cli/flags.go index 37d3cbd9..0d92054f 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -74,6 +74,8 @@ type Flags struct { ListStrategies bool `long:"liststrategies" description:"List all strategies"` ListVendors bool `long:"listvendors" description:"List all vendors"` ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"` + Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic)"` + SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"` } var debug = false @@ -263,6 +265,8 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) { Raw: o.Raw, Seed: o.Seed, ModelContextLength: o.ModelContextLength, + Search: o.Search, + SearchLocation: o.SearchLocation, } return } diff --git a/common/domain.go b/common/domain.go index c0554ce5..44357e0b 100644 --- a/common/domain.go +++ b/common/domain.go @@ -26,6 +26,8 @@ type ChatOptions struct { Seed int ModelContextLength int MaxTokens int + Search bool + SearchLocation string } // NormalizeMessages remove empty messages and ensure messages order user-assist-user diff --git a/plugins/ai/anthropic/anthropic.go b/plugins/ai/anthropic/anthropic.go index 0e7d947a..ee8b6bd6 100644 --- a/plugins/ai/anthropic/anthropic.go +++ b/plugins/ai/anthropic/anthropic.go @@ -5,8 +5,6 @@ import ( "fmt" "strings" - "github.com/samber/lo" - "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/danielmiessler/fabric/chat" @@ -29,9 +27,6 @@ func NewClient() (ret *Client) { ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false) ret.ApiBaseURL.Value = defaultBaseUrl ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true) - ret.UseWebTool = ret.AddSetupQuestionBool("Web Search Tool Enabled", false) - ret.WebToolLocation = ret.AddSetupQuestionCustom("Web Search Tool Location", false, - "Enter your approximate timezone location for web search (e.g., 'America/Los_Angeles', see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).") ret.maxTokens = 4096 ret.defaultRequiredUserMessage = "Hi" @@ -49,10 +44,8 @@ func NewClient() (ret *Client) { type Client struct { *plugins.PluginBase - ApiBaseURL *plugins.SetupQuestion - ApiKey *plugins.SetupQuestion - UseWebTool *plugins.SetupQuestion - WebToolLocation *plugins.SetupQuestion + ApiBaseURL *plugins.SetupQuestion + ApiKey *plugins.SetupQuestion maxTokens int defaultRequiredUserMessage string @@ -127,7 +120,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common Messages: msgs, } - if plugins.ParseBoolElseFalse(an.UseWebTool.Value) { + if opts.Search { // Build the web-search tool definition: webTool := anthropic.WebSearchTool20250305Param{ Name: "web_search", // string literal instead of constant @@ -138,9 +131,9 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common // MaxUses: anthropic.Opt[int64](5), } - if an.WebToolLocation.Value != "" { + if opts.SearchLocation != "" { webTool.UserLocation.Type = "approximate" - webTool.UserLocation.Timezone = anthropic.Opt(an.WebToolLocation.Value) + webTool.UserLocation.Timezone = anthropic.Opt(opts.SearchLocation) } // Wrap it in the union: @@ -165,13 +158,37 @@ func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, return } - texts := lo.FilterMap(message.Content, func(block anthropic.ContentBlockUnion, _ int) (ret string, ok bool) { - if ok = block.Type == "text" && block.Text != ""; ok { - ret = block.Text + var textParts []string + var citations []string + citationMap := make(map[string]bool) // To avoid duplicate citations + + for _, block := range message.Content { + if block.Type == "text" && block.Text != "" { + textParts = append(textParts, block.Text) + + // Extract citations from this text block + for _, citation := range block.Citations { + if citation.Type == "web_search_result_location" { + citationKey := citation.URL + "|" + citation.Title + if !citationMap[citationKey] { + citationMap[citationKey] = true + citationText := fmt.Sprintf("- [%s](%s)", citation.Title, citation.URL) + if citation.CitedText != "" { + citationText += fmt.Sprintf(" - \"%s\"", citation.CitedText) + } + citations = append(citations, citationText) + } + } + } } - return - }) - ret = strings.Join(texts, "") + } + + ret = strings.Join(textParts, "") + + // Append citations if any were found + if len(citations) > 0 { + ret += "\n\n## Sources\n\n" + strings.Join(citations, "\n") + } return } diff --git a/plugins/ai/anthropic/anthropic_test.go b/plugins/ai/anthropic/anthropic_test.go index e89ea3db..d4497386 100644 --- a/plugins/ai/anthropic/anthropic_test.go +++ b/plugins/ai/anthropic/anthropic_test.go @@ -1,7 +1,11 @@ package anthropic import ( + "strings" "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/danielmiessler/fabric/common" ) // Test generated using Keploy @@ -63,3 +67,192 @@ func TestClient_ListModels_ReturnsCorrectModels(t *testing.T) { } } } + +func TestBuildMessageParams_WithoutSearch(t *testing.T) { + client := NewClient() + opts := &common.ChatOptions{ + Model: "claude-3-5-sonnet-latest", + Temperature: 0.7, + Search: false, + } + + messages := []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")), + } + + params := client.buildMessageParams(messages, opts) + + if params.Tools != nil { + t.Error("Expected no tools when search is disabled, got tools") + } + + if params.Model != anthropic.Model(opts.Model) { + t.Errorf("Expected model %s, got %s", opts.Model, params.Model) + } + + if params.Temperature.Value != opts.Temperature { + t.Errorf("Expected temperature %f, got %f", opts.Temperature, params.Temperature.Value) + } +} + +func TestBuildMessageParams_WithSearch(t *testing.T) { + client := NewClient() + opts := &common.ChatOptions{ + Model: "claude-3-5-sonnet-latest", + Temperature: 0.7, + Search: true, + } + + messages := []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather today?")), + } + + params := client.buildMessageParams(messages, opts) + + if params.Tools == nil { + t.Fatal("Expected tools when search is enabled, got nil") + } + + if len(params.Tools) != 1 { + t.Errorf("Expected 1 tool, got %d", len(params.Tools)) + } + + webTool := params.Tools[0].OfWebSearchTool20250305 + if webTool == nil { + t.Fatal("Expected web search tool, got nil") + } + + if webTool.Name != "web_search" { + t.Errorf("Expected tool name 'web_search', got %s", webTool.Name) + } + + if webTool.Type != "web_search_20250305" { + t.Errorf("Expected tool type 'web_search_20250305', got %s", webTool.Type) + } +} + +func TestBuildMessageParams_WithSearchAndLocation(t *testing.T) { + client := NewClient() + opts := &common.ChatOptions{ + Model: "claude-3-5-sonnet-latest", + Temperature: 0.7, + Search: true, + SearchLocation: "America/Los_Angeles", + } + + messages := []anthropic.MessageParam{ + anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in San Francisco?")), + } + + params := client.buildMessageParams(messages, opts) + + if params.Tools == nil { + t.Fatal("Expected tools when search is enabled, got nil") + } + + webTool := params.Tools[0].OfWebSearchTool20250305 + if webTool == nil { + t.Fatal("Expected web search tool, got nil") + } + + if webTool.UserLocation.Type != "approximate" { + t.Errorf("Expected location type 'approximate', got %s", webTool.UserLocation.Type) + } + + if webTool.UserLocation.Timezone.Value != opts.SearchLocation { + t.Errorf("Expected timezone %s, got %s", opts.SearchLocation, webTool.UserLocation.Timezone.Value) + } +} + +func TestCitationFormatting(t *testing.T) { + // Test the citation formatting logic by creating a mock message with citations + message := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{ + { + Type: "text", + Text: "Based on recent research, artificial intelligence is advancing rapidly.", + Citations: []anthropic.TextCitationUnion{ + { + Type: "web_search_result_location", + URL: "https://example.com/ai-research", + Title: "AI Research Advances 2025", + CitedText: "artificial intelligence is advancing rapidly", + }, + { + Type: "web_search_result_location", + URL: "https://another-source.com/tech-news", + Title: "Technology News Today", + CitedText: "recent developments in AI", + }, + }, + }, + { + Type: "text", + Text: " Machine learning models are becoming more sophisticated.", + Citations: []anthropic.TextCitationUnion{ + { + Type: "web_search_result_location", + URL: "https://example.com/ai-research", // Duplicate URL should be deduplicated + Title: "AI Research Advances 2025", + CitedText: "machine learning models", + }, + }, + }, + }, + } + + // Extract text and citations using the same logic as the Send method + var textParts []string + var citations []string + citationMap := make(map[string]bool) + + for _, block := range message.Content { + if block.Type == "text" && block.Text != "" { + textParts = append(textParts, block.Text) + + for _, citation := range block.Citations { + if citation.Type == "web_search_result_location" { + citationKey := citation.URL + "|" + citation.Title + if !citationMap[citationKey] { + citationMap[citationKey] = true + citationText := "- [" + citation.Title + "](" + citation.URL + ")" + if citation.CitedText != "" { + citationText += " - \"" + citation.CitedText + "\"" + } + citations = append(citations, citationText) + } + } + } + } + } + + result := strings.Join(textParts, "") + if len(citations) > 0 { + result += "\n\n## Sources\n\n" + strings.Join(citations, "\n") + } + + // Verify the result contains the expected text + expectedText := "Based on recent research, artificial intelligence is advancing rapidly. Machine learning models are becoming more sophisticated." + if !strings.Contains(result, expectedText) { + t.Errorf("Expected result to contain text: %s", expectedText) + } + + // Verify citations are included + if !strings.Contains(result, "## Sources") { + t.Error("Expected result to contain Sources section") + } + + if !strings.Contains(result, "[AI Research Advances 2025](https://example.com/ai-research)") { + t.Error("Expected result to contain first citation") + } + + if !strings.Contains(result, "[Technology News Today](https://another-source.com/tech-news)") { + t.Error("Expected result to contain second citation") + } + + // Verify deduplication - should only have 2 unique citations, not 3 + citationCount := strings.Count(result, "- [") + if citationCount != 2 { + t.Errorf("Expected 2 unique citations, got %d", citationCount) + } +} diff --git a/plugins/plugin.go b/plugins/plugin.go index fbb85a1b..2561f9a8 100644 --- a/plugins/plugin.go +++ b/plugins/plugin.go @@ -152,7 +152,6 @@ func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) { } buffer.WriteString("\n") } - return } func ParseBoolElseFalse(val string) (ret bool) { @@ -279,7 +278,6 @@ func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) { for _, setting := range o { setting.FillEnvFileContent(buffer) } - return } type SetupQuestions []*SetupQuestion