feat: add web search tool support for Anthropic models

## CHANGES

- Add --search flag to enable web search
- Add --search-location for timezone-based results
- Pass search options through ChatOptions struct
- Implement web search tool in Anthropic client
- Format search citations with sources section
- Add comprehensive tests for search functionality
- Remove plugin-level web search configuration
This commit is contained in:
Kayvan Sylvan
2025-07-03 22:40:39 -07:00
parent 095890a556
commit d232222787
5 changed files with 234 additions and 20 deletions

View File

@@ -74,6 +74,8 @@ type Flags struct {
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
ListVendors bool `long:"listvendors" description:"List all vendors"`
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic)"`
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
}
var debug = false
@@ -263,6 +265,8 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
Raw: o.Raw,
Seed: o.Seed,
ModelContextLength: o.ModelContextLength,
Search: o.Search,
SearchLocation: o.SearchLocation,
}
return
}

View File

@@ -26,6 +26,8 @@ type ChatOptions struct {
Seed int
ModelContextLength int
MaxTokens int
Search bool
SearchLocation string
}
// NormalizeMessages remove empty messages and ensure messages order user-assist-user

View File

@@ -5,8 +5,6 @@ import (
"fmt"
"strings"
"github.com/samber/lo"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/danielmiessler/fabric/chat"
@@ -29,9 +27,6 @@ func NewClient() (ret *Client) {
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
ret.ApiBaseURL.Value = defaultBaseUrl
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
ret.UseWebTool = ret.AddSetupQuestionBool("Web Search Tool Enabled", false)
ret.WebToolLocation = ret.AddSetupQuestionCustom("Web Search Tool Location", false,
"Enter your approximate timezone location for web search (e.g., 'America/Los_Angeles', see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).")
ret.maxTokens = 4096
ret.defaultRequiredUserMessage = "Hi"
@@ -49,10 +44,8 @@ func NewClient() (ret *Client) {
type Client struct {
*plugins.PluginBase
ApiBaseURL *plugins.SetupQuestion
ApiKey *plugins.SetupQuestion
UseWebTool *plugins.SetupQuestion
WebToolLocation *plugins.SetupQuestion
ApiBaseURL *plugins.SetupQuestion
ApiKey *plugins.SetupQuestion
maxTokens int
defaultRequiredUserMessage string
@@ -127,7 +120,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
Messages: msgs,
}
if plugins.ParseBoolElseFalse(an.UseWebTool.Value) {
if opts.Search {
// Build the web-search tool definition:
webTool := anthropic.WebSearchTool20250305Param{
Name: "web_search", // string literal instead of constant
@@ -138,9 +131,9 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
// MaxUses: anthropic.Opt[int64](5),
}
if an.WebToolLocation.Value != "" {
if opts.SearchLocation != "" {
webTool.UserLocation.Type = "approximate"
webTool.UserLocation.Timezone = anthropic.Opt(an.WebToolLocation.Value)
webTool.UserLocation.Timezone = anthropic.Opt(opts.SearchLocation)
}
// Wrap it in the union:
@@ -165,13 +158,37 @@ func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage,
return
}
texts := lo.FilterMap(message.Content, func(block anthropic.ContentBlockUnion, _ int) (ret string, ok bool) {
if ok = block.Type == "text" && block.Text != ""; ok {
ret = block.Text
var textParts []string
var citations []string
citationMap := make(map[string]bool) // To avoid duplicate citations
for _, block := range message.Content {
if block.Type == "text" && block.Text != "" {
textParts = append(textParts, block.Text)
// Extract citations from this text block
for _, citation := range block.Citations {
if citation.Type == "web_search_result_location" {
citationKey := citation.URL + "|" + citation.Title
if !citationMap[citationKey] {
citationMap[citationKey] = true
citationText := fmt.Sprintf("- [%s](%s)", citation.Title, citation.URL)
if citation.CitedText != "" {
citationText += fmt.Sprintf(" - \"%s\"", citation.CitedText)
}
citations = append(citations, citationText)
}
}
}
}
return
})
ret = strings.Join(texts, "")
}
ret = strings.Join(textParts, "")
// Append citations if any were found
if len(citations) > 0 {
ret += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
}
return
}

View File

@@ -1,7 +1,11 @@
package anthropic
import (
"strings"
"testing"
"github.com/anthropics/anthropic-sdk-go"
"github.com/danielmiessler/fabric/common"
)
// Test generated using Keploy
@@ -63,3 +67,192 @@ func TestClient_ListModels_ReturnsCorrectModels(t *testing.T) {
}
}
}
func TestBuildMessageParams_WithoutSearch(t *testing.T) {
client := NewClient()
opts := &common.ChatOptions{
Model: "claude-3-5-sonnet-latest",
Temperature: 0.7,
Search: false,
}
messages := []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")),
}
params := client.buildMessageParams(messages, opts)
if params.Tools != nil {
t.Error("Expected no tools when search is disabled, got tools")
}
if params.Model != anthropic.Model(opts.Model) {
t.Errorf("Expected model %s, got %s", opts.Model, params.Model)
}
if params.Temperature.Value != opts.Temperature {
t.Errorf("Expected temperature %f, got %f", opts.Temperature, params.Temperature.Value)
}
}
func TestBuildMessageParams_WithSearch(t *testing.T) {
client := NewClient()
opts := &common.ChatOptions{
Model: "claude-3-5-sonnet-latest",
Temperature: 0.7,
Search: true,
}
messages := []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather today?")),
}
params := client.buildMessageParams(messages, opts)
if params.Tools == nil {
t.Fatal("Expected tools when search is enabled, got nil")
}
if len(params.Tools) != 1 {
t.Errorf("Expected 1 tool, got %d", len(params.Tools))
}
webTool := params.Tools[0].OfWebSearchTool20250305
if webTool == nil {
t.Fatal("Expected web search tool, got nil")
}
if webTool.Name != "web_search" {
t.Errorf("Expected tool name 'web_search', got %s", webTool.Name)
}
if webTool.Type != "web_search_20250305" {
t.Errorf("Expected tool type 'web_search_20250305', got %s", webTool.Type)
}
}
func TestBuildMessageParams_WithSearchAndLocation(t *testing.T) {
client := NewClient()
opts := &common.ChatOptions{
Model: "claude-3-5-sonnet-latest",
Temperature: 0.7,
Search: true,
SearchLocation: "America/Los_Angeles",
}
messages := []anthropic.MessageParam{
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in San Francisco?")),
}
params := client.buildMessageParams(messages, opts)
if params.Tools == nil {
t.Fatal("Expected tools when search is enabled, got nil")
}
webTool := params.Tools[0].OfWebSearchTool20250305
if webTool == nil {
t.Fatal("Expected web search tool, got nil")
}
if webTool.UserLocation.Type != "approximate" {
t.Errorf("Expected location type 'approximate', got %s", webTool.UserLocation.Type)
}
if webTool.UserLocation.Timezone.Value != opts.SearchLocation {
t.Errorf("Expected timezone %s, got %s", opts.SearchLocation, webTool.UserLocation.Timezone.Value)
}
}
func TestCitationFormatting(t *testing.T) {
// Test the citation formatting logic by creating a mock message with citations
message := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{
{
Type: "text",
Text: "Based on recent research, artificial intelligence is advancing rapidly.",
Citations: []anthropic.TextCitationUnion{
{
Type: "web_search_result_location",
URL: "https://example.com/ai-research",
Title: "AI Research Advances 2025",
CitedText: "artificial intelligence is advancing rapidly",
},
{
Type: "web_search_result_location",
URL: "https://another-source.com/tech-news",
Title: "Technology News Today",
CitedText: "recent developments in AI",
},
},
},
{
Type: "text",
Text: " Machine learning models are becoming more sophisticated.",
Citations: []anthropic.TextCitationUnion{
{
Type: "web_search_result_location",
URL: "https://example.com/ai-research", // Duplicate URL should be deduplicated
Title: "AI Research Advances 2025",
CitedText: "machine learning models",
},
},
},
},
}
// Extract text and citations using the same logic as the Send method
var textParts []string
var citations []string
citationMap := make(map[string]bool)
for _, block := range message.Content {
if block.Type == "text" && block.Text != "" {
textParts = append(textParts, block.Text)
for _, citation := range block.Citations {
if citation.Type == "web_search_result_location" {
citationKey := citation.URL + "|" + citation.Title
if !citationMap[citationKey] {
citationMap[citationKey] = true
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
if citation.CitedText != "" {
citationText += " - \"" + citation.CitedText + "\""
}
citations = append(citations, citationText)
}
}
}
}
}
result := strings.Join(textParts, "")
if len(citations) > 0 {
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
}
// Verify the result contains the expected text
expectedText := "Based on recent research, artificial intelligence is advancing rapidly. Machine learning models are becoming more sophisticated."
if !strings.Contains(result, expectedText) {
t.Errorf("Expected result to contain text: %s", expectedText)
}
// Verify citations are included
if !strings.Contains(result, "## Sources") {
t.Error("Expected result to contain Sources section")
}
if !strings.Contains(result, "[AI Research Advances 2025](https://example.com/ai-research)") {
t.Error("Expected result to contain first citation")
}
if !strings.Contains(result, "[Technology News Today](https://another-source.com/tech-news)") {
t.Error("Expected result to contain second citation")
}
// Verify deduplication - should only have 2 unique citations, not 3
citationCount := strings.Count(result, "- [")
if citationCount != 2 {
t.Errorf("Expected 2 unique citations, got %d", citationCount)
}
}

View File

@@ -152,7 +152,6 @@ func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) {
}
buffer.WriteString("\n")
}
return
}
func ParseBoolElseFalse(val string) (ret bool) {
@@ -279,7 +278,6 @@ func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) {
for _, setting := range o {
setting.FillEnvFileContent(buffer)
}
return
}
type SetupQuestions []*SetupQuestion