Merge pull request #1926 from henricook/feature/vertexai-dynamic-model-listing

feat(vertexai): add dynamic model listing and multi-model support
Kayvan Sylvan
2026-01-08 11:33:33 -08:00
committed by GitHub
7 changed files with 1085 additions and 156 deletions


@@ -0,0 +1,7 @@
### PR [#1926](https://github.com/danielmiessler/Fabric/pull/1926) by [henricook](https://github.com/henricook) and [ksylvan](https://github.com/ksylvan): feat(vertexai): add dynamic model listing and multi-model support
- Dynamic model listing from Vertex AI Model Garden API
- Support for both Gemini (genai SDK) and Claude (Anthropic SDK) models
- Curated Gemini model list with web search support
- Thinking/extended thinking support for Gemini
- TopP parameter support for Claude models
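
A minimal usage sketch of the combined provider (hypothetical caller: the `vertexai` import path is inferred from the `geminicommon` path below, and ADC credentials plus the plugin's project/region settings are assumed to be configured already):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/danielmiessler/fabric/internal/chat"
	"github.com/danielmiessler/fabric/internal/domain"
	"github.com/danielmiessler/fabric/internal/plugins/ai/vertexai" // assumed path
)

func main() {
	client := vertexai.NewClient()

	// Dynamic listing: Model Garden API for the google/anthropic publishers,
	// merged with the curated Gemini list (Gemini has no list endpoint).
	models, err := client.ListModels()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(models)

	msgs := []*chat.ChatCompletionMessage{
		{Role: chat.ChatMessageRoleUser, Content: "Hello"},
	}
	opts := &domain.ChatOptions{
		Model:       "gemini-2.5-pro", // Gemini names route to the genai SDK
		Temperature: 0.7,
	}
	// A Claude name such as "claude-sonnet-4-5" routes to the Anthropic SDK
	// instead, where TopP takes precedence over Temperature when set.
	reply, err := client.Send(context.Background(), msgs, opts)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(reply)
}
```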


@@ -10,9 +10,9 @@ import (
"strings"
"github.com/danielmiessler/fabric/internal/chat"
"github.com/danielmiessler/fabric/internal/plugins"
"github.com/danielmiessler/fabric/internal/domain"
"github.com/danielmiessler/fabric/internal/plugins"
"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
"google.golang.org/genai"
)
@@ -29,10 +29,6 @@ const (
)
const (
citationHeader = "\n\n## Sources\n\n"
citationSeparator = "\n"
citationFormat = "- [%s](%s)"
errInvalidLocationFormat = "invalid search location format %q: must be timezone (e.g., 'America/Los_Angeles') or language code (e.g., 'en-US')"
locationSeparator = "/"
langCodeSeparator = "_"
@@ -111,7 +107,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
}
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
contents := geminicommon.ConvertMessages(msgs)
cfg, err := o.buildGenerateContentConfig(opts)
if err != nil {
@@ -125,7 +121,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
}
// Extract text from response
ret = o.extractTextFromResponse(response)
ret = geminicommon.ExtractTextWithCitations(response)
return
}
@@ -142,7 +138,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
}
// Convert messages to new SDK format
contents := o.convertMessages(msgs)
contents := geminicommon.ConvertMessages(msgs)
cfg, err := o.buildGenerateContentConfig(opts)
if err != nil {
@@ -161,7 +157,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
return err
}
text := o.extractTextFromResponse(response)
text := geminicommon.ExtractTextWithCitations(response)
if text != "" {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
@@ -218,10 +214,14 @@ func parseThinkingConfig(level domain.ThinkingLevel) (*genai.ThinkingConfig, boo
func (o *Client) buildGenerateContentConfig(opts *domain.ChatOptions) (*genai.GenerateContentConfig, error) {
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
var maxTokens int32
if opts.MaxTokens > 0 {
maxTokens = int32(opts.MaxTokens)
}
cfg := &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(opts.ModelContextLength),
MaxOutputTokens: maxTokens,
}
if opts.Search {
@@ -452,113 +452,3 @@ func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
return result, nil
}
// convertMessages converts fabric chat messages to genai Content format
func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
var contents []*genai.Content
for _, msg := range msgs {
content := &genai.Content{Parts: []*genai.Part{}}
switch msg.Role {
case chat.ChatMessageRoleAssistant:
content.Role = "model"
case chat.ChatMessageRoleUser:
content.Role = "user"
case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
// Gemini's API only accepts "user" and "model" roles.
// Map all other roles to "user" to preserve instruction context.
content.Role = "user"
default:
content.Role = "user"
}
if strings.TrimSpace(msg.Content) != "" {
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
}
// Handle multi-content messages (images, etc.)
for _, part := range msg.MultiContent {
switch part.Type {
case chat.ChatMessagePartTypeText:
content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
case chat.ChatMessagePartTypeImageURL:
// TODO: Handle image URLs if needed
// This would require downloading and converting to inline data
}
}
contents = append(contents, content)
}
return contents
}
// extractTextFromResponse extracts text content from the response and appends
// any web citations in a standardized format.
func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
if response == nil {
return ""
}
text := o.extractTextParts(response)
citations := o.extractCitations(response)
if len(citations) > 0 {
return text + citationHeader + strings.Join(citations, citationSeparator)
}
return text
}
func (o *Client) extractTextParts(response *genai.GenerateContentResponse) string {
var builder strings.Builder
for _, candidate := range response.Candidates {
if candidate == nil || candidate.Content == nil {
continue
}
for _, part := range candidate.Content.Parts {
if part != nil && part.Text != "" {
builder.WriteString(part.Text)
}
}
}
return builder.String()
}
func (o *Client) extractCitations(response *genai.GenerateContentResponse) []string {
if response == nil || len(response.Candidates) == 0 {
return nil
}
citationMap := make(map[string]bool)
var citations []string
for _, candidate := range response.Candidates {
if candidate == nil || candidate.GroundingMetadata == nil {
continue
}
chunks := candidate.GroundingMetadata.GroundingChunks
if len(chunks) == 0 {
continue
}
for _, chunk := range chunks {
if chunk == nil || chunk.Web == nil {
continue
}
uri := chunk.Web.URI
title := chunk.Web.Title
if uri == "" || title == "" {
continue
}
var keyBuilder strings.Builder
keyBuilder.WriteString(uri)
keyBuilder.WriteByte('|')
keyBuilder.WriteString(title)
key := keyBuilder.String()
if !citationMap[key] {
citationMap[key] = true
citationText := fmt.Sprintf(citationFormat, title, uri)
citations = append(citations, citationText)
}
}
}
return citations
}


@@ -4,10 +4,10 @@ import (
"strings"
"testing"
"google.golang.org/genai"
"github.com/danielmiessler/fabric/internal/chat"
"github.com/danielmiessler/fabric/internal/domain"
"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
"google.golang.org/genai"
)
// Test buildModelNameFull method
@@ -31,9 +31,8 @@ func TestBuildModelNameFull(t *testing.T) {
}
}
// Test extractTextFromResponse method
// Test ExtractTextWithCitations from geminicommon
func TestExtractTextFromResponse(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
@@ -48,7 +47,7 @@ func TestExtractTextFromResponse(t *testing.T) {
}
expected := "Hello, world!"
result := client.extractTextFromResponse(response)
result := geminicommon.ExtractTextWithCitations(response)
if result != expected {
t.Errorf("Expected %v, got %v", expected, result)
@@ -56,14 +55,12 @@ func TestExtractTextFromResponse(t *testing.T) {
}
func TestExtractTextFromResponse_Nil(t *testing.T) {
client := &Client{}
if got := client.extractTextFromResponse(nil); got != "" {
if got := geminicommon.ExtractTextWithCitations(nil); got != "" {
t.Fatalf("expected empty string, got %q", got)
}
}
func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
@@ -72,7 +69,7 @@ func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
},
},
}
if got := client.extractTextFromResponse(response); got != "Hello" {
if got := geminicommon.ExtractTextWithCitations(response); got != "Hello" {
t.Fatalf("expected 'Hello', got %q", got)
}
}
@@ -162,7 +159,6 @@ func TestBuildGenerateContentConfig_ThinkingTokens(t *testing.T) {
}
func TestCitationFormatting(t *testing.T) {
client := &Client{}
response := &genai.GenerateContentResponse{
Candidates: []*genai.Candidate{
{
@@ -178,7 +174,7 @@ func TestCitationFormatting(t *testing.T) {
},
}
result := client.extractTextFromResponse(response)
result := geminicommon.ExtractTextWithCitations(response)
if !strings.Contains(result, "## Sources") {
t.Fatalf("expected sources section in result: %s", result)
}
@@ -189,14 +185,13 @@ func TestCitationFormatting(t *testing.T) {
// Test convertMessages handles role mapping correctly
func TestConvertMessagesRoles(t *testing.T) {
client := &Client{}
msgs := []*chat.ChatCompletionMessage{
{Role: chat.ChatMessageRoleUser, Content: "user"},
{Role: chat.ChatMessageRoleAssistant, Content: "assistant"},
{Role: chat.ChatMessageRoleSystem, Content: "system"},
}
contents := client.convertMessages(msgs)
contents := geminicommon.ConvertMessages(msgs)
expected := []string{"user", "model", "user"}


@@ -0,0 +1,130 @@
// Package geminicommon provides shared utilities for Gemini API integrations.
// Used by both the standalone Gemini provider (API key auth) and VertexAI provider (ADC auth).
package geminicommon
import (
"fmt"
"strings"
"github.com/danielmiessler/fabric/internal/chat"
"google.golang.org/genai"
)
// Citation formatting constants
const (
CitationHeader = "\n\n## Sources\n\n"
CitationSeparator = "\n"
CitationFormat = "- [%s](%s)"
)
// ConvertMessages converts fabric chat messages to genai Content format.
// Gemini's API only accepts "user" and "model" roles, so other roles are mapped to "user".
func ConvertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
var contents []*genai.Content
for _, msg := range msgs {
content := &genai.Content{Parts: []*genai.Part{}}
switch msg.Role {
case chat.ChatMessageRoleAssistant:
content.Role = "model"
case chat.ChatMessageRoleUser:
content.Role = "user"
case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
// Gemini's API only accepts "user" and "model" roles.
// Map all other roles to "user" to preserve instruction context.
content.Role = "user"
default:
content.Role = "user"
}
if strings.TrimSpace(msg.Content) != "" {
content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
}
// Handle multi-content messages (images, etc.)
for _, part := range msg.MultiContent {
switch part.Type {
case chat.ChatMessagePartTypeText:
content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
case chat.ChatMessagePartTypeImageURL:
// TODO: Handle image URLs if needed
// This would require downloading and converting to inline data
}
}
contents = append(contents, content)
}
return contents
}
// ExtractText extracts just the text parts from a Gemini response.
func ExtractText(response *genai.GenerateContentResponse) string {
if response == nil {
return ""
}
var builder strings.Builder
for _, candidate := range response.Candidates {
if candidate == nil || candidate.Content == nil {
continue
}
for _, part := range candidate.Content.Parts {
if part != nil && part.Text != "" {
builder.WriteString(part.Text)
}
}
}
return builder.String()
}
// ExtractTextWithCitations extracts text content from the response and appends
// any web citations in a standardized format.
func ExtractTextWithCitations(response *genai.GenerateContentResponse) string {
if response == nil {
return ""
}
text := ExtractText(response)
citations := ExtractCitations(response)
if len(citations) > 0 {
return text + CitationHeader + strings.Join(citations, CitationSeparator)
}
return text
}
// ExtractCitations extracts web citations from grounding metadata.
func ExtractCitations(response *genai.GenerateContentResponse) []string {
if response == nil || len(response.Candidates) == 0 {
return nil
}
citationMap := make(map[string]bool)
var citations []string
for _, candidate := range response.Candidates {
if candidate == nil || candidate.GroundingMetadata == nil {
continue
}
chunks := candidate.GroundingMetadata.GroundingChunks
if len(chunks) == 0 {
continue
}
for _, chunk := range chunks {
if chunk == nil || chunk.Web == nil {
continue
}
uri := chunk.Web.URI
title := chunk.Web.Title
if uri == "" || title == "" {
continue
}
key := uri + "|" + title
if !citationMap[key] {
citationMap[key] = true
citations = append(citations, fmt.Sprintf(CitationFormat, title, uri))
}
}
}
return citations
}
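
Both the standalone Gemini provider and the VertexAI provider consume these helpers the same way; a minimal sketch of the shared call pattern (hypothetical caller-side wrapper, not part of this package):

func generateWithGemini(ctx context.Context, client *genai.Client, model string, msgs []*chat.ChatCompletionMessage, cfg *genai.GenerateContentConfig) (string, error) {
	contents := geminicommon.ConvertMessages(msgs)
	if len(contents) == 0 {
		return "", fmt.Errorf("no valid messages to send")
	}
	resp, err := client.Models.GenerateContent(ctx, model, contents, cfg)
	if err != nil {
		return "", err
	}
	// Streaming callers extract per-chunk text with ExtractText; single-shot
	// callers get web citations appended once via ExtractTextWithCitations.
	return geminicommon.ExtractTextWithCitations(resp), nil
}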


@@ -0,0 +1,237 @@
package vertexai
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sort"
"strings"
debuglog "github.com/danielmiessler/fabric/internal/log"
)
const (
// API limits
maxResponseSize = 10 * 1024 * 1024 // 10MB
errorResponseLimit = 1024 // 1KB for error messages
// Default region for Model Garden API (global doesn't work for this endpoint)
defaultModelGardenRegion = "us-central1"
)
// Supported Model Garden publishers (others can be added when SDK support is implemented)
var publishers = []string{"google", "anthropic"}
// publisherModelsResponse represents the API response from publishers.models.list
type publisherModelsResponse struct {
PublisherModels []publisherModel `json:"publisherModels"`
NextPageToken string `json:"nextPageToken"`
}
// publisherModel represents a single model in the API response
type publisherModel struct {
Name string `json:"name"` // Format: publishers/{publisher}/models/{model}
}
// fetchModelsPage makes a single API request and returns the parsed response.
// Extracted to ensure proper cleanup of HTTP response bodies in pagination loops.
func fetchModelsPage(ctx context.Context, httpClient *http.Client, url, projectID, publisher string) (*publisherModelsResponse, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
// Set quota project header required by Vertex AI API
req.Header.Set("x-goog-user-project", projectID)
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, errorResponseLimit))
debuglog.Debug(debuglog.Basic, "API error for %s: status %d, url: %s, body: %s\n", publisher, resp.StatusCode, url, string(bodyBytes))
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize+1))
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if len(bodyBytes) > maxResponseSize {
return nil, fmt.Errorf("response too large (>%d bytes)", maxResponseSize)
}
var response publisherModelsResponse
if err := json.Unmarshal(bodyBytes, &response); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &response, nil
}
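// Extracting the per-page request matters because defer is function-scoped,
// not loop-scoped; a sketch of the pitfall this avoids (illustrative only):
//
//	for pageToken != "" {
//		resp, err := httpClient.Do(req)
//		if err != nil {
//			return nil, err
//		}
//		defer resp.Body.Close() // every body stays open until the function returns
//		// ... decode the page, advance pageToken ...
//	}
//
// Calling fetchModelsPage once per page scopes each defer to one request.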
// listPublisherModels fetches models from a specific publisher via the Model Garden API
func listPublisherModels(ctx context.Context, httpClient *http.Client, region, projectID, publisher string) ([]string, error) {
// Use default region if global or empty (Model Garden API requires a specific region)
if region == "" || region == "global" {
region = defaultModelGardenRegion
}
baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/publishers/%s/models", region, publisher)
var allModels []string
pageToken := ""
for {
url := baseURL
if pageToken != "" {
url = fmt.Sprintf("%s?pageToken=%s", baseURL, pageToken)
}
response, err := fetchModelsPage(ctx, httpClient, url, projectID, publisher)
if err != nil {
return nil, err
}
// Extract model names, stripping the publishers/{publisher}/models/ prefix
for _, model := range response.PublisherModels {
modelName := extractModelName(model.Name)
if modelName != "" {
allModels = append(allModels, modelName)
}
}
// Check for more pages
if response.NextPageToken == "" {
break
}
pageToken = response.NextPageToken
}
debuglog.Debug(debuglog.Detailed, "Listed %d models from publisher %s\n", len(allModels), publisher)
return allModels, nil
}
// extractModelName extracts the model name from the full resource path
// Input: "publishers/google/models/gemini-2.0-flash"
// Output: "gemini-2.0-flash"
func extractModelName(fullName string) string {
parts := strings.Split(fullName, "/")
if len(parts) >= 4 && parts[0] == "publishers" && parts[2] == "models" {
return parts[3]
}
// Fallback: return the last segment
if len(parts) > 0 {
return parts[len(parts)-1]
}
return fullName
}
// sortModels sorts models by priority: Gemini > Claude > Others
// Within each group, models are sorted alphabetically
func sortModels(models []string) []string {
sort.Slice(models, func(i, j int) bool {
pi := modelPriority(models[i])
pj := modelPriority(models[j])
if pi != pj {
return pi < pj
}
// Same priority: sort alphabetically (case-insensitive)
return strings.ToLower(models[i]) < strings.ToLower(models[j])
})
return models
}
// modelPriority returns the sort priority for a model (lower = higher priority)
func modelPriority(model string) int {
lower := strings.ToLower(model)
switch {
case strings.HasPrefix(lower, "gemini"):
return 1
case strings.HasPrefix(lower, "claude"):
return 2
default:
return 3
}
}
// knownGeminiModels is a curated list of Gemini models available on Vertex AI.
// Vertex AI doesn't provide a list API for Gemini models - they must be known ahead of time.
// This list is based on Google Cloud documentation as of January 2026.
// See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/models
var knownGeminiModels = []string{
// Gemini 3 (Preview)
"gemini-3-pro-preview",
"gemini-3-flash-preview",
// Gemini 2.5 (GA)
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.5-flash-lite",
// Gemini 2.0 (GA)
"gemini-2.0-flash",
"gemini-2.0-flash-lite",
}
// getKnownGeminiModels returns the curated list of Gemini models available on Vertex AI.
// Unlike third-party models which can be listed via the Model Garden API,
// Gemini models must be known ahead of time as there's no list endpoint for them.
func getKnownGeminiModels() []string {
return knownGeminiModels
}
// isGeminiModel returns true if the model is a Gemini model
func isGeminiModel(modelName string) bool {
return strings.HasPrefix(strings.ToLower(modelName), "gemini")
}
// isConversationalModel returns true if the model is suitable for text generation/chat
// Filters out image generation, embeddings, and other non-conversational models
func isConversationalModel(modelName string) bool {
lower := strings.ToLower(modelName)
// Exclude patterns for non-conversational models
excludePatterns := []string{
"imagen", // Image generation models
"imagegeneration",
"imagetext",
"image-segmentation",
"embedding", // Embedding models
"textembedding",
"multimodalembedding",
"text-bison", // Legacy completion models (not chat)
"text-unicorn",
"code-bison", // Legacy code models
"code-gecko",
"codechat-bison", // Deprecated chat model
"chat-bison", // Deprecated chat model
"veo", // Video generation
"chirp", // Audio/speech models
"medlm", // Medical models (restricted)
"medical",
}
for _, pattern := range excludePatterns {
if strings.Contains(lower, pattern) {
return false
}
}
return true
}
// filterConversationalModels returns only models suitable for text generation/chat
func filterConversationalModels(models []string) []string {
var filtered []string
for _, model := range models {
if isConversationalModel(model) {
filtered = append(filtered, model)
}
}
return filtered
}


@@ -9,13 +9,18 @@ import (
"github.com/anthropics/anthropic-sdk-go/vertex"
"github.com/danielmiessler/fabric/internal/chat"
"github.com/danielmiessler/fabric/internal/domain"
debuglog "github.com/danielmiessler/fabric/internal/log"
"github.com/danielmiessler/fabric/internal/plugins"
"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/genai"
)
const (
cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
defaultRegion = "global"
maxTokens = 4096
defaultMaxTokens = 4096
)
// NewClient creates a new Vertex AI client for accessing Claude models via Google Cloud
@@ -59,17 +64,78 @@ func (c *Client) configure() error {
}
func (c *Client) ListModels() ([]string, error) {
// Return Claude models available on Vertex AI
return []string{
string(anthropic.ModelClaudeSonnet4_5),
string(anthropic.ModelClaudeOpus4_5),
string(anthropic.ModelClaudeHaiku4_5),
string(anthropic.ModelClaude3_7SonnetLatest),
string(anthropic.ModelClaude3_5HaikuLatest),
}, nil
ctx := context.Background()
// Get ADC credentials for API authentication
creds, err := google.FindDefaultCredentials(ctx, cloudPlatformScope)
if err != nil {
return nil, fmt.Errorf("failed to get Google credentials (ensure ADC is configured): %w", err)
}
httpClient := oauth2.NewClient(ctx, creds.TokenSource)
// Query all publishers in parallel for better performance
type result struct {
models []string
err error
publisher string
}
// +1 for known Gemini models (no API to list them)
results := make(chan result, len(publishers)+1)
// Query Model Garden API for third-party models
for _, pub := range publishers {
go func(publisher string) {
models, err := listPublisherModels(ctx, httpClient, c.Region.Value, c.ProjectID.Value, publisher)
results <- result{models: models, err: err, publisher: publisher}
}(pub)
}
// Add known Gemini models (Vertex AI doesn't have a list API for Gemini)
go func() {
results <- result{models: getKnownGeminiModels(), err: nil, publisher: "gemini"}
}()
// Collect results from all sources
var allModels []string
for range len(publishers) + 1 {
r := <-results
if r.err != nil {
// Log warning but continue - some sources may not be available
debuglog.Debug(debuglog.Basic, "Failed to list %s models: %v\n", r.publisher, r.err)
continue
}
allModels = append(allModels, r.models...)
}
if len(allModels) == 0 {
return nil, fmt.Errorf("no models found from any publisher")
}
// Filter to only conversational models and sort
filtered := filterConversationalModels(allModels)
if len(filtered) == 0 {
return nil, fmt.Errorf("no conversational models found")
}
return sortModels(filtered), nil
}
func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
if isGeminiModel(opts.Model) {
return c.sendGemini(ctx, msgs, opts)
}
return c.sendClaude(ctx, msgs, opts)
}
// getMaxTokens returns the max output tokens to use for a request
func getMaxTokens(opts *domain.ChatOptions) int64 {
if opts.MaxTokens > 0 {
return int64(opts.MaxTokens)
}
return int64(defaultMaxTokens)
}
func (c *Client) sendClaude(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
if c.client == nil {
return "", fmt.Errorf("VertexAI client not initialized")
}
@@ -80,14 +146,22 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
return "", fmt.Errorf("no valid messages to send")
}
// Create the request
response, err := c.client.Messages.New(ctx, anthropic.MessageNewParams{
Model: anthropic.Model(opts.Model),
MaxTokens: int64(maxTokens),
Messages: anthropicMessages,
Temperature: anthropic.Opt(opts.Temperature),
})
// Build request params
params := anthropic.MessageNewParams{
Model: anthropic.Model(opts.Model),
MaxTokens: getMaxTokens(opts),
Messages: anthropicMessages,
}
// Only set one of Temperature or TopP as some models don't allow both
// (following anthropic.go pattern)
if opts.TopP != domain.DefaultTopP {
params.TopP = anthropic.Opt(opts.TopP)
} else {
params.Temperature = anthropic.Opt(opts.Temperature)
}
response, err := c.client.Messages.New(ctx, params)
if err != nil {
return "", err
}
@@ -108,6 +182,13 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
}
func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
if isGeminiModel(opts.Model) {
return c.sendStreamGemini(msgs, opts, channel)
}
return c.sendStreamClaude(msgs, opts, channel)
}
func (c *Client) sendStreamClaude(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
if c.client == nil {
close(channel)
return fmt.Errorf("VertexAI client not initialized")
@@ -122,13 +203,22 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
return fmt.Errorf("no valid messages to send")
}
// Build request params
params := anthropic.MessageNewParams{
Model: anthropic.Model(opts.Model),
MaxTokens: getMaxTokens(opts),
Messages: anthropicMessages,
}
// Only set one of Temperature or TopP as some models don't allow both
if opts.TopP != domain.DefaultTopP {
params.TopP = anthropic.Opt(opts.TopP)
} else {
params.Temperature = anthropic.Opt(opts.Temperature)
}
// Create streaming request
stream := c.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
Model: anthropic.Model(opts.Model),
MaxTokens: int64(maxTokens),
Messages: anthropicMessages,
Temperature: anthropic.Opt(opts.Temperature),
})
stream := c.client.Messages.NewStreaming(ctx, params)
// Process stream
for stream.Next() {
@@ -167,6 +257,144 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
return stream.Err()
}
// Gemini methods using genai SDK with Vertex AI backend
// getGeminiRegion returns the appropriate region for a Gemini model.
// Preview models are often only available on the global endpoint.
func (c *Client) getGeminiRegion(model string) string {
if strings.Contains(strings.ToLower(model), "preview") {
return "global"
}
return c.Region.Value
}
func (c *Client) sendGemini(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
client, err := genai.NewClient(ctx, &genai.ClientConfig{
Project: c.ProjectID.Value,
Location: c.getGeminiRegion(opts.Model),
Backend: genai.BackendVertexAI,
})
if err != nil {
return "", fmt.Errorf("failed to create Gemini client: %w", err)
}
contents := geminicommon.ConvertMessages(msgs)
if len(contents) == 0 {
return "", fmt.Errorf("no valid messages to send")
}
config := c.buildGeminiConfig(opts)
response, err := client.Models.GenerateContent(ctx, opts.Model, contents, config)
if err != nil {
return "", err
}
return geminicommon.ExtractTextWithCitations(response), nil
}
// buildGeminiConfig creates the generation config for Gemini models
// following the gemini.go pattern for feature parity
func (c *Client) buildGeminiConfig(opts *domain.ChatOptions) *genai.GenerateContentConfig {
temperature := float32(opts.Temperature)
topP := float32(opts.TopP)
config := &genai.GenerateContentConfig{
Temperature: &temperature,
TopP: &topP,
MaxOutputTokens: int32(getMaxTokens(opts)),
}
// Add web search support
if opts.Search {
config.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
}
// Add thinking support
if tc := parseGeminiThinking(opts.Thinking); tc != nil {
config.ThinkingConfig = tc
}
return config
}
// parseGeminiThinking converts thinking level to Gemini thinking config
func parseGeminiThinking(level domain.ThinkingLevel) *genai.ThinkingConfig {
lower := strings.ToLower(strings.TrimSpace(string(level)))
switch domain.ThinkingLevel(lower) {
case "", domain.ThinkingOff:
return nil
case domain.ThinkingLow, domain.ThinkingMedium, domain.ThinkingHigh:
if budget, ok := domain.ThinkingBudgets[domain.ThinkingLevel(lower)]; ok {
b := int32(budget)
return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &b}
}
default:
// Try parsing as integer token count
var tokens int
if _, err := fmt.Sscanf(lower, "%d", &tokens); err == nil && tokens > 0 {
t := int32(tokens)
return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &t}
}
}
return nil
}
func (c *Client) sendStreamGemini(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
defer close(channel)
ctx := context.Background()
client, err := genai.NewClient(ctx, &genai.ClientConfig{
Project: c.ProjectID.Value,
Location: c.getGeminiRegion(opts.Model),
Backend: genai.BackendVertexAI,
})
if err != nil {
return fmt.Errorf("failed to create Gemini client: %w", err)
}
contents := geminicommon.ConvertMessages(msgs)
if len(contents) == 0 {
return fmt.Errorf("no valid messages to send")
}
config := c.buildGeminiConfig(opts)
stream := client.Models.GenerateContentStream(ctx, opts.Model, contents, config)
for response, err := range stream {
if err != nil {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeError,
Content: fmt.Sprintf("Error: %v", err),
}
return err
}
text := geminicommon.ExtractText(response)
if text != "" {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeContent,
Content: text,
}
}
if response.UsageMetadata != nil {
channel <- domain.StreamUpdate{
Type: domain.StreamTypeUsage,
Usage: &domain.UsageMetadata{
InputTokens: int(response.UsageMetadata.PromptTokenCount),
OutputTokens: int(response.UsageMetadata.CandidatesTokenCount),
TotalTokens: int(response.UsageMetadata.TotalTokenCount),
},
}
}
}
return nil
}
// Claude message conversion
func (c *Client) toMessages(msgs []*chat.ChatCompletionMessage) []anthropic.MessageParam {
// Convert messages to Anthropic format with proper role handling
// - System messages become part of the first user message
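// A hedged sketch of the folding described above (names from the Anthropic
// SDK; the concrete implementation is not shown in this hunk):
//
//	var system []string
//	var out []anthropic.MessageParam
//	for _, m := range msgs {
//		switch m.Role {
//		case chat.ChatMessageRoleSystem:
//			system = append(system, m.Content) // folded into the first user turn
//		case chat.ChatMessageRoleAssistant:
//			out = append(out, anthropic.NewAssistantMessage(anthropic.NewTextBlock(m.Content)))
//		default:
//			out = append(out, anthropic.NewUserMessage(anthropic.NewTextBlock(m.Content)))
//		}
//	}
//	// prepend strings.Join(system, "\n") to the first user message in out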


@@ -0,0 +1,442 @@
package vertexai
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/danielmiessler/fabric/internal/domain"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractModelName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "standard format",
input: "publishers/google/models/gemini-2.0-flash",
expected: "gemini-2.0-flash",
},
{
name: "anthropic model",
input: "publishers/anthropic/models/claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
{
name: "model with version",
input: "publishers/anthropic/models/claude-3-opus@20240229",
expected: "claude-3-opus@20240229",
},
{
name: "just model name",
input: "gemini-pro",
expected: "gemini-pro",
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractModelName(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestSortModels(t *testing.T) {
input := []string{
"claude-sonnet-4-5",
"gemini-2.0-flash",
"gemini-pro",
"claude-opus-4",
"unknown-model",
}
result := sortModels(input)
// Verify order: Gemini first, then Claude, then others (alphabetically within each group)
expected := []string{
"gemini-2.0-flash",
"gemini-pro",
"claude-opus-4",
"claude-sonnet-4-5",
"unknown-model",
}
assert.Equal(t, expected, result)
}
func TestModelPriority(t *testing.T) {
tests := []struct {
model string
priority int
}{
{"gemini-2.0-flash", 1},
{"Gemini-Pro", 1},
{"claude-sonnet-4-5", 2},
{"CLAUDE-OPUS", 2},
{"some-other-model", 3},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
result := modelPriority(tt.model)
assert.Equal(t, tt.priority, result, "priority for %s", tt.model)
})
}
}
func TestListPublisherModels_Success(t *testing.T) {
// Create mock server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodGet, r.Method)
assert.Contains(t, r.URL.Path, "/v1/publishers/google/models")
response := publisherModelsResponse{
PublisherModels: []publisherModel{
{Name: "publishers/google/models/gemini-2.0-flash"},
{Name: "publishers/google/models/gemini-pro"},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
// Note: This test would need to mock the actual API endpoint
// For now, we just verify the mock server works
resp, err := http.Get(server.URL + "/v1/publishers/google/models")
require.NoError(t, err)
defer resp.Body.Close()
var response publisherModelsResponse
err = json.NewDecoder(resp.Body).Decode(&response)
require.NoError(t, err)
assert.Len(t, response.PublisherModels, 2)
assert.Equal(t, "publishers/google/models/gemini-2.0-flash", response.PublisherModels[0].Name)
}
func TestListPublisherModels_Pagination(t *testing.T) {
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
var response publisherModelsResponse
if callCount == 1 {
response = publisherModelsResponse{
PublisherModels: []publisherModel{
{Name: "publishers/google/models/gemini-flash"},
},
NextPageToken: "page2",
}
} else {
response = publisherModelsResponse{
PublisherModels: []publisherModel{
{Name: "publishers/google/models/gemini-pro"},
},
NextPageToken: "",
}
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}))
defer server.Close()
// Verify the server handles pagination correctly
resp, err := http.Get(server.URL + "/page1")
require.NoError(t, err)
resp.Body.Close()
resp, err = http.Get(server.URL + "/page2")
require.NoError(t, err)
resp.Body.Close()
assert.Equal(t, 2, callCount)
}
func TestListPublisherModels_ErrorResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte(`{"error": "access denied"}`))
}))
defer server.Close()
resp, err := http.Get(server.URL + "/v1/publishers/google/models")
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
}
func TestNewClient(t *testing.T) {
client := NewClient()
assert.NotNil(t, client)
assert.Equal(t, "VertexAI", client.Name)
assert.NotNil(t, client.ProjectID)
assert.NotNil(t, client.Region)
assert.Equal(t, "global", client.Region.Value)
}
func TestPublishersListComplete(t *testing.T) {
// Verify supported publishers are in the list
expectedPublishers := []string{"google", "anthropic"}
assert.Equal(t, expectedPublishers, publishers)
}
func TestIsConversationalModel(t *testing.T) {
tests := []struct {
model string
expected bool
}{
// Conversational models (should return true)
{"gemini-2.0-flash", true},
{"gemini-2.5-pro", true},
{"claude-sonnet-4-5", true},
{"claude-opus-4", true},
{"deepseek-v3", true},
{"llama-3.1-405b", true},
{"mistral-large", true},
// Non-conversational models (should return false)
{"imagen-3.0-capability-002", false},
{"imagen-4.0-fast-generate-001", false},
{"imagegeneration", false},
{"imagetext", false},
{"image-segmentation-001", false},
{"textembedding-gecko", false},
{"multimodalembedding", false},
{"text-embedding-004", false},
{"text-bison", false},
{"text-unicorn", false},
{"code-bison", false},
{"code-gecko", false},
{"codechat-bison", false},
{"chat-bison", false},
{"veo-001", false},
{"chirp", false},
{"medlm-medium", false},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
result := isConversationalModel(tt.model)
assert.Equal(t, tt.expected, result, "isConversationalModel(%s)", tt.model)
})
}
}
func TestFilterConversationalModels(t *testing.T) {
input := []string{
"gemini-2.0-flash",
"imagen-3.0-capability-002",
"claude-sonnet-4-5",
"textembedding-gecko",
"deepseek-v3",
"chat-bison",
"llama-3.1-405b",
"code-bison",
}
result := filterConversationalModels(input)
expected := []string{
"gemini-2.0-flash",
"claude-sonnet-4-5",
"deepseek-v3",
"llama-3.1-405b",
}
assert.Equal(t, expected, result)
}
func TestFilterConversationalModels_EmptyInput(t *testing.T) {
result := filterConversationalModels([]string{})
assert.Empty(t, result)
}
func TestFilterConversationalModels_AllFiltered(t *testing.T) {
input := []string{
"imagen-3.0",
"textembedding-gecko",
"chat-bison",
}
result := filterConversationalModels(input)
assert.Empty(t, result)
}
func TestIsGeminiModel(t *testing.T) {
tests := []struct {
model string
expected bool
}{
{"gemini-2.5-pro", true},
{"gemini-3-pro-preview", true},
{"Gemini-2.0-flash", true},
{"GEMINI-flash", true},
{"claude-sonnet-4-5", false},
{"claude-opus-4", false},
{"deepseek-v3", false},
{"llama-3.1-405b", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.model, func(t *testing.T) {
result := isGeminiModel(tt.model)
assert.Equal(t, tt.expected, result, "isGeminiModel(%s)", tt.model)
})
}
}
func TestGetMaxTokens(t *testing.T) {
tests := []struct {
name string
opts *domain.ChatOptions
expected int64
}{
{
name: "MaxTokens specified",
opts: &domain.ChatOptions{MaxTokens: 8192},
expected: 8192,
},
{
name: "Default when MaxTokens is 0",
opts: &domain.ChatOptions{MaxTokens: 0},
expected: int64(defaultMaxTokens),
},
{
name: "Default when MaxTokens is negative",
opts: &domain.ChatOptions{MaxTokens: -1},
expected: int64(defaultMaxTokens),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getMaxTokens(tt.opts)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseGeminiThinking(t *testing.T) {
tests := []struct {
name string
level domain.ThinkingLevel
expectNil bool
expectedBudget int32
}{
{
name: "empty string returns nil",
level: "",
expectNil: true,
},
{
name: "off returns nil",
level: domain.ThinkingOff,
expectNil: true,
},
{
name: "low thinking",
level: domain.ThinkingLow,
expectNil: false,
expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingLow]),
},
{
name: "medium thinking",
level: domain.ThinkingMedium,
expectNil: false,
expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingMedium]),
},
{
name: "high thinking",
level: domain.ThinkingHigh,
expectNil: false,
expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingHigh]),
},
{
name: "numeric string",
level: "5000",
expectNil: false,
expectedBudget: 5000,
},
{
name: "invalid string returns nil",
level: "invalid",
expectNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseGeminiThinking(tt.level)
if tt.expectNil {
assert.Nil(t, result)
} else {
require.NotNil(t, result)
assert.True(t, result.IncludeThoughts)
assert.Equal(t, tt.expectedBudget, *result.ThinkingBudget)
}
})
}
}
func TestBuildGeminiConfig(t *testing.T) {
client := &Client{}
t.Run("basic config with temperature and TopP", func(t *testing.T) {
opts := &domain.ChatOptions{
Temperature: 0.7,
TopP: 0.9,
MaxTokens: 8192,
}
config := client.buildGeminiConfig(opts)
assert.NotNil(t, config)
assert.Equal(t, float32(0.7), *config.Temperature)
assert.Equal(t, float32(0.9), *config.TopP)
assert.Equal(t, int32(8192), config.MaxOutputTokens)
assert.Nil(t, config.Tools)
assert.Nil(t, config.ThinkingConfig)
})
t.Run("config with search enabled", func(t *testing.T) {
opts := &domain.ChatOptions{
Temperature: 0.5,
TopP: 0.8,
Search: true,
}
config := client.buildGeminiConfig(opts)
assert.NotNil(t, config.Tools)
assert.Len(t, config.Tools, 1)
assert.NotNil(t, config.Tools[0].GoogleSearch)
})
t.Run("config with thinking enabled", func(t *testing.T) {
opts := &domain.ChatOptions{
Temperature: 0.5,
TopP: 0.8,
Thinking: domain.ThinkingHigh,
}
config := client.buildGeminiConfig(opts)
assert.NotNil(t, config.ThinkingConfig)
assert.True(t, config.ThinkingConfig.IncludeThoughts)
})
}