mirror of
https://github.com/danielmiessler/Fabric.git
synced 2026-01-10 06:48:04 -05:00
feat: add web search tool support for Anthropic models
## CHANGES - Add --search flag to enable web search - Add --search-location for timezone-based results - Pass search options through ChatOptions struct - Implement web search tool in Anthropic client - Format search citations with sources section - Add comprehensive tests for search functionality - Remove plugin-level web search configuration
This commit is contained in:
@@ -74,6 +74,8 @@ type Flags struct {
|
||||
ListStrategies bool `long:"liststrategies" description:"List all strategies"`
|
||||
ListVendors bool `long:"listvendors" description:"List all vendors"`
|
||||
ShellCompleteOutput bool `long:"shell-complete-list" description:"Output raw list without headers/formatting (for shell completion)"`
|
||||
Search bool `long:"search" description:"Enable web search tool for supported models (Anthropic)"`
|
||||
SearchLocation string `long:"search-location" description:"Set location for web search results (e.g., 'America/Los_Angeles')"`
|
||||
}
|
||||
|
||||
var debug = false
|
||||
@@ -263,6 +265,8 @@ func (o *Flags) BuildChatOptions() (ret *common.ChatOptions) {
|
||||
Raw: o.Raw,
|
||||
Seed: o.Seed,
|
||||
ModelContextLength: o.ModelContextLength,
|
||||
Search: o.Search,
|
||||
SearchLocation: o.SearchLocation,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -26,6 +26,8 @@ type ChatOptions struct {
|
||||
Seed int
|
||||
ModelContextLength int
|
||||
MaxTokens int
|
||||
Search bool
|
||||
SearchLocation string
|
||||
}
|
||||
|
||||
// NormalizeMessages remove empty messages and ensure messages order user-assist-user
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/option"
|
||||
"github.com/danielmiessler/fabric/chat"
|
||||
@@ -29,9 +27,6 @@ func NewClient() (ret *Client) {
|
||||
ret.ApiBaseURL = ret.AddSetupQuestion("API Base URL", false)
|
||||
ret.ApiBaseURL.Value = defaultBaseUrl
|
||||
ret.ApiKey = ret.PluginBase.AddSetupQuestion("API key", true)
|
||||
ret.UseWebTool = ret.AddSetupQuestionBool("Web Search Tool Enabled", false)
|
||||
ret.WebToolLocation = ret.AddSetupQuestionCustom("Web Search Tool Location", false,
|
||||
"Enter your approximate timezone location for web search (e.g., 'America/Los_Angeles', see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).")
|
||||
|
||||
ret.maxTokens = 4096
|
||||
ret.defaultRequiredUserMessage = "Hi"
|
||||
@@ -49,10 +44,8 @@ func NewClient() (ret *Client) {
|
||||
|
||||
type Client struct {
|
||||
*plugins.PluginBase
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
UseWebTool *plugins.SetupQuestion
|
||||
WebToolLocation *plugins.SetupQuestion
|
||||
ApiBaseURL *plugins.SetupQuestion
|
||||
ApiKey *plugins.SetupQuestion
|
||||
|
||||
maxTokens int
|
||||
defaultRequiredUserMessage string
|
||||
@@ -127,7 +120,7 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
Messages: msgs,
|
||||
}
|
||||
|
||||
if plugins.ParseBoolElseFalse(an.UseWebTool.Value) {
|
||||
if opts.Search {
|
||||
// Build the web-search tool definition:
|
||||
webTool := anthropic.WebSearchTool20250305Param{
|
||||
Name: "web_search", // string literal instead of constant
|
||||
@@ -138,9 +131,9 @@ func (an *Client) buildMessageParams(msgs []anthropic.MessageParam, opts *common
|
||||
// MaxUses: anthropic.Opt[int64](5),
|
||||
}
|
||||
|
||||
if an.WebToolLocation.Value != "" {
|
||||
if opts.SearchLocation != "" {
|
||||
webTool.UserLocation.Type = "approximate"
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(an.WebToolLocation.Value)
|
||||
webTool.UserLocation.Timezone = anthropic.Opt(opts.SearchLocation)
|
||||
}
|
||||
|
||||
// Wrap it in the union:
|
||||
@@ -165,13 +158,37 @@ func (an *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage,
|
||||
return
|
||||
}
|
||||
|
||||
texts := lo.FilterMap(message.Content, func(block anthropic.ContentBlockUnion, _ int) (ret string, ok bool) {
|
||||
if ok = block.Type == "text" && block.Text != ""; ok {
|
||||
ret = block.Text
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool) // To avoid duplicate citations
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
// Extract citations from this text block
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := fmt.Sprintf("- [%s](%s)", citation.Title, citation.URL)
|
||||
if citation.CitedText != "" {
|
||||
citationText += fmt.Sprintf(" - \"%s\"", citation.CitedText)
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
})
|
||||
ret = strings.Join(texts, "")
|
||||
}
|
||||
|
||||
ret = strings.Join(textParts, "")
|
||||
|
||||
// Append citations if any were found
|
||||
if len(citations) > 0 {
|
||||
ret += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/danielmiessler/fabric/common"
|
||||
)
|
||||
|
||||
// Test generated using Keploy
|
||||
@@ -63,3 +67,192 @@ func TestClient_ListModels_ReturnsCorrectModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithoutSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: false,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("Hello")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools != nil {
|
||||
t.Error("Expected no tools when search is disabled, got tools")
|
||||
}
|
||||
|
||||
if params.Model != anthropic.Model(opts.Model) {
|
||||
t.Errorf("Expected model %s, got %s", opts.Model, params.Model)
|
||||
}
|
||||
|
||||
if params.Temperature.Value != opts.Temperature {
|
||||
t.Errorf("Expected temperature %f, got %f", opts.Temperature, params.Temperature.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearch(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather today?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
if len(params.Tools) != 1 {
|
||||
t.Errorf("Expected 1 tool, got %d", len(params.Tools))
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.Name != "web_search" {
|
||||
t.Errorf("Expected tool name 'web_search', got %s", webTool.Name)
|
||||
}
|
||||
|
||||
if webTool.Type != "web_search_20250305" {
|
||||
t.Errorf("Expected tool type 'web_search_20250305', got %s", webTool.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildMessageParams_WithSearchAndLocation(t *testing.T) {
|
||||
client := NewClient()
|
||||
opts := &common.ChatOptions{
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
Temperature: 0.7,
|
||||
Search: true,
|
||||
SearchLocation: "America/Los_Angeles",
|
||||
}
|
||||
|
||||
messages := []anthropic.MessageParam{
|
||||
anthropic.NewUserMessage(anthropic.NewTextBlock("What's the weather in San Francisco?")),
|
||||
}
|
||||
|
||||
params := client.buildMessageParams(messages, opts)
|
||||
|
||||
if params.Tools == nil {
|
||||
t.Fatal("Expected tools when search is enabled, got nil")
|
||||
}
|
||||
|
||||
webTool := params.Tools[0].OfWebSearchTool20250305
|
||||
if webTool == nil {
|
||||
t.Fatal("Expected web search tool, got nil")
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Type != "approximate" {
|
||||
t.Errorf("Expected location type 'approximate', got %s", webTool.UserLocation.Type)
|
||||
}
|
||||
|
||||
if webTool.UserLocation.Timezone.Value != opts.SearchLocation {
|
||||
t.Errorf("Expected timezone %s, got %s", opts.SearchLocation, webTool.UserLocation.Timezone.Value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCitationFormatting(t *testing.T) {
|
||||
// Test the citation formatting logic by creating a mock message with citations
|
||||
message := &anthropic.Message{
|
||||
Content: []anthropic.ContentBlockUnion{
|
||||
{
|
||||
Type: "text",
|
||||
Text: "Based on recent research, artificial intelligence is advancing rapidly.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research",
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "artificial intelligence is advancing rapidly",
|
||||
},
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://another-source.com/tech-news",
|
||||
Title: "Technology News Today",
|
||||
CitedText: "recent developments in AI",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "text",
|
||||
Text: " Machine learning models are becoming more sophisticated.",
|
||||
Citations: []anthropic.TextCitationUnion{
|
||||
{
|
||||
Type: "web_search_result_location",
|
||||
URL: "https://example.com/ai-research", // Duplicate URL should be deduplicated
|
||||
Title: "AI Research Advances 2025",
|
||||
CitedText: "machine learning models",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract text and citations using the same logic as the Send method
|
||||
var textParts []string
|
||||
var citations []string
|
||||
citationMap := make(map[string]bool)
|
||||
|
||||
for _, block := range message.Content {
|
||||
if block.Type == "text" && block.Text != "" {
|
||||
textParts = append(textParts, block.Text)
|
||||
|
||||
for _, citation := range block.Citations {
|
||||
if citation.Type == "web_search_result_location" {
|
||||
citationKey := citation.URL + "|" + citation.Title
|
||||
if !citationMap[citationKey] {
|
||||
citationMap[citationKey] = true
|
||||
citationText := "- [" + citation.Title + "](" + citation.URL + ")"
|
||||
if citation.CitedText != "" {
|
||||
citationText += " - \"" + citation.CitedText + "\""
|
||||
}
|
||||
citations = append(citations, citationText)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result := strings.Join(textParts, "")
|
||||
if len(citations) > 0 {
|
||||
result += "\n\n## Sources\n\n" + strings.Join(citations, "\n")
|
||||
}
|
||||
|
||||
// Verify the result contains the expected text
|
||||
expectedText := "Based on recent research, artificial intelligence is advancing rapidly. Machine learning models are becoming more sophisticated."
|
||||
if !strings.Contains(result, expectedText) {
|
||||
t.Errorf("Expected result to contain text: %s", expectedText)
|
||||
}
|
||||
|
||||
// Verify citations are included
|
||||
if !strings.Contains(result, "## Sources") {
|
||||
t.Error("Expected result to contain Sources section")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[AI Research Advances 2025](https://example.com/ai-research)") {
|
||||
t.Error("Expected result to contain first citation")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "[Technology News Today](https://another-source.com/tech-news)") {
|
||||
t.Error("Expected result to contain second citation")
|
||||
}
|
||||
|
||||
// Verify deduplication - should only have 2 unique citations, not 3
|
||||
citationCount := strings.Count(result, "- [")
|
||||
if citationCount != 2 {
|
||||
t.Errorf("Expected 2 unique citations, got %d", citationCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,7 +152,6 @@ func (o *Setting) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
}
|
||||
buffer.WriteString("\n")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ParseBoolElseFalse(val string) (ret bool) {
|
||||
@@ -279,7 +278,6 @@ func (o Settings) FillEnvFileContent(buffer *bytes.Buffer) {
|
||||
for _, setting := range o {
|
||||
setting.FillEnvFileContent(buffer)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type SetupQuestions []*SetupQuestion
|
||||
|
||||
Reference in New Issue
Block a user