mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-07 22:54:06 -05:00
Merge branch 'main' into docs/toolbox-adk-migration
This commit is contained in:
60
cmd/root.go
60
cmd/root.go
@@ -33,6 +33,7 @@ import (
|
||||
"github.com/fsnotify/fsnotify"
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
@@ -197,6 +198,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats"
|
||||
@@ -385,12 +387,13 @@ func NewCommand(opts ...Option) *Command {
|
||||
}
|
||||
|
||||
type ToolsFile struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
|
||||
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
|
||||
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
|
||||
Prompts server.PromptConfigs `yaml:"prompts"`
|
||||
}
|
||||
|
||||
// parseEnv replaces environment variables ${ENV_NAME} with their values.
|
||||
@@ -439,11 +442,12 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
|
||||
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
|
||||
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
merged := ToolsFile{
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: make(server.PromptConfigs),
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: make(server.PromptConfigs),
|
||||
}
|
||||
|
||||
var conflicts []string
|
||||
@@ -479,6 +483,15 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge embeddingModels
|
||||
for name, model := range file.EmbeddingModels {
|
||||
if _, exists := merged.EmbeddingModels[name]; exists {
|
||||
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
|
||||
} else {
|
||||
merged.EmbeddingModels[name] = model
|
||||
}
|
||||
}
|
||||
|
||||
// Check for conflicts and merge tools
|
||||
for name, tool := range file.Tools {
|
||||
if _, exists := merged.Tools[name]; exists {
|
||||
@@ -583,14 +596,14 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -598,7 +611,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
|
||||
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
|
||||
func validateReloadEdits(
|
||||
ctx context.Context, toolsFile ToolsFile,
|
||||
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
|
||||
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
|
||||
) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
@@ -616,22 +629,23 @@ func validateReloadEdits(
|
||||
defer span.End()
|
||||
|
||||
reloadedConfig := server.ServerConfig{
|
||||
Version: versionString,
|
||||
SourceConfigs: toolsFile.Sources,
|
||||
AuthServiceConfigs: toolsFile.AuthServices,
|
||||
ToolConfigs: toolsFile.Tools,
|
||||
ToolsetConfigs: toolsFile.Toolsets,
|
||||
PromptConfigs: toolsFile.Prompts,
|
||||
Version: versionString,
|
||||
SourceConfigs: toolsFile.Sources,
|
||||
AuthServiceConfigs: toolsFile.AuthServices,
|
||||
EmbeddingModelConfigs: toolsFile.EmbeddingModels,
|
||||
ToolConfigs: toolsFile.Tools,
|
||||
ToolsetConfigs: toolsFile.Toolsets,
|
||||
PromptConfigs: toolsFile.Prompts,
|
||||
}
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
|
||||
if err != nil {
|
||||
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
|
||||
logger.WarnContext(ctx, errMsg.Error())
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
// watchChanges checks for changes in the provided yaml tools file(s) or folder.
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
@@ -1503,7 +1504,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"alloydb_postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "alloydb_postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1533,7 +1534,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1633,7 +1634,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1830,9 +1831,10 @@ func TestFileLoadingErrors(t *testing.T) {
|
||||
|
||||
func TestMergeToolsFiles(t *testing.T) {
|
||||
file1 := ToolsFile{
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
}
|
||||
file2 := ToolsFile{
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
@@ -1854,11 +1856,12 @@ func TestMergeToolsFiles(t *testing.T) {
|
||||
name: "merge two distinct files",
|
||||
files: []ToolsFile{file1, file2},
|
||||
want: ToolsFile{
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
|
||||
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
|
||||
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
|
||||
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
|
||||
Prompts: server.PromptConfigs{},
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
@@ -1871,22 +1874,24 @@ func TestMergeToolsFiles(t *testing.T) {
|
||||
name: "merge single file",
|
||||
files: []ToolsFile{file1},
|
||||
want: ToolsFile{
|
||||
Sources: file1.Sources,
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: file1.Tools,
|
||||
Toolsets: file1.Toolsets,
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: file1.Sources,
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
|
||||
Tools: file1.Tools,
|
||||
Toolsets: file1.Toolsets,
|
||||
Prompts: server.PromptConfigs{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge empty list",
|
||||
files: []ToolsFile{},
|
||||
want: ToolsFile{
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: server.PromptConfigs{},
|
||||
Sources: make(server.SourceConfigs),
|
||||
AuthServices: make(server.AuthServiceConfigs),
|
||||
EmbeddingModels: make(server.EmbeddingModelConfigs),
|
||||
Tools: make(server.ToolConfigs),
|
||||
Toolsets: make(server.ToolsetConfigs),
|
||||
Prompts: server.PromptConfigs{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
84
docs/en/resources/embeddingModels/_index.md
Normal file
84
docs/en/resources/embeddingModels/_index.md
Normal file
@@ -0,0 +1,84 @@
|
||||
---
|
||||
title: "EmbeddingModels"
|
||||
type: docs
|
||||
weight: 2
|
||||
description: >
|
||||
EmbeddingModels represent services that transform text into vector embeddings for semantic search.
|
||||
---
|
||||
|
||||
EmbeddingModels represent services that generate vector representations of text
|
||||
data. In the MCP Toolbox, these models enable **Semantic Queries**,
|
||||
allowing [Tools](../tools/) to automatically convert human-readable text into
|
||||
numerical vectors before using them in a query.
|
||||
|
||||
This is primarily used in two scenarios:
|
||||
|
||||
- **Vector Ingestion**: Converting a text parameter into a vector string during
|
||||
an `INSERT` operation.
|
||||
|
||||
- **Semantic Search**: Converting a natural language query into a vector to
|
||||
perform similarity searches.
|
||||
|
||||
## Example
|
||||
|
||||
The following configuration defines an embedding model and applies it to
|
||||
specific tool parameters.
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your API keys into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
### Step 1 - Define an Embedding Model
|
||||
|
||||
Define an embedding model in the `embeddingModels` section:
|
||||
|
||||
```yaml
|
||||
embeddingModels:
|
||||
gemini-model: # Name of the embedding model
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: ${GOOGLE_API_KEY}
|
||||
dimension: 768
|
||||
|
||||
```
|
||||
|
||||
### Step 2 - Embed Tool Parameters
|
||||
|
||||
Use the defined embedding model, embed your query parameters using the
|
||||
`embeddedBy` field. Only string-typed
|
||||
parameters can be embedded:
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
# Vector ingestion tool
|
||||
insert_embedding:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
statement: |
|
||||
INSERT INTO documents (content, embedding)
|
||||
VALUES ($1, $2);
|
||||
parameters:
|
||||
- name: content
|
||||
type: string
|
||||
- name: vector_string
|
||||
type: string
|
||||
description: The text to be vectorized and stored.
|
||||
embeddedBy: gemini-model # refers to the name of a defined embedding model
|
||||
|
||||
# Semantic search tool
|
||||
search_embedding:
|
||||
kind: postgres-sql
|
||||
source: my-pg-instance
|
||||
statement: |
|
||||
SELECT id, content, embedding <-> $1 AS distance
|
||||
FROM documents
|
||||
ORDER BY distance LIMIT 1
|
||||
parameters:
|
||||
- name: semantic_search_string
|
||||
type: string
|
||||
description: The search query that will be converted to a vector.
|
||||
embeddedBy: gemini-model # refers to the name of a defined embedding model
|
||||
```
|
||||
|
||||
## Kinds of Embedding Models
|
||||
73
docs/en/resources/embeddingModels/gemini.md
Normal file
73
docs/en/resources/embeddingModels/gemini.md
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
title: "Gemini Embedding"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Use Google's Gemini models to generate high-performance text embeddings for vector databases.
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
Google Gemini provides state-of-the-art embedding models that convert text into
|
||||
high-dimensional vectors.
|
||||
|
||||
### Authentication
|
||||
|
||||
Toolbox uses your [Application Default Credentials
|
||||
(ADC)][adc] to authorize with the
|
||||
Gemini API client.
|
||||
|
||||
Optionally, you can use an [API key][api-key] obtain an API
|
||||
Key from the [Google AI Studio][ai-studio].
|
||||
|
||||
We recommend using an API key for testing and using application default
|
||||
credentials for production.
|
||||
|
||||
[adc]: https://cloud.google.com/docs/authentication#adc
|
||||
[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys
|
||||
[ai-studio]: https://aistudio.google.com/app/apikey
|
||||
|
||||
## Behavior
|
||||
|
||||
### Automatic Vectorization
|
||||
|
||||
When a tool parameter is configured with `embeddedBy: <your-gemini-model-name>`,
|
||||
the Toolbox intercepts the raw text input from the client and sends it to the
|
||||
Gemini API. The resulting numerical array is then formatted before being passed
|
||||
to your database source.
|
||||
|
||||
### Dimension Matching
|
||||
|
||||
The `dimension` field must match the expected size of your database column
|
||||
(e.g., a `vector(768)` column in PostgreSQL). This setting is supported by newer
|
||||
models since 2024 only. You cannot set this value if using the earlier model
|
||||
(`models/embedding-001`). Check out [available Gemini models][modellist] for more
|
||||
information.
|
||||
|
||||
[modellist]:
|
||||
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
embeddingModels:
|
||||
gemini-model:
|
||||
kind: gemini
|
||||
model: gemini-embedding-001
|
||||
apiKey: ${GOOGLE_API_KEY}
|
||||
dimension: 768
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
Use environment variable replacement with the format ${ENV_NAME}
|
||||
instead of hardcoding your secrets into the configuration file.
|
||||
{{< /notice >}}
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-----------|:--------:|:------------:|--------------------------------------------------------------|
|
||||
| kind | string | true | Must be `gemini`. |
|
||||
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
|
||||
| apiKey | string | false | Your API Key from Google AI Studio. |
|
||||
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |
|
||||
@@ -94,7 +94,10 @@ cluster][alloydb-free-trial].
|
||||
instance.
|
||||
|
||||
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
|
||||
Lists all the user-created roles in PostgreSQL database..
|
||||
Lists all the user-created roles in PostgreSQL database.
|
||||
|
||||
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
|
||||
Lists all the stored procedure in PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
|
||||
@@ -91,7 +91,10 @@ to a database by following these instructions][csql-pg-quickstart].
|
||||
instance.
|
||||
|
||||
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
|
||||
Lists all the user-created roles in PostgreSQL database..
|
||||
Lists all the user-created roles in PostgreSQL database.
|
||||
|
||||
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
|
||||
Lists all the stored procedure in PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
|
||||
@@ -85,7 +85,10 @@ reputation for reliability, feature robustness, and performance.
|
||||
server.
|
||||
|
||||
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
|
||||
Lists all the user-created roles in PostgreSQL database..
|
||||
Lists all the user-created roles in PostgreSQL database.
|
||||
|
||||
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
|
||||
Lists all the stored procedure in PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
---
|
||||
title: "postgres-list-stored-procedure"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "postgres-list-stored-procedure" tool retrieves metadata for stored procedures in PostgreSQL, including procedure definitions, owners, languages, and descriptions.
|
||||
aliases:
|
||||
- /resources/tools/postgres-list-stored-procedure
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `postgres-list-stored-procedure` tool queries PostgreSQL system catalogs (`pg_proc`, `pg_namespace`, `pg_roles`, and `pg_language`) to retrieve comprehensive metadata about stored procedures in the database. It filters for procedures (kind = 'p') and provides the full procedure definition along with ownership and language information.
|
||||
|
||||
Compatible sources:
|
||||
|
||||
- [alloydb-postgres](../../sources/alloydb-pg.md)
|
||||
- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
|
||||
- [postgres](../../sources/postgres.md)
|
||||
|
||||
The tool returns a JSON array where each element represents a stored procedure with its schema, name, owner, language, complete definition, and optional description. Results are sorted by schema name and procedure name, with a default limit of 20 procedures.
|
||||
|
||||
## Parameters
|
||||
|
||||
| parameter | type | required | default | description |
|
||||
|--------------|---------|----------|---------|-------------|
|
||||
| role_name | string | false | null | Optional: The owner name to filter stored procedures by (supports partial matching) |
|
||||
| schema_name | string | false | null | Optional: The schema name to filter stored procedures by (supports partial matching) |
|
||||
| limit | integer | false | 20 | Optional: The maximum number of stored procedures to return |
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_stored_procedure:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: postgres-source
|
||||
description: "Retrieves stored procedure metadata including definitions and owners."
|
||||
```
|
||||
|
||||
### Example Requests
|
||||
|
||||
**List all stored procedures (default limit 20):**
|
||||
```json
|
||||
{}
|
||||
```
|
||||
|
||||
**Filter by specific owner (role):**
|
||||
```json
|
||||
{
|
||||
"role_name": "app_user"
|
||||
}
|
||||
```
|
||||
|
||||
**Filter by schema:**
|
||||
```json
|
||||
{
|
||||
"schema_name": "public"
|
||||
}
|
||||
```
|
||||
|
||||
**Filter by owner and schema with custom limit:**
|
||||
```json
|
||||
{
|
||||
"role_name": "postgres",
|
||||
"schema_name": "public",
|
||||
"limit": 50
|
||||
}
|
||||
```
|
||||
|
||||
**Filter by partial schema name:**
|
||||
```json
|
||||
{
|
||||
"schema_name": "audit"
|
||||
}
|
||||
```
|
||||
|
||||
### Example Response
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"schema_name": "public",
|
||||
"name": "process_payment",
|
||||
"owner": "postgres",
|
||||
"language": "plpgsql",
|
||||
"definition": "CREATE OR REPLACE PROCEDURE public.process_payment(p_order_id integer, p_amount numeric)\n LANGUAGE plpgsql\nAS $procedure$\nBEGIN\n UPDATE orders SET status = 'paid', amount = p_amount WHERE id = p_order_id;\n INSERT INTO payment_log (order_id, amount, timestamp) VALUES (p_order_id, p_amount, now());\n COMMIT;\nEND\n$procedure$",
|
||||
"description": "Processes payment for an order and logs the transaction"
|
||||
},
|
||||
{
|
||||
"schema_name": "public",
|
||||
"name": "cleanup_old_records",
|
||||
"owner": "postgres",
|
||||
"language": "plpgsql",
|
||||
"definition": "CREATE OR REPLACE PROCEDURE public.cleanup_old_records(p_days_old integer)\n LANGUAGE plpgsql\nAS $procedure$\nDECLARE\n v_deleted integer;\nBEGIN\n DELETE FROM audit_logs WHERE created_at < now() - (p_days_old || ' days')::interval;\n GET DIAGNOSTICS v_deleted = ROW_COUNT;\n RAISE NOTICE 'Deleted % records', v_deleted;\nEND\n$procedure$",
|
||||
"description": "Removes audit log records older than specified days"
|
||||
},
|
||||
{
|
||||
"schema_name": "audit",
|
||||
"name": "audit_table_changes",
|
||||
"owner": "app_user",
|
||||
"language": "plpgsql",
|
||||
"definition": "CREATE OR REPLACE PROCEDURE audit.audit_table_changes()\n LANGUAGE plpgsql\nAS $procedure$\nBEGIN\n INSERT INTO audit.change_log (table_name, operation, changed_at) VALUES (TG_TABLE_NAME, TG_OP, now());\nEND\n$procedure$",
|
||||
"description": null
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Output Fields Reference
|
||||
|
||||
| field | type | description |
|
||||
|-------------|---------|-------------|
|
||||
| schema_name | string | Name of the schema containing the stored procedure. |
|
||||
| name | string | Name of the stored procedure. |
|
||||
| owner | string | PostgreSQL role/user who owns the stored procedure. |
|
||||
| language | string | Programming language in which the procedure is written (e.g., plpgsql, sql, c). |
|
||||
| definition | string | Complete SQL definition of the stored procedure, including the CREATE PROCEDURE statement. |
|
||||
| description | string | Optional description or comment for the procedure (may be null if no comment is set). |
|
||||
|
||||
## Use Cases
|
||||
|
||||
- **Code review and auditing**: Export procedure definitions for version control or compliance audits.
|
||||
- **Documentation generation**: Automatically extract procedure metadata and descriptions for documentation.
|
||||
- **Permission auditing**: Identify procedures owned by specific users or in specific schemas.
|
||||
- **Migration planning**: Retrieve all procedure definitions when planning database migrations.
|
||||
- **Dependency analysis**: Review procedure definitions to understand dependencies and call chains.
|
||||
- **Security assessment**: Audit which roles own and can modify stored procedures.
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- The tool filters at the database level using LIKE pattern matching, so partial matches are supported.
|
||||
- Procedure definitions can be large; consider using the `limit` parameter for large databases with many procedures.
|
||||
- Results are ordered by schema name and procedure name for consistent output.
|
||||
- The default limit of 20 procedures is suitable for most use cases; increase as needed.
|
||||
|
||||
## Notes
|
||||
|
||||
- Only stored **procedures** are returned; functions and other callable objects are excluded via the `prokind = 'p'` filter.
|
||||
- Filtering uses `LIKE` pattern matching, so filter values support partial matches (e.g., `role_name: "app"` will match "app_user", "app_admin", etc.).
|
||||
- The `definition` field contains the complete, runnable CREATE PROCEDURE statement.
|
||||
- The `description` field is populated from comments set via PostgreSQL's COMMENT command and may be null.
|
||||
1
go.mod
1
go.mod
@@ -59,6 +59,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
google.golang.org/api v0.256.0
|
||||
google.golang.org/genai v1.37.0
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
|
||||
google.golang.org/protobuf v1.36.10
|
||||
modernc.org/sqlite v1.40.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1869,6 +1869,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
|
||||
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
|
||||
google.golang.org/genai v1.37.0 h1:dgp71k1wQ+/+APdZrN3LFgAGnVnr5IdTF1Oj0Dg+BQc=
|
||||
google.golang.org/genai v1.37.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
|
||||
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
|
||||
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
|
||||
|
||||
29
internal/embeddingmodels/embeddingmodels.go
Normal file
29
internal/embeddingmodels/embeddingmodels.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package embeddingmodels
|
||||
|
||||
import "context"
|
||||
|
||||
// EmbeddingModelConfig is the interface for configuring embedding models.
|
||||
type EmbeddingModelConfig interface {
|
||||
EmbeddingModelConfigKind() string
|
||||
Initialize(context.Context) (EmbeddingModel, error)
|
||||
}
|
||||
|
||||
type EmbeddingModel interface {
|
||||
EmbeddingModelKind() string
|
||||
ToConfig() EmbeddingModelConfig
|
||||
EmbedParameters(context.Context, []string) ([][]float32, error)
|
||||
}
|
||||
122
internal/embeddingmodels/gemini/gemini.go
Normal file
122
internal/embeddingmodels/gemini/gemini.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"google.golang.org/genai"
|
||||
)
|
||||
|
||||
const EmbeddingModelKind string = "gemini"
|
||||
|
||||
// validate interface
|
||||
var _ embeddingmodels.EmbeddingModelConfig = Config{}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Model string `yaml:"model" validate:"required"`
|
||||
ApiKey string `yaml:"apiKey"`
|
||||
Dimension int32 `yaml:"dimension"`
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (cfg Config) EmbeddingModelConfigKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
// Initialize a Gemini embedding model
|
||||
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
|
||||
// Get client configs
|
||||
configs := &genai.ClientConfig{}
|
||||
if cfg.ApiKey != "" {
|
||||
configs.APIKey = cfg.ApiKey
|
||||
}
|
||||
|
||||
// Create new Gemini API client
|
||||
client, err := genai.NewClient(ctx, configs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create Gemini API client")
|
||||
}
|
||||
|
||||
m := &EmbeddingModel{
|
||||
Config: cfg,
|
||||
Client: client,
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
|
||||
|
||||
type EmbeddingModel struct {
|
||||
Client *genai.Client
|
||||
Config
|
||||
}
|
||||
|
||||
// Returns the embedding model kind
|
||||
func (m EmbeddingModel) EmbeddingModelKind() string {
|
||||
return EmbeddingModelKind
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
|
||||
return m.Config
|
||||
}
|
||||
|
||||
func (m EmbeddingModel) EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) {
|
||||
logger, err := util.LoggerFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
|
||||
}
|
||||
|
||||
contents := convertStringsToContents(parameters)
|
||||
|
||||
embedConfig := &genai.EmbedContentConfig{
|
||||
TaskType: "SEMANTIC_SIMILARITY",
|
||||
}
|
||||
|
||||
if m.Dimension > 0 {
|
||||
embedConfig.OutputDimensionality = genai.Ptr(m.Dimension)
|
||||
}
|
||||
|
||||
result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig)
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
embeddings := make([][]float32, 0, len(result.Embeddings))
|
||||
for _, embedding := range result.Embeddings {
|
||||
embeddings = append(embeddings, embedding.Values)
|
||||
}
|
||||
|
||||
logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model)
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects.
|
||||
func convertStringsToContents(texts []string) []*genai.Content {
|
||||
contents := make([]*genai.Content, 0, len(texts))
|
||||
|
||||
for _, text := range texts {
|
||||
content := genai.NewContentFromText(text, "")
|
||||
contents = append(contents, content)
|
||||
}
|
||||
return contents
|
||||
}
|
||||
130
internal/embeddingmodels/gemini/gemini_test.go
Normal file
130
internal/embeddingmodels/gemini/gemini_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package gemini_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlGemini(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.EmbeddingModelConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
embeddingModels:
|
||||
my-gemini-model:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"my-gemini-model": gemini.Config{
|
||||
Name: "my-gemini-model",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "full example with optional fields",
|
||||
in: `
|
||||
embeddingModels:
|
||||
complex-gemini:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
apiKey: "test-api-key"
|
||||
dimension: 768
|
||||
`,
|
||||
want: map[string]embeddingmodels.EmbeddingModelConfig{
|
||||
"complex-gemini": gemini.Config{
|
||||
Name: "complex-gemini",
|
||||
Kind: gemini.EmbeddingModelKind,
|
||||
Model: "text-embedding-004",
|
||||
ApiKey: "test-api-key",
|
||||
Dimension: 768,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Models) {
|
||||
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestFailParseFromYamlGemini(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "missing required model field",
|
||||
in: `
|
||||
embeddingModels:
|
||||
bad-model:
|
||||
kind: gemini
|
||||
`,
|
||||
// Removed the specific model name from the prefix to match your output
|
||||
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "unknown field",
|
||||
in: `
|
||||
embeddingModels:
|
||||
bad-field:
|
||||
kind: gemini
|
||||
model: text-embedding-004
|
||||
invalid_param: true
|
||||
`,
|
||||
// Updated to match the specific line-starting format of your error output
|
||||
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
|
||||
}{}
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
if err.Error() != tc.err {
|
||||
t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -224,6 +224,10 @@ tools:
|
||||
kind: postgres-list-roles
|
||||
source: alloydb-pg-source
|
||||
|
||||
list_stored_procedure:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: alloydb-pg-source
|
||||
|
||||
toolsets:
|
||||
alloydb_postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -254,3 +258,4 @@ toolsets:
|
||||
- list_database_stats
|
||||
- list_roles
|
||||
- list_table_stats
|
||||
- list_stored_procedure
|
||||
|
||||
@@ -19,8 +19,8 @@ sources:
|
||||
region: ${CLOUD_SQL_MYSQL_REGION}
|
||||
instance: ${CLOUD_SQL_MYSQL_INSTANCE}
|
||||
database: ${CLOUD_SQL_MYSQL_DATABASE}
|
||||
user: ${CLOUD_SQL_MYSQL_USER}
|
||||
password: ${CLOUD_SQL_MYSQL_PASSWORD}
|
||||
user: ${CLOUD_SQL_MYSQL_USER:}
|
||||
password: ${CLOUD_SQL_MYSQL_PASSWORD:}
|
||||
ipType: ${CLOUD_SQL_MYSQL_IP_TYPE:PUBLIC}
|
||||
tools:
|
||||
execute_sql:
|
||||
|
||||
@@ -226,6 +226,10 @@ tools:
|
||||
kind: postgres-list-roles
|
||||
source: cloudsql-pg-source
|
||||
|
||||
list_stored_procedure:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: cloudsql-pg-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -256,3 +260,4 @@ toolsets:
|
||||
- list_database_stats
|
||||
- list_roles
|
||||
- list_table_stats
|
||||
- list_stored_procedure
|
||||
|
||||
@@ -225,6 +225,10 @@ tools:
|
||||
kind: postgres-list-roles
|
||||
source: postgresql-source
|
||||
|
||||
list_stored_procedure:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: postgresql-source
|
||||
|
||||
toolsets:
|
||||
postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -255,3 +259,4 @@ toolsets:
|
||||
- list_database_stats
|
||||
- list_roles
|
||||
- list_table_stats
|
||||
- list_stored_procedure
|
||||
|
||||
@@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
|
||||
|
||||
server := Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth/google"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -38,6 +40,8 @@ type ServerConfig struct {
|
||||
SourceConfigs SourceConfigs
|
||||
// AuthServiceConfigs defines what sources of authentication are available for tools.
|
||||
AuthServiceConfigs AuthServiceConfigs
|
||||
// EmbeddingModelConfigs defines a models used to embed parameters.
|
||||
EmbeddingModelConfigs EmbeddingModelConfigs
|
||||
// ToolConfigs defines what tools are available.
|
||||
ToolConfigs ToolConfigs
|
||||
// ToolsetConfigs defines what tools are available.
|
||||
@@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
|
||||
return nil
|
||||
}
|
||||
|
||||
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
|
||||
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
|
||||
|
||||
// validate interface
|
||||
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
|
||||
|
||||
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
|
||||
*c = make(EmbeddingModelConfigs)
|
||||
// Parse the 'kind' fields for each embedding model
|
||||
var raw map[string]util.DelayedUnmarshaler
|
||||
if err := unmarshal(&raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for name, u := range raw {
|
||||
// Unmarshal to a general type that ensure it capture all fields
|
||||
var v map[string]any
|
||||
if err := u.Unmarshal(&v); err != nil {
|
||||
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
|
||||
}
|
||||
|
||||
kind, ok := v["kind"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
|
||||
}
|
||||
|
||||
dec, err := util.NewStrictDecoder(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating decoder: %w", err)
|
||||
}
|
||||
switch kind {
|
||||
case gemini.EmbeddingModelKind:
|
||||
actual := gemini.Config{Name: name}
|
||||
if err := dec.DecodeContext(ctx, &actual); err != nil {
|
||||
return fmt.Errorf("unable to parse as %q: %w", kind, err)
|
||||
}
|
||||
(*c)[name] = actual
|
||||
default:
|
||||
return fmt.Errorf("%q is not a valid kind of auth source", kind)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToolConfigs is a type used to allow unmarshal of the tool configs
|
||||
type ToolConfigs map[string]tools.ToolConfig
|
||||
|
||||
|
||||
@@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets)
|
||||
|
||||
server := &Server{
|
||||
version: fakeVersionString,
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
@@ -25,30 +26,33 @@ import (
|
||||
|
||||
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
|
||||
type ResourceManager struct {
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
mu sync.RWMutex
|
||||
sources map[string]sources.Source
|
||||
authServices map[string]auth.AuthService
|
||||
embeddingModels map[string]embeddingmodels.EmbeddingModel
|
||||
tools map[string]tools.Tool
|
||||
toolsets map[string]tools.Toolset
|
||||
prompts map[string]prompts.Prompt
|
||||
promptsets map[string]prompts.Promptset
|
||||
}
|
||||
|
||||
func NewResourceManager(
|
||||
sourcesMap map[string]sources.Source,
|
||||
authServicesMap map[string]auth.AuthService,
|
||||
embeddingModelsMap map[string]embeddingmodels.EmbeddingModel,
|
||||
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
|
||||
promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset,
|
||||
|
||||
) *ResourceManager {
|
||||
resourceMgr := &ResourceManager{
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
mu: sync.RWMutex{},
|
||||
sources: sourcesMap,
|
||||
authServices: authServicesMap,
|
||||
embeddingModels: embeddingModelsMap,
|
||||
tools: toolsMap,
|
||||
toolsets: toolsetsMap,
|
||||
prompts: promptsMap,
|
||||
promptsets: promptsetsMap,
|
||||
}
|
||||
|
||||
return resourceMgr
|
||||
@@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi
|
||||
return authService, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
model, ok := r.embeddingModels[embeddingModelName]
|
||||
return model, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
@@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset,
|
||||
return promptset, ok
|
||||
}
|
||||
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.sources = sourcesMap
|
||||
r.authServices = authServicesMap
|
||||
r.embeddingModels = embeddingModelsMap
|
||||
r.tools = toolsMap
|
||||
r.toolsets = toolsetsMap
|
||||
r.prompts = promptsMap
|
||||
@@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels))
|
||||
for k, v := range r.embeddingModels {
|
||||
copiedMap[k] = v
|
||||
}
|
||||
return copiedMap
|
||||
}
|
||||
|
||||
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
@@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
|
||||
gotSource, _ := resMgr.GetSource("example-source")
|
||||
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
|
||||
@@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
gotSource, _ = resMgr.GetSource("example-source2")
|
||||
if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" {
|
||||
t.Errorf("error updating server, sources (-want +got):\n%s", diff)
|
||||
|
||||
@@ -30,6 +30,7 @@ import (
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/go-chi/httplog/v2"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server/resources"
|
||||
@@ -56,6 +57,7 @@ type Server struct {
|
||||
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
map[string]sources.Source,
|
||||
map[string]auth.AuthService,
|
||||
map[string]embeddingmodels.EmbeddingModel,
|
||||
map[string]tools.Tool,
|
||||
map[string]tools.Toolset,
|
||||
map[string]prompts.Prompt,
|
||||
@@ -91,7 +93,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return s, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
sourcesMap[name] = s
|
||||
}
|
||||
@@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return a, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
authServicesMap[name] = a
|
||||
}
|
||||
@@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", ")))
|
||||
|
||||
// Initialize and validate embedding models from configs.
|
||||
embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel)
|
||||
for name, ec := range cfg.EmbeddingModelConfigs {
|
||||
em, err := func() (embeddingmodels.EmbeddingModel, error) {
|
||||
_, span := instrumentation.Tracer.Start(
|
||||
ctx,
|
||||
"toolbox/server/embeddingmodel/init",
|
||||
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
|
||||
trace.WithAttributes(attribute.String("model_name", name)),
|
||||
)
|
||||
defer span.End()
|
||||
em, err := ec.Initialize(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err)
|
||||
}
|
||||
return em, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
embeddingModelsMap[name] = em
|
||||
}
|
||||
embeddingModelNames := make([]string, 0, len(embeddingModelsMap))
|
||||
for name := range embeddingModelsMap {
|
||||
embeddingModelNames = append(embeddingModelNames, name)
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", ")))
|
||||
|
||||
// initialize and validate the tools from configs
|
||||
toolsMap := make(map[string]tools.Tool)
|
||||
for name, tc := range cfg.ToolConfigs {
|
||||
@@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsMap[name] = t
|
||||
}
|
||||
@@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return t, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
toolsetsMap[name] = t
|
||||
}
|
||||
@@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, nil
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsMap[name] = p
|
||||
}
|
||||
@@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
return p, err
|
||||
}()
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, nil, nil, err
|
||||
return nil, nil, nil, nil, nil, nil, nil, err
|
||||
}
|
||||
promptsetsMap[name] = p
|
||||
}
|
||||
@@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
|
||||
}
|
||||
l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", ")))
|
||||
|
||||
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
|
||||
}
|
||||
|
||||
// NewServer returns a Server object based on provided Config.
|
||||
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
httpLogger := httplog.NewLogger("httplog", httpOpts)
|
||||
r.Use(httplog.RequestLogger(httpLogger))
|
||||
|
||||
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to initialize configs: %w", err)
|
||||
}
|
||||
@@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
|
||||
|
||||
sseManager := newSseManager(ctx)
|
||||
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
|
||||
|
||||
s := &Server{
|
||||
version: cfg.Version,
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/auth"
|
||||
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
|
||||
"github.com/googleapis/genai-toolbox/internal/log"
|
||||
"github.com/googleapis/genai-toolbox/internal/prompts"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
@@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
},
|
||||
}
|
||||
newAuth := map[string]auth.AuthService{"example-auth": nil}
|
||||
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
|
||||
newTools := map[string]tools.Tool{"example-tool": nil}
|
||||
newToolsets := map[string]tools.Toolset{
|
||||
"example-toolset": {
|
||||
@@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) {
|
||||
Prompts: []*prompts.Prompt{},
|
||||
},
|
||||
}
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
|
||||
if err != nil {
|
||||
t.Errorf("error updating server: %s", err)
|
||||
}
|
||||
|
||||
@@ -333,7 +333,7 @@ func TestInvoke(t *testing.T) {
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
|
||||
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgresliststoredprocedure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const kind string = "postgres-list-stored-procedure"
|
||||
|
||||
const listStoredProcedure = `
|
||||
SELECT
|
||||
n.nspname AS schema_name,
|
||||
p.proname AS name,
|
||||
r.rolname AS owner,
|
||||
l.lanname AS language,
|
||||
pg_catalog.pg_get_functiondef(p.oid) AS definition,
|
||||
pg_catalog.obj_description(p.oid, 'pg_proc') AS description
|
||||
FROM pg_catalog.pg_proc p
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
JOIN pg_catalog.pg_roles r ON r.oid = p.proowner
|
||||
JOIN pg_catalog.pg_language l ON l.oid = p.prolang
|
||||
WHERE
|
||||
p.prokind = 'p' AND
|
||||
($1::text IS NULL OR r.rolname LIKE '%' || $1::text || '%') AND
|
||||
($2::text IS NULL OR n.nspname LIKE '%' || $2::text || '%')
|
||||
ORDER BY n.nspname, p.proname
|
||||
LIMIT
|
||||
COALESCE($3::int, 20);
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
var _ compatibleSource = &cloudsqlpg.Source{}
|
||||
var _ compatibleSource = &postgres.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("role_name", "Optional: The owner name to filter the stored procedures by. Defaults to NULL.", false),
|
||||
parameters.NewStringParameterWithRequired("schema_name", "Optional: The schema name to filter the stored procedures by. Defaults to NULL.", false),
|
||||
parameters.NewIntParameterWithDefault("limit", 20, "Optional: The maximum number of stored procedures to return. Defaults to 20."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
if cfg.Description == "" {
|
||||
cfg.Description = "Retrieves stored procedure metadata returning schema name, procedure name, procedure owner, language, definition, and description, filtered by optional role name (procedure owner), schema name, and limit (default 20)."
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
name: cfg.Name,
|
||||
kind: cfg.Kind,
|
||||
authRequired: cfg.AuthRequired,
|
||||
allParams: allParameters,
|
||||
pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: paramManifest,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
name string `yaml:"name"`
|
||||
kind string `yaml:"kind"`
|
||||
authRequired []string `yaml:"authRequired"`
|
||||
allParams parameters.Parameters `yaml:"allParams"`
|
||||
pool *pgxpool.Pool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
newParams, err := parameters.GetParams(t.allParams, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
|
||||
results, err := t.pool.Query(ctx, listStoredProcedure, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
var out []map[string]any
|
||||
|
||||
for results.Next() {
|
||||
values, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
rowMap := make(map[string]any)
|
||||
for i, field := range fields {
|
||||
rowMap[string(field.Name)] = values[i]
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.allParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.authRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
|
||||
return "Authorization", nil
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
// Copyright 2026 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgresliststoredprocedure_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
|
||||
)
|
||||
|
||||
func TestParseFromYamlPostgresListStoredProcedure(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgresliststoredprocedure.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-list-stored-procedure",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-list-stored-procedure
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgresliststoredprocedure.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-list-stored-procedure",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -201,6 +201,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
|
||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -185,6 +185,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
|
||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -213,6 +213,7 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
|
||||
PostgresListPGSettingsToolKind = "postgres-list-pg-settings"
|
||||
PostgresListDatabaseStatsToolKind = "postgres-list-database-stats"
|
||||
PostgresListRolesToolKind = "postgres-list-roles"
|
||||
PostgresListStoredProcedureToolKind = "postgres-list-stored-procedure"
|
||||
)
|
||||
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
@@ -310,6 +311,11 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
|
||||
"kind": PostgresListRolesToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
|
||||
tools["list_stored_procedure"] = map[string]any{
|
||||
"kind": PostgresListStoredProcedureToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -164,4 +164,5 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
|
||||
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListRolesTest(t, ctx, pool)
|
||||
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
251
tests/tool.go
251
tests/tool.go
@@ -4390,6 +4390,257 @@ func RunPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *pgxp
|
||||
}
|
||||
}
|
||||
|
||||
// RunPostgresListStoredProcedureTest runs tests for the postgres list-stored-procedure tool
|
||||
func RunPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
type storedProcedureDetails struct {
|
||||
SchemaName string `json:"schema_name"`
|
||||
Name string `json:"name"`
|
||||
Owner string `json:"owner"`
|
||||
Language string `json:"language"`
|
||||
Definition string `json:"definition"`
|
||||
Description any `json:"description"`
|
||||
}
|
||||
|
||||
// Create test schema
|
||||
testSchemaName := "test_proc_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", testSchemaName)
|
||||
if _, err := pool.Exec(ctx, createSchemaStmt); err != nil {
|
||||
t.Fatalf("unable to create test schema: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", testSchemaName)
|
||||
if _, err := pool.Exec(ctx, dropSchemaStmt); err != nil {
|
||||
t.Logf("warning: unable to drop test schema: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create test procedures
|
||||
proc1Name := "test_proc_1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
createProc1Stmt := fmt.Sprintf(`
|
||||
CREATE PROCEDURE %s.%s(p_count INT)
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
BEGIN
|
||||
INSERT INTO test_table VALUES (p_count);
|
||||
COMMIT;
|
||||
END;
|
||||
$$
|
||||
`, testSchemaName, proc1Name)
|
||||
|
||||
if _, err := pool.Exec(ctx, createProc1Stmt); err != nil {
|
||||
t.Fatalf("unable to create test procedure 1: %v", err)
|
||||
}
|
||||
|
||||
// Add a comment/description to the procedure
|
||||
commentStmt := fmt.Sprintf("COMMENT ON PROCEDURE %s.%s(INT) IS 'Test procedure that inserts a record'", testSchemaName, proc1Name)
|
||||
if _, err := pool.Exec(ctx, commentStmt); err != nil {
|
||||
t.Logf("warning: unable to add comment to procedure: %v", err)
|
||||
}
|
||||
|
||||
// Create a second test procedure
|
||||
proc2Name := "test_proc_2_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
createProc2Stmt := fmt.Sprintf(`
|
||||
CREATE PROCEDURE %s.%s()
|
||||
LANGUAGE plpgsql
|
||||
AS $$
|
||||
DECLARE
|
||||
v_count INT;
|
||||
BEGIN
|
||||
SELECT COUNT(*) INTO v_count FROM test_table;
|
||||
RAISE NOTICE 'Total records: %%', v_count;
|
||||
END;
|
||||
$$
|
||||
`, testSchemaName, proc2Name)
|
||||
|
||||
if _, err := pool.Exec(ctx, createProc2Stmt); err != nil {
|
||||
t.Fatalf("unable to create test procedure 2: %v", err)
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
shouldHaveData bool
|
||||
expectedCount int
|
||||
filterByRole string
|
||||
filterBySchema string
|
||||
}{
|
||||
{
|
||||
name: "list stored procedures with no arguments (default limit 20)",
|
||||
requestBody: bytes.NewBufferString(`{}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false, // may or may not have data depending on what's in the database
|
||||
},
|
||||
{
|
||||
name: "list stored procedures filtering by specific schema",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s"}`, testSchemaName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
expectedCount: 2,
|
||||
filterBySchema: testSchemaName,
|
||||
},
|
||||
{
|
||||
name: "list stored procedures filtering by procedure owner (postgres)",
|
||||
requestBody: bytes.NewBufferString(`{"role_name": "postgres"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false, // might have procedures owned by postgres
|
||||
},
|
||||
{
|
||||
name: "list stored procedures with custom limit",
|
||||
requestBody: bytes.NewBufferString(`{"limit": 5}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "list stored procedures filtering by schema and role",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "role_name": "postgres"}`, testSchemaName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
expectedCount: 2,
|
||||
filterBySchema: testSchemaName,
|
||||
filterByRole: "postgres",
|
||||
},
|
||||
{
|
||||
name: "list stored procedures with non-existent schema",
|
||||
requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema_xyz"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "list stored procedures with non-existent role",
|
||||
requestBody: bytes.NewBufferString(`{"role_name": "non_existent_role_xyz"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "list stored procedures with partial schema name match",
|
||||
requestBody: bytes.NewBufferString(`{"schema_name": "test_proc"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list stored procedures with limit 1",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "limit": 1}`, testSchemaName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
expectedCount: 1,
|
||||
filterBySchema: testSchemaName,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_stored_procedure/invoke"
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got []storedProcedureDetails
|
||||
if resultString != "null" {
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected data presence
|
||||
if tc.shouldHaveData {
|
||||
if len(got) == 0 {
|
||||
t.Fatalf("expected data but got empty result")
|
||||
}
|
||||
|
||||
// If filtering by schema, verify all results are from that schema
|
||||
if tc.filterBySchema != "" {
|
||||
for _, proc := range got {
|
||||
if proc.SchemaName != tc.filterBySchema && !strings.Contains(proc.SchemaName, tc.filterBySchema) {
|
||||
t.Errorf("procedure schema %s does not match filter %s", proc.SchemaName, tc.filterBySchema)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If filtering by role, verify all results are owned by that role
|
||||
if tc.filterByRole != "" {
|
||||
for _, proc := range got {
|
||||
if proc.Owner != tc.filterByRole {
|
||||
t.Errorf("procedure owner %s does not match filter %s", proc.Owner, tc.filterByRole)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify expected count if specified
|
||||
if tc.expectedCount > 0 && len(got) != tc.expectedCount {
|
||||
t.Errorf("expected %d procedures but got %d", tc.expectedCount, len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify result structure and data types
|
||||
for _, proc := range got {
|
||||
// Verify all required fields are present and non-empty
|
||||
if proc.SchemaName == "" {
|
||||
t.Errorf("schema_name should not be empty")
|
||||
}
|
||||
if proc.Name == "" {
|
||||
t.Errorf("procedure name should not be empty")
|
||||
}
|
||||
if proc.Owner == "" {
|
||||
t.Errorf("owner should not be empty")
|
||||
}
|
||||
if proc.Language == "" {
|
||||
t.Errorf("language should not be empty")
|
||||
}
|
||||
if proc.Definition == "" {
|
||||
t.Errorf("definition should not be empty")
|
||||
}
|
||||
|
||||
// Verify definition contains CREATE PROCEDURE
|
||||
if !strings.Contains(proc.Definition, "CREATE PROCEDURE") {
|
||||
t.Logf("warning: definition may not be a valid CREATE PROCEDURE statement: %s", proc.Definition)
|
||||
}
|
||||
|
||||
// Verify language is a valid PostgreSQL language
|
||||
validLanguages := []string{"plpgsql", "sql", "c", "internal", "plperl", "pltcl", "plpython"}
|
||||
found := false
|
||||
for _, lang := range validLanguages {
|
||||
if proc.Language == lang {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Logf("warning: language %s may not be a standard PostgreSQL language", proc.Language)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify results are sorted by schema_name and name
|
||||
if len(got) > 1 {
|
||||
for i := 0; i < len(got)-1; i++ {
|
||||
currentKey := fmt.Sprintf("%s.%s", got[i].SchemaName, got[i].Name)
|
||||
nextKey := fmt.Sprintf("%s.%s", got[i+1].SchemaName, got[i+1].Name)
|
||||
if currentKey > nextKey {
|
||||
t.Logf("warning: results may not be sorted by schema_name and name")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RunRequest is a helper function to send HTTP requests and return the response
|
||||
func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
|
||||
// Send request
|
||||
|
||||
Reference in New Issue
Block a user