From 9c62f313ff5edf0a3b5b8a3e996eba078fba4095 Mon Sep 17 00:00:00 2001 From: Wenxin Du <117315983+duwenxin99@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:34:54 -0500 Subject: [PATCH] feat: Add `embeddingModel` support (#2121) First part of the implementation to support semantic search in tools. Second part: https://github.com/googleapis/genai-toolbox/pull/2151 --- cmd/root.go | 59 ++++---- cmd/root_test.go | 41 +++--- docs/en/resources/embeddingModels/_index.md | 84 +++++++++++ docs/en/resources/embeddingModels/gemini.md | 73 ++++++++++ go.mod | 1 + go.sum | 2 + internal/embeddingmodels/embeddingmodels.go | 29 ++++ internal/embeddingmodels/gemini/gemini.go | 122 ++++++++++++++++ .../embeddingmodels/gemini/gemini_test.go | 130 ++++++++++++++++++ internal/server/common_test.go | 2 +- internal/server/config.go | 48 +++++++ internal/server/mcp_test.go | 2 +- internal/server/resources/resources.go | 52 +++++-- internal/server/resources/resources_test.go | 6 +- internal/server/server.go | 48 +++++-- internal/server/server_test.go | 4 +- internal/tools/cloudgda/cloudgda_test.go | 2 +- 17 files changed, 634 insertions(+), 71 deletions(-) create mode 100644 docs/en/resources/embeddingModels/_index.md create mode 100644 docs/en/resources/embeddingModels/gemini.md create mode 100644 internal/embeddingmodels/embeddingmodels.go create mode 100644 internal/embeddingmodels/gemini/gemini.go create mode 100644 internal/embeddingmodels/gemini/gemini_test.go diff --git a/cmd/root.go b/cmd/root.go index e0bb46c642..ca9a2edc3f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -33,6 +33,7 @@ import ( "github.com/fsnotify/fsnotify" yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" "github.com/googleapis/genai-toolbox/internal/prompts" @@ -385,12 +386,13 @@ func NewCommand(opts 
...Option) *Command { } type ToolsFile struct { - Sources server.SourceConfigs `yaml:"sources"` - AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility. - AuthServices server.AuthServiceConfigs `yaml:"authServices"` - Tools server.ToolConfigs `yaml:"tools"` - Toolsets server.ToolsetConfigs `yaml:"toolsets"` - Prompts server.PromptConfigs `yaml:"prompts"` + Sources server.SourceConfigs `yaml:"sources"` + AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility. + AuthServices server.AuthServiceConfigs `yaml:"authServices"` + EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"` + Tools server.ToolConfigs `yaml:"tools"` + Toolsets server.ToolsetConfigs `yaml:"toolsets"` + Prompts server.PromptConfigs `yaml:"prompts"` } // parseEnv replaces environment variables ${ENV_NAME} with their values. @@ -439,11 +441,12 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) { // All resource names (sources, authServices, tools, toolsets) must be unique across all files. 
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { merged := ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: make(server.PromptConfigs), + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: make(server.PromptConfigs), } var conflicts []string @@ -479,6 +482,15 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) { } } + // Check for conflicts and merge embeddingModels + for name, model := range file.EmbeddingModels { + if _, exists := merged.EmbeddingModels[name]; exists { + conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1)) + } else { + merged.EmbeddingModels[name] = model + } + } + // Check for conflicts and merge tools for name, tool := range file.Tools { if _, exists := merged.Tools[name]; exists { @@ -583,14 +595,14 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser panic(err) } - sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile) if err != nil { errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err) logger.WarnContext(ctx, errMsg.Error()) return err } - s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) + s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) return nil } @@ -598,7 +610,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser // validateReloadEdits checks that 
the reloaded tools file configs can initialized without failing func validateReloadEdits( ctx context.Context, toolsFile ToolsFile, -) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error, +) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error, ) { logger, err := util.LoggerFromContext(ctx) if err != nil { @@ -616,22 +628,23 @@ func validateReloadEdits( defer span.End() reloadedConfig := server.ServerConfig{ - Version: versionString, - SourceConfigs: toolsFile.Sources, - AuthServiceConfigs: toolsFile.AuthServices, - ToolConfigs: toolsFile.Tools, - ToolsetConfigs: toolsFile.Toolsets, - PromptConfigs: toolsFile.Prompts, + Version: versionString, + SourceConfigs: toolsFile.Sources, + AuthServiceConfigs: toolsFile.AuthServices, + EmbeddingModelConfigs: toolsFile.EmbeddingModels, + ToolConfigs: toolsFile.Tools, + ToolsetConfigs: toolsFile.Toolsets, + PromptConfigs: toolsFile.Prompts, } - sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig) if err != nil { errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err) logger.WarnContext(ctx, errMsg.Error()) - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } - return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil + return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil } // watchChanges checks for changes in the provided yaml tools file(s) or folder. 
diff --git a/cmd/root_test.go b/cmd/root_test.go index 6036c9c478..bb2b7e4949 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -32,6 +32,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/auth/google" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prebuiltconfigs" "github.com/googleapis/genai-toolbox/internal/prompts" @@ -1830,9 +1831,10 @@ func TestFileLoadingErrors(t *testing.T) { func TestMergeToolsFiles(t *testing.T) { file1 := ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, } file2 := ToolsFile{ AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, @@ -1854,11 +1856,12 @@ func TestMergeToolsFiles(t *testing.T) { name: "merge two distinct files", files: []ToolsFile{file1, file2}, want: ToolsFile{ - Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, - AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, - Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, - Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, - Prompts: server.PromptConfigs{}, + Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}}, + AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}}, + Tools: 
server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}}, + Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}}, + Prompts: server.PromptConfigs{}, + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, }, wantErr: false, }, @@ -1871,22 +1874,24 @@ func TestMergeToolsFiles(t *testing.T) { name: "merge single file", files: []ToolsFile{file1}, want: ToolsFile{ - Sources: file1.Sources, - AuthServices: make(server.AuthServiceConfigs), - Tools: file1.Tools, - Toolsets: file1.Toolsets, - Prompts: server.PromptConfigs{}, + Sources: file1.Sources, + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}}, + Tools: file1.Tools, + Toolsets: file1.Toolsets, + Prompts: server.PromptConfigs{}, }, }, { name: "merge empty list", files: []ToolsFile{}, want: ToolsFile{ - Sources: make(server.SourceConfigs), - AuthServices: make(server.AuthServiceConfigs), - Tools: make(server.ToolConfigs), - Toolsets: make(server.ToolsetConfigs), - Prompts: server.PromptConfigs{}, + Sources: make(server.SourceConfigs), + AuthServices: make(server.AuthServiceConfigs), + EmbeddingModels: make(server.EmbeddingModelConfigs), + Tools: make(server.ToolConfigs), + Toolsets: make(server.ToolsetConfigs), + Prompts: server.PromptConfigs{}, }, }, } diff --git a/docs/en/resources/embeddingModels/_index.md b/docs/en/resources/embeddingModels/_index.md new file mode 100644 index 0000000000..84ddfb9222 --- /dev/null +++ b/docs/en/resources/embeddingModels/_index.md @@ -0,0 +1,84 @@ +--- +title: "EmbeddingModels" +type: docs +weight: 2 +description: > + EmbeddingModels represent services that transform text into vector embeddings for semantic search. +--- + +EmbeddingModels represent services that generate vector representations of text +data. 
In the MCP Toolbox, these models enable **Semantic Queries**, +allowing [Tools](../tools/) to automatically convert human-readable text into +numerical vectors before using them in a query. + +This is primarily used in two scenarios: + +- **Vector Ingestion**: Converting a text parameter into a vector string during + an `INSERT` operation. + +- **Semantic Search**: Converting a natural language query into a vector to + perform similarity searches. + +## Example + +The following configuration defines an embedding model and applies it to +specific tool parameters. + +{{< notice tip >}} +Use environment variable replacement with the format ${ENV_NAME} +instead of hardcoding your API keys into the configuration file. +{{< /notice >}} + +### Step 1 - Define an Embedding Model + +Define an embedding model in the `embeddingModels` section: + +```yaml +embeddingModels: + gemini-model: # Name of the embedding model + kind: gemini + model: gemini-embedding-001 + apiKey: ${GOOGLE_API_KEY} + dimension: 768 + +``` + +### Step 2 - Embed Tool Parameters + +Using the defined embedding model, embed your query parameters with the +`embeddedBy` field. Only string-typed +parameters can be embedded: + +```yaml +tools: + # Vector ingestion tool + insert_embedding: + kind: postgres-sql + source: my-pg-instance + statement: | + INSERT INTO documents (content, embedding) + VALUES ($1, $2); + parameters: + - name: content + type: string + - name: vector_string + type: string + description: The text to be vectorized and stored. + embeddedBy: gemini-model # refers to the name of a defined embedding model + + # Semantic search tool + search_embedding: + kind: postgres-sql + source: my-pg-instance + statement: | + SELECT id, content, embedding <-> $1 AS distance + FROM documents + ORDER BY distance LIMIT 1 + parameters: + - name: semantic_search_string + type: string + description: The search query that will be converted to a vector. 
+ embeddedBy: gemini-model # refers to the name of a defined embedding model +``` + +## Kinds of Embedding Models diff --git a/docs/en/resources/embeddingModels/gemini.md b/docs/en/resources/embeddingModels/gemini.md new file mode 100644 index 0000000000..8fcf8cf358 --- /dev/null +++ b/docs/en/resources/embeddingModels/gemini.md @@ -0,0 +1,73 @@ +--- +title: "Gemini Embedding" +type: docs +weight: 1 +description: > + Use Google's Gemini models to generate high-performance text embeddings for vector databases. +--- + +## About + +Google Gemini provides state-of-the-art embedding models that convert text into +high-dimensional vectors. + +### Authentication + +Toolbox uses your [Application Default Credentials +(ADC)][adc] to authorize with the +Gemini API client. + +Optionally, you can use an [API key][api-key]. You can obtain an API +key from [Google AI Studio][ai-studio]. + +We recommend using an API key for testing and using application default +credentials for production. + +[adc]: https://cloud.google.com/docs/authentication#adc +[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys +[ai-studio]: https://aistudio.google.com/app/apikey + +## Behavior + +### Automatic Vectorization + +When a tool parameter is configured with `embeddedBy: <embedding-model-name>`, +the Toolbox intercepts the raw text input from the client and sends it to the +Gemini API. The resulting numerical array is then formatted before being passed +to your database source. + +### Dimension Matching + +The `dimension` field must match the expected size of your database column +(e.g., a `vector(768)` column in PostgreSQL). This setting is only supported by +models released since 2024. You cannot set this value when using the earlier model +(`models/embedding-001`). Check out [available Gemini models][modellist] for more +information. 
+ +[modellist]: + https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models + +## Example + +```yaml +embeddingModels: + gemini-model: + kind: gemini + model: gemini-embedding-001 + apiKey: ${GOOGLE_API_KEY} + dimension: 768 +``` + +{{< notice tip >}} +Use environment variable replacement with the format ${ENV_NAME} +instead of hardcoding your secrets into the configuration file. +{{< /notice >}} + +## Reference + +| **field** | **type** | **required** | **description** | +|-----------|:--------:|:------------:|--------------------------------------------------------------| +| kind | string | true | Must be `gemini`. | +| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). | +| apiKey | string | false | Your API Key from Google AI Studio. | +| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). | diff --git a/go.mod b/go.mod index a2d598c060..c351fd27f1 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( go.opentelemetry.io/otel/trace v1.38.0 golang.org/x/oauth2 v0.33.0 google.golang.org/api v0.256.0 + google.golang.org/genai v1.37.0 google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 google.golang.org/protobuf v1.36.10 modernc.org/sqlite v1.40.0 diff --git a/go.sum b/go.sum index eeac2b4fd4..e52edfb83a 100644 --- a/go.sum +++ b/go.sum @@ -1869,6 +1869,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genai v1.37.0 h1:dgp71k1wQ+/+APdZrN3LFgAGnVnr5IdTF1Oj0Dg+BQc= +google.golang.org/genai v1.37.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk= google.golang.org/genproto 
v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= diff --git a/internal/embeddingmodels/embeddingmodels.go b/internal/embeddingmodels/embeddingmodels.go new file mode 100644 index 0000000000..f079bc8a2d --- /dev/null +++ b/internal/embeddingmodels/embeddingmodels.go @@ -0,0 +1,29 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package embeddingmodels + +import "context" + +// EmbeddingModelConfig is the interface for configuring embedding models. +type EmbeddingModelConfig interface { + EmbeddingModelConfigKind() string + Initialize(context.Context) (EmbeddingModel, error) +} + +type EmbeddingModel interface { + EmbeddingModelKind() string + ToConfig() EmbeddingModelConfig + EmbedParameters(context.Context, []string) ([][]float32, error) +} diff --git a/internal/embeddingmodels/gemini/gemini.go b/internal/embeddingmodels/gemini/gemini.go new file mode 100644 index 0000000000..1a31060e7b --- /dev/null +++ b/internal/embeddingmodels/gemini/gemini.go @@ -0,0 +1,122 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gemini + +import ( + "context" + "fmt" + + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/util" + "google.golang.org/genai" +) + +const EmbeddingModelKind string = "gemini" + +// validate interface +var _ embeddingmodels.EmbeddingModelConfig = Config{} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Kind string `yaml:"kind" validate:"required"` + Model string `yaml:"model" validate:"required"` + ApiKey string `yaml:"apiKey"` + Dimension int32 `yaml:"dimension"` +} + +// EmbeddingModelConfigKind returns the embedding model kind. +func (cfg Config) EmbeddingModelConfigKind() string { + return EmbeddingModelKind +} + +// Initialize creates a Gemini API client and returns the embedding model that wraps it. +func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) { + // Only set the API key when one is configured + configs := &genai.ClientConfig{} + if cfg.ApiKey != "" { + configs.APIKey = cfg.ApiKey + } + + // Create new Gemini API client + client, err := genai.NewClient(ctx, configs) + if err != nil { + return nil, fmt.Errorf("unable to create Gemini API client: %w", err) + } + + m := &EmbeddingModel{ + Config: cfg, + Client: client, + } + return m, nil +} + +var _ embeddingmodels.EmbeddingModel = EmbeddingModel{} + +type EmbeddingModel struct { + Client *genai.Client + Config +} + +// EmbeddingModelKind returns the embedding model kind. +func (m EmbeddingModel) EmbeddingModelKind() string { + return EmbeddingModelKind +} + +func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig { + return m.Config +} + +func (m EmbeddingModel) 
EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) { + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("unable to get logger from ctx: %w", err) + } + + contents := convertStringsToContents(parameters) + + embedConfig := &genai.EmbedContentConfig{ + TaskType: "SEMANTIC_SIMILARITY", + } + + if m.Dimension > 0 { + embedConfig.OutputDimensionality = genai.Ptr(m.Dimension) + } + + result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig) + if err != nil { + logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err) + return nil, err + } + + embeddings := make([][]float32, 0, len(result.Embeddings)) + for _, embedding := range result.Embeddings { + embeddings = append(embeddings, embedding.Values) + } + + logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model) + + return embeddings, nil +} + +// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects. +func convertStringsToContents(texts []string) []*genai.Content { + contents := make([]*genai.Content, 0, len(texts)) + + for _, text := range texts { + content := genai.NewContentFromText(text, "") + contents = append(contents, content) + } + return contents +} diff --git a/internal/embeddingmodels/gemini/gemini_test.go b/internal/embeddingmodels/gemini/gemini_test.go new file mode 100644 index 0000000000..c4c3b8dcdc --- /dev/null +++ b/internal/embeddingmodels/gemini/gemini_test.go @@ -0,0 +1,130 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gemini_test + +import ( + "testing" + + yaml "github.com/goccy/go-yaml" + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYamlGemini(t *testing.T) { + tcs := []struct { + desc string + in string + want server.EmbeddingModelConfigs + }{ + { + desc: "basic example", + in: ` + embeddingModels: + my-gemini-model: + kind: gemini + model: text-embedding-004 + `, + want: map[string]embeddingmodels.EmbeddingModelConfig{ + "my-gemini-model": gemini.Config{ + Name: "my-gemini-model", + Kind: gemini.EmbeddingModelKind, + Model: "text-embedding-004", + }, + }, + }, + { + desc: "full example with optional fields", + in: ` + embeddingModels: + complex-gemini: + kind: gemini + model: text-embedding-004 + apiKey: "test-api-key" + dimension: 768 + `, + want: map[string]embeddingmodels.EmbeddingModelConfig{ + "complex-gemini": gemini.Config{ + Name: "complex-gemini", + Kind: gemini.EmbeddingModelKind, + Model: "text-embedding-004", + ApiKey: "test-api-key", + Dimension: 768, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Models server.EmbeddingModelConfigs `yaml:"embeddingModels"` + }{} + // Parse contents + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if 
!cmp.Equal(tc.want, got.Models) { + t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models)) + } + }) + } +} +func TestFailParseFromYamlGemini(t *testing.T) { + tcs := []struct { + desc string + in string + err string + }{ + { + desc: "missing required model field", + in: ` + embeddingModels: + bad-model: + kind: gemini + `, + // Removed the specific model name from the prefix to match your output + err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag", + }, + { + desc: "unknown field", + in: ` + embeddingModels: + bad-field: + kind: gemini + model: text-embedding-004 + invalid_param: true + `, + // Updated to match the specific line-starting format of your error output + err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := struct { + Models server.EmbeddingModelConfigs `yaml:"embeddingModels"` + }{} + err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got) + if err == nil { + t.Fatalf("expect parsing to fail") + } + if err.Error() != tc.err { + t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err) + } + }) + } +} diff --git a/internal/server/common_test.go b/internal/server/common_test.go index 3953e1c7bc..5419351a41 100644 --- a/internal/server/common_test.go +++ b/internal/server/common_test.go @@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools sseManager := newSseManager(ctx) - resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets) + resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets) server := Server{ version: fakeVersionString, diff --git a/internal/server/config.go b/internal/server/config.go index fa0f1952a7..8beb32c839 100644 --- 
a/internal/server/config.go +++ b/internal/server/config.go @@ -21,6 +21,8 @@ import ( yaml "github.com/goccy/go-yaml" "github.com/googleapis/genai-toolbox/internal/auth" "github.com/googleapis/genai-toolbox/internal/auth/google" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -38,6 +40,8 @@ type ServerConfig struct { SourceConfigs SourceConfigs // AuthServiceConfigs defines what sources of authentication are available for tools. AuthServiceConfigs AuthServiceConfigs + // EmbeddingModelConfigs defines the models used to embed parameters. + EmbeddingModelConfigs EmbeddingModelConfigs // ToolConfigs defines what tools are available. ToolConfigs ToolConfigs // ToolsetConfigs defines what tools are available. @@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i return nil } +// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map +type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig + +// validate interface +var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{} + +func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error { + *c = make(EmbeddingModelConfigs) + // Parse the 'kind' fields for each embedding model + var raw map[string]util.DelayedUnmarshaler + if err := unmarshal(&raw); err != nil { + return err + } + + for name, u := range raw { + // Unmarshal to a general type to ensure it captures all fields + var v map[string]any + if err := u.Unmarshal(&v); err != nil { + return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err) + } + + kind, ok := v["kind"] + if !ok { + return fmt.Errorf("missing 'kind' field for embedding model %q", 
name) + } + + dec, err := util.NewStrictDecoder(v) + if err != nil { + return fmt.Errorf("error creating decoder: %w", err) + } + switch kind { + case gemini.EmbeddingModelKind: + actual := gemini.Config{Name: name} + if err := dec.DecodeContext(ctx, &actual); err != nil { + return fmt.Errorf("unable to parse as %q: %w", kind, err) + } + (*c)[name] = actual + default: + return fmt.Errorf("%q is not a valid kind of embedding model", kind) + } + } + return nil +} + // ToolConfigs is a type used to allow unmarshal of the tool configs type ToolConfigs map[string]tools.ToolConfig diff --git a/internal/server/mcp_test.go b/internal/server/mcp_test.go index 90b8676098..ff6ffffe84 100644 --- a/internal/server/mcp_test.go +++ b/internal/server/mcp_test.go @@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) { sseManager := newSseManager(ctx) - resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets) + resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets) server := &Server{ version: fakeVersionString, diff --git a/internal/server/resources/resources.go b/internal/server/resources/resources.go index 0cea0b7eaa..b41e160a39 100644 --- a/internal/server/resources/resources.go +++ b/internal/server/resources/resources.go @@ -18,6 +18,7 @@ import ( "sync" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/sources" "github.com/googleapis/genai-toolbox/internal/tools" @@ -25,30 +26,33 @@ import ( // ResourceManager contains available resources for the server. Should be initialized with NewResourceManager(). 
type ResourceManager struct { - mu sync.RWMutex - sources map[string]sources.Source - authServices map[string]auth.AuthService - tools map[string]tools.Tool - toolsets map[string]tools.Toolset - prompts map[string]prompts.Prompt - promptsets map[string]prompts.Promptset + mu sync.RWMutex + sources map[string]sources.Source + authServices map[string]auth.AuthService + embeddingModels map[string]embeddingmodels.EmbeddingModel + tools map[string]tools.Tool + toolsets map[string]tools.Toolset + prompts map[string]prompts.Prompt + promptsets map[string]prompts.Promptset } func NewResourceManager( sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, + embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset, ) *ResourceManager { resourceMgr := &ResourceManager{ - mu: sync.RWMutex{}, - sources: sourcesMap, - authServices: authServicesMap, - tools: toolsMap, - toolsets: toolsetsMap, - prompts: promptsMap, - promptsets: promptsetsMap, + mu: sync.RWMutex{}, + sources: sourcesMap, + authServices: authServicesMap, + embeddingModels: embeddingModelsMap, + tools: toolsMap, + toolsets: toolsetsMap, + prompts: promptsMap, + promptsets: promptsetsMap, } return resourceMgr @@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi return authService, ok } +func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + model, ok := r.embeddingModels[embeddingModelName] + return model, ok +} + func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset, return promptset, ok } -func (r *ResourceManager) SetResources(sourcesMap 
map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) { +func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) { r.mu.Lock() defer r.mu.Unlock() r.sources = sourcesMap r.authServices = authServicesMap + r.embeddingModels = embeddingModelsMap r.tools = toolsMap r.toolsets = toolsetsMap r.prompts = promptsMap @@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService { return copiedMap } +func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel { + r.mu.RLock() + defer r.mu.RUnlock() + copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels)) + for k, v := range r.embeddingModels { + copiedMap[k] = v + } + return copiedMap +} + func (r *ResourceManager) GetToolsMap() map[string]tools.Tool { r.mu.RLock() defer r.mu.RUnlock() diff --git a/internal/server/resources/resources_test.go b/internal/server/resources/resources_test.go index b746abf3fc..ad4bfdd326 100644 --- a/internal/server/resources/resources_test.go +++ b/internal/server/resources/resources_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/resources" "github.com/googleapis/genai-toolbox/internal/sources" @@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) { }, } newAuth := map[string]auth.AuthService{"example-auth": nil} + 
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil} newTools := map[string]tools.Tool{"example-tool": nil} newToolsets := map[string]tools.Toolset{ "example-toolset": { @@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) { Prompts: []*prompts.Prompt{}, }, } - resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets) + resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets) gotSource, _ := resMgr.GetSource("example-source") if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" { @@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) { }, } - resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets) + resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets) gotSource, _ = resMgr.GetSource("example-source2") if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" { t.Errorf("error updating server, sources (-want +got):\n%s", diff) diff --git a/internal/server/server.go b/internal/server/server.go index 4d2f600bd1..10d146143c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -30,6 +30,7 @@ import ( "github.com/go-chi/cors" "github.com/go-chi/httplog/v2" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server/resources" @@ -56,6 +57,7 @@ type Server struct { func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( map[string]sources.Source, map[string]auth.AuthService, + map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, @@ -91,7 +93,7 @@ func InitializeConfigs(ctx 
context.Context, cfg ServerConfig) ( return s, nil }() if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } sourcesMap[name] = s } @@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return a, nil }() if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } authServicesMap[name] = a } @@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( } l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", "))) + // Initialize and validate embedding models from configs. + embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel) + for name, ec := range cfg.EmbeddingModelConfigs { + em, err := func() (embeddingmodels.EmbeddingModel, error) { + _, span := instrumentation.Tracer.Start( + ctx, + "toolbox/server/embeddingmodel/init", + trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())), + trace.WithAttributes(attribute.String("model_name", name)), + ) + defer span.End() + em, err := ec.Initialize(ctx) + if err != nil { + return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err) + } + return em, nil + }() + if err != nil { + return nil, nil, nil, nil, nil, nil, nil, err + } + embeddingModelsMap[name] = em + } + embeddingModelNames := make([]string, 0, len(embeddingModelsMap)) + for name := range embeddingModelsMap { + embeddingModelNames = append(embeddingModelNames, name) + } + l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", "))) + // initialize and validate the tools from configs toolsMap := make(map[string]tools.Tool) for name, tc := range cfg.ToolConfigs { @@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return t, nil }() if err != nil { - return nil, nil, nil, 
nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } toolsMap[name] = t } @@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return t, err }() if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } toolsetsMap[name] = t } @@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return p, nil }() if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } promptsMap[name] = p } @@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( return p, err }() if err != nil { - return nil, nil, nil, nil, nil, nil, err + return nil, nil, nil, nil, nil, nil, nil, err } promptsetsMap[name] = p } @@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) ( } l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", "))) - return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil + return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil } // NewServer returns a Server object based on provided Config. 
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { httpLogger := httplog.NewLogger("httplog", httpOpts) r.Use(httplog.RequestLogger(httpLogger)) - sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg) + sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg) if err != nil { return nil, fmt.Errorf("unable to initialize configs: %w", err) } @@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) { sseManager := newSseManager(ctx) - resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) + resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap) s := &Server{ version: cfg.Version, diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 1d11379d9c..dc221e9a5e 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -25,6 +25,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/googleapis/genai-toolbox/internal/auth" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" "github.com/googleapis/genai-toolbox/internal/log" "github.com/googleapis/genai-toolbox/internal/prompts" "github.com/googleapis/genai-toolbox/internal/server" @@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) { }, } newAuth := map[string]auth.AuthService{"example-auth": nil} + newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil} newTools := map[string]tools.Tool{"example-tool": nil} newToolsets := map[string]tools.Toolset{ "example-toolset": { @@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) { Prompts: []*prompts.Prompt{}, }, } - s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, 
newPromptsets) + s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets) if err != nil { t.Errorf("error updating server: %s", err) } diff --git a/internal/tools/cloudgda/cloudgda_test.go b/internal/tools/cloudgda/cloudgda_test.go index 0d57032904..34e8848cb8 100644 --- a/internal/tools/cloudgda/cloudgda_test.go +++ b/internal/tools/cloudgda/cloudgda_test.go @@ -333,7 +333,7 @@ func TestInvoke(t *testing.T) { {Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"}, } - resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil) + resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil) // Invoke the tool result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client