Mirror of https://github.com/danielmiessler/Fabric.git (synced 2026-01-09 14:28:01 -05:00)
Compare commits
11 Commits
| SHA1 |
|---|
| 678db0c43e |
| 765977cd42 |
| 8017f376b1 |
| 6f103b2db2 |
| 19aeebe6f5 |
| 2d79d3b706 |
| 4fe501da02 |
| 2501cbf47e |
| d96a1721bb |
| c1838d3744 |
| 643a60a2cf |
CHANGELOG.md (18 lines changed)

````diff
@@ -1,5 +1,23 @@
 # Changelog
 
+## v1.4.375 (2026-01-08)
+
+### PR [#1925](https://github.com/danielmiessler/Fabric/pull/1925) by [ksylvan](https://github.com/ksylvan): docs: update README to document new AI providers and features
+
+- Docs: update README to document new AI providers and features
+- List supported native and OpenAI-compatible AI provider integrations
+- Document dry run mode for previewing prompt construction
+- Explain Ollama compatibility mode for exposing API endpoints
+- Detail available prompt strategies like chain-of-thought and reflexion
+
+### PR [#1926](https://github.com/danielmiessler/Fabric/pull/1926) by [henricook](https://github.com/henricook) and [ksylvan](https://github.com/ksylvan): feat(vertexai): add dynamic model listing and multi-model support
+
+- Dynamic model listing from Vertex AI Model Garden API
+- Support for both Gemini (genai SDK) and Claude (Anthropic SDK) models
+- Curated Gemini model list with web search support for Gemini models
+- Thinking/extended thinking support for Gemini
+- TopP parameter support for Claude models
+
 ## v1.4.374 (2026-01-05)
 
 ### PR [#1924](https://github.com/danielmiessler/Fabric/pull/1924) by [ksylvan](https://github.com/ksylvan): Rename `code_helper` to `code2context` across documentation and CLI
````
README.md (114 lines changed)

````diff
@@ -160,6 +160,7 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
 - [Docker](#docker)
 - [Environment Variables](#environment-variables)
 - [Setup](#setup)
+- [Supported AI Providers](#supported-ai-providers)
 - [Per-Pattern Model Mapping](#per-pattern-model-mapping)
 - [Add aliases for all patterns](#add-aliases-for-all-patterns)
 - [Save your files in markdown using aliases](#save-your-files-in-markdown-using-aliases)
@@ -172,12 +173,15 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
 - [Fish Completion](#fish-completion)
 - [Usage](#usage)
 - [Debug Levels](#debug-levels)
+- [Dry Run Mode](#dry-run-mode)
 - [Extensions](#extensions)
 - [REST API Server](#rest-api-server)
+- [Ollama Compatibility Mode](#ollama-compatibility-mode)
 - [Our approach to prompting](#our-approach-to-prompting)
 - [Examples](#examples)
 - [Just use the Patterns](#just-use-the-patterns)
 - [Prompt Strategies](#prompt-strategies)
+- [Available Strategies](#available-strategies)
 - [Custom Patterns](#custom-patterns)
 - [Setting Up Custom Patterns](#setting-up-custom-patterns)
 - [Using Custom Patterns](#using-custom-patterns)
@@ -186,6 +190,7 @@ Keep in mind that many of these were recorded when Fabric was Python-based, so r
 - [`to_pdf`](#to_pdf)
 - [`to_pdf` Installation](#to_pdf-installation)
 - [`code2context`](#code2context)
+- [`generate_changelog`](#generate_changelog)
 - [pbpaste](#pbpaste)
 - [Web Interface (Fabric Web App)](#web-interface-fabric-web-app)
 - [Meta](#meta)
@@ -349,6 +354,43 @@ fabric --setup
 
 If everything works you are good to go.
 
+### Supported AI Providers
+
+Fabric supports a wide range of AI providers:
+
+**Native Integrations:**
+
+- OpenAI
+- Anthropic (Claude)
+- Google Gemini
+- Ollama (local models)
+- Azure OpenAI
+- Amazon Bedrock
+- Vertex AI
+- LM Studio
+- Perplexity
+
+**OpenAI-Compatible Providers:**
+
+- Abacus
+- AIML
+- Cerebras
+- DeepSeek
+- GitHub Models
+- GrokAI
+- Groq
+- Langdock
+- LiteLLM
+- MiniMax
+- Mistral
+- OpenRouter
+- SiliconCloud
+- Together
+- Venice AI
+- Z AI
+
+Run `fabric --setup` to configure your preferred provider(s), or use `fabric --listvendors` to see all available vendors.
+
 ### Per-Pattern Model Mapping
 
 You can configure specific models for individual patterns using environment variables
````
````diff
@@ -720,6 +762,16 @@ Use the `--debug` flag to control runtime logging:
 - `2`: detailed debugging
 - `3`: trace level
 
+### Dry Run Mode
+
+Use `--dry-run` to preview what would be sent to the AI model without making an API call:
+
+```bash
+echo "test input" | fabric --dry-run -p summarize
+```
+
+This is useful for debugging patterns, checking prompt construction, and verifying input formatting before using API credits.
+
 ### Extensions
 
 Fabric supports extensions that can be called within patterns. See the [Extension Guide](internal/plugins/template/Examples/README.md) for complete documentation.
````
````diff
@@ -745,6 +797,22 @@ The server provides endpoints for:
 
 For complete endpoint documentation, authentication setup, and usage examples, see [REST API Documentation](docs/rest-api.md).
 
+### Ollama Compatibility Mode
+
+Fabric can serve as a drop-in replacement for Ollama by exposing Ollama-compatible API endpoints. Start the server with:
+
+```bash
+fabric --serve --serveOllama
+```
+
+This enables the following Ollama-compatible endpoints:
+
+- `GET /api/tags` - List available patterns as models
+- `POST /api/chat` - Chat completions
+- `GET /api/version` - Server version
+
+Applications configured to use the Ollama API can point to your Fabric server instead, allowing you to use any of Fabric's supported AI providers through the Ollama interface. Patterns appear as models (e.g., `summarize:latest`).
+
 ## Our approach to prompting
 
 Fabric _Patterns_ are different than most prompts you'll see.
````
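To see the compatibility mode from the client side, here is a minimal Go sketch that posts a chat request to the `/api/chat` endpoint added above. The request body follows the standard Ollama chat API; the `localhost:8080` address is an assumption, so substitute whatever address your `fabric --serve` instance actually listens on.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

// chatRequest mirrors the standard Ollama /api/chat request body.
type chatRequest struct {
	Model    string        `json:"model"`
	Messages []chatMessage `json:"messages"`
	Stream   bool          `json:"stream"`
}

type chatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

func main() {
	// Patterns are exposed as models, e.g. "summarize:latest".
	body, err := json.Marshal(chatRequest{
		Model:    "summarize:latest",
		Messages: []chatMessage{{Role: "user", Content: "Text to summarize goes here."}},
		Stream:   false,
	})
	if err != nil {
		panic(err)
	}

	// Assumption: the Fabric server is listening on localhost:8080.
	resp, err := http.Post("http://localhost:8080/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	raw, _ := io.ReadAll(resp.Body)
	fmt.Println(string(raw)) // the assistant message, in Ollama's response shape
}
```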
````diff
@@ -825,6 +893,34 @@ LLM in the chat session.
 
 Use `fabric -S` and select the option to install the strategies in your `~/.config/fabric` directory.
 
+#### Available Strategies
+
+Fabric includes several prompt strategies:
+
+- `cot` - Chain-of-Thought: Step-by-step reasoning
+- `cod` - Chain-of-Draft: Iterative drafting with minimal notes (5 words max per step)
+- `tot` - Tree-of-Thought: Generate multiple reasoning paths and select the best one
+- `aot` - Atom-of-Thought: Break problems into smallest independent atomic sub-problems
+- `ltm` - Least-to-Most: Solve problems from easiest to hardest sub-problems
+- `self-consistent` - Self-Consistency: Multiple reasoning paths with consensus
+- `self-refine` - Self-Refinement: Answer, critique, and refine
+- `reflexion` - Reflexion: Answer, critique briefly, and provide refined answer
+- `standard` - Standard: Direct answer without explanation
+
+Use the `--strategy` flag to apply a strategy:
+
+```bash
+echo "Analyze this code" | fabric --strategy cot -p analyze_code
+```
+
+List all available strategies with:
+
+```bash
+fabric --liststrategies
+```
+
+Strategies are stored as JSON files in `~/.config/fabric/strategies/`. See the default strategies for the format specification.
+
 ## Custom Patterns
 
 You may want to use Fabric to create your own custom Patterns—but not share them with others. No problem!
````
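The README defers to the shipped defaults for the exact strategy file format. Purely as an illustration of the loading side, the sketch below reads one of those JSON files from `~/.config/fabric/strategies/`; the `description` and `prompt` field names are assumptions made for this example, not the documented schema.

```go
package main

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
)

// strategy is an illustrative stand-in for a Fabric strategy file;
// the authoritative field names live in the shipped default strategies.
type strategy struct {
	Description string `json:"description"`
	Prompt      string `json:"prompt"`
}

func main() {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	path := filepath.Join(home, ".config", "fabric", "strategies", "cot.json")

	data, err := os.ReadFile(path)
	if err != nil {
		panic(err) // run `fabric -S` first to install the strategies
	}

	var s strategy
	if err := json.Unmarshal(data, &s); err != nil {
		panic(err)
	}
	fmt.Printf("cot: %s\n", s.Description)
}
```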
````diff
@@ -918,6 +1014,24 @@ Install it first using:
 go install github.com/danielmiessler/fabric/cmd/code2context@latest
 ```
 
+### `generate_changelog`
+
+`generate_changelog` generates changelogs from git commit history and GitHub pull requests. It walks through your repository's git history, extracts PR information, and produces well-formatted markdown changelogs.
+
+```bash
+generate_changelog --help
+```
+
+Features include SQLite caching for fast incremental updates, GitHub GraphQL API integration for efficient PR fetching, and optional AI-enhanced summaries using Fabric.
+
+Install it using:
+
+```bash
+go install github.com/danielmiessler/fabric/cmd/generate_changelog@latest
+```
+
+See the [generate_changelog README](./cmd/generate_changelog/README.md) for detailed usage and options.
+
 ## pbpaste
 
 The [examples](#examples) use the macOS program `pbpaste` to paste content from the clipboard to pipe into `fabric` as the input. `pbpaste` is not available on Windows or Linux, but there are alternatives.
````
````diff
@@ -1,3 +1,3 @@
 package main
 
-var version = "v1.4.374"
+var version = "v1.4.375"
````
Binary file not shown.
````diff
@@ -10,9 +10,9 @@ import (
 	"strings"
 
 	"github.com/danielmiessler/fabric/internal/chat"
-	"github.com/danielmiessler/fabric/internal/plugins"
 
 	"github.com/danielmiessler/fabric/internal/domain"
+	"github.com/danielmiessler/fabric/internal/plugins"
+	"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
 	"google.golang.org/genai"
 )
@@ -29,10 +29,6 @@ const (
 )
 
 const (
-	citationHeader    = "\n\n## Sources\n\n"
-	citationSeparator = "\n"
-	citationFormat    = "- [%s](%s)"
-
 	errInvalidLocationFormat = "invalid search location format %q: must be timezone (e.g., 'America/Los_Angeles') or language code (e.g., 'en-US')"
 	locationSeparator        = "/"
 	langCodeSeparator        = "_"
@@ -111,7 +107,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
 	}
 
 	// Convert messages to new SDK format
-	contents := o.convertMessages(msgs)
+	contents := geminicommon.ConvertMessages(msgs)
 
 	cfg, err := o.buildGenerateContentConfig(opts)
 	if err != nil {
@@ -125,7 +121,7 @@ func (o *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
 	}
 
 	// Extract text from response
-	ret = o.extractTextFromResponse(response)
+	ret = geminicommon.ExtractTextWithCitations(response)
 	return
 }
@@ -142,7 +138,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
 	}
 
 	// Convert messages to new SDK format
-	contents := o.convertMessages(msgs)
+	contents := geminicommon.ConvertMessages(msgs)
 
 	cfg, err := o.buildGenerateContentConfig(opts)
 	if err != nil {
@@ -161,7 +157,7 @@ func (o *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
 		return err
 	}
 
-	text := o.extractTextFromResponse(response)
+	text := geminicommon.ExtractTextWithCitations(response)
 	if text != "" {
 		channel <- domain.StreamUpdate{
 			Type: domain.StreamTypeContent,
@@ -218,10 +214,14 @@ func parseThinkingConfig(level domain.ThinkingLevel) (*genai.ThinkingConfig, boo
 func (o *Client) buildGenerateContentConfig(opts *domain.ChatOptions) (*genai.GenerateContentConfig, error) {
 	temperature := float32(opts.Temperature)
 	topP := float32(opts.TopP)
+	var maxTokens int32
+	if opts.MaxTokens > 0 {
+		maxTokens = int32(opts.MaxTokens)
+	}
 	cfg := &genai.GenerateContentConfig{
 		Temperature: &temperature,
 		TopP:        &topP,
-		MaxOutputTokens: int32(opts.ModelContextLength),
+		MaxOutputTokens: maxTokens,
 	}
 
 	if opts.Search {
@@ -452,113 +452,3 @@ func (o *Client) generateWAVFile(pcmData []byte) ([]byte, error) {
 
 	return result, nil
 }
-
-// convertMessages converts fabric chat messages to genai Content format
-func (o *Client) convertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
-	var contents []*genai.Content
-
-	for _, msg := range msgs {
-		content := &genai.Content{Parts: []*genai.Part{}}
-
-		switch msg.Role {
-		case chat.ChatMessageRoleAssistant:
-			content.Role = "model"
-		case chat.ChatMessageRoleUser:
-			content.Role = "user"
-		case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
-			// Gemini's API only accepts "user" and "model" roles.
-			// Map all other roles to "user" to preserve instruction context.
-			content.Role = "user"
-		default:
-			content.Role = "user"
-		}
-
-		if strings.TrimSpace(msg.Content) != "" {
-			content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
-		}
-
-		// Handle multi-content messages (images, etc.)
-		for _, part := range msg.MultiContent {
-			switch part.Type {
-			case chat.ChatMessagePartTypeText:
-				content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
-			case chat.ChatMessagePartTypeImageURL:
-				// TODO: Handle image URLs if needed
-				// This would require downloading and converting to inline data
-			}
-		}
-
-		contents = append(contents, content)
-	}
-
-	return contents
-}
-
-// extractTextFromResponse extracts text content from the response and appends
-// any web citations in a standardized format.
-func (o *Client) extractTextFromResponse(response *genai.GenerateContentResponse) string {
-	if response == nil {
-		return ""
-	}
-
-	text := o.extractTextParts(response)
-	citations := o.extractCitations(response)
-	if len(citations) > 0 {
-		return text + citationHeader + strings.Join(citations, citationSeparator)
-	}
-	return text
-}
-
-func (o *Client) extractTextParts(response *genai.GenerateContentResponse) string {
-	var builder strings.Builder
-	for _, candidate := range response.Candidates {
-		if candidate == nil || candidate.Content == nil {
-			continue
-		}
-		for _, part := range candidate.Content.Parts {
-			if part != nil && part.Text != "" {
-				builder.WriteString(part.Text)
-			}
-		}
-	}
-	return builder.String()
-}
-
-func (o *Client) extractCitations(response *genai.GenerateContentResponse) []string {
-	if response == nil || len(response.Candidates) == 0 {
-		return nil
-	}
-
-	citationMap := make(map[string]bool)
-	var citations []string
-	for _, candidate := range response.Candidates {
-		if candidate == nil || candidate.GroundingMetadata == nil {
-			continue
-		}
-		chunks := candidate.GroundingMetadata.GroundingChunks
-		if len(chunks) == 0 {
-			continue
-		}
-		for _, chunk := range chunks {
-			if chunk == nil || chunk.Web == nil {
-				continue
-			}
-			uri := chunk.Web.URI
-			title := chunk.Web.Title
-			if uri == "" || title == "" {
-				continue
-			}
-			var keyBuilder strings.Builder
-			keyBuilder.WriteString(uri)
-			keyBuilder.WriteByte('|')
-			keyBuilder.WriteString(title)
-			key := keyBuilder.String()
-			if !citationMap[key] {
-				citationMap[key] = true
-				citationText := fmt.Sprintf(citationFormat, title, uri)
-				citations = append(citations, citationText)
-			}
-		}
-	}
-	return citations
-}
````
````diff
@@ -4,10 +4,10 @@ import (
 	"strings"
 	"testing"
 
-	"google.golang.org/genai"
-
 	"github.com/danielmiessler/fabric/internal/chat"
 	"github.com/danielmiessler/fabric/internal/domain"
+	"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
+	"google.golang.org/genai"
 )
@@ -31,9 +31,8 @@ func TestBuildModelNameFull(t *testing.T) {
 	}
 }
 
-// Test extractTextFromResponse method
+// Test ExtractTextWithCitations from geminicommon
 func TestExtractTextFromResponse(t *testing.T) {
-	client := &Client{}
 	response := &genai.GenerateContentResponse{
 		Candidates: []*genai.Candidate{
 			{
@@ -48,7 +47,7 @@ func TestExtractTextFromResponse(t *testing.T) {
 	}
 	expected := "Hello, world!"
 
-	result := client.extractTextFromResponse(response)
+	result := geminicommon.ExtractTextWithCitations(response)
 
 	if result != expected {
 		t.Errorf("Expected %v, got %v", expected, result)
@@ -56,14 +55,12 @@ func TestExtractTextFromResponse(t *testing.T) {
 }
 
 func TestExtractTextFromResponse_Nil(t *testing.T) {
-	client := &Client{}
-	if got := client.extractTextFromResponse(nil); got != "" {
+	if got := geminicommon.ExtractTextWithCitations(nil); got != "" {
 		t.Fatalf("expected empty string, got %q", got)
 	}
 }
 
 func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
-	client := &Client{}
 	response := &genai.GenerateContentResponse{
 		Candidates: []*genai.Candidate{
 			{
@@ -72,7 +69,7 @@ func TestExtractTextFromResponse_EmptyGroundingChunks(t *testing.T) {
 		},
 	}
-	if got := client.extractTextFromResponse(response); got != "Hello" {
+	if got := geminicommon.ExtractTextWithCitations(response); got != "Hello" {
 		t.Fatalf("expected 'Hello', got %q", got)
 	}
 }
@@ -162,7 +159,6 @@ func TestBuildGenerateContentConfig_ThinkingTokens(t *testing.T) {
 }
 
 func TestCitationFormatting(t *testing.T) {
-	client := &Client{}
 	response := &genai.GenerateContentResponse{
 		Candidates: []*genai.Candidate{
 			{
@@ -178,7 +174,7 @@ func TestCitationFormatting(t *testing.T) {
 		},
 	}
 
-	result := client.extractTextFromResponse(response)
+	result := geminicommon.ExtractTextWithCitations(response)
 	if !strings.Contains(result, "## Sources") {
 		t.Fatalf("expected sources section in result: %s", result)
 	}
@@ -189,14 +185,13 @@ func TestCitationFormatting(t *testing.T) {
 
 // Test convertMessages handles role mapping correctly
 func TestConvertMessagesRoles(t *testing.T) {
-	client := &Client{}
 	msgs := []*chat.ChatCompletionMessage{
 		{Role: chat.ChatMessageRoleUser, Content: "user"},
 		{Role: chat.ChatMessageRoleAssistant, Content: "assistant"},
 		{Role: chat.ChatMessageRoleSystem, Content: "system"},
 	}
 
-	contents := client.convertMessages(msgs)
+	contents := geminicommon.ConvertMessages(msgs)
 
 	expected := []string{"user", "model", "user"}
 
````
internal/plugins/ai/geminicommon/geminicommon.go (new file, 130 lines)

```go
// Package geminicommon provides shared utilities for Gemini API integrations.
// Used by both the standalone Gemini provider (API key auth) and VertexAI provider (ADC auth).
package geminicommon

import (
	"fmt"
	"strings"

	"github.com/danielmiessler/fabric/internal/chat"
	"google.golang.org/genai"
)

// Citation formatting constants
const (
	CitationHeader    = "\n\n## Sources\n\n"
	CitationSeparator = "\n"
	CitationFormat    = "- [%s](%s)"
)

// ConvertMessages converts fabric chat messages to genai Content format.
// Gemini's API only accepts "user" and "model" roles, so other roles are mapped to "user".
func ConvertMessages(msgs []*chat.ChatCompletionMessage) []*genai.Content {
	var contents []*genai.Content

	for _, msg := range msgs {
		content := &genai.Content{Parts: []*genai.Part{}}

		switch msg.Role {
		case chat.ChatMessageRoleAssistant:
			content.Role = "model"
		case chat.ChatMessageRoleUser:
			content.Role = "user"
		case chat.ChatMessageRoleSystem, chat.ChatMessageRoleDeveloper, chat.ChatMessageRoleFunction, chat.ChatMessageRoleTool:
			// Gemini's API only accepts "user" and "model" roles.
			// Map all other roles to "user" to preserve instruction context.
			content.Role = "user"
		default:
			content.Role = "user"
		}

		if strings.TrimSpace(msg.Content) != "" {
			content.Parts = append(content.Parts, &genai.Part{Text: msg.Content})
		}

		// Handle multi-content messages (images, etc.)
		for _, part := range msg.MultiContent {
			switch part.Type {
			case chat.ChatMessagePartTypeText:
				content.Parts = append(content.Parts, &genai.Part{Text: part.Text})
			case chat.ChatMessagePartTypeImageURL:
				// TODO: Handle image URLs if needed
				// This would require downloading and converting to inline data
			}
		}

		contents = append(contents, content)
	}

	return contents
}

// ExtractText extracts just the text parts from a Gemini response.
func ExtractText(response *genai.GenerateContentResponse) string {
	if response == nil {
		return ""
	}

	var builder strings.Builder
	for _, candidate := range response.Candidates {
		if candidate == nil || candidate.Content == nil {
			continue
		}
		for _, part := range candidate.Content.Parts {
			if part != nil && part.Text != "" {
				builder.WriteString(part.Text)
			}
		}
	}
	return builder.String()
}

// ExtractTextWithCitations extracts text content from the response and appends
// any web citations in a standardized format.
func ExtractTextWithCitations(response *genai.GenerateContentResponse) string {
	if response == nil {
		return ""
	}

	text := ExtractText(response)
	citations := ExtractCitations(response)
	if len(citations) > 0 {
		return text + CitationHeader + strings.Join(citations, CitationSeparator)
	}
	return text
}

// ExtractCitations extracts web citations from grounding metadata.
func ExtractCitations(response *genai.GenerateContentResponse) []string {
	if response == nil || len(response.Candidates) == 0 {
		return nil
	}

	citationMap := make(map[string]bool)
	var citations []string
	for _, candidate := range response.Candidates {
		if candidate == nil || candidate.GroundingMetadata == nil {
			continue
		}
		chunks := candidate.GroundingMetadata.GroundingChunks
		if len(chunks) == 0 {
			continue
		}
		for _, chunk := range chunks {
			if chunk == nil || chunk.Web == nil {
				continue
			}
			uri := chunk.Web.URI
			title := chunk.Web.Title
			if uri == "" || title == "" {
				continue
			}
			key := uri + "|" + title
			if !citationMap[key] {
				citationMap[key] = true
				citations = append(citations, fmt.Sprintf(CitationFormat, title, uri))
			}
		}
	}
	return citations
}
```
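For orientation, here is a minimal sketch (not part of the diff) of how a provider wires these two helpers around a `genai` call, mirroring the updated `Send` path above; client construction and the generation config are elided, and the package name is a placeholder.

```go
package example

import (
	"context"

	"github.com/danielmiessler/fabric/internal/chat"
	"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
	"google.golang.org/genai"
)

// send converts fabric roles to Gemini's "user"/"model" roles on the way in,
// and extracts text plus any grounding citations on the way out.
func send(ctx context.Context, client *genai.Client, model string, msgs []*chat.ChatCompletionMessage) (string, error) {
	contents := geminicommon.ConvertMessages(msgs)

	resp, err := client.Models.GenerateContent(ctx, model, contents, nil)
	if err != nil {
		return "", err
	}

	return geminicommon.ExtractTextWithCitations(resp), nil
}
```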
internal/plugins/ai/vertexai/models.go (new file, 237 lines)

```go
package vertexai

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sort"
	"strings"

	debuglog "github.com/danielmiessler/fabric/internal/log"
)

const (
	// API limits
	maxResponseSize    = 10 * 1024 * 1024 // 10MB
	errorResponseLimit = 1024             // 1KB for error messages

	// Default region for Model Garden API (global doesn't work for this endpoint)
	defaultModelGardenRegion = "us-central1"
)

// Supported Model Garden publishers (others can be added when SDK support is implemented)
var publishers = []string{"google", "anthropic"}

// publisherModelsResponse represents the API response from publishers.models.list
type publisherModelsResponse struct {
	PublisherModels []publisherModel `json:"publisherModels"`
	NextPageToken   string           `json:"nextPageToken"`
}

// publisherModel represents a single model in the API response
type publisherModel struct {
	Name string `json:"name"` // Format: publishers/{publisher}/models/{model}
}

// fetchModelsPage makes a single API request and returns the parsed response.
// Extracted to ensure proper cleanup of HTTP response bodies in pagination loops.
func fetchModelsPage(ctx context.Context, httpClient *http.Client, url, projectID, publisher string) (*publisherModelsResponse, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	req.Header.Set("Accept", "application/json")
	// Set quota project header required by Vertex AI API
	req.Header.Set("x-goog-user-project", projectID)

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, errorResponseLimit))
		debuglog.Debug(debuglog.Basic, "API error for %s: status %d, url: %s, body: %s\n", publisher, resp.StatusCode, url, string(bodyBytes))
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	bodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize+1))
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if len(bodyBytes) > maxResponseSize {
		return nil, fmt.Errorf("response too large (>%d bytes)", maxResponseSize)
	}

	var response publisherModelsResponse
	if err := json.Unmarshal(bodyBytes, &response); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	return &response, nil
}

// listPublisherModels fetches models from a specific publisher via the Model Garden API
func listPublisherModels(ctx context.Context, httpClient *http.Client, region, projectID, publisher string) ([]string, error) {
	// Use default region if global or empty (Model Garden API requires a specific region)
	if region == "" || region == "global" {
		region = defaultModelGardenRegion
	}

	baseURL := fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1beta1/publishers/%s/models", region, publisher)

	var allModels []string
	pageToken := ""

	for {
		url := baseURL
		if pageToken != "" {
			url = fmt.Sprintf("%s?pageToken=%s", baseURL, pageToken)
		}

		response, err := fetchModelsPage(ctx, httpClient, url, projectID, publisher)
		if err != nil {
			return nil, err
		}

		// Extract model names, stripping the publishers/{publisher}/models/ prefix
		for _, model := range response.PublisherModels {
			modelName := extractModelName(model.Name)
			if modelName != "" {
				allModels = append(allModels, modelName)
			}
		}

		// Check for more pages
		if response.NextPageToken == "" {
			break
		}
		pageToken = response.NextPageToken
	}

	debuglog.Debug(debuglog.Detailed, "Listed %d models from publisher %s\n", len(allModels), publisher)
	return allModels, nil
}

// extractModelName extracts the model name from the full resource path
// Input: "publishers/google/models/gemini-2.0-flash"
// Output: "gemini-2.0-flash"
func extractModelName(fullName string) string {
	parts := strings.Split(fullName, "/")
	if len(parts) >= 4 && parts[0] == "publishers" && parts[2] == "models" {
		return parts[3]
	}
	// Fallback: return the last segment
	if len(parts) > 0 {
		return parts[len(parts)-1]
	}
	return fullName
}

// sortModels sorts models by priority: Gemini > Claude > Others
// Within each group, models are sorted alphabetically
func sortModels(models []string) []string {
	sort.Slice(models, func(i, j int) bool {
		pi := modelPriority(models[i])
		pj := modelPriority(models[j])
		if pi != pj {
			return pi < pj
		}
		// Same priority: sort alphabetically (case-insensitive)
		return strings.ToLower(models[i]) < strings.ToLower(models[j])
	})
	return models
}

// modelPriority returns the sort priority for a model (lower = higher priority)
func modelPriority(model string) int {
	lower := strings.ToLower(model)
	switch {
	case strings.HasPrefix(lower, "gemini"):
		return 1
	case strings.HasPrefix(lower, "claude"):
		return 2
	default:
		return 3
	}
}

// knownGeminiModels is a curated list of Gemini models available on Vertex AI.
// Vertex AI doesn't provide a list API for Gemini models - they must be known ahead of time.
// This list is based on Google Cloud documentation as of January 2025.
// See: https://docs.cloud.google.com/vertex-ai/generative-ai/docs/models
var knownGeminiModels = []string{
	// Gemini 3 (Preview)
	"gemini-3-pro-preview",
	"gemini-3-flash-preview",
	// Gemini 2.5 (GA)
	"gemini-2.5-pro",
	"gemini-2.5-flash",
	"gemini-2.5-flash-lite",
	// Gemini 2.0 (GA)
	"gemini-2.0-flash",
	"gemini-2.0-flash-lite",
}

// getKnownGeminiModels returns the curated list of Gemini models available on Vertex AI.
// Unlike third-party models which can be listed via the Model Garden API,
// Gemini models must be known ahead of time as there's no list endpoint for them.
func getKnownGeminiModels() []string {
	return knownGeminiModels
}

// isGeminiModel returns true if the model is a Gemini model
func isGeminiModel(modelName string) bool {
	return strings.HasPrefix(strings.ToLower(modelName), "gemini")
}

// isConversationalModel returns true if the model is suitable for text generation/chat
// Filters out image generation, embeddings, and other non-conversational models
func isConversationalModel(modelName string) bool {
	lower := strings.ToLower(modelName)

	// Exclude patterns for non-conversational models
	excludePatterns := []string{
		"imagen", // Image generation models
		"imagegeneration",
		"imagetext",
		"image-segmentation",
		"embedding", // Embedding models
		"textembedding",
		"multimodalembedding",
		"text-bison", // Legacy completion models (not chat)
		"text-unicorn",
		"code-bison", // Legacy code models
		"code-gecko",
		"codechat-bison", // Deprecated chat model
		"chat-bison",     // Deprecated chat model
		"veo",            // Video generation
		"chirp",          // Audio/speech models
		"medlm",          // Medical models (restricted)
		"medical",
	}

	for _, pattern := range excludePatterns {
		if strings.Contains(lower, pattern) {
			return false
		}
	}

	return true
}

// filterConversationalModels returns only models suitable for text generation/chat
func filterConversationalModels(models []string) []string {
	var filtered []string
	for _, model := range models {
		if isConversationalModel(model) {
			filtered = append(filtered, model)
		}
	}
	return filtered
}
```
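The unexported helpers in this file compose into a small pipeline: strip the `publishers/.../models/` prefix, drop non-conversational models, then sort Gemini-first. A sketch over canned names (same package, so the helpers are visible; the function and inputs are illustrative):

```go
package vertexai

import "fmt"

// demoModelPipeline composes the helpers above over canned API names.
func demoModelPipeline() {
	raw := []string{
		"publishers/google/models/gemini-2.0-flash",
		"publishers/google/models/imagen-3.0-generate",
		"publishers/anthropic/models/claude-sonnet-4-5",
	}

	var names []string
	for _, r := range raw {
		names = append(names, extractModelName(r)) // strips "publishers/.../models/"
	}

	// filterConversationalModels drops "imagen-3.0-generate";
	// sortModels puts Gemini before Claude.
	for _, m := range sortModels(filterConversationalModels(names)) {
		fmt.Println(m) // gemini-2.0-flash, then claude-sonnet-4-5
	}
}
```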
````diff
@@ -9,13 +9,18 @@ import (
 	"github.com/anthropics/anthropic-sdk-go/vertex"
 	"github.com/danielmiessler/fabric/internal/chat"
 	"github.com/danielmiessler/fabric/internal/domain"
+	debuglog "github.com/danielmiessler/fabric/internal/log"
 	"github.com/danielmiessler/fabric/internal/plugins"
+	"github.com/danielmiessler/fabric/internal/plugins/ai/geminicommon"
+	"golang.org/x/oauth2"
+	"golang.org/x/oauth2/google"
+	"google.golang.org/genai"
 )
 
 const (
 	cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"
 	defaultRegion      = "global"
-	maxTokens          = 4096
+	defaultMaxTokens   = 4096
 )
 
 // NewClient creates a new Vertex AI client for accessing Claude models via Google Cloud
@@ -59,17 +64,78 @@ func (c *Client) configure() error {
 }
 
 func (c *Client) ListModels() ([]string, error) {
-	// Return Claude models available on Vertex AI
-	return []string{
-		string(anthropic.ModelClaudeSonnet4_5),
-		string(anthropic.ModelClaudeOpus4_5),
-		string(anthropic.ModelClaudeHaiku4_5),
-		string(anthropic.ModelClaude3_7SonnetLatest),
-		string(anthropic.ModelClaude3_5HaikuLatest),
-	}, nil
+	ctx := context.Background()
+	// Get ADC credentials for API authentication
+	creds, err := google.FindDefaultCredentials(ctx, cloudPlatformScope)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get Google credentials (ensure ADC is configured): %w", err)
+	}
+	httpClient := oauth2.NewClient(ctx, creds.TokenSource)
+
+	// Query all publishers in parallel for better performance
+	type result struct {
+		models    []string
+		err       error
+		publisher string
+	}
+	// +1 for known Gemini models (no API to list them)
+	results := make(chan result, len(publishers)+1)
+
+	// Query Model Garden API for third-party models
+	for _, pub := range publishers {
+		go func(publisher string) {
+			models, err := listPublisherModels(ctx, httpClient, c.Region.Value, c.ProjectID.Value, publisher)
+			results <- result{models: models, err: err, publisher: publisher}
+		}(pub)
+	}
+
+	// Add known Gemini models (Vertex AI doesn't have a list API for Gemini)
+	go func() {
+		results <- result{models: getKnownGeminiModels(), err: nil, publisher: "gemini"}
+	}()
+
+	// Collect results from all sources
+	var allModels []string
+	for range len(publishers) + 1 {
+		r := <-results
+		if r.err != nil {
+			// Log warning but continue - some sources may not be available
+			debuglog.Debug(debuglog.Basic, "Failed to list %s models: %v\n", r.publisher, r.err)
+			continue
+		}
+		allModels = append(allModels, r.models...)
+	}
+
+	if len(allModels) == 0 {
+		return nil, fmt.Errorf("no models found from any publisher")
+	}
+
+	// Filter to only conversational models and sort
+	filtered := filterConversationalModels(allModels)
+	if len(filtered) == 0 {
+		return nil, fmt.Errorf("no conversational models found")
+	}
+
+	return sortModels(filtered), nil
 }
 
 func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
+	if isGeminiModel(opts.Model) {
+		return c.sendGemini(ctx, msgs, opts)
+	}
+	return c.sendClaude(ctx, msgs, opts)
+}
+
+// getMaxTokens returns the max output tokens to use for a request
+func getMaxTokens(opts *domain.ChatOptions) int64 {
+	if opts.MaxTokens > 0 {
+		return int64(opts.MaxTokens)
+	}
+	return int64(defaultMaxTokens)
+}
+
+func (c *Client) sendClaude(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
 	if c.client == nil {
 		return "", fmt.Errorf("VertexAI client not initialized")
 	}
@@ -80,14 +146,22 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
 		return "", fmt.Errorf("no valid messages to send")
 	}
 
-	// Create the request
-	response, err := c.client.Messages.New(ctx, anthropic.MessageNewParams{
+	// Build request params
+	params := anthropic.MessageNewParams{
 		Model:     anthropic.Model(opts.Model),
-		MaxTokens: int64(maxTokens),
+		MaxTokens: getMaxTokens(opts),
 		Messages:  anthropicMessages,
-		Temperature: anthropic.Opt(opts.Temperature),
-	})
+	}
+
+	// Only set one of Temperature or TopP as some models don't allow both
+	// (following anthropic.go pattern)
+	if opts.TopP != domain.DefaultTopP {
+		params.TopP = anthropic.Opt(opts.TopP)
+	} else {
+		params.Temperature = anthropic.Opt(opts.Temperature)
+	}
+
+	response, err := c.client.Messages.New(ctx, params)
 	if err != nil {
 		return "", err
 	}
@@ -108,6 +182,13 @@ func (c *Client) Send(ctx context.Context, msgs []*chat.ChatCompletionMessage, o
 }
 
 func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
+	if isGeminiModel(opts.Model) {
+		return c.sendStreamGemini(msgs, opts, channel)
+	}
+	return c.sendStreamClaude(msgs, opts, channel)
+}
+
+func (c *Client) sendStreamClaude(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
 	if c.client == nil {
 		close(channel)
 		return fmt.Errorf("VertexAI client not initialized")
@@ -122,13 +203,22 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
 		return fmt.Errorf("no valid messages to send")
 	}
 
+	// Build request params
+	params := anthropic.MessageNewParams{
+		Model:     anthropic.Model(opts.Model),
+		MaxTokens: getMaxTokens(opts),
+		Messages:  anthropicMessages,
+	}
+
+	// Only set one of Temperature or TopP as some models don't allow both
+	if opts.TopP != domain.DefaultTopP {
+		params.TopP = anthropic.Opt(opts.TopP)
+	} else {
+		params.Temperature = anthropic.Opt(opts.Temperature)
+	}
+
 	// Create streaming request
-	stream := c.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
-		Model:     anthropic.Model(opts.Model),
-		MaxTokens: int64(maxTokens),
-		Messages:  anthropicMessages,
-		Temperature: anthropic.Opt(opts.Temperature),
-	})
+	stream := c.client.Messages.NewStreaming(ctx, params)
 
 	// Process stream
 	for stream.Next() {
@@ -167,6 +257,144 @@ func (c *Client) SendStream(msgs []*chat.ChatCompletionMessage, opts *domain.Cha
 	return stream.Err()
 }
 
+// Gemini methods using genai SDK with Vertex AI backend
+
+// getGeminiRegion returns the appropriate region for a Gemini model.
+// Preview models are often only available on the global endpoint.
+func (c *Client) getGeminiRegion(model string) string {
+	if strings.Contains(strings.ToLower(model), "preview") {
+		return "global"
+	}
+	return c.Region.Value
+}
+
+func (c *Client) sendGemini(ctx context.Context, msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions) (string, error) {
+	client, err := genai.NewClient(ctx, &genai.ClientConfig{
+		Project:  c.ProjectID.Value,
+		Location: c.getGeminiRegion(opts.Model),
+		Backend:  genai.BackendVertexAI,
+	})
+	if err != nil {
+		return "", fmt.Errorf("failed to create Gemini client: %w", err)
+	}
+
+	contents := geminicommon.ConvertMessages(msgs)
+	if len(contents) == 0 {
+		return "", fmt.Errorf("no valid messages to send")
+	}
+
+	config := c.buildGeminiConfig(opts)
+
+	response, err := client.Models.GenerateContent(ctx, opts.Model, contents, config)
+	if err != nil {
+		return "", err
+	}
+
+	return geminicommon.ExtractTextWithCitations(response), nil
+}
+
+// buildGeminiConfig creates the generation config for Gemini models
+// following the gemini.go pattern for feature parity
+func (c *Client) buildGeminiConfig(opts *domain.ChatOptions) *genai.GenerateContentConfig {
+	temperature := float32(opts.Temperature)
+	topP := float32(opts.TopP)
+	config := &genai.GenerateContentConfig{
+		Temperature:     &temperature,
+		TopP:            &topP,
+		MaxOutputTokens: int32(getMaxTokens(opts)),
+	}
+
+	// Add web search support
+	if opts.Search {
+		config.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
+	}
+
+	// Add thinking support
+	if tc := parseGeminiThinking(opts.Thinking); tc != nil {
+		config.ThinkingConfig = tc
+	}
+
+	return config
+}
+
+// parseGeminiThinking converts thinking level to Gemini thinking config
+func parseGeminiThinking(level domain.ThinkingLevel) *genai.ThinkingConfig {
+	lower := strings.ToLower(strings.TrimSpace(string(level)))
+	switch domain.ThinkingLevel(lower) {
+	case "", domain.ThinkingOff:
+		return nil
+	case domain.ThinkingLow, domain.ThinkingMedium, domain.ThinkingHigh:
+		if budget, ok := domain.ThinkingBudgets[domain.ThinkingLevel(lower)]; ok {
+			b := int32(budget)
+			return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &b}
+		}
+	default:
+		// Try parsing as integer token count
+		var tokens int
+		if _, err := fmt.Sscanf(lower, "%d", &tokens); err == nil && tokens > 0 {
+			t := int32(tokens)
+			return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &t}
+		}
+	}
+	return nil
+}
+
+func (c *Client) sendStreamGemini(msgs []*chat.ChatCompletionMessage, opts *domain.ChatOptions, channel chan domain.StreamUpdate) error {
+	defer close(channel)
+	ctx := context.Background()
+
+	client, err := genai.NewClient(ctx, &genai.ClientConfig{
+		Project:  c.ProjectID.Value,
+		Location: c.getGeminiRegion(opts.Model),
+		Backend:  genai.BackendVertexAI,
+	})
+	if err != nil {
+		return fmt.Errorf("failed to create Gemini client: %w", err)
+	}
+
+	contents := geminicommon.ConvertMessages(msgs)
+	if len(contents) == 0 {
+		return fmt.Errorf("no valid messages to send")
+	}
+
+	config := c.buildGeminiConfig(opts)
+
+	stream := client.Models.GenerateContentStream(ctx, opts.Model, contents, config)
+
+	for response, err := range stream {
+		if err != nil {
+			channel <- domain.StreamUpdate{
+				Type:    domain.StreamTypeError,
+				Content: fmt.Sprintf("Error: %v", err),
+			}
+			return err
+		}
+
+		text := geminicommon.ExtractText(response)
+		if text != "" {
+			channel <- domain.StreamUpdate{
+				Type:    domain.StreamTypeContent,
+				Content: text,
+			}
+		}
+
+		if response.UsageMetadata != nil {
+			channel <- domain.StreamUpdate{
+				Type: domain.StreamTypeUsage,
+				Usage: &domain.UsageMetadata{
+					InputTokens:  int(response.UsageMetadata.PromptTokenCount),
+					OutputTokens: int(response.UsageMetadata.CandidatesTokenCount),
+					TotalTokens:  int(response.UsageMetadata.TotalTokenCount),
+				},
+			}
+		}
+	}
+
+	return nil
+}
+
+// Claude message conversion
+
 func (c *Client) toMessages(msgs []*chat.ChatCompletionMessage) []anthropic.MessageParam {
 	// Convert messages to Anthropic format with proper role handling
 	// - System messages become part of the first user message
````
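The new `ListModels` uses a classic fan-out/fan-in: one goroutine per source, a channel buffered to the number of sources so every sender can complete, and a counted receive loop that tolerates per-source failures. The same skeleton, stripped of the Vertex specifics (names here are illustrative):

```go
package main

import "fmt"

// fetch stands in for listPublisherModels: each source either
// returns model names or an error.
func fetch(source string) ([]string, error) {
	return []string{source + "-model"}, nil
}

func main() {
	sources := []string{"google", "anthropic", "gemini-curated"}

	type result struct {
		models []string
		err    error
	}
	// Buffered to len(sources) so no goroutine blocks on send,
	// even if the receiver were to stop early.
	results := make(chan result, len(sources))

	for _, s := range sources {
		go func(src string) {
			models, err := fetch(src)
			results <- result{models: models, err: err}
		}(s)
	}

	var all []string
	for range len(sources) { // integer range needs Go 1.22+
		r := <-results
		if r.err != nil {
			continue // log and skip failed sources, as ListModels does
		}
		all = append(all, r.models...)
	}
	fmt.Println(all)
}
```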
442
internal/plugins/ai/vertexai/vertexai_test.go
Normal file
442
internal/plugins/ai/vertexai/vertexai_test.go
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
package vertexai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/danielmiessler/fabric/internal/domain"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractModelName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "standard format",
|
||||||
|
input: "publishers/google/models/gemini-2.0-flash",
|
||||||
|
expected: "gemini-2.0-flash",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anthropic model",
|
||||||
|
input: "publishers/anthropic/models/claude-sonnet-4-5",
|
||||||
|
expected: "claude-sonnet-4-5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with version",
|
||||||
|
input: "publishers/anthropic/models/claude-3-opus@20240229",
|
||||||
|
expected: "claude-3-opus@20240229",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "just model name",
|
||||||
|
input: "gemini-pro",
|
||||||
|
expected: "gemini-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractModelName(tt.input)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
func TestSortModels(t *testing.T) {
	input := []string{
		"claude-sonnet-4-5",
		"gemini-2.0-flash",
		"gemini-pro",
		"claude-opus-4",
		"unknown-model",
	}

	result := sortModels(input)

	// Verify order: Gemini first, then Claude, then others (alphabetically within each group)
	expected := []string{
		"gemini-2.0-flash",
		"gemini-pro",
		"claude-opus-4",
		"claude-sonnet-4-5",
		"unknown-model",
	}

	assert.Equal(t, expected, result)
}
func TestModelPriority(t *testing.T) {
	tests := []struct {
		model    string
		priority int
	}{
		{"gemini-2.0-flash", 1},
		{"Gemini-Pro", 1},
		{"claude-sonnet-4-5", 2},
		{"CLAUDE-OPUS", 2},
		{"some-other-model", 3},
	}

	for _, tt := range tests {
		t.Run(tt.model, func(t *testing.T) {
			result := modelPriority(tt.model)
			assert.Equal(t, tt.priority, result, "priority for %s", tt.model)
		})
	}
}
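// [Illustrative sketch — not part of this commit.] Together, the two tests
// above imply a case-insensitive priority function and a grouped sort that is
// alphabetical within each priority band; one implementation consistent with
// them could be:
//
//	func modelPriority(model string) int {
//		lower := strings.ToLower(model)
//		switch {
//		case strings.HasPrefix(lower, "gemini"):
//			return 1
//		case strings.HasPrefix(lower, "claude"):
//			return 2
//		default:
//			return 3
//		}
//	}
//
//	func sortModels(models []string) []string {
//		sorted := append([]string(nil), models...)
//		sort.Slice(sorted, func(i, j int) bool {
//			pi, pj := modelPriority(sorted[i]), modelPriority(sorted[j])
//			if pi != pj {
//				return pi < pj
//			}
//			return sorted[i] < sorted[j]
//		})
//		return sorted
//	}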
func TestListPublisherModels_Success(t *testing.T) {
	// Create mock server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, http.MethodGet, r.Method)
		assert.Contains(t, r.URL.Path, "/v1/publishers/google/models")

		response := publisherModelsResponse{
			PublisherModels: []publisherModel{
				{Name: "publishers/google/models/gemini-2.0-flash"},
				{Name: "publishers/google/models/gemini-pro"},
			},
		}
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(response)
	}))
	defer server.Close()

	// Note: This test would need to mock the actual API endpoint.
	// For now, we just verify the mock server works.
	resp, err := http.Get(server.URL + "/v1/publishers/google/models")
	require.NoError(t, err)
	defer resp.Body.Close()

	var response publisherModelsResponse
	err = json.NewDecoder(resp.Body).Decode(&response)
	require.NoError(t, err)

	assert.Len(t, response.PublisherModels, 2)
	assert.Equal(t, "publishers/google/models/gemini-2.0-flash", response.PublisherModels[0].Name)
}
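// [Illustrative sketch — not part of this commit.] The tests decode into
// publisherModelsResponse; based on how the fields are used here and on the
// shape of the Vertex AI publishers.models.list response, the types
// presumably look roughly like:
//
//	type publisherModel struct {
//		Name string `json:"name"`
//	}
//
//	type publisherModelsResponse struct {
//		PublisherModels []publisherModel `json:"publisherModels"`
//		NextPageToken   string           `json:"nextPageToken"`
//	}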
func TestListPublisherModels_Pagination(t *testing.T) {
	callCount := 0

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		callCount++

		var response publisherModelsResponse
		if callCount == 1 {
			response = publisherModelsResponse{
				PublisherModels: []publisherModel{
					{Name: "publishers/google/models/gemini-flash"},
				},
				NextPageToken: "page2",
			}
		} else {
			response = publisherModelsResponse{
				PublisherModels: []publisherModel{
					{Name: "publishers/google/models/gemini-pro"},
				},
				NextPageToken: "",
			}
		}

		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(response)
	}))
	defer server.Close()

	// Verify the server handles pagination correctly
	resp, err := http.Get(server.URL + "/page1")
	require.NoError(t, err)
	resp.Body.Close()

	resp, err = http.Get(server.URL + "/page2")
	require.NoError(t, err)
	resp.Body.Close()

	assert.Equal(t, 2, callCount)
}
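// [Illustrative sketch — not part of this commit; listAllModels and baseURL
// are hypothetical names.] A client paging through this endpoint would follow
// NextPageToken until it comes back empty:
//
//	func listAllModels(baseURL string) ([]string, error) {
//		var names []string
//		pageToken := ""
//		for {
//			url := baseURL + "/v1/publishers/google/models"
//			if pageToken != "" {
//				url += "?pageToken=" + pageToken
//			}
//			resp, err := http.Get(url)
//			if err != nil {
//				return nil, err
//			}
//			var page publisherModelsResponse
//			err = json.NewDecoder(resp.Body).Decode(&page)
//			resp.Body.Close()
//			if err != nil {
//				return nil, err
//			}
//			for _, m := range page.PublisherModels {
//				names = append(names, m.Name)
//			}
//			if page.NextPageToken == "" {
//				return names, nil
//			}
//			pageToken = page.NextPageToken
//		}
//	}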
func TestListPublisherModels_ErrorResponse(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusForbidden)
		w.Write([]byte(`{"error": "access denied"}`))
	}))
	defer server.Close()

	resp, err := http.Get(server.URL + "/v1/publishers/google/models")
	require.NoError(t, err)
	defer resp.Body.Close()

	assert.Equal(t, http.StatusForbidden, resp.StatusCode)
}
func TestNewClient(t *testing.T) {
	client := NewClient()

	assert.NotNil(t, client)
	assert.Equal(t, "VertexAI", client.Name)
	assert.NotNil(t, client.ProjectID)
	assert.NotNil(t, client.Region)
	assert.Equal(t, "global", client.Region.Value)
}

func TestPublishersListComplete(t *testing.T) {
	// Verify supported publishers are in the list
	expectedPublishers := []string{"google", "anthropic"}

	assert.Equal(t, expectedPublishers, publishers)
}
func TestIsConversationalModel(t *testing.T) {
	tests := []struct {
		model    string
		expected bool
	}{
		// Conversational models (should return true)
		{"gemini-2.0-flash", true},
		{"gemini-2.5-pro", true},
		{"claude-sonnet-4-5", true},
		{"claude-opus-4", true},
		{"deepseek-v3", true},
		{"llama-3.1-405b", true},
		{"mistral-large", true},

		// Non-conversational models (should return false)
		{"imagen-3.0-capability-002", false},
		{"imagen-4.0-fast-generate-001", false},
		{"imagegeneration", false},
		{"imagetext", false},
		{"image-segmentation-001", false},
		{"textembedding-gecko", false},
		{"multimodalembedding", false},
		{"text-embedding-004", false},
		{"text-bison", false},
		{"text-unicorn", false},
		{"code-bison", false},
		{"code-gecko", false},
		{"codechat-bison", false},
		{"chat-bison", false},
		{"veo-001", false},
		{"chirp", false},
		{"medlm-medium", false},
	}

	for _, tt := range tests {
		t.Run(tt.model, func(t *testing.T) {
			result := isConversationalModel(tt.model)
			assert.Equal(t, tt.expected, result, "isConversationalModel(%s)", tt.model)
		})
	}
}
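// [Illustrative sketch — not part of this commit; the actual blocklist in the
// PR may differ.] A prefix blocklist over lowercased names satisfies every
// case in the table above:
//
//	var nonConversationalPrefixes = []string{
//		"imagen", "imagegeneration", "imagetext", "image-",
//		"textembedding", "multimodalembedding", "text-embedding",
//		"text-bison", "text-unicorn", "code-", "codechat",
//		"chat-bison", "veo", "chirp", "medlm",
//	}
//
//	func isConversationalModel(model string) bool {
//		lower := strings.ToLower(model)
//		for _, prefix := range nonConversationalPrefixes {
//			if strings.HasPrefix(lower, prefix) {
//				return false
//			}
//		}
//		return true
//	}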
func TestFilterConversationalModels(t *testing.T) {
	input := []string{
		"gemini-2.0-flash",
		"imagen-3.0-capability-002",
		"claude-sonnet-4-5",
		"textembedding-gecko",
		"deepseek-v3",
		"chat-bison",
		"llama-3.1-405b",
		"code-bison",
	}

	result := filterConversationalModels(input)

	expected := []string{
		"gemini-2.0-flash",
		"claude-sonnet-4-5",
		"deepseek-v3",
		"llama-3.1-405b",
	}

	assert.Equal(t, expected, result)
}

func TestFilterConversationalModels_EmptyInput(t *testing.T) {
	result := filterConversationalModels([]string{})
	assert.Empty(t, result)
}

func TestFilterConversationalModels_AllFiltered(t *testing.T) {
	input := []string{
		"imagen-3.0",
		"textembedding-gecko",
		"chat-bison",
	}

	result := filterConversationalModels(input)
	assert.Empty(t, result)
}
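// [Illustrative sketch — not part of this commit.] filterConversationalModels
// presumably just applies isConversationalModel as a predicate, preserving
// input order:
//
//	func filterConversationalModels(models []string) []string {
//		var filtered []string
//		for _, m := range models {
//			if isConversationalModel(m) {
//				filtered = append(filtered, m)
//			}
//		}
//		return filtered
//	}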
func TestIsGeminiModel(t *testing.T) {
	tests := []struct {
		model    string
		expected bool
	}{
		{"gemini-2.5-pro", true},
		{"gemini-3-pro-preview", true},
		{"Gemini-2.0-flash", true},
		{"GEMINI-flash", true},
		{"claude-sonnet-4-5", false},
		{"claude-opus-4", false},
		{"deepseek-v3", false},
		{"llama-3.1-405b", false},
		{"", false},
	}

	for _, tt := range tests {
		t.Run(tt.model, func(t *testing.T) {
			result := isGeminiModel(tt.model)
			assert.Equal(t, tt.expected, result, "isGeminiModel(%s)", tt.model)
		})
	}
}
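// [Illustrative sketch — not part of this commit.] The cases above pin
// isGeminiModel to a case-insensitive prefix check:
//
//	func isGeminiModel(model string) bool {
//		return strings.HasPrefix(strings.ToLower(model), "gemini")
//	}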
func TestGetMaxTokens(t *testing.T) {
	tests := []struct {
		name     string
		opts     *domain.ChatOptions
		expected int64
	}{
		{
			name:     "MaxTokens specified",
			opts:     &domain.ChatOptions{MaxTokens: 8192},
			expected: 8192,
		},
		{
			name:     "Default when MaxTokens is 0",
			opts:     &domain.ChatOptions{MaxTokens: 0},
			expected: int64(defaultMaxTokens),
		},
		{
			name:     "Default when MaxTokens is negative",
			opts:     &domain.ChatOptions{MaxTokens: -1},
			expected: int64(defaultMaxTokens),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := getMaxTokens(tt.opts)
			assert.Equal(t, tt.expected, result)
		})
	}
}
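// [Illustrative sketch — not part of this commit.] getMaxTokens falls back to
// the package's defaultMaxTokens whenever the option is unset or non-positive:
//
//	func getMaxTokens(opts *domain.ChatOptions) int64 {
//		if opts.MaxTokens > 0 {
//			return int64(opts.MaxTokens)
//		}
//		return int64(defaultMaxTokens)
//	}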
func TestParseGeminiThinking(t *testing.T) {
	tests := []struct {
		name           string
		level          domain.ThinkingLevel
		expectNil      bool
		expectedBudget int32
	}{
		{
			name:      "empty string returns nil",
			level:     "",
			expectNil: true,
		},
		{
			name:      "off returns nil",
			level:     domain.ThinkingOff,
			expectNil: true,
		},
		{
			name:           "low thinking",
			level:          domain.ThinkingLow,
			expectNil:      false,
			expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingLow]),
		},
		{
			name:           "medium thinking",
			level:          domain.ThinkingMedium,
			expectNil:      false,
			expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingMedium]),
		},
		{
			name:           "high thinking",
			level:          domain.ThinkingHigh,
			expectNil:      false,
			expectedBudget: int32(domain.ThinkingBudgets[domain.ThinkingHigh]),
		},
		{
			name:           "numeric string",
			level:          "5000",
			expectNil:      false,
			expectedBudget: 5000,
		},
		{
			name:      "invalid string returns nil",
			level:     "invalid",
			expectNil: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := parseGeminiThinking(tt.level)
			if tt.expectNil {
				assert.Nil(t, result)
			} else {
				require.NotNil(t, result)
				assert.True(t, result.IncludeThoughts)
				assert.Equal(t, tt.expectedBudget, *result.ThinkingBudget)
			}
		})
	}
}
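// [Illustrative sketch — not part of this commit; genai field names assumed
// from the google.golang.org/genai SDK.] The table above implies: nil for
// empty, off, or unparseable levels; a named budget looked up in
// domain.ThinkingBudgets; and a raw token budget for numeric strings:
//
//	func parseGeminiThinking(level domain.ThinkingLevel) *genai.ThinkingConfig {
//		if level == "" || level == domain.ThinkingOff {
//			return nil
//		}
//		var budget int32
//		if b, ok := domain.ThinkingBudgets[level]; ok {
//			budget = int32(b)
//		} else if n, err := strconv.Atoi(string(level)); err == nil {
//			budget = int32(n)
//		} else {
//			return nil
//		}
//		return &genai.ThinkingConfig{IncludeThoughts: true, ThinkingBudget: &budget}
//	}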
func TestBuildGeminiConfig(t *testing.T) {
	client := &Client{}

	t.Run("basic config with temperature and TopP", func(t *testing.T) {
		opts := &domain.ChatOptions{
			Temperature: 0.7,
			TopP:        0.9,
			MaxTokens:   8192,
		}
		config := client.buildGeminiConfig(opts)

		assert.NotNil(t, config)
		assert.Equal(t, float32(0.7), *config.Temperature)
		assert.Equal(t, float32(0.9), *config.TopP)
		assert.Equal(t, int32(8192), config.MaxOutputTokens)
		assert.Nil(t, config.Tools)
		assert.Nil(t, config.ThinkingConfig)
	})

	t.Run("config with search enabled", func(t *testing.T) {
		opts := &domain.ChatOptions{
			Temperature: 0.5,
			TopP:        0.8,
			Search:      true,
		}
		config := client.buildGeminiConfig(opts)

		assert.NotNil(t, config.Tools)
		assert.Len(t, config.Tools, 1)
		assert.NotNil(t, config.Tools[0].GoogleSearch)
	})

	t.Run("config with thinking enabled", func(t *testing.T) {
		opts := &domain.ChatOptions{
			Temperature: 0.5,
			TopP:        0.8,
			Thinking:    domain.ThinkingHigh,
		}
		config := client.buildGeminiConfig(opts)

		assert.NotNil(t, config.ThinkingConfig)
		assert.True(t, config.ThinkingConfig.IncludeThoughts)
	})
}
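// [Illustrative sketch — not part of this commit; genai config field names
// assumed from the google.golang.org/genai SDK.] The three subtests above
// suggest buildGeminiConfig wires temperature, TopP, max tokens, an optional
// GoogleSearch tool, and an optional thinking config:
//
//	func (c *Client) buildGeminiConfig(opts *domain.ChatOptions) *genai.GenerateContentConfig {
//		temperature := float32(opts.Temperature)
//		topP := float32(opts.TopP)
//		config := &genai.GenerateContentConfig{
//			Temperature:     &temperature,
//			TopP:            &topP,
//			MaxOutputTokens: int32(getMaxTokens(opts)),
//		}
//		if opts.Search {
//			config.Tools = []*genai.Tool{{GoogleSearch: &genai.GoogleSearch{}}}
//		}
//		config.ThinkingConfig = parseGeminiThinking(opts.Thinking)
//		return config
//	}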
@@ -1 +1 @@
-"1.4.374"
+"1.4.375"