mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-07 22:54:06 -05:00
Merge branch 'main' into integration-test
This commit is contained in:
59
cmd/root.go
59
cmd/root.go
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/fsnotify/fsnotify"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
@@ -385,12 +386,13 @@ func NewCommand(opts ...Option) *Command {
|
||||
}
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||
}
|
||||
|
||||
// parseEnv replaces environment variables ${ENV_NAME} with their values.
|
||||
@@ -439,11 +441,12 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
|
||||
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
merged := ToolsFile{
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: make(server.PromptConfigs),
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: make(server.PromptConfigs),
|
||||
}
|
||||
|
||||
var conflicts []string
|
||||
@@ -479,6 +482,15 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge embeddingModels
|
||||
for name, model := range file.EmbeddingModels {
|
||||
if _, exists := merged.EmbeddingModels[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.EmbeddingModels[name] = model
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge tools
|
||||
for name, tool := range file.Tools {
|
||||
if _, exists := merged.Tools[name]; exists {
|
||||
@@ -583,14 +595,14 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -598,7 +610,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
|
||||
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
|
||||
func validateReloadEdits(
|
||||
ctx context.Context, toolsFile ToolsFile,
|
||||
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
|
||||
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
|
||||
) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -616,22 +628,23 @@ func validateReloadEdits(
|
||||
defer span.End()
|
||||
|
||||
reloadedConfig := server.ServerConfig{
|
||||
Version: versionString,
|
||||
SourceConfigs: toolsFile.Sources,
|
||||
AuthServiceConfigs: toolsFile.AuthServices,
|
||||
ToolConfigs: toolsFile.Tools,
|
||||
ToolsetConfigs: toolsFile.Toolsets,
|
||||
PromptConfigs: toolsFile.Prompts,
|
||||
Version: versionString,
|
||||
SourceConfigs: toolsFile.Sources,
|
||||
AuthServiceConfigs: toolsFile.AuthServices,
|
||||
EmbeddingModelConfigs: toolsFile.EmbeddingModels,
|
||||
ToolConfigs: toolsFile.Tools,
|
||||
ToolsetConfigs: toolsFile.Toolsets,
|
||||
PromptConfigs: toolsFile.Prompts,
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
@@ -1830,9 +1831,10 @@ func TestFileLoadingErrors(t *testing.T) {
|
||||
|
||||
func TestMergeToolsFiles(t *testing.T) {
|
||||
file1 := ToolsFile{
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
}
|
||||
file2 := ToolsFile{
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
@@ -1854,11 +1856,12 @@ func TestMergeToolsFiles(t *testing.T) {
|
||||
name: "merge two distinct files",
|
||||
files: []ToolsFile{file1, file2},
|
||||
want: ToolsFile{
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
|
||||
Prompts: server.PromptConfigs{},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -1871,22 +1874,24 @@ func TestMergeToolsFiles(t *testing.T) {
|
||||
name: "merge single file",
|
||||
files: []ToolsFile{file1},
|
||||
want: ToolsFile{
|
||||
Sources: file1.Sources,
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: file1.Tools,
|
||||
Toolsets: file1.Toolsets,
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: file1.Sources,
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
Tools: file1.Tools,
|
||||
Toolsets: file1.Toolsets,
|
||||
Prompts: server.PromptConfigs{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge empty list",
|
||||
files: []ToolsFile{},
|
||||
want: ToolsFile{
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: server.PromptConfigs{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
84
docs/en/resources/embeddingModels/_index.md
Normal file
84
docs/en/resources/embeddingModels/_index.md
Normal file
@@ -0,0 +1,84 @@
|
||||
---
|
||||
title: "EmbeddingModels"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
EmbeddingModels represent services that transform text into vector embeddings for semantic search.
|
||||
---
|
||||
|
||||
EmbeddingModels represent services that generate vector representations of text
|
||||
data. In the MCP Toolbox, these models enable **Semantic Queries**,
|
||||
allowing [Tools](../tools/) to automatically convert human-readable text into
|
||||
numerical vectors before using them in a query.
|
||||
|
||||
This is primarily used in two scenarios:
|
||||
|
||||
- **Vector Ingestion**: Converting a text parameter into a vector string during
|
||||
an `INSERT` operation.
|
||||
|
||||
- **Semantic Search**: Converting a natural language query into a vector to
|
||||
perform similarity searches.
|
||||
|
||||
## Example
|
||||
|
||||
The following configuration defines an embedding model and applies it to
|
||||
specific tool parameters.
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your API keys into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
### Step 1 - Define an Embedding Model
|
||||
|
||||
Define an embedding model in the `embeddingModels` section:
|
||||
|
||||
```yaml
|
||||
embeddingModels:
|
||||
gemini-model: # Name of the embedding model
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: ${GOOGLE_API_KEY}
|
||||
dimension: 768
|
||||
|
||||
```
|
||||
|
||||
### Step 2 - Embed Tool Parameters
|
||||
|
||||
Use the defined embedding model, embed your query parameters using the
|
||||
`embeddedBy` field. Only string-typed
|
||||
parameters can be embedded:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
# Vector ingestion tool
|
||||
insert_embedding:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
statement: |
|
||||
INSERT INTO documents (content, embedding)
|
||||
VALUES ($1, $2);
|
||||
parameters:
|
||||
- name: content
|
||||
type: string
|
||||
- name: vector_string
|
||||
type: string
|
||||
description: The text to be vectorized and stored.
|
||||
embeddedBy: gemini-model # refers to the name of a defined embedding model
|
||||
|
||||
# Semantic search tool
|
||||
search_embedding:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
statement: |
|
||||
SELECT id, content, embedding <-> $1 AS distance
|
||||
FROM documents
|
||||
ORDER BY distance LIMIT 1
|
||||
parameters:
|
||||
- name: semantic_search_string
|
||||
type: string
|
||||
description: The search query that will be converted to a vector.
|
||||
embeddedBy: gemini-model # refers to the name of a defined embedding model
|
||||
```
|
||||
|
||||
## Kinds of Embedding Models
|
||||
73
docs/en/resources/embeddingModels/gemini.md
Normal file
73
docs/en/resources/embeddingModels/gemini.md
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: "Gemini Embedding"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Use Google's Gemini models to generate high-performance text embeddings for vector databases.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
Google Gemini provides state-of-the-art embedding models that convert text into
|
||||
high-dimensional vectors.
|
||||
|
||||
### Authentication
|
||||
|
||||
Toolbox uses your [Application Default Credentials
|
||||
(ADC)][adc] to authorize with the
|
||||
Gemini API client.
|
||||
|
||||
Optionally, you can use an [API key][api-key] obtain an API
|
||||
Key from the [Google AI Studio][ai-studio].
|
||||
|
||||
We recommend using an API key for testing and using application default
|
||||
credentials for production.
|
||||
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys
|
||||
[ai-studio]: https://aistudio.google.com/app/apikey
|
||||
|
||||
## Behavior
|
||||
|
||||
### Automatic Vectorization
|
||||
|
||||
When a tool parameter is configured with `embeddedBy: <your-gemini-model-name>`,
|
||||
the Toolbox intercepts the raw text input from the client and sends it to the
|
||||
Gemini API. The resulting numerical array is then formatted before being passed
|
||||
to your database source.
|
||||
|
||||
### Dimension Matching
|
||||
|
||||
The `dimension` field must match the expected size of your database column
|
||||
(e.g., a `vector(768)` column in PostgreSQL). This setting is supported by newer
|
||||
models since 2024 only. You cannot set this value if using the earlier model
|
||||
(`models/embedding-001`). Check out [available Gemini models][modellist] for more
|
||||
information.
|
||||
|
||||
[modellist]:
|
||||
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
embeddingModels:
|
||||
gemini-model:
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: ${GOOGLE_API_KEY}
|
||||
dimension: 768
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|--------------------------------------------------------------|
|
||||
| kind | string | true | Must be `gemini`. |
|
||||
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
|
||||
| apiKey | string | false | Your API Key from Google AI Studio. |
|
||||
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |
|
||||
1
go.mod
1
go.mod
@@ -59,6 +59,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
google.golang.org/api v0.256.0
|
||||
google.golang.org/genai v1.37.0
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
|
||||
google.golang.org/protobuf v1.36.10
|
||||
modernc.org/sqlite v1.40.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1869,6 +1869,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genai v1.37.0 h1:dgp71k1wQ+/+APdZrN3LFgAGnVnr5IdTF1Oj0Dg+BQc=
|
||||
google.golang.org/genai v1.37.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
|
||||
29
internal/embeddingmodels/embeddingmodels.go
Normal file
29
internal/embeddingmodels/embeddingmodels.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package embeddingmodels
|
||||
|
||||
import "context"
|
||||
|
||||
// EmbeddingModelConfig is the interface for configuring embedding models.
|
||||
type EmbeddingModelConfig interface {
|
||||
EmbeddingModelConfigKind() string
|
||||
Initialize(context.Context) (EmbeddingModel, error)
|
||||
}
|
||||
|
||||
type EmbeddingModel interface {
|
||||
EmbeddingModelKind() string
|
||||
ToConfig() EmbeddingModelConfig
|
||||
EmbedParameters(context.Context, []string) ([][]float32, error)
|
||||
}
|
||||
122
internal/embeddingmodels/gemini/gemini.go
Normal file
122
internal/embeddingmodels/gemini/gemini.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
const EmbeddingModelKind string = "gemini"
|
||||
|
||||
// validate interface
|
||||
var _ embeddingmodels.EmbeddingModelConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Model string `yaml:"model" validate:"required"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Dimension int32 `yaml:"dimension"`
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (cfg Config) EmbeddingModelConfigKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
// Initialize a Gemini embedding model
|
||||
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
|
||||
// Get client configs
|
||||
configs := &genai.ClientConfig{}
|
||||
if cfg.ApiKey != "" {
|
||||
configs.APIKey = cfg.ApiKey
|
||||
}
|
||||
|
||||
// Create new Gemini API client
|
||||
client, err := genai.NewClient(ctx, configs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create Gemini API client")
|
||||
}
|
||||
|
||||
m := &EmbeddingModel{
|
||||
Config: cfg,
|
||||
Client: client,
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
|
||||
|
||||
type EmbeddingModel struct {
|
||||
Client *genai.Client
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (m EmbeddingModel) EmbeddingModelKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
|
||||
return m.Config
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
contents := convertStringsToContents(parameters)
|
||||
|
||||
embedConfig := &genai.EmbedContentConfig{
|
||||
TaskType: "SEMANTIC_SIMILARITY",
|
||||
}
|
||||
|
||||
if m.Dimension > 0 {
|
||||
embedConfig.OutputDimensionality = genai.Ptr(m.Dimension)
|
||||
}
|
||||
|
||||
result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, 0, len(result.Embeddings))
|
||||
for _, embedding := range result.Embeddings {
|
||||
embeddings = append(embeddings, embedding.Values)
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model)
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects.
|
||||
func convertStringsToContents(texts []string) []*genai.Content {
|
||||
contents := make([]*genai.Content, 0, len(texts))
|
||||
|
||||
for _, text := range texts {
|
||||
content := genai.NewContentFromText(text, "")
|
||||
contents = append(contents, content)
|
||||
}
|
||||
return contents
|
||||
}
|
||||
130
internal/embeddingmodels/gemini/gemini_test.go
Normal file
130
internal/embeddingmodels/gemini/gemini_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gemini_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlGemini(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.EmbeddingModelConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
embeddingModels:
|
||||
my-gemini-model:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"my-gemini-model": gemini.Config{
|
||||
Name: "my-gemini-model",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "full example with optional fields",
|
||||
in: `
|
||||
embeddingModels:
|
||||
complex-gemini:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
apiKey: "test-api-key"
|
||||
dimension: 768
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"complex-gemini": gemini.Config{
|
||||
Name: "complex-gemini",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
ApiKey: "test-api-key",
|
||||
Dimension: 768,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Models) {
|
||||
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestFailParseFromYamlGemini(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "missing required model field",
|
||||
in: `
|
||||
embeddingModels:
|
||||
bad-model:
|
||||
kind: gemini
|
||||
`,
|
||||
// Removed the specific model name from the prefix to match your output
|
||||
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "unknown field",
|
||||
in: `
|
||||
embeddingModels:
|
||||
bad-field:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
invalid_param: true
|
||||
`,
|
||||
// Updated to match the specific line-starting format of your error output
|
||||
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
if err.Error() != tc.err {
|
||||
t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,8 +19,8 @@ sources:
|
||||
region: ${CLOUD_SQL_MYSQL_REGION}
|
||||
instance: ${CLOUD_SQL_MYSQL_INSTANCE}
|
||||
database: ${CLOUD_SQL_MYSQL_DATABASE}
|
||||
user: ${CLOUD_SQL_MYSQL_USER}
|
||||
password: ${CLOUD_SQL_MYSQL_PASSWORD}
|
||||
user: ${CLOUD_SQL_MYSQL_USER:}
|
||||
password: ${CLOUD_SQL_MYSQL_PASSWORD:}
|
||||
ipType: ${CLOUD_SQL_MYSQL_IP_TYPE:PUBLIC}
|
||||
tools:
|
||||
execute_sql:
|
||||
|
||||
@@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
|
||||
|
||||
server := Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -38,6 +40,8 @@ type ServerConfig struct {
|
||||
SourceConfigs SourceConfigs
|
||||
// AuthServiceConfigs defines what sources of authentication are available for tools.
|
||||
AuthServiceConfigs AuthServiceConfigs
|
||||
// EmbeddingModelConfigs defines a models used to embed parameters.
|
||||
EmbeddingModelConfigs EmbeddingModelConfigs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs ToolConfigs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
@@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
|
||||
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
|
||||
|
||||
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(EmbeddingModelConfigs)
|
||||
// Parse the 'kind' fields for each embedding model
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelKind:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.ToolConfig
|
||||
|
||||
|
||||
@@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
|
||||
server := &Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -25,30 +26,33 @@ import (
|
||||
|
||||
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
|
||||
type ResourceManager struct {
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
embeddingModels map[string]embeddingmodels.EmbeddingModel
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
}
|
||||
|
||||
func NewResourceManager(
|
||||
sourcesMap map[string]sources.Source,
|
||||
authServicesMap map[string]auth.AuthService,
|
||||
embeddingModelsMap map[string]embeddingmodels.EmbeddingModel,
|
||||
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
|
||||
promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset,
|
||||
|
||||
) *ResourceManager {
|
||||
resourceMgr := &ResourceManager{
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
embeddingModels: embeddingModelsMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
}
|
||||
|
||||
return resourceMgr
|
||||
@@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi
|
||||
return authService, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
model, ok := r.embeddingModels[embeddingModelName]
|
||||
return model, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
@@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset,
|
||||
return promptset, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.sources = sourcesMap
|
||||
r.authServices = authServicesMap
|
||||
r.embeddingModels = embeddingModelsMap
|
||||
r.tools = toolsMap
|
||||
r.toolsets = toolsetsMap
|
||||
r.prompts = promptsMap
|
||||
@@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels))
|
||||
for k, v := range r.embeddingModels {
|
||||
copiedMap[k] = v
|
||||
}
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
|
||||
gotSource, _ := resMgr.GetSource("example-source")
|
||||
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
|
||||
@@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
gotSource, _ = resMgr.GetSource("example-source2")
|
||||
if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" {
|
||||
t.Errorf("error updating server, sources (-want +got):\n%s", diff)
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/go-chi/httplog/v2"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
@@ -56,6 +57,7 @@ type Server struct {
|
||||
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
map[string]sources.Source,
|
||||
map[string]auth.AuthService,
|
||||
map[string]embeddingmodels.EmbeddingModel,
|
||||
map[string]tools.Tool,
|
||||
map[string]tools.Toolset,
|
||||
map[string]prompts.Prompt,
|
||||
@@ -91,7 +93,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return s, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
@@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
@@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", ")))
|
||||
|
||||
// Initialize and validate embedding models from configs.
|
||||
embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel)
|
||||
for name, ec := range cfg.EmbeddingModelConfigs {
|
||||
em, err := func() (embeddingmodels.EmbeddingModel, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/embeddingmodel/init",
|
||||
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
|
||||
trace.WithAttributes(attribute.String("model_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
em, err := ec.Initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err)
|
||||
}
|
||||
return em, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
embeddingModelsMap[name] = em
|
||||
}
|
||||
embeddingModelNames := make([]string, 0, len(embeddingModelsMap))
|
||||
for name := range embeddingModelsMap {
|
||||
embeddingModelNames = append(embeddingModelNames, name)
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", ")))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
@@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
@@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
@@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsMap[name] = p
|
||||
}
|
||||
@@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsetsMap[name] = p
|
||||
}
|
||||
@@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", ")))
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
||||
r.Use(httplog.RequestLogger(httpLogger))
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize configs: %w", err)
|
||||
}
|
||||
@@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
@@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
if err != nil {
|
||||
t.Errorf("error updating server: %s", err)
|
||||
}
|
||||
|
||||
@@ -333,7 +333,7 @@ func TestInvoke(t *testing.T) {
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
|
||||
Reference in New Issue
Block a user