Merge pull request #1926 from henricook/feature/vertexai-dynamic-model-listing
feat(vertexai): add dynamic model listing and multi-model support
cmd/generate_changelog/incoming/1926.txt (new file, +7)
@@ -0,0 +1,7 @@
+### PR [#1926](https://github.com/danielmiessler/Fabric/pull/1926) by [henricook](https://github.com/henricook) and [ksylvan](https://github.com/ksylvan): feat(vertexai): add dynamic model listing and multi-model support
+
+- Dynamic model listing from Vertex AI Model Garden API
+- Support for both Gemini (genai SDK) and Claude (Anthropic SDK) models
+- Curated Gemini model list with web search support for Gemini models
+- Thinking/extended thinking support for Gemini
+- TopP parameter support for Claude models
@@ -10,9 +10,9 @@ import (
    "strings"

    "github.com/danielmiessler/fabric/internal/chat"
-   "github.com/danielmiessler/fabric/internal/plugins"

    "github.com/danielmiessler/fabric/internal/domain"
+   "github.com/danielmiessler/fabric/internal/plugins"
+   "github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
    "google.golang.org/genai"
)
@@ -29,10 +29,6 @@ const (
)

const (
-   citationHeader    = "\n\n## Sources\n\n"
-   citationSeparator = "\n"
-   citationFormat    = "- [%s](%s)"
-
    errInvalidLocationFormat = "invalid search location format %q: must be timezone (e.g., 'America/Los_Angeles') or language code (e.g., 'en-US')"
    locationSeparator        = "/"
    langCodeSeparator        = "_"
@@ -111,7 +107,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
    }

    // Convert messages to new SDK format
-   contents := o.convertMessages(msgs)
+   contents := geminicommon.ConvertMessages(msgs)

    cfg, err := o.buildGenerateContentConfig(opts)
    if err != nil {
@@ -125,7 +121,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
    }

    // Extract text from response
-   ret = o.extractTextFromResponse(response)
+   ret = geminicommon.ExtractTextWithCitations(response)
    return
}
@@ -142,7 +138,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
    }

    // Convert messages to new SDK format
-   contents := o.convertMessages(msgs)
+   contents := geminicommon.ConvertMessages(msgs)

    cfg, err := o.buildGenerateContentConfig(opts)
    if err != nil {
@@ -161,7 +157,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
        return err
    }

-   text := o.extractTextFromResponse(response)
+   text := geminicommon.ExtractTextWithCitations(response)
    if text != "" {
        channel <- domain.StreamUpdate{
            Type: domain.StreamTypeContent,
@@ -218,10 +214,14 @@ func parseThinkingConfig(level domain.ThinkingLevel) (*genai.ThinkingConfig, boo
func (o *Client) buildGenerateContentConfig(opts *domain.ChatOptions) (*genai.GenerateContentConfig, error) {
    temperature := float32(opts.Temperature)
    topP := float32(opts.TopP)
+   var maxTokens int32
+   if opts.MaxTokens > 0 {
+       maxTokens = int32(opts.MaxTokens)
+   }
    cfg := &genai.GenerateContentConfig{
        Temperature:     &temperature,
        TopP:            &topP,
-       MaxOutputTokens: int32(opts.ModelContextLength),
+       MaxOutputTokens: maxTokens,
    }

    if opts.Search {
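The hunk above only sets an output cap when one was actually requested. A minimal sketch of the resulting behavior, assuming the genai SDK treats a zero MaxOutputTokens as "unset" rather than as a hard zero-token cap:

```go
// Hypothetical options with no explicit max-tokens; MaxTokens stays 0.
opts := &domain.ChatOptions{Temperature: 0.7, TopP: 0.9}

cfg, err := o.buildGenerateContentConfig(opts)
if err != nil {
    return nil, err
}
// cfg.MaxOutputTokens == 0 here; previously it was forced to
// int32(opts.ModelContextLength) even when no cap was requested.
_ = cfg
```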
@@ -452,113 +452,3 @@ func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {

    return result, nil
}
-
-// convertMessages converts fabric chat messages to genai Content format
-func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
-   var contents []*genai.Content
-
-   for _, msg := range msgs {
-       content := &genai.Content{Parts: []*genai.Part{}}
-
-       switch msg.Role {
-       case chat.ChatMessageRoleAssistant:
-           content.Role = "model"
-       case chat.ChatMessageRoleUser:
-           content.Role = "user"
-       case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
-           // Gemini's API only accepts "user" and "model" roles.
-           // Map all other roles to "user" to preserve instruction context.
-           content.Role = "user"
-       default:
-           content.Role = "user"
-       }
-
-       if strings.TrimSpace(msg.Content) != "" {
-           content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
-       }
-
-       // Handle multi-content messages (images, etc.)
-       for _, part := range msg.MultiContent {
-           switch part.Type {
-           case chat.ChatMessagePartTypeText:
-               content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
-           case chat.ChatMessagePartTypeImageURL:
-               // TODO: Handle image URLs if needed
-               // This would require downloading and converting to inline data
-           }
-       }
-
-       contents = append(contents, content)
-   }
-
-   return contents
-}
-
-// extractTextFromResponse extracts text content from the response and appends
-// any web citations in a standardized format.
-func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
-   if response == nil {
-       return ""
-   }
-
-   text := o.extractTextParts(response)
-   citations := o.extractCitations(response)
-   if len(citations) > 0 {
-       return text + citationHeader + strings.Join(citations, citationSeparator)
-   }
-   return text
-}
-
-func (o *Client) extractTextParts(response *genai.GenerateContentResponse) string {
-   var builder strings.Builder
-   for _, candidate := range response.Candidates {
-       if candidate == nil || candidate.Content == nil {
-           continue
-       }
-       for _, part := range candidate.Content.Parts {
-           if part != nil && part.Text != "" {
-               builder.WriteString(part.Text)
-           }
-       }
-   }
-   return builder.String()
-}
-
-func (o *Client) extractCitations(response *genai.GenerateContentResponse) []string {
-   if response == nil || len(response.Candidates) == 0 {
-       return nil
-   }
-
-   citationMap := make(map[string]bool)
-   var citations []string
-   for _, candidate := range response.Candidates {
-       if candidate == nil || candidate.GroundingMetadata == nil {
-           continue
-       }
-       chunks := candidate.GroundingMetadata.GroundingChunks
-       if len(chunks) == 0 {
-           continue
-       }
-       for _, chunk := range chunks {
-           if chunk == nil || chunk.Web == nil {
-               continue
-           }
-           uri := chunk.Web.URI
-           title := chunk.Web.Title
-           if uri == "" || title == "" {
-               continue
-           }
-           var keyBuilder strings.Builder
-           keyBuilder.WriteString(uri)
-           keyBuilder.WriteByte('|')
-           keyBuilder.WriteString(title)
-           key := keyBuilder.String()
-           if !citationMap[key] {
-               citationMap[key] = true
-               citationText := fmt.Sprintf(citationFormat, title, uri)
-               citations = append(citations, citationText)
-           }
-       }
-   }
-   return citations
-}
@@ -4,10 +4,10 @@ import (
    "strings"
    "testing"

-   "google.golang.org/genai"
-
    "github.com/danielmiessler/fabric/internal/chat"
    "github.com/danielmiessler/fabric/internal/domain"
+   "github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
+   "google.golang.org/genai"
)

// Test buildModelNameFull method
@@ -31,9 +31,8 @@ func TestBuildModelNameFull(t *testing.T) {
    }
}

-// Test extractTextFromResponse method
+// Test ExtractTextWithCitations from geminicommon
func TestExtractTextFromResponse(t *testing.T) {
-   client := &Client{}
    response := &genai.GenerateContentResponse{
        Candidates: []*genai.Candidate{
            {
@@ -48,7 +47,7 @@ func TestExtractTextFromResponse(t *testing.T) {
    }
    expected := "Hello, world!"

-   result := client.extractTextFromResponse(response)
+   result := geminicommon.ExtractTextWithCitations(response)

    if result != expected {
        t.Errorf("Expected %v, got %v", expected, result)
@@ -56,14 +55,12 @@ func TestExtractTextFromResponse(t *testing.T) {
}

func TestExtractTextFromResponse_Nil(t *testing.T) {
-   client := &Client{}
-   if got := client.extractTextFromResponse(nil); got != "" {
+   if got := geminicommon.ExtractTextWithCitations(nil); got != "" {
        t.Fatalf("expected empty string, got %q", got)
    }
}

func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
-   client := &Client{}
    response := &genai.GenerateContentResponse{
        Candidates: []*genai.Candidate{
            {
@@ -72,7 +69,7 @@ func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
            },
        },
    }
-   if got := client.extractTextFromResponse(response); got != "Hello" {
+   if got := geminicommon.ExtractTextWithCitations(response); got != "Hello" {
        t.Fatalf("expected 'Hello', got %q", got)
    }
}
@@ -162,7 +159,6 @@ func TestBuildGenerateContentConfig_ThinkingTokens(t *testing.T) {
}

func TestCitationFormatting(t *testing.T) {
-   client := &Client{}
    response := &genai.GenerateContentResponse{
        Candidates: []*genai.Candidate{
            {
@@ -178,7 +174,7 @@ func TestCitationFormatting(t *testing.T) {
        },
    }

-   result := client.extractTextFromResponse(response)
+   result := geminicommon.ExtractTextWithCitations(response)
    if !strings.Contains(result, "## Sources") {
        t.Fatalf("expected sources section in result: %s", result)
    }
@@ -189,14 +185,13 @@ func TestCitationFormatting(t *testing.T) {

// Test convertMessages handles role mapping correctly
func TestConvertMessagesRoles(t *testing.T) {
-   client := &Client{}
    msgs := []*chat.ChatCompletionMessage{
        {Role: chat.ChatMessageRoleUser, Content: "user"},
        {Role: chat.ChatMessageRoleAssistant, Content: "assistant"},
        {Role: chat.ChatMessageRoleSystem, Content: "system"},
    }

-   contents := client.convertMessages(msgs)
+   contents := geminicommon.ConvertMessages(msgs)

    expected := []string{"user", "model", "user"}
internal/plugins/ai/geminicommon/geminicommon.go (new file, +130)
@@ -0,0 +1,130 @@
// Package geminicommon provides shared utilities for Gemini API integrations.
// Used by both the standalone Gemini provider (API key auth) and VertexAI provider (ADC auth).
package geminicommon

import (
    "fmt"
    "strings"

    "github.com/danielmiessler/fabric/internal/chat"
    "google.golang.org/genai"
)

// Citation formatting constants
const (
    CitationHeader    = "\n\n## Sources\n\n"
    CitationSeparator = "\n"
    CitationFormat    = "- [%s](%s)"
)

// ConvertMessages converts fabric chat messages to genai Content format.
// Gemini's API only accepts "user" and "model" roles, so other roles are mapped to "user".
func ConvertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
    var contents []*genai.Content

    for _, msg := range msgs {
        content := &genai.Content{Parts: []*genai.Part{}}

        switch msg.Role {
        case chat.ChatMessageRoleAssistant:
            content.Role = "model"
        case chat.ChatMessageRoleUser:
            content.Role = "user"
        case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
            // Gemini's API only accepts "user" and "model" roles.
            // Map all other roles to "user" to preserve instruction context.
            content.Role = "user"
        default:
            content.Role = "user"
        }

        if strings.TrimSpace(msg.Content) != "" {
            content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
        }

        // Handle multi-content messages (images, etc.)
        for _, part := range msg.MultiContent {
            switch part.Type {
            case chat.ChatMessagePartTypeText:
                content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
            case chat.ChatMessagePartTypeImageURL:
                // TODO: Handle image URLs if needed
                // This would require downloading and converting to inline data
            }
        }

        contents = append(contents, content)
    }

    return contents
}

// ExtractText extracts just the text parts from a Gemini response.
func ExtractText(response *genai.GenerateContentResponse) string {
    if response == nil {
        return ""
    }

    var builder strings.Builder
    for _, candidate := range response.Candidates {
        if candidate == nil || candidate.Content == nil {
            continue
        }
        for _, part := range candidate.Content.Parts {
            if part != nil && part.Text != "" {
                builder.WriteString(part.Text)
            }
        }
    }
    return builder.String()
}

// ExtractTextWithCitations extracts text content from the response and appends
// any web citations in a standardized format.
func ExtractTextWithCitations(response *genai.GenerateContentResponse) string {
    if response == nil {
        return ""
    }

    text := ExtractText(response)
    citations := ExtractCitations(response)
    if len(citations) > 0 {
        return text + CitationHeader + strings.Join(citations, CitationSeparator)
    }
    return text
}

// ExtractCitations extracts web citations from grounding metadata.
func ExtractCitations(response *genai.GenerateContentResponse) []string {
    if response == nil || len(response.Candidates) == 0 {
        return nil
    }

    citationMap := make(map[string]bool)
    var citations []string
    for _, candidate := range response.Candidates {
        if candidate == nil || candidate.GroundingMetadata == nil {
            continue
        }
        chunks := candidate.GroundingMetadata.GroundingChunks
        if len(chunks) == 0 {
            continue
        }
        for _, chunk := range chunks {
            if chunk == nil || chunk.Web == nil {
                continue
            }
            uri := chunk.Web.URI
            title := chunk.Web.Title
            if uri == "" || title == "" {
                continue
            }
            key := uri + "|" + title
            if !citationMap[key] {
                citationMap[key] = true
                citations = append(citations, fmt.Sprintf(CitationFormat, title, uri))
            }
        }
    }
    return citations
}
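With this package in place, both providers share one conversion and extraction path. A minimal, self-contained sketch of how a caller composes the helpers; the response literal is hand-built for illustration and mirrors the fixtures used by the tests in this PR:

```go
package main

import (
    "fmt"

    "github.com/danielmiessler/fabric/internal/chat"
    "github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
    "google.golang.org/genai"
)

func main() {
    // System role is mapped to "user"; assistant would map to "model".
    msgs := []*chat.ChatCompletionMessage{
        {Role: chat.ChatMessageRoleSystem, Content: "Be terse."},
        {Role: chat.ChatMessageRoleUser, Content: "Hello"},
    }
    contents := geminicommon.ConvertMessages(msgs)
    fmt.Println(contents[0].Role, contents[1].Role) // user user

    // A hand-built response; a real one comes from GenerateContent.
    resp := &genai.GenerateContentResponse{
        Candidates: []*genai.Candidate{
            {Content: &genai.Content{Parts: []*genai.Part{{Text: "Hi!"}}}},
        },
    }
    fmt.Println(geminicommon.ExtractTextWithCitations(resp)) // Hi!
}
```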
internal/plugins/ai/vertexai/models.go (new file, +237)
@@ -0,0 +1,237 @@
package vertexai

import (
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "sort"
    "strings"

    debuglog "github.com/danielmiessler/fabric/internal/log"
)

const (
    // API limits
    maxResponseSize    = 10 * 1024 * 1024 // 10MB
    errorResponseLimit = 1024             // 1KB for error messages

    // Default region for Model Garden API (global doesn't work for this endpoint)
    defaultModelGardenRegion = "us-central1"
)

// Supported Model Garden publishers (others can be added when SDK support is implemented)
var publishers = []string{"google", "anthropic"}

// publisherModelsResponse represents the API response from publishers.models.list
type publisherModelsResponse struct {
    PublisherModels []publisherModel `json:"publisherModels"`
    NextPageToken   string           `json:"nextPageToken"`
}

// publisherModel represents a single model in the API response
type publisherModel struct {
    Name string `json:"name"` // Format: publishers/{publisher}/models/{model}
}

// fetchModelsPage makes a single API request and returns the parsed response.
// Extracted to ensure proper cleanup of HTTP response bodies in pagination loops.
func fetchModelsPage(ctx context.Context, httpClient *http.Client, url, projectID, publisher string) (*publisherModelsResponse, error) {
    req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
    if err != nil {
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    req.Header.Set("Accept", "application/json")
    // Set quota project header required by Vertex AI API
    req.Header.Set("x-goog-user-project", projectID)

    resp, err := httpClient.Do(req)
    if err != nil {
        return nil, fmt.Errorf("request failed: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, errorResponseLimit))
        debuglog.Debug(debuglog.Basic, "API error for %s: status %d, url: %s, body: %s\n", publisher, resp.StatusCode, url, string(bodyBytes))
        return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(bodyBytes))
    }

    bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize+1))
    if err != nil {
        return nil, fmt.Errorf("failed to read response: %w", err)
    }

    if len(bodyBytes) > maxResponseSize {
        return nil, fmt.Errorf("response too large (>%d bytes)", maxResponseSize)
    }

    var response publisherModelsResponse
    if err := json.Unmarshal(bodyBytes, &response); err != nil {
        return nil, fmt.Errorf("failed to parse response: %w", err)
    }

    return &response, nil
}

// listPublisherModels fetches models from a specific publisher via the Model Garden API
func listPublisherModels(ctx context.Context, httpClient *http.Client, region, projectID, publisher string) ([]string, error) {
    // Use default region if global or empty (Model Garden API requires a specific region)
    if region == "" || region == "global" {
        region = defaultModelGardenRegion
    }

    baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/publishers/%s/models", region, publisher)

    var allModels []string
    pageToken := ""

    for {
        url := baseURL
        if pageToken != "" {
            url = fmt.Sprintf("%s?pageToken=%s", baseURL, pageToken)
        }

        response, err := fetchModelsPage(ctx, httpClient, url, projectID, publisher)
        if err != nil {
            return nil, err
        }

        // Extract model names, stripping the publishers/{publisher}/models/ prefix
        for _, model := range response.PublisherModels {
            modelName := extractModelName(model.Name)
            if modelName != "" {
                allModels = append(allModels, modelName)
            }
        }

        // Check for more pages
        if response.NextPageToken == "" {
            break
        }
        pageToken = response.NextPageToken
    }

    debuglog.Debug(debuglog.Detailed, "Listed %d models from publisher %s\n", len(allModels), publisher)
    return allModels, nil
}

// extractModelName extracts the model name from the full resource path
// Input: "publishers/google/models/gemini-2.0-flash"
// Output: "gemini-2.0-flash"
func extractModelName(fullName string) string {
    parts := strings.Split(fullName, "/")
    if len(parts) >= 4 && parts[0] == "publishers" && parts[2] == "models" {
        return parts[3]
    }
    // Fallback: return the last segment
    if len(parts) > 0 {
        return parts[len(parts)-1]
    }
    return fullName
}

// sortModels sorts models by priority: Gemini > Claude > Others
// Within each group, models are sorted alphabetically
func sortModels(models []string) []string {
    sort.Slice(models, func(i, j int) bool {
        pi := modelPriority(models[i])
        pj := modelPriority(models[j])
        if pi != pj {
            return pi < pj
        }
        // Same priority: sort alphabetically (case-insensitive)
        return strings.ToLower(models[i]) < strings.ToLower(models[j])
    })
    return models
}

// modelPriority returns the sort priority for a model (lower = higher priority)
func modelPriority(model string) int {
    lower := strings.ToLower(model)
    switch {
    case strings.HasPrefix(lower, "gemini"):
        return 1
    case strings.HasPrefix(lower, "claude"):
        return 2
    default:
        return 3
    }
}

// knownGeminiModels is a curated list of Gemini models available on Vertex AI.
// Vertex AI doesn't provide a list API for Gemini models - they must be known ahead of time.
// This list is based on Google Cloud documentation as of January 2025.
// See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/models
var knownGeminiModels = []string{
    // Gemini 3 (Preview)
    "gemini-3-pro-preview",
    "gemini-3-flash-preview",
    // Gemini 2.5 (GA)
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    // Gemini 2.0 (GA)
    "gemini-2.0-flash",
    "gemini-2.0-flash-lite",
}

// getKnownGeminiModels returns the curated list of Gemini models available on Vertex AI.
// Unlike third-party models which can be listed via the Model Garden API,
// Gemini models must be known ahead of time as there's no list endpoint for them.
func getKnownGeminiModels() []string {
    return knownGeminiModels
}

// isGeminiModel returns true if the model is a Gemini model
func isGeminiModel(modelName string) bool {
    return strings.HasPrefix(strings.ToLower(modelName), "gemini")
}

// isConversationalModel returns true if the model is suitable for text generation/chat
// Filters out image generation, embeddings, and other non-conversational models
func isConversationalModel(modelName string) bool {
    lower := strings.ToLower(modelName)

    // Exclude patterns for non-conversational models
    excludePatterns := []string{
        "imagen", // Image generation models
        "imagegeneration",
        "imagetext",
        "image-segmentation",
        "embedding", // Embedding models
        "textembedding",
        "multimodalembedding",
        "text-bison", // Legacy completion models (not chat)
        "text-unicorn",
        "code-bison", // Legacy code models
        "code-gecko",
        "codechat-bison", // Deprecated chat model
        "chat-bison",     // Deprecated chat model
        "veo",            // Video generation
        "chirp",          // Audio/speech models
        "medlm",          // Medical models (restricted)
        "medical",
    }

    for _, pattern := range excludePatterns {
        if strings.Contains(lower, pattern) {
            return false
        }
    }

    return true
}

// filterConversationalModels returns only models suitable for text generation/chat
func filterConversationalModels(models []string) []string {
    var filtered []string
    for _, model := range models {
        if isConversationalModel(model) {
            filtered = append(filtered, model)
        }
    }
    return filtered
}
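Composed, the helpers above form the listing pipeline: strip the resource prefix, drop non-conversational models, then order by family. An illustrative run with hypothetical inputs:

```go
names := []string{
    "publishers/google/models/textembedding-gecko",
    "publishers/anthropic/models/claude-sonnet-4-5",
    "publishers/google/models/gemini-2.0-flash",
}
var models []string
for _, n := range names {
    models = append(models, extractModelName(n))
}
models = sortModels(filterConversationalModels(models))
// models == []string{"gemini-2.0-flash", "claude-sonnet-4-5"}:
// the embedding model is filtered out, and Gemini sorts before Claude.
```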
@@ -9,13 +9,18 @@ import (
    "github.com/anthropics/anthropic-sdk-go/vertex"
    "github.com/danielmiessler/fabric/internal/chat"
    "github.com/danielmiessler/fabric/internal/domain"
+   debuglog "github.com/danielmiessler/fabric/internal/log"
    "github.com/danielmiessler/fabric/internal/plugins"
+   "github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
+   "golang.org/x/oauth2"
+   "golang.org/x/oauth2/google"
+   "google.golang.org/genai"
)

const (
    cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
    defaultRegion      = "global"
-   maxTokens          = 4096
+   defaultMaxTokens   = 4096
)

// NewClient creates a new Vertex AI client for accessing Claude models via Google Cloud
@@ -59,17 +64,78 @@ func (c *Client) configure() error {
}

func (c *Client) ListModels() ([]string, error) {
-   // Return Claude models available on Vertex AI
-   return []string{
-       string(anthropic.ModelClaudeSonnet4_5),
-       string(anthropic.ModelClaudeOpus4_5),
-       string(anthropic.ModelClaudeHaiku4_5),
-       string(anthropic.ModelClaude3_7SonnetLatest),
-       string(anthropic.ModelClaude3_5HaikuLatest),
-   }, nil
+   ctx := context.Background()
+
+   // Get ADC credentials for API authentication
+   creds, err := google.FindDefaultCredentials(ctx, cloudPlatformScope)
+   if err != nil {
+       return nil, fmt.Errorf("failed to get Google credentials (ensure ADC is configured): %w", err)
+   }
+   httpClient := oauth2.NewClient(ctx, creds.TokenSource)
+
+   // Query all publishers in parallel for better performance
+   type result struct {
+       models    []string
+       err       error
+       publisher string
+   }
+   // +1 for known Gemini models (no API to list them)
+   results := make(chan result, len(publishers)+1)
+
+   // Query Model Garden API for third-party models
+   for _, pub := range publishers {
+       go func(publisher string) {
+           models, err := listPublisherModels(ctx, httpClient, c.Region.Value, c.ProjectID.Value, publisher)
+           results <- result{models: models, err: err, publisher: publisher}
+       }(pub)
+   }
+
+   // Add known Gemini models (Vertex AI doesn't have a list API for Gemini)
+   go func() {
+       results <- result{models: getKnownGeminiModels(), err: nil, publisher: "gemini"}
+   }()
+
+   // Collect results from all sources
+   var allModels []string
+   for range len(publishers) + 1 {
+       r := <-results
+       if r.err != nil {
+           // Log warning but continue - some sources may not be available
+           debuglog.Debug(debuglog.Basic, "Failed to list %s models: %v\n", r.publisher, r.err)
+           continue
+       }
+       allModels = append(allModels, r.models...)
+   }
+
+   if len(allModels) == 0 {
+       return nil, fmt.Errorf("no models found from any publisher")
+   }
+
+   // Filter to only conversational models and sort
+   filtered := filterConversationalModels(allModels)
+   if len(filtered) == 0 {
+       return nil, fmt.Errorf("no conversational models found")
+   }
+
+   return sortModels(filtered), nil
}

func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
+   if isGeminiModel(opts.Model) {
+       return c.sendGemini(ctx, msgs, opts)
+   }
+   return c.sendClaude(ctx, msgs, opts)
+}
+
+// getMaxTokens returns the max output tokens to use for a request
+func getMaxTokens(opts *domain.ChatOptions) int64 {
+   if opts.MaxTokens > 0 {
+       return int64(opts.MaxTokens)
+   }
+   return int64(defaultMaxTokens)
+}
+
+func (c *Client) sendClaude(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
    if c.client == nil {
        return "", fmt.Errorf("VertexAI client not initialized")
    }
@@ -80,14 +146,22 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
        return "", fmt.Errorf("no valid messages to send")
    }

-   // Create the request
-   response, err := c.client.Messages.New(ctx, anthropic.MessageNewParams{
-       Model:       anthropic.Model(opts.Model),
-       MaxTokens:   int64(maxTokens),
-       Messages:    anthropicMessages,
-       Temperature: anthropic.Opt(opts.Temperature),
-   })
+   // Build request params
+   params := anthropic.MessageNewParams{
+       Model:     anthropic.Model(opts.Model),
+       MaxTokens: getMaxTokens(opts),
+       Messages:  anthropicMessages,
+   }
+
+   // Only set one of Temperature or TopP as some models don't allow both
+   // (following anthropic.go pattern)
+   if opts.TopP != domain.DefaultTopP {
+       params.TopP = anthropic.Opt(opts.TopP)
+   } else {
+       params.Temperature = anthropic.Opt(opts.Temperature)
+   }
+
+   response, err := c.client.Messages.New(ctx, params)
    if err != nil {
        return "", err
    }
@@ -108,6 +182,13 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
    }

func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
+   if isGeminiModel(opts.Model) {
+       return c.sendStreamGemini(msgs, opts, channel)
+   }
+   return c.sendStreamClaude(msgs, opts, channel)
+}
+
+func (c *Client) sendStreamClaude(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
    if c.client == nil {
        close(channel)
        return fmt.Errorf("VertexAI client not initialized")
@@ -122,13 +203,22 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
        return fmt.Errorf("no valid messages to send")
    }

+   // Build request params
+   params := anthropic.MessageNewParams{
+       Model:     anthropic.Model(opts.Model),
+       MaxTokens: getMaxTokens(opts),
+       Messages:  anthropicMessages,
+   }
+
+   // Only set one of Temperature or TopP as some models don't allow both
+   if opts.TopP != domain.DefaultTopP {
+       params.TopP = anthropic.Opt(opts.TopP)
+   } else {
+       params.Temperature = anthropic.Opt(opts.Temperature)
+   }
+
    // Create streaming request
-   stream := c.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
-       Model:       anthropic.Model(opts.Model),
-       MaxTokens:   int64(maxTokens),
-       Messages:    anthropicMessages,
-       Temperature: anthropic.Opt(opts.Temperature),
-   })
+   stream := c.client.Messages.NewStreaming(ctx, params)

    // Process stream
    for stream.Next() {
@@ -167,6 +257,144 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
    return stream.Err()
}

+// Gemini methods using genai SDK with Vertex AI backend
+
+// getGeminiRegion returns the appropriate region for a Gemini model.
+// Preview models are often only available on the global endpoint.
+func (c *Client) getGeminiRegion(model string) string {
+   if strings.Contains(strings.ToLower(model), "preview") {
+       return "global"
+   }
+   return c.Region.Value
+}
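Two illustrative calls, assuming the client was configured with a regional endpoint such as us-central1 (hypothetical here):

```go
c.getGeminiRegion("gemini-3-pro-preview") // "global": preview models need it
c.getGeminiRegion("gemini-2.5-pro")       // c.Region.Value, e.g. "us-central1"
```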
+func (c *Client) sendGemini(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
+   client, err := genai.NewClient(ctx, &genai.ClientConfig{
+       Project:  c.ProjectID.Value,
+       Location: c.getGeminiRegion(opts.Model),
+       Backend:  genai.BackendVertexAI,
+   })
+   if err != nil {
+       return "", fmt.Errorf("failed to create Gemini client: %w", err)
+   }
+
+   contents := geminicommon.ConvertMessages(msgs)
+   if len(contents) == 0 {
+       return "", fmt.Errorf("no valid messages to send")
+   }
+
+   config := c.buildGeminiConfig(opts)
+
+   response, err := client.Models.GenerateContent(ctx, opts.Model, contents, config)
+   if err != nil {
+       return "", err
+   }
+
+   return geminicommon.ExtractTextWithCitations(response), nil
+}
+
+// buildGeminiConfig creates the generation config for Gemini models
+// following the gemini.go pattern for feature parity
+func (c *Client) buildGeminiConfig(opts *domain.ChatOptions) *genai.GenerateContentConfig {
+   temperature := float32(opts.Temperature)
+   topP := float32(opts.TopP)
+   config := &genai.GenerateContentConfig{
+       Temperature:     &temperature,
+       TopP:            &topP,
+       MaxOutputTokens: int32(getMaxTokens(opts)),
+   }
+
+   // Add web search support
+   if opts.Search {
+       config.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
+   }
+
+   // Add thinking support
+   if tc := parseGeminiThinking(opts.Thinking); tc != nil {
+       config.ThinkingConfig = tc
+   }
+
+   return config
+}
+
+// parseGeminiThinking converts thinking level to Gemini thinking config
+func parseGeminiThinking(level domain.ThinkingLevel) *genai.ThinkingConfig {
+   lower := strings.ToLower(strings.TrimSpace(string(level)))
+   switch domain.ThinkingLevel(lower) {
+   case "", domain.ThinkingOff:
+       return nil
+   case domain.ThinkingLow, domain.ThinkingMedium, domain.ThinkingHigh:
+       if budget, ok := domain.ThinkingBudgets[domain.ThinkingLevel(lower)]; ok {
+           b := int32(budget)
+           return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &b}
+       }
+   default:
+       // Try parsing as integer token count
+       var tokens int
+       if _, err := fmt.Sscanf(lower, "%d", &tokens); err == nil && tokens > 0 {
+           t := int32(tokens)
+           return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &t}
+       }
+   }
+   return nil
+}
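parseGeminiThinking accepts the named levels plus a raw token count; a few illustrative calls matching the switch above:

```go
parseGeminiThinking(domain.ThinkingOff)  // nil: thinking disabled
parseGeminiThinking(domain.ThinkingHigh) // budget taken from domain.ThinkingBudgets
parseGeminiThinking("5000")              // explicit 5000-token budget
parseGeminiThinking("bogus")             // nil: unrecognized level
```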
+func (c *Client) sendStreamGemini(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
+   defer close(channel)
+   ctx := context.Background()
+
+   client, err := genai.NewClient(ctx, &genai.ClientConfig{
+       Project:  c.ProjectID.Value,
+       Location: c.getGeminiRegion(opts.Model),
+       Backend:  genai.BackendVertexAI,
+   })
+   if err != nil {
+       return fmt.Errorf("failed to create Gemini client: %w", err)
+   }
+
+   contents := geminicommon.ConvertMessages(msgs)
+   if len(contents) == 0 {
+       return fmt.Errorf("no valid messages to send")
+   }
+
+   config := c.buildGeminiConfig(opts)
+
+   stream := client.Models.GenerateContentStream(ctx, opts.Model, contents, config)
+
+   for response, err := range stream {
+       if err != nil {
+           channel <- domain.StreamUpdate{
+               Type:    domain.StreamTypeError,
+               Content: fmt.Sprintf("Error: %v", err),
+           }
+           return err
+       }
+
+       text := geminicommon.ExtractText(response)
+       if text != "" {
+           channel <- domain.StreamUpdate{
+               Type:    domain.StreamTypeContent,
+               Content: text,
+           }
+       }
+
+       if response.UsageMetadata != nil {
+           channel <- domain.StreamUpdate{
+               Type: domain.StreamTypeUsage,
+               Usage: &domain.UsageMetadata{
+                   InputTokens:  int(response.UsageMetadata.PromptTokenCount),
+                   OutputTokens: int(response.UsageMetadata.CandidatesTokenCount),
+                   TotalTokens:  int(response.UsageMetadata.TotalTokenCount),
+               },
+           }
+       }
+   }
+
+   return nil
+}
+
// Claude message conversion

func (c *Client) toMessages(msgs []*chat.ChatCompletionMessage) []anthropic.MessageParam {
    // Convert messages to Anthropic format with proper role handling
    // - System messages become part of the first user message
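End to end, the provider now resolves its model list at runtime and routes each request by model family. A hedged caller sketch; the project ID is hypothetical, and Application Default Credentials must already be configured for the environment:

```go
client := NewClient()
client.ProjectID.Value = "my-gcp-project" // hypothetical project
models, err := client.ListModels()        // Model Garden publishers + curated Gemini list
if err != nil {
    // Individual sources being unreachable is tolerated; this fires only
    // when no publisher returned any usable model.
}
// models come back filtered to chat-capable entries and sorted:
// gemini-* first, then claude-*, then everything else.
// Send/SendStream then dispatch on isGeminiModel(opts.Model).
_ = models
```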
internal/plugins/ai/vertexai/vertexai_test.go (new file, +442)
@@ -0,0 +1,442 @@
package vertexai

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "testing"

    "github.com/danielmiessler/fabric/internal/domain"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestExtractModelName(t *testing.T) {
    tests := []struct {
        name     string
        input    string
        expected string
    }{
        {
            name:     "standard format",
            input:    "publishers/google/models/gemini-2.0-flash",
            expected: "gemini-2.0-flash",
        },
        {
            name:     "anthropic model",
            input:    "publishers/anthropic/models/claude-sonnet-4-5",
            expected: "claude-sonnet-4-5",
        },
        {
            name:     "model with version",
            input:    "publishers/anthropic/models/claude-3-opus@20240229",
            expected: "claude-3-opus@20240229",
        },
        {
            name:     "just model name",
            input:    "gemini-pro",
            expected: "gemini-pro",
        },
        {
            name:     "empty string",
            input:    "",
            expected: "",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := extractModelName(tt.input)
            assert.Equal(t, tt.expected, result)
        })
    }
}

func TestSortModels(t *testing.T) {
    input := []string{
        "claude-sonnet-4-5",
        "gemini-2.0-flash",
        "gemini-pro",
        "claude-opus-4",
        "unknown-model",
    }

    result := sortModels(input)

    // Verify order: Gemini first, then Claude, then others (alphabetically within each group)
    expected := []string{
        "gemini-2.0-flash",
        "gemini-pro",
        "claude-opus-4",
        "claude-sonnet-4-5",
        "unknown-model",
    }

    assert.Equal(t, expected, result)
}

func TestModelPriority(t *testing.T) {
    tests := []struct {
        model    string
        priority int
    }{
        {"gemini-2.0-flash", 1},
        {"Gemini-Pro", 1},
        {"claude-sonnet-4-5", 2},
        {"CLAUDE-OPUS", 2},
        {"some-other-model", 3},
    }

    for _, tt := range tests {
        t.Run(tt.model, func(t *testing.T) {
            result := modelPriority(tt.model)
            assert.Equal(t, tt.priority, result, "priority for %s", tt.model)
        })
    }
}

func TestListPublisherModels_Success(t *testing.T) {
    // Create mock server
    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        assert.Equal(t, http.MethodGet, r.Method)
        assert.Contains(t, r.URL.Path, "/v1/publishers/google/models")

        response := publisherModelsResponse{
            PublisherModels: []publisherModel{
                {Name: "publishers/google/models/gemini-2.0-flash"},
                {Name: "publishers/google/models/gemini-pro"},
            },
        }
        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(response)
    }))
    defer server.Close()

    // Note: This test would need to mock the actual API endpoint
    // For now, we just verify the mock server works
    resp, err := http.Get(server.URL + "/v1/publishers/google/models")
    require.NoError(t, err)
    defer resp.Body.Close()

    var response publisherModelsResponse
    err = json.NewDecoder(resp.Body).Decode(&response)
    require.NoError(t, err)

    assert.Len(t, response.PublisherModels, 2)
    assert.Equal(t, "publishers/google/models/gemini-2.0-flash", response.PublisherModels[0].Name)
}

func TestListPublisherModels_Pagination(t *testing.T) {
    callCount := 0

    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        callCount++

        var response publisherModelsResponse
        if callCount == 1 {
            response = publisherModelsResponse{
                PublisherModels: []publisherModel{
                    {Name: "publishers/google/models/gemini-flash"},
                },
                NextPageToken: "page2",
            }
        } else {
            response = publisherModelsResponse{
                PublisherModels: []publisherModel{
                    {Name: "publishers/google/models/gemini-pro"},
                },
                NextPageToken: "",
            }
        }

        w.Header().Set("Content-Type", "application/json")
        json.NewEncoder(w).Encode(response)
    }))
    defer server.Close()

    // Verify the server handles pagination correctly
    resp, err := http.Get(server.URL + "/page1")
    require.NoError(t, err)
    resp.Body.Close()

    resp, err = http.Get(server.URL + "/page2")
    require.NoError(t, err)
    resp.Body.Close()

    assert.Equal(t, 2, callCount)
}

func TestListPublisherModels_ErrorResponse(t *testing.T) {
    server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusForbidden)
        w.Write([]byte(`{"error": "access denied"}`))
    }))
    defer server.Close()

    resp, err := http.Get(server.URL + "/v1/publishers/google/models")
    require.NoError(t, err)
    defer resp.Body.Close()

    assert.Equal(t, http.StatusForbidden, resp.StatusCode)
}

func TestNewClient(t *testing.T) {
    client := NewClient()

    assert.NotNil(t, client)
    assert.Equal(t, "VertexAI", client.Name)
    assert.NotNil(t, client.ProjectID)
    assert.NotNil(t, client.Region)
    assert.Equal(t, "global", client.Region.Value)
}

func TestPublishersListComplete(t *testing.T) {
    // Verify supported publishers are in the list
    expectedPublishers := []string{"google", "anthropic"}

    assert.Equal(t, expectedPublishers, publishers)
}

func TestIsConversationalModel(t *testing.T) {
    tests := []struct {
        model    string
        expected bool
    }{
        // Conversational models (should return true)
        {"gemini-2.0-flash", true},
        {"gemini-2.5-pro", true},
        {"claude-sonnet-4-5", true},
        {"claude-opus-4", true},
        {"deepseek-v3", true},
        {"llama-3.1-405b", true},
        {"mistral-large", true},

        // Non-conversational models (should return false)
        {"imagen-3.0-capability-002", false},
        {"imagen-4.0-fast-generate-001", false},
        {"imagegeneration", false},
        {"imagetext", false},
        {"image-segmentation-001", false},
        {"textembedding-gecko", false},
        {"multimodalembedding", false},
        {"text-embedding-004", false},
        {"text-bison", false},
        {"text-unicorn", false},
        {"code-bison", false},
        {"code-gecko", false},
        {"codechat-bison", false},
        {"chat-bison", false},
        {"veo-001", false},
        {"chirp", false},
        {"medlm-medium", false},
    }

    for _, tt := range tests {
        t.Run(tt.model, func(t *testing.T) {
            result := isConversationalModel(tt.model)
            assert.Equal(t, tt.expected, result, "isConversationalModel(%s)", tt.model)
        })
    }
}

func TestFilterConversationalModels(t *testing.T) {
    input := []string{
        "gemini-2.0-flash",
        "imagen-3.0-capability-002",
        "claude-sonnet-4-5",
        "textembedding-gecko",
        "deepseek-v3",
        "chat-bison",
        "llama-3.1-405b",
        "code-bison",
    }

    result := filterConversationalModels(input)

    expected := []string{
        "gemini-2.0-flash",
        "claude-sonnet-4-5",
        "deepseek-v3",
        "llama-3.1-405b",
    }

    assert.Equal(t, expected, result)
}

func TestFilterConversationalModels_EmptyInput(t *testing.T) {
    result := filterConversationalModels([]string{})
    assert.Empty(t, result)
}

func TestFilterConversationalModels_AllFiltered(t *testing.T) {
    input := []string{
        "imagen-3.0",
        "textembedding-gecko",
        "chat-bison",
    }

    result := filterConversationalModels(input)
    assert.Empty(t, result)
}

func TestIsGeminiModel(t *testing.T) {
    tests := []struct {
        model    string
        expected bool
    }{
        {"gemini-2.5-pro", true},
        {"gemini-3-pro-preview", true},
        {"Gemini-2.0-flash", true},
        {"GEMINI-flash", true},
        {"claude-sonnet-4-5", false},
        {"claude-opus-4", false},
        {"deepseek-v3", false},
        {"llama-3.1-405b", false},
        {"", false},
    }

    for _, tt := range tests {
        t.Run(tt.model, func(t *testing.T) {
            result := isGeminiModel(tt.model)
            assert.Equal(t, tt.expected, result, "isGeminiModel(%s)", tt.model)
        })
    }
}

func TestGetMaxTokens(t *testing.T) {
    tests := []struct {
        name     string
        opts     *domain.ChatOptions
        expected int64
    }{
        {
            name:     "MaxTokens specified",
            opts:     &domain.ChatOptions{MaxTokens: 8192},
            expected: 8192,
        },
        {
            name:     "Default when MaxTokens is 0",
            opts:     &domain.ChatOptions{MaxTokens: 0},
            expected: int64(defaultMaxTokens),
        },
        {
            name:     "Default when MaxTokens is negative",
            opts:     &domain.ChatOptions{MaxTokens: -1},
            expected: int64(defaultMaxTokens),
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := getMaxTokens(tt.opts)
            assert.Equal(t, tt.expected, result)
        })
    }
}

func TestParseGeminiThinking(t *testing.T) {
    tests := []struct {
        name           string
        level          domain.ThinkingLevel
        expectNil      bool
        expectedBudget int32
    }{
        {
            name:      "empty string returns nil",
            level:     "",
            expectNil: true,
        },
        {
            name:      "off returns nil",
            level:     domain.ThinkingOff,
            expectNil: true,
        },
        {
            name:           "low thinking",
            level:          domain.ThinkingLow,
            expectNil:      false,
            expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingLow]),
        },
        {
            name:           "medium thinking",
            level:          domain.ThinkingMedium,
            expectNil:      false,
            expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingMedium]),
        },
        {
            name:           "high thinking",
            level:          domain.ThinkingHigh,
            expectNil:      false,
            expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingHigh]),
        },
        {
            name:           "numeric string",
            level:          "5000",
            expectNil:      false,
            expectedBudget: 5000,
        },
        {
            name:      "invalid string returns nil",
            level:     "invalid",
            expectNil: true,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            result := parseGeminiThinking(tt.level)
            if tt.expectNil {
                assert.Nil(t, result)
            } else {
                require.NotNil(t, result)
                assert.True(t, result.IncludeThoughts)
                assert.Equal(t, tt.expectedBudget, *result.ThinkingBudget)
            }
        })
    }
}

func TestBuildGeminiConfig(t *testing.T) {
    client := &Client{}

    t.Run("basic config with temperature and TopP", func(t *testing.T) {
        opts := &domain.ChatOptions{
            Temperature: 0.7,
            TopP:        0.9,
            MaxTokens:   8192,
        }
        config := client.buildGeminiConfig(opts)

        assert.NotNil(t, config)
        assert.Equal(t, float32(0.7), *config.Temperature)
        assert.Equal(t, float32(0.9), *config.TopP)
        assert.Equal(t, int32(8192), config.MaxOutputTokens)
        assert.Nil(t, config.Tools)
        assert.Nil(t, config.ThinkingConfig)
    })

    t.Run("config with search enabled", func(t *testing.T) {
        opts := &domain.ChatOptions{
            Temperature: 0.5,
            TopP:        0.8,
            Search:      true,
        }
        config := client.buildGeminiConfig(opts)

        assert.NotNil(t, config.Tools)
        assert.Len(t, config.Tools, 1)
        assert.NotNil(t, config.Tools[0].GoogleSearch)
    })

    t.Run("config with thinking enabled", func(t *testing.T) {
        opts := &domain.ChatOptions{
            Temperature: 0.5,
            TopP:        0.8,
            Thinking:    domain.ThinkingHigh,
        }
        config := client.buildGeminiConfig(opts)

        assert.NotNil(t, config.ThinkingConfig)
        assert.True(t, config.ThinkingConfig.IncludeThoughts)
    })
}