Merge branch 'main' into integration-test

This commit is contained in:
manuka rahul
2026-01-06 12:22:16 +05:30
committed by GitHub
18 changed files with 636 additions and 73 deletions

View File

@@ -33,6 +33,7 @@ import (
"github.com/fsnotify/fsnotify"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
"github.com/googleapis/genai-toolbox/internal/prompts"
@@ -385,12 +386,13 @@ func NewCommand(opts ...Option) *Command {
}
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Prompts server.PromptConfigs `yaml:"prompts"`
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Prompts server.PromptConfigs `yaml:"prompts"`
}
// parseEnv replaces environment variables ${ENV_NAME} with their values.
@@ -439,11 +441,12 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
merged := ToolsFile{
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: make(server.PromptConfigs),
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: make(server.EmbeddingModelConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: make(server.PromptConfigs),
}
var conflicts []string
@@ -479,6 +482,15 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
}
}
// Check for conflicts and merge embeddingModels
for name, model := range file.EmbeddingModels {
if _, exists := merged.EmbeddingModels[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
} else {
merged.EmbeddingModels[name] = model
}
}
// Check for conflicts and merge tools
for name, tool := range file.Tools {
if _, exists := merged.Tools[name]; exists {
@@ -583,14 +595,14 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
panic(err)
}
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
if err != nil {
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return err
}
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
return nil
}
@@ -598,7 +610,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
func validateReloadEdits(
ctx context.Context, toolsFile ToolsFile,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -616,22 +628,23 @@ func validateReloadEdits(
defer span.End()
reloadedConfig := server.ServerConfig{
Version: versionString,
SourceConfigs: toolsFile.Sources,
AuthServiceConfigs: toolsFile.AuthServices,
ToolConfigs: toolsFile.Tools,
ToolsetConfigs: toolsFile.Toolsets,
PromptConfigs: toolsFile.Prompts,
Version: versionString,
SourceConfigs: toolsFile.Sources,
AuthServiceConfigs: toolsFile.AuthServices,
EmbeddingModelConfigs: toolsFile.EmbeddingModels,
ToolConfigs: toolsFile.Tools,
ToolsetConfigs: toolsFile.Toolsets,
PromptConfigs: toolsFile.Prompts,
}
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
if err != nil {
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
}
// watchChanges checks for changes in the provided yaml tools file(s) or folder.

View File

@@ -32,6 +32,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
"github.com/googleapis/genai-toolbox/internal/prompts"
@@ -1830,9 +1831,10 @@ func TestFileLoadingErrors(t *testing.T) {
func TestMergeToolsFiles(t *testing.T) {
file1 := ToolsFile{
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
}
file2 := ToolsFile{
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
@@ -1854,11 +1856,12 @@ func TestMergeToolsFiles(t *testing.T) {
name: "merge two distinct files",
files: []ToolsFile{file1, file2},
want: ToolsFile{
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
Prompts: server.PromptConfigs{},
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
Prompts: server.PromptConfigs{},
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
},
wantErr: false,
},
@@ -1871,22 +1874,24 @@ func TestMergeToolsFiles(t *testing.T) {
name: "merge single file",
files: []ToolsFile{file1},
want: ToolsFile{
Sources: file1.Sources,
AuthServices: make(server.AuthServiceConfigs),
Tools: file1.Tools,
Toolsets: file1.Toolsets,
Prompts: server.PromptConfigs{},
Sources: file1.Sources,
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
Tools: file1.Tools,
Toolsets: file1.Toolsets,
Prompts: server.PromptConfigs{},
},
},
{
name: "merge empty list",
files: []ToolsFile{},
want: ToolsFile{
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: server.PromptConfigs{},
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: make(server.EmbeddingModelConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: server.PromptConfigs{},
},
},
}

View File

@@ -0,0 +1,84 @@
---
title: "EmbeddingModels"
type: docs
weight: 2
description: >
EmbeddingModels represent services that transform text into vector embeddings for semantic search.
---
EmbeddingModels represent services that generate vector representations of text
data. In the MCP Toolbox, these models enable **Semantic Queries**,
allowing [Tools](../tools/) to automatically convert human-readable text into
numerical vectors before using them in a query.
This is primarily used in two scenarios:
- **Vector Ingestion**: Converting a text parameter into a vector string during
an `INSERT` operation.
- **Semantic Search**: Converting a natural language query into a vector to
perform similarity searches.
## Example
The following configuration defines an embedding model and applies it to
specific tool parameters.
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your API keys into the configuration file.
{{< /notice >}}
### Step 1 - Define an Embedding Model
Define an embedding model in the `embeddingModels` section:
```yaml
embeddingModels:
gemini-model: # Name of the embedding model
kind: gemini
model: gemini-embedding-001
apiKey: ${GOOGLE_API_KEY}
dimension: 768
```
### Step 2 - Embed Tool Parameters
Use the defined embedding model, embed your query parameters using the
`embeddedBy` field. Only string-typed
parameters can be embedded:
```yaml
tools:
# Vector ingestion tool
insert_embedding:
kind: postgres-sql
source: my-pg-instance
statement: |
INSERT INTO documents (content, embedding)
VALUES ($1, $2);
parameters:
- name: content
type: string
- name: vector_string
type: string
description: The text to be vectorized and stored.
embeddedBy: gemini-model # refers to the name of a defined embedding model
# Semantic search tool
search_embedding:
kind: postgres-sql
source: my-pg-instance
statement: |
SELECT id, content, embedding <-> $1 AS distance
FROM documents
ORDER BY distance LIMIT 1
parameters:
- name: semantic_search_string
type: string
description: The search query that will be converted to a vector.
embeddedBy: gemini-model # refers to the name of a defined embedding model
```
## Kinds of Embedding Models

View File

@@ -0,0 +1,73 @@
---
title: "Gemini Embedding"
type: docs
weight: 1
description: >
Use Google's Gemini models to generate high-performance text embeddings for vector databases.
---
## About
Google Gemini provides state-of-the-art embedding models that convert text into
high-dimensional vectors.
### Authentication
Toolbox uses your [Application Default Credentials
(ADC)][adc] to authorize with the
Gemini API client.
Optionally, you can use an [API key][api-key] obtain an API
Key from the [Google AI Studio][ai-studio].
We recommend using an API key for testing and using application default
credentials for production.
[adc]: https://cloud.google.com/docs/authentication#adc
[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys
[ai-studio]: https://aistudio.google.com/app/apikey
## Behavior
### Automatic Vectorization
When a tool parameter is configured with `embeddedBy: <your-gemini-model-name>`,
the Toolbox intercepts the raw text input from the client and sends it to the
Gemini API. The resulting numerical array is then formatted before being passed
to your database source.
### Dimension Matching
The `dimension` field must match the expected size of your database column
(e.g., a `vector(768)` column in PostgreSQL). This setting is supported by newer
models since 2024 only. You cannot set this value if using the earlier model
(`models/embedding-001`). Check out [available Gemini models][modellist] for more
information.
[modellist]:
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models
## Example
```yaml
embeddingModels:
gemini-model:
kind: gemini
model: gemini-embedding-001
apiKey: ${GOOGLE_API_KEY}
dimension: 768
```
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your secrets into the configuration file.
{{< /notice >}}
## Reference
| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|--------------------------------------------------------------|
| kind | string | true | Must be `gemini`. |
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
| apiKey | string | false | Your API Key from Google AI Studio. |
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |

1
go.mod
View File

@@ -59,6 +59,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
golang.org/x/oauth2 v0.33.0
google.golang.org/api v0.256.0
google.golang.org/genai v1.37.0
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
google.golang.org/protobuf v1.36.10
modernc.org/sqlite v1.40.0

2
go.sum
View File

@@ -1869,6 +1869,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genai v1.37.0 h1:dgp71k1wQ+/+APdZrN3LFgAGnVnr5IdTF1Oj0Dg+BQc=
google.golang.org/genai v1.37.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=

View File

@@ -0,0 +1,29 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package embeddingmodels
import "context"
// EmbeddingModelConfig is the interface for configuring embedding models.
type EmbeddingModelConfig interface {
EmbeddingModelConfigKind() string
Initialize(context.Context) (EmbeddingModel, error)
}
type EmbeddingModel interface {
EmbeddingModelKind() string
ToConfig() EmbeddingModelConfig
EmbedParameters(context.Context, []string) ([][]float32, error)
}

View File

@@ -0,0 +1,122 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gemini
import (
"context"
"fmt"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util"
"google.golang.org/genai"
)
const EmbeddingModelKind string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model kind
func (cfg Config) EmbeddingModelConfigKind() string {
return EmbeddingModelKind
}
// Initialize a Gemini embedding model
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
// Get client configs
configs := &genai.ClientConfig{}
if cfg.ApiKey != "" {
configs.APIKey = cfg.ApiKey
}
// Create new Gemini API client
client, err := genai.NewClient(ctx, configs)
if err != nil {
return nil, fmt.Errorf("unable to create Gemini API client")
}
m := &EmbeddingModel{
Config: cfg,
Client: client,
}
return m, nil
}
var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
type EmbeddingModel struct {
Client *genai.Client
Config
}
// Returns the embedding model kind
func (m EmbeddingModel) EmbeddingModelKind() string {
return EmbeddingModelKind
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
return m.Config
}
func (m EmbeddingModel) EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
contents := convertStringsToContents(parameters)
embedConfig := &genai.EmbedContentConfig{
TaskType: "SEMANTIC_SIMILARITY",
}
if m.Dimension > 0 {
embedConfig.OutputDimensionality = genai.Ptr(m.Dimension)
}
result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig)
if err != nil {
logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err)
return nil, err
}
embeddings := make([][]float32, 0, len(result.Embeddings))
for _, embedding := range result.Embeddings {
embeddings = append(embeddings, embedding.Values)
}
logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model)
return embeddings, nil
}
// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects.
func convertStringsToContents(texts []string) []*genai.Content {
contents := make([]*genai.Content, 0, len(texts))
for _, text := range texts {
content := genai.NewContentFromText(text, "")
contents = append(contents, content)
}
return contents
}

View File

@@ -0,0 +1,130 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gemini_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlGemini(t *testing.T) {
tcs := []struct {
desc string
in string
want server.EmbeddingModelConfigs
}{
{
desc: "basic example",
in: `
embeddingModels:
my-gemini-model:
kind: gemini
model: text-embedding-004
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
},
},
},
{
desc: "full example with optional fields",
in: `
embeddingModels:
complex-gemini:
kind: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
ApiKey: "test-api-key",
Dimension: 768,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Models) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
}
})
}
}
func TestFailParseFromYamlGemini(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "missing required model field",
in: `
embeddingModels:
bad-model:
kind: gemini
`,
// Removed the specific model name from the prefix to match your output
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
},
{
desc: "unknown field",
in: `
embeddingModels:
bad-field:
kind: gemini
model: text-embedding-004
invalid_param: true
`,
// Updated to match the specific line-starting format of your error output
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
if err.Error() != tc.err {
t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err)
}
})
}
}

View File

@@ -19,8 +19,8 @@ sources:
region: ${CLOUD_SQL_MYSQL_REGION}
instance: ${CLOUD_SQL_MYSQL_INSTANCE}
database: ${CLOUD_SQL_MYSQL_DATABASE}
user: ${CLOUD_SQL_MYSQL_USER}
password: ${CLOUD_SQL_MYSQL_PASSWORD}
user: ${CLOUD_SQL_MYSQL_USER:}
password: ${CLOUD_SQL_MYSQL_PASSWORD:}
ipType: ${CLOUD_SQL_MYSQL_IP_TYPE:PUBLIC}
tools:
execute_sql:

View File

@@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets)
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
server := Server{
version: fakeVersionString,

View File

@@ -21,6 +21,8 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -38,6 +40,8 @@ type ServerConfig struct {
SourceConfigs SourceConfigs
// AuthServiceConfigs defines what sources of authentication are available for tools.
AuthServiceConfigs AuthServiceConfigs
// EmbeddingModelConfigs defines a models used to embed parameters.
EmbeddingModelConfigs EmbeddingModelConfigs
// ToolConfigs defines what tools are available.
ToolConfigs ToolConfigs
// ToolsetConfigs defines what tools are available.
@@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
return nil
}
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(EmbeddingModelConfigs)
// Parse the 'kind' fields for each embedding model
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelKind:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// ToolConfigs is a type used to allow unmarshal of the tool configs
type ToolConfigs map[string]tools.ToolConfig

View File

@@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) {
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets)
resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets)
server := &Server{
version: fakeVersionString,

View File

@@ -18,6 +18,7 @@ import (
"sync"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -25,30 +26,33 @@ import (
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
type ResourceManager struct {
mu sync.RWMutex
sources map[string]sources.Source
authServices map[string]auth.AuthService
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
prompts map[string]prompts.Prompt
promptsets map[string]prompts.Promptset
mu sync.RWMutex
sources map[string]sources.Source
authServices map[string]auth.AuthService
embeddingModels map[string]embeddingmodels.EmbeddingModel
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
prompts map[string]prompts.Prompt
promptsets map[string]prompts.Promptset
}
func NewResourceManager(
sourcesMap map[string]sources.Source,
authServicesMap map[string]auth.AuthService,
embeddingModelsMap map[string]embeddingmodels.EmbeddingModel,
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset,
) *ResourceManager {
resourceMgr := &ResourceManager{
mu: sync.RWMutex{},
sources: sourcesMap,
authServices: authServicesMap,
tools: toolsMap,
toolsets: toolsetsMap,
prompts: promptsMap,
promptsets: promptsetsMap,
mu: sync.RWMutex{},
sources: sourcesMap,
authServices: authServicesMap,
embeddingModels: embeddingModelsMap,
tools: toolsMap,
toolsets: toolsetsMap,
prompts: promptsMap,
promptsets: promptsetsMap,
}
return resourceMgr
@@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi
return authService, ok
}
func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
model, ok := r.embeddingModels[embeddingModelName]
return model, ok
}
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset,
return promptset, ok
}
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
r.mu.Lock()
defer r.mu.Unlock()
r.sources = sourcesMap
r.authServices = authServicesMap
r.embeddingModels = embeddingModelsMap
r.tools = toolsMap
r.toolsets = toolsetsMap
r.prompts = promptsMap
@@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
return copiedMap
}
func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel {
r.mu.RLock()
defer r.mu.RUnlock()
copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels))
for k, v := range r.embeddingModels {
copiedMap[k] = v
}
return copiedMap
}
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
r.mu.RLock()
defer r.mu.RUnlock()

View File

@@ -19,6 +19,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) {
},
}
newAuth := map[string]auth.AuthService{"example-auth": nil}
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
newTools := map[string]tools.Tool{"example-tool": nil}
newToolsets := map[string]tools.Toolset{
"example-toolset": {
@@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) {
Prompts: []*prompts.Prompt{},
},
}
resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
gotSource, _ := resMgr.GetSource("example-source")
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
@@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) {
},
}
resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
gotSource, _ = resMgr.GetSource("example-source2")
if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" {
t.Errorf("error updating server, sources (-want +got):\n%s", diff)

View File

@@ -30,6 +30,7 @@ import (
"github.com/go-chi/cors"
"github.com/go-chi/httplog/v2"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/resources"
@@ -56,6 +57,7 @@ type Server struct {
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]sources.Source,
map[string]auth.AuthService,
map[string]embeddingmodels.EmbeddingModel,
map[string]tools.Tool,
map[string]tools.Toolset,
map[string]prompts.Prompt,
@@ -91,7 +93,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return s, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
sourcesMap[name] = s
}
@@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return a, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
authServicesMap[name] = a
}
@@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", ")))
// Initialize and validate embedding models from configs.
embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel)
for name, ec := range cfg.EmbeddingModelConfigs {
em, err := func() (embeddingmodels.EmbeddingModel, error) {
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/embeddingmodel/init",
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
trace.WithAttributes(attribute.String("model_name", name)),
)
defer span.End()
em, err := ec.Initialize(ctx)
if err != nil {
return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err)
}
return em, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, nil, err
}
embeddingModelsMap[name] = em
}
embeddingModelNames := make([]string, 0, len(embeddingModelsMap))
for name := range embeddingModelsMap {
embeddingModelNames = append(embeddingModelNames, name)
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", ")))
// initialize and validate the tools from configs
toolsMap := make(map[string]tools.Tool)
for name, tc := range cfg.ToolConfigs {
@@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return t, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
toolsMap[name] = t
}
@@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return t, err
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
toolsetsMap[name] = t
}
@@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return p, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
promptsMap[name] = p
}
@@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return p, err
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
promptsetsMap[name] = p
}
@@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", ")))
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
}
// NewServer returns a Server object based on provided Config.
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
httpLogger := httplog.NewLogger("httplog", httpOpts)
r.Use(httplog.RequestLogger(httpLogger))
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("unable to initialize configs: %w", err)
}
@@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
s := &Server{
version: cfg.Version,

View File

@@ -25,6 +25,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server"
@@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) {
},
}
newAuth := map[string]auth.AuthService{"example-auth": nil}
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
newTools := map[string]tools.Tool{"example-tool": nil}
newToolsets := map[string]tools.Toolset{
"example-toolset": {
@@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) {
Prompts: []*prompts.Prompt{},
},
}
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
if err != nil {
t.Errorf("error updating server: %s", err)
}

View File

@@ -333,7 +333,7 @@ func TestInvoke(t *testing.T) {
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client