Merge branch 'main' into docs/toolbox-adk-migration

This commit is contained in:
Anmol Shukla
2026-01-06 16:46:28 +05:30
committed by GitHub
32 changed files with 1372 additions and 79 deletions

View File

@@ -33,6 +33,7 @@ import (
"github.com/fsnotify/fsnotify"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
"github.com/googleapis/genai-toolbox/internal/prompts"
@@ -197,6 +198,7 @@ import (
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistroles"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablespaces"
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttablestats"
@@ -385,12 +387,13 @@ func NewCommand(opts ...Option) *Command {
}
type ToolsFile struct {
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Prompts server.PromptConfigs `yaml:"prompts"`
Sources server.SourceConfigs `yaml:"sources"`
AuthSources server.AuthServiceConfigs `yaml:"authSources"` // Deprecated: Kept for compatibility.
AuthServices server.AuthServiceConfigs `yaml:"authServices"`
EmbeddingModels server.EmbeddingModelConfigs `yaml:"embeddingModels"`
Tools server.ToolConfigs `yaml:"tools"`
Toolsets server.ToolsetConfigs `yaml:"toolsets"`
Prompts server.PromptConfigs `yaml:"prompts"`
}
// parseEnv replaces environment variables ${ENV_NAME} with their values.
@@ -439,11 +442,12 @@ func parseToolsFile(ctx context.Context, raw []byte) (ToolsFile, error) {
// All resource names (sources, authServices, tools, toolsets) must be unique across all files.
func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
merged := ToolsFile{
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: make(server.PromptConfigs),
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: make(server.EmbeddingModelConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: make(server.PromptConfigs),
}
var conflicts []string
@@ -479,6 +483,15 @@ func mergeToolsFiles(files ...ToolsFile) (ToolsFile, error) {
}
}
// Check for conflicts and merge embeddingModels
for name, model := range file.EmbeddingModels {
if _, exists := merged.EmbeddingModels[name]; exists {
conflicts = append(conflicts, fmt.Sprintf("embedding model '%s' (file #%d)", name, fileIndex+1))
} else {
merged.EmbeddingModels[name] = model
}
}
// Check for conflicts and merge tools
for name, tool := range file.Tools {
if _, exists := merged.Tools[name]; exists {
@@ -583,14 +596,14 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
panic(err)
}
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := validateReloadEdits(ctx, toolsFile)
if err != nil {
errMsg := fmt.Errorf("unable to validate reloaded edits: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return err
}
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
s.ResourceMgr.SetResources(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
return nil
}
@@ -598,7 +611,7 @@ func handleDynamicReload(ctx context.Context, toolsFile ToolsFile, s *server.Ser
// validateReloadEdits checks that the reloaded tools file configs can initialized without failing
func validateReloadEdits(
ctx context.Context, toolsFile ToolsFile,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) (map[string]sources.Source, map[string]auth.AuthService, map[string]embeddingmodels.EmbeddingModel, map[string]tools.Tool, map[string]tools.Toolset, map[string]prompts.Prompt, map[string]prompts.Promptset, error,
) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
@@ -616,22 +629,23 @@ func validateReloadEdits(
defer span.End()
reloadedConfig := server.ServerConfig{
Version: versionString,
SourceConfigs: toolsFile.Sources,
AuthServiceConfigs: toolsFile.AuthServices,
ToolConfigs: toolsFile.Tools,
ToolsetConfigs: toolsFile.Toolsets,
PromptConfigs: toolsFile.Prompts,
Version: versionString,
SourceConfigs: toolsFile.Sources,
AuthServiceConfigs: toolsFile.AuthServices,
EmbeddingModelConfigs: toolsFile.EmbeddingModels,
ToolConfigs: toolsFile.Tools,
ToolsetConfigs: toolsFile.Toolsets,
PromptConfigs: toolsFile.Prompts,
}
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := server.InitializeConfigs(ctx, reloadedConfig)
if err != nil {
errMsg := fmt.Errorf("unable to initialize reloaded configs: %w", err)
logger.WarnContext(ctx, errMsg.Error())
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
}
// watchChanges checks for changes in the provided yaml tools file(s) or folder.

View File

@@ -32,6 +32,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prebuiltconfigs"
"github.com/googleapis/genai-toolbox/internal/prompts"
@@ -1503,7 +1504,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"alloydb_postgres_database_tools": tools.ToolsetConfig{
Name: "alloydb_postgres_database_tools",
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
},
},
},
@@ -1533,7 +1534,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"cloud_sql_postgres_database_tools": tools.ToolsetConfig{
Name: "cloud_sql_postgres_database_tools",
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
},
},
},
@@ -1633,7 +1634,7 @@ func TestPrebuiltTools(t *testing.T) {
wantToolset: server.ToolsetConfigs{
"postgres_database_tools": tools.ToolsetConfig{
Name: "postgres_database_tools",
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats"},
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality", "list_publication_tables", "list_tablespaces", "list_pg_settings", "list_database_stats", "list_roles", "list_table_stats", "list_stored_procedure"},
},
},
},
@@ -1830,9 +1831,10 @@ func TestFileLoadingErrors(t *testing.T) {
func TestMergeToolsFiles(t *testing.T) {
file1 := ToolsFile{
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}},
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
}
file2 := ToolsFile{
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
@@ -1854,11 +1856,12 @@ func TestMergeToolsFiles(t *testing.T) {
name: "merge two distinct files",
files: []ToolsFile{file1, file2},
want: ToolsFile{
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
Prompts: server.PromptConfigs{},
Sources: server.SourceConfigs{"source1": httpsrc.Config{Name: "source1"}},
AuthServices: server.AuthServiceConfigs{"auth1": google.Config{Name: "auth1"}},
Tools: server.ToolConfigs{"tool1": http.Config{Name: "tool1"}, "tool2": http.Config{Name: "tool2"}},
Toolsets: server.ToolsetConfigs{"set1": tools.ToolsetConfig{Name: "set1"}, "set2": tools.ToolsetConfig{Name: "set2"}},
Prompts: server.PromptConfigs{},
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
},
wantErr: false,
},
@@ -1871,22 +1874,24 @@ func TestMergeToolsFiles(t *testing.T) {
name: "merge single file",
files: []ToolsFile{file1},
want: ToolsFile{
Sources: file1.Sources,
AuthServices: make(server.AuthServiceConfigs),
Tools: file1.Tools,
Toolsets: file1.Toolsets,
Prompts: server.PromptConfigs{},
Sources: file1.Sources,
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: server.EmbeddingModelConfigs{"model1": gemini.Config{Name: "gemini-text"}},
Tools: file1.Tools,
Toolsets: file1.Toolsets,
Prompts: server.PromptConfigs{},
},
},
{
name: "merge empty list",
files: []ToolsFile{},
want: ToolsFile{
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: server.PromptConfigs{},
Sources: make(server.SourceConfigs),
AuthServices: make(server.AuthServiceConfigs),
EmbeddingModels: make(server.EmbeddingModelConfigs),
Tools: make(server.ToolConfigs),
Toolsets: make(server.ToolsetConfigs),
Prompts: server.PromptConfigs{},
},
},
}

View File

@@ -0,0 +1,84 @@
---
title: "EmbeddingModels"
type: docs
weight: 2
description: >
EmbeddingModels represent services that transform text into vector embeddings for semantic search.
---
EmbeddingModels represent services that generate vector representations of text
data. In the MCP Toolbox, these models enable **Semantic Queries**,
allowing [Tools](../tools/) to automatically convert human-readable text into
numerical vectors before using them in a query.
This is primarily used in two scenarios:
- **Vector Ingestion**: Converting a text parameter into a vector string during
an `INSERT` operation.
- **Semantic Search**: Converting a natural language query into a vector to
perform similarity searches.
## Example
The following configuration defines an embedding model and applies it to
specific tool parameters.
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your API keys into the configuration file.
{{< /notice >}}
### Step 1 - Define an Embedding Model
Define an embedding model in the `embeddingModels` section:
```yaml
embeddingModels:
gemini-model: # Name of the embedding model
kind: gemini
model: gemini-embedding-001
apiKey: ${GOOGLE_API_KEY}
dimension: 768
```
### Step 2 - Embed Tool Parameters
Use the defined embedding model, embed your query parameters using the
`embeddedBy` field. Only string-typed
parameters can be embedded:
```yaml
tools:
# Vector ingestion tool
insert_embedding:
kind: postgres-sql
source: my-pg-instance
statement: |
INSERT INTO documents (content, embedding)
VALUES ($1, $2);
parameters:
- name: content
type: string
- name: vector_string
type: string
description: The text to be vectorized and stored.
embeddedBy: gemini-model # refers to the name of a defined embedding model
# Semantic search tool
search_embedding:
kind: postgres-sql
source: my-pg-instance
statement: |
SELECT id, content, embedding <-> $1 AS distance
FROM documents
ORDER BY distance LIMIT 1
parameters:
- name: semantic_search_string
type: string
description: The search query that will be converted to a vector.
embeddedBy: gemini-model # refers to the name of a defined embedding model
```
## Kinds of Embedding Models

View File

@@ -0,0 +1,73 @@
---
title: "Gemini Embedding"
type: docs
weight: 1
description: >
Use Google's Gemini models to generate high-performance text embeddings for vector databases.
---
## About
Google Gemini provides state-of-the-art embedding models that convert text into
high-dimensional vectors.
### Authentication
Toolbox uses your [Application Default Credentials
(ADC)][adc] to authorize with the
Gemini API client.
Optionally, you can use an [API key][api-key] obtain an API
Key from the [Google AI Studio][ai-studio].
We recommend using an API key for testing and using application default
credentials for production.
[adc]: https://cloud.google.com/docs/authentication#adc
[api-key]: https://ai.google.dev/gemini-api/docs/api-key#api-keys
[ai-studio]: https://aistudio.google.com/app/apikey
## Behavior
### Automatic Vectorization
When a tool parameter is configured with `embeddedBy: <your-gemini-model-name>`,
the Toolbox intercepts the raw text input from the client and sends it to the
Gemini API. The resulting numerical array is then formatted before being passed
to your database source.
### Dimension Matching
The `dimension` field must match the expected size of your database column
(e.g., a `vector(768)` column in PostgreSQL). This setting is supported by newer
models since 2024 only. You cannot set this value if using the earlier model
(`models/embedding-001`). Check out [available Gemini models][modellist] for more
information.
[modellist]:
https://docs.cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-text-embeddings#supported-models
## Example
```yaml
embeddingModels:
gemini-model:
kind: gemini
model: gemini-embedding-001
apiKey: ${GOOGLE_API_KEY}
dimension: 768
```
{{< notice tip >}}
Use environment variable replacement with the format ${ENV_NAME}
instead of hardcoding your secrets into the configuration file.
{{< /notice >}}
## Reference
| **field** | **type** | **required** | **description** |
|-----------|:--------:|:------------:|--------------------------------------------------------------|
| kind | string | true | Must be `gemini`. |
| model | string | true | The Gemini model ID to use (e.g., `gemini-embedding-001`). |
| apiKey | string | false | Your API Key from Google AI Studio. |
| dimension | integer | false | The number of dimensions in the output vector (e.g., `768`). |

View File

@@ -94,7 +94,10 @@ cluster][alloydb-free-trial].
instance.
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
Lists all the user-created roles in PostgreSQL database..
Lists all the user-created roles in PostgreSQL database.
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
Lists all the stored procedure in PostgreSQL database.
### Pre-built Configurations

View File

@@ -91,7 +91,10 @@ to a database by following these instructions][csql-pg-quickstart].
instance.
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
Lists all the user-created roles in PostgreSQL database..
Lists all the user-created roles in PostgreSQL database.
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
Lists all the stored procedure in PostgreSQL database.
### Pre-built Configurations

View File

@@ -85,7 +85,10 @@ reputation for reliability, feature robustness, and performance.
server.
- [`postgres-list-roles`](../tools/postgres/postgres-list-roles.md)
Lists all the user-created roles in PostgreSQL database..
Lists all the user-created roles in PostgreSQL database.
- [`postgres-list-stored-procedure`](../tools/postgres/postgres-list-stored-procedure.md)
Lists all the stored procedure in PostgreSQL database.
### Pre-built Configurations

View File

@@ -0,0 +1,141 @@
---
title: "postgres-list-stored-procedure"
type: docs
weight: 1
description: >
The "postgres-list-stored-procedure" tool retrieves metadata for stored procedures in PostgreSQL, including procedure definitions, owners, languages, and descriptions.
aliases:
- /resources/tools/postgres-list-stored-procedure
---
## About
The `postgres-list-stored-procedure` tool queries PostgreSQL system catalogs (`pg_proc`, `pg_namespace`, `pg_roles`, and `pg_language`) to retrieve comprehensive metadata about stored procedures in the database. It filters for procedures (kind = 'p') and provides the full procedure definition along with ownership and language information.
Compatible sources:
- [alloydb-postgres](../../sources/alloydb-pg.md)
- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
- [postgres](../../sources/postgres.md)
The tool returns a JSON array where each element represents a stored procedure with its schema, name, owner, language, complete definition, and optional description. Results are sorted by schema name and procedure name, with a default limit of 20 procedures.
## Parameters
| parameter | type | required | default | description |
|--------------|---------|----------|---------|-------------|
| role_name | string | false | null | Optional: The owner name to filter stored procedures by (supports partial matching) |
| schema_name | string | false | null | Optional: The schema name to filter stored procedures by (supports partial matching) |
| limit | integer | false | 20 | Optional: The maximum number of stored procedures to return |
## Example
```yaml
tools:
list_stored_procedure:
kind: postgres-list-stored-procedure
source: postgres-source
description: "Retrieves stored procedure metadata including definitions and owners."
```
### Example Requests
**List all stored procedures (default limit 20):**
```json
{}
```
**Filter by specific owner (role):**
```json
{
"role_name": "app_user"
}
```
**Filter by schema:**
```json
{
"schema_name": "public"
}
```
**Filter by owner and schema with custom limit:**
```json
{
"role_name": "postgres",
"schema_name": "public",
"limit": 50
}
```
**Filter by partial schema name:**
```json
{
"schema_name": "audit"
}
```
### Example Response
```json
[
{
"schema_name": "public",
"name": "process_payment",
"owner": "postgres",
"language": "plpgsql",
"definition": "CREATE OR REPLACE PROCEDURE public.process_payment(p_order_id integer, p_amount numeric)\n LANGUAGE plpgsql\nAS $procedure$\nBEGIN\n UPDATE orders SET status = 'paid', amount = p_amount WHERE id = p_order_id;\n INSERT INTO payment_log (order_id, amount, timestamp) VALUES (p_order_id, p_amount, now());\n COMMIT;\nEND\n$procedure$",
"description": "Processes payment for an order and logs the transaction"
},
{
"schema_name": "public",
"name": "cleanup_old_records",
"owner": "postgres",
"language": "plpgsql",
"definition": "CREATE OR REPLACE PROCEDURE public.cleanup_old_records(p_days_old integer)\n LANGUAGE plpgsql\nAS $procedure$\nDECLARE\n v_deleted integer;\nBEGIN\n DELETE FROM audit_logs WHERE created_at < now() - (p_days_old || ' days')::interval;\n GET DIAGNOSTICS v_deleted = ROW_COUNT;\n RAISE NOTICE 'Deleted % records', v_deleted;\nEND\n$procedure$",
"description": "Removes audit log records older than specified days"
},
{
"schema_name": "audit",
"name": "audit_table_changes",
"owner": "app_user",
"language": "plpgsql",
"definition": "CREATE OR REPLACE PROCEDURE audit.audit_table_changes()\n LANGUAGE plpgsql\nAS $procedure$\nBEGIN\n INSERT INTO audit.change_log (table_name, operation, changed_at) VALUES (TG_TABLE_NAME, TG_OP, now());\nEND\n$procedure$",
"description": null
}
]
```
## Output Fields Reference
| field | type | description |
|-------------|---------|-------------|
| schema_name | string | Name of the schema containing the stored procedure. |
| name | string | Name of the stored procedure. |
| owner | string | PostgreSQL role/user who owns the stored procedure. |
| language | string | Programming language in which the procedure is written (e.g., plpgsql, sql, c). |
| definition | string | Complete SQL definition of the stored procedure, including the CREATE PROCEDURE statement. |
| description | string | Optional description or comment for the procedure (may be null if no comment is set). |
## Use Cases
- **Code review and auditing**: Export procedure definitions for version control or compliance audits.
- **Documentation generation**: Automatically extract procedure metadata and descriptions for documentation.
- **Permission auditing**: Identify procedures owned by specific users or in specific schemas.
- **Migration planning**: Retrieve all procedure definitions when planning database migrations.
- **Dependency analysis**: Review procedure definitions to understand dependencies and call chains.
- **Security assessment**: Audit which roles own and can modify stored procedures.
## Performance Considerations
- The tool filters at the database level using LIKE pattern matching, so partial matches are supported.
- Procedure definitions can be large; consider using the `limit` parameter for large databases with many procedures.
- Results are ordered by schema name and procedure name for consistent output.
- The default limit of 20 procedures is suitable for most use cases; increase as needed.
## Notes
- Only stored **procedures** are returned; functions and other callable objects are excluded via the `prokind = 'p'` filter.
- Filtering uses `LIKE` pattern matching, so filter values support partial matches (e.g., `role_name: "app"` will match "app_user", "app_admin", etc.).
- The `definition` field contains the complete, runnable CREATE PROCEDURE statement.
- The `description` field is populated from comments set via PostgreSQL's COMMENT command and may be null.

1
go.mod
View File

@@ -59,6 +59,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
golang.org/x/oauth2 v0.33.0
google.golang.org/api v0.256.0
google.golang.org/genai v1.37.0
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8
google.golang.org/protobuf v1.36.10
modernc.org/sqlite v1.40.0

2
go.sum
View File

@@ -1869,6 +1869,8 @@ google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genai v1.37.0 h1:dgp71k1wQ+/+APdZrN3LFgAGnVnr5IdTF1Oj0Dg+BQc=
google.golang.org/genai v1.37.0/go.mod h1:A3kkl0nyBjyFlNjgxIwKq70julKbIxpSxqKO5gw/gmk=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=

View File

@@ -0,0 +1,29 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package embeddingmodels
import "context"
// EmbeddingModelConfig is the interface for configuring embedding models.
type EmbeddingModelConfig interface {
EmbeddingModelConfigKind() string
Initialize(context.Context) (EmbeddingModel, error)
}
type EmbeddingModel interface {
EmbeddingModelKind() string
ToConfig() EmbeddingModelConfig
EmbedParameters(context.Context, []string) ([][]float32, error)
}

View File

@@ -0,0 +1,122 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gemini
import (
"context"
"fmt"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/util"
"google.golang.org/genai"
)
const EmbeddingModelKind string = "gemini"
// validate interface
var _ embeddingmodels.EmbeddingModelConfig = Config{}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Model string `yaml:"model" validate:"required"`
ApiKey string `yaml:"apiKey"`
Dimension int32 `yaml:"dimension"`
}
// Returns the embedding model kind
func (cfg Config) EmbeddingModelConfigKind() string {
return EmbeddingModelKind
}
// Initialize a Gemini embedding model
func (cfg Config) Initialize(ctx context.Context) (embeddingmodels.EmbeddingModel, error) {
// Get client configs
configs := &genai.ClientConfig{}
if cfg.ApiKey != "" {
configs.APIKey = cfg.ApiKey
}
// Create new Gemini API client
client, err := genai.NewClient(ctx, configs)
if err != nil {
return nil, fmt.Errorf("unable to create Gemini API client")
}
m := &EmbeddingModel{
Config: cfg,
Client: client,
}
return m, nil
}
var _ embeddingmodels.EmbeddingModel = EmbeddingModel{}
type EmbeddingModel struct {
Client *genai.Client
Config
}
// Returns the embedding model kind
func (m EmbeddingModel) EmbeddingModelKind() string {
return EmbeddingModelKind
}
func (m EmbeddingModel) ToConfig() embeddingmodels.EmbeddingModelConfig {
return m.Config
}
func (m EmbeddingModel) EmbedParameters(ctx context.Context, parameters []string) ([][]float32, error) {
logger, err := util.LoggerFromContext(ctx)
if err != nil {
return nil, fmt.Errorf("unable to get logger from ctx: %s", err)
}
contents := convertStringsToContents(parameters)
embedConfig := &genai.EmbedContentConfig{
TaskType: "SEMANTIC_SIMILARITY",
}
if m.Dimension > 0 {
embedConfig.OutputDimensionality = genai.Ptr(m.Dimension)
}
result, err := m.Client.Models.EmbedContent(ctx, m.Model, contents, embedConfig)
if err != nil {
logger.ErrorContext(ctx, "Error calling EmbedContent for model %s: %v", m.Model, err)
return nil, err
}
embeddings := make([][]float32, 0, len(result.Embeddings))
for _, embedding := range result.Embeddings {
embeddings = append(embeddings, embedding.Values)
}
logger.InfoContext(ctx, "Successfully embedded %d text parameters using model %s", len(parameters), m.Model)
return embeddings, nil
}
// convertStringsToContents takes a slice of strings and converts it into a slice of *genai.Content objects.
func convertStringsToContents(texts []string) []*genai.Content {
contents := make([]*genai.Content, 0, len(texts))
for _, text := range texts {
content := genai.NewContentFromText(text, "")
contents = append(contents, content)
}
return contents
}

View File

@@ -0,0 +1,130 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gemini_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
)
func TestParseFromYamlGemini(t *testing.T) {
tcs := []struct {
desc string
in string
want server.EmbeddingModelConfigs
}{
{
desc: "basic example",
in: `
embeddingModels:
my-gemini-model:
kind: gemini
model: text-embedding-004
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"my-gemini-model": gemini.Config{
Name: "my-gemini-model",
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
},
},
},
{
desc: "full example with optional fields",
in: `
embeddingModels:
complex-gemini:
kind: gemini
model: text-embedding-004
apiKey: "test-api-key"
dimension: 768
`,
want: map[string]embeddingmodels.EmbeddingModelConfig{
"complex-gemini": gemini.Config{
Name: "complex-gemini",
Kind: gemini.EmbeddingModelKind,
Model: "text-embedding-004",
ApiKey: "test-api-key",
Dimension: 768,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
// Parse contents
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if !cmp.Equal(tc.want, got.Models) {
t.Fatalf("incorrect parse: %v", cmp.Diff(tc.want, got.Models))
}
})
}
}
func TestFailParseFromYamlGemini(t *testing.T) {
tcs := []struct {
desc string
in string
err string
}{
{
desc: "missing required model field",
in: `
embeddingModels:
bad-model:
kind: gemini
`,
// Removed the specific model name from the prefix to match your output
err: "unable to parse as \"gemini\": Key: 'Config.Model' Error:Field validation for 'Model' failed on the 'required' tag",
},
{
desc: "unknown field",
in: `
embeddingModels:
bad-field:
kind: gemini
model: text-embedding-004
invalid_param: true
`,
// Updated to match the specific line-starting format of your error output
err: "unable to parse as \"gemini\": [1:1] unknown field \"invalid_param\"\n> 1 | invalid_param: true\n ^\n 2 | kind: gemini\n 3 | model: text-embedding-004",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Models server.EmbeddingModelConfigs `yaml:"embeddingModels"`
}{}
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
if err == nil {
t.Fatalf("expect parsing to fail")
}
if err.Error() != tc.err {
t.Fatalf("unexpected error:\ngot: %q\nwant: %q", err.Error(), tc.err)
}
})
}
}

View File

@@ -224,6 +224,10 @@ tools:
kind: postgres-list-roles
source: alloydb-pg-source
list_stored_procedure:
kind: postgres-list-stored-procedure
source: alloydb-pg-source
toolsets:
alloydb_postgres_database_tools:
- execute_sql
@@ -254,3 +258,4 @@ toolsets:
- list_database_stats
- list_roles
- list_table_stats
- list_stored_procedure

View File

@@ -19,8 +19,8 @@ sources:
region: ${CLOUD_SQL_MYSQL_REGION}
instance: ${CLOUD_SQL_MYSQL_INSTANCE}
database: ${CLOUD_SQL_MYSQL_DATABASE}
user: ${CLOUD_SQL_MYSQL_USER}
password: ${CLOUD_SQL_MYSQL_PASSWORD}
user: ${CLOUD_SQL_MYSQL_USER:}
password: ${CLOUD_SQL_MYSQL_PASSWORD:}
ipType: ${CLOUD_SQL_MYSQL_IP_TYPE:PUBLIC}
tools:
execute_sql:

View File

@@ -226,6 +226,10 @@ tools:
kind: postgres-list-roles
source: cloudsql-pg-source
list_stored_procedure:
kind: postgres-list-stored-procedure
source: cloudsql-pg-source
toolsets:
cloud_sql_postgres_database_tools:
- execute_sql
@@ -256,3 +260,4 @@ toolsets:
- list_database_stats
- list_roles
- list_table_stats
- list_stored_procedure

View File

@@ -225,6 +225,10 @@ tools:
kind: postgres-list-roles
source: postgresql-source
list_stored_procedure:
kind: postgres-list-stored-procedure
source: postgresql-source
toolsets:
postgres_database_tools:
- execute_sql
@@ -255,3 +259,4 @@ toolsets:
- list_database_stats
- list_roles
- list_table_stats
- list_stored_procedure

View File

@@ -276,7 +276,7 @@ func setUpServer(t *testing.T, router string, tools map[string]tools.Tool, tools
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(nil, nil, tools, toolsets, prompts, promptsets)
resourceManager := resources.NewResourceManager(nil, nil, nil, tools, toolsets, prompts, promptsets)
server := Server{
version: fakeVersionString,

View File

@@ -21,6 +21,8 @@ import (
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/auth/google"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels/gemini"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -38,6 +40,8 @@ type ServerConfig struct {
SourceConfigs SourceConfigs
// AuthServiceConfigs defines what sources of authentication are available for tools.
AuthServiceConfigs AuthServiceConfigs
// EmbeddingModelConfigs defines a models used to embed parameters.
EmbeddingModelConfigs EmbeddingModelConfigs
// ToolConfigs defines what tools are available.
ToolConfigs ToolConfigs
// ToolsetConfigs defines what tools are available.
@@ -205,6 +209,50 @@ func (c *AuthServiceConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(i
return nil
}
// EmbeddingModelConfigs is a type used to allow unmarshal of the embedding model config map
type EmbeddingModelConfigs map[string]embeddingmodels.EmbeddingModelConfig
// validate interface
var _ yaml.InterfaceUnmarshalerContext = &EmbeddingModelConfigs{}
func (c *EmbeddingModelConfigs) UnmarshalYAML(ctx context.Context, unmarshal func(interface{}) error) error {
*c = make(EmbeddingModelConfigs)
// Parse the 'kind' fields for each embedding model
var raw map[string]util.DelayedUnmarshaler
if err := unmarshal(&raw); err != nil {
return err
}
for name, u := range raw {
// Unmarshal to a general type that ensure it capture all fields
var v map[string]any
if err := u.Unmarshal(&v); err != nil {
return fmt.Errorf("unable to unmarshal embedding model %q: %w", name, err)
}
kind, ok := v["kind"]
if !ok {
return fmt.Errorf("missing 'kind' field for embedding model %q", name)
}
dec, err := util.NewStrictDecoder(v)
if err != nil {
return fmt.Errorf("error creating decoder: %w", err)
}
switch kind {
case gemini.EmbeddingModelKind:
actual := gemini.Config{Name: name}
if err := dec.DecodeContext(ctx, &actual); err != nil {
return fmt.Errorf("unable to parse as %q: %w", kind, err)
}
(*c)[name] = actual
default:
return fmt.Errorf("%q is not a valid kind of auth source", kind)
}
}
return nil
}
// ToolConfigs is a type used to allow unmarshal of the tool configs
type ToolConfigs map[string]tools.ToolConfig

View File

@@ -1107,7 +1107,7 @@ func TestStdioSession(t *testing.T) {
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(nil, nil, toolsMap, toolsets, promptsMap, promptsets)
resourceManager := resources.NewResourceManager(nil, nil, nil, toolsMap, toolsets, promptsMap, promptsets)
server := &Server{
version: fakeVersionString,

View File

@@ -18,6 +18,7 @@ import (
"sync"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/tools"
@@ -25,30 +26,33 @@ import (
// ResourceManager contains available resources for the server. Should be initialized with NewResourceManager().
type ResourceManager struct {
mu sync.RWMutex
sources map[string]sources.Source
authServices map[string]auth.AuthService
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
prompts map[string]prompts.Prompt
promptsets map[string]prompts.Promptset
mu sync.RWMutex
sources map[string]sources.Source
authServices map[string]auth.AuthService
embeddingModels map[string]embeddingmodels.EmbeddingModel
tools map[string]tools.Tool
toolsets map[string]tools.Toolset
prompts map[string]prompts.Prompt
promptsets map[string]prompts.Promptset
}
func NewResourceManager(
sourcesMap map[string]sources.Source,
authServicesMap map[string]auth.AuthService,
embeddingModelsMap map[string]embeddingmodels.EmbeddingModel,
toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset,
promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset,
) *ResourceManager {
resourceMgr := &ResourceManager{
mu: sync.RWMutex{},
sources: sourcesMap,
authServices: authServicesMap,
tools: toolsMap,
toolsets: toolsetsMap,
prompts: promptsMap,
promptsets: promptsetsMap,
mu: sync.RWMutex{},
sources: sourcesMap,
authServices: authServicesMap,
embeddingModels: embeddingModelsMap,
tools: toolsMap,
toolsets: toolsetsMap,
prompts: promptsMap,
promptsets: promptsetsMap,
}
return resourceMgr
@@ -68,6 +72,13 @@ func (r *ResourceManager) GetAuthService(authServiceName string) (auth.AuthServi
return authService, ok
}
func (r *ResourceManager) GetEmbeddingModel(embeddingModelName string) (embeddingmodels.EmbeddingModel, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
model, ok := r.embeddingModels[embeddingModelName]
return model, ok
}
func (r *ResourceManager) GetTool(toolName string) (tools.Tool, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -96,11 +107,12 @@ func (r *ResourceManager) GetPromptset(promptsetName string) (prompts.Promptset,
return promptset, ok
}
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
func (r *ResourceManager) SetResources(sourcesMap map[string]sources.Source, authServicesMap map[string]auth.AuthService, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel, toolsMap map[string]tools.Tool, toolsetsMap map[string]tools.Toolset, promptsMap map[string]prompts.Prompt, promptsetsMap map[string]prompts.Promptset) {
r.mu.Lock()
defer r.mu.Unlock()
r.sources = sourcesMap
r.authServices = authServicesMap
r.embeddingModels = embeddingModelsMap
r.tools = toolsMap
r.toolsets = toolsetsMap
r.prompts = promptsMap
@@ -117,6 +129,16 @@ func (r *ResourceManager) GetAuthServiceMap() map[string]auth.AuthService {
return copiedMap
}
func (r *ResourceManager) GetEmbeddingModelMap() map[string]embeddingmodels.EmbeddingModel {
r.mu.RLock()
defer r.mu.RUnlock()
copiedMap := make(map[string]embeddingmodels.EmbeddingModel, len(r.embeddingModels))
for k, v := range r.embeddingModels {
copiedMap[k] = v
}
return copiedMap
}
func (r *ResourceManager) GetToolsMap() map[string]tools.Tool {
r.mu.RLock()
defer r.mu.RUnlock()

View File

@@ -19,6 +19,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/resources"
"github.com/googleapis/genai-toolbox/internal/sources"
@@ -36,6 +37,7 @@ func TestUpdateServer(t *testing.T) {
},
}
newAuth := map[string]auth.AuthService{"example-auth": nil}
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
newTools := map[string]tools.Tool{"example-tool": nil}
newToolsets := map[string]tools.Toolset{
"example-toolset": {
@@ -54,7 +56,7 @@ func TestUpdateServer(t *testing.T) {
Prompts: []*prompts.Prompt{},
},
}
resMgr := resources.NewResourceManager(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
resMgr := resources.NewResourceManager(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
gotSource, _ := resMgr.GetSource("example-source")
if diff := cmp.Diff(gotSource, newSources["example-source"]); diff != "" {
@@ -95,7 +97,7 @@ func TestUpdateServer(t *testing.T) {
},
}
resMgr.SetResources(updateSource, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
resMgr.SetResources(updateSource, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
gotSource, _ = resMgr.GetSource("example-source2")
if diff := cmp.Diff(gotSource, updateSource["example-source2"]); diff != "" {
t.Errorf("error updating server, sources (-want +got):\n%s", diff)

View File

@@ -30,6 +30,7 @@ import (
"github.com/go-chi/cors"
"github.com/go-chi/httplog/v2"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server/resources"
@@ -56,6 +57,7 @@ type Server struct {
func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
map[string]sources.Source,
map[string]auth.AuthService,
map[string]embeddingmodels.EmbeddingModel,
map[string]tools.Tool,
map[string]tools.Toolset,
map[string]prompts.Prompt,
@@ -91,7 +93,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return s, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
sourcesMap[name] = s
}
@@ -119,7 +121,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return a, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
authServicesMap[name] = a
}
@@ -129,6 +131,34 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d authServices: %s", len(authServicesMap), strings.Join(authServiceNames, ", ")))
// Initialize and validate embedding models from configs.
embeddingModelsMap := make(map[string]embeddingmodels.EmbeddingModel)
for name, ec := range cfg.EmbeddingModelConfigs {
em, err := func() (embeddingmodels.EmbeddingModel, error) {
_, span := instrumentation.Tracer.Start(
ctx,
"toolbox/server/embeddingmodel/init",
trace.WithAttributes(attribute.String("model_kind", ec.EmbeddingModelConfigKind())),
trace.WithAttributes(attribute.String("model_name", name)),
)
defer span.End()
em, err := ec.Initialize(ctx)
if err != nil {
return nil, fmt.Errorf("unable to initialize embedding model %q: %w", name, err)
}
return em, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, nil, err
}
embeddingModelsMap[name] = em
}
embeddingModelNames := make([]string, 0, len(embeddingModelsMap))
for name := range embeddingModelsMap {
embeddingModelNames = append(embeddingModelNames, name)
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d embeddingModels: %s", len(embeddingModelsMap), strings.Join(embeddingModelNames, ", ")))
// initialize and validate the tools from configs
toolsMap := make(map[string]tools.Tool)
for name, tc := range cfg.ToolConfigs {
@@ -147,7 +177,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return t, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
toolsMap[name] = t
}
@@ -184,7 +214,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return t, err
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
toolsetsMap[name] = t
}
@@ -216,7 +246,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return p, nil
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
promptsMap[name] = p
}
@@ -253,7 +283,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
return p, err
}()
if err != nil {
return nil, nil, nil, nil, nil, nil, err
return nil, nil, nil, nil, nil, nil, nil, err
}
promptsetsMap[name] = p
}
@@ -267,7 +297,7 @@ func InitializeConfigs(ctx context.Context, cfg ServerConfig) (
}
l.InfoContext(ctx, fmt.Sprintf("Initialized %d promptsets: %s", len(promptsetsMap), strings.Join(promptsetNames, ", ")))
return sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
return sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, nil
}
// NewServer returns a Server object based on provided Config.
@@ -320,7 +350,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
httpLogger := httplog.NewLogger("httplog", httpOpts)
r.Use(httplog.RequestLogger(httpLogger))
sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap, err := InitializeConfigs(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("unable to initialize configs: %w", err)
}
@@ -330,7 +360,7 @@ func NewServer(ctx context.Context, cfg ServerConfig) (*Server, error) {
sseManager := newSseManager(ctx)
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
resourceManager := resources.NewResourceManager(sourcesMap, authServicesMap, embeddingModelsMap, toolsMap, toolsetsMap, promptsMap, promptsetsMap)
s := &Server{
version: cfg.Version,

View File

@@ -25,6 +25,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/auth"
"github.com/googleapis/genai-toolbox/internal/embeddingmodels"
"github.com/googleapis/genai-toolbox/internal/log"
"github.com/googleapis/genai-toolbox/internal/prompts"
"github.com/googleapis/genai-toolbox/internal/server"
@@ -144,6 +145,7 @@ func TestUpdateServer(t *testing.T) {
},
}
newAuth := map[string]auth.AuthService{"example-auth": nil}
newEmbeddingModels := map[string]embeddingmodels.EmbeddingModel{"example-model": nil}
newTools := map[string]tools.Tool{"example-tool": nil}
newToolsets := map[string]tools.Toolset{
"example-toolset": {
@@ -162,7 +164,7 @@ func TestUpdateServer(t *testing.T) {
Prompts: []*prompts.Prompt{},
},
}
s.ResourceMgr.SetResources(newSources, newAuth, newTools, newToolsets, newPrompts, newPromptsets)
s.ResourceMgr.SetResources(newSources, newAuth, newEmbeddingModels, newTools, newToolsets, newPrompts, newPromptsets)
if err != nil {
t.Errorf("error updating server: %s", err)
}

View File

@@ -333,7 +333,7 @@ func TestInvoke(t *testing.T) {
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
}
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil)
resourceMgr := resources.NewResourceManager(srcs, nil, nil, nil, nil, nil, nil)
// Invoke the tool
result, err := tool.Invoke(ctx, resourceMgr, params, "") // No accessToken needed for ADC client

View File

@@ -0,0 +1,209 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgresliststoredprocedure
import (
"context"
"fmt"
yaml "github.com/goccy/go-yaml"
"github.com/googleapis/genai-toolbox/internal/sources"
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
"github.com/googleapis/genai-toolbox/internal/tools"
"github.com/googleapis/genai-toolbox/internal/util/parameters"
"github.com/jackc/pgx/v5/pgxpool"
)
const kind string = "postgres-list-stored-procedure"
const listStoredProcedure = `
SELECT
n.nspname AS schema_name,
p.proname AS name,
r.rolname AS owner,
l.lanname AS language,
pg_catalog.pg_get_functiondef(p.oid) AS definition,
pg_catalog.obj_description(p.oid, 'pg_proc') AS description
FROM pg_catalog.pg_proc p
JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
JOIN pg_catalog.pg_roles r ON r.oid = p.proowner
JOIN pg_catalog.pg_language l ON l.oid = p.prolang
WHERE
p.prokind = 'p' AND
($1::text IS NULL OR r.rolname LIKE '%' || $1::text || '%') AND
($2::text IS NULL OR n.nspname LIKE '%' || $2::text || '%')
ORDER BY n.nspname, p.proname
LIMIT
COALESCE($3::int, 20);
`
func init() {
if !tools.Register(kind, newConfig) {
panic(fmt.Sprintf("tool kind %q already registered", kind))
}
}
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
actual := Config{Name: name}
if err := decoder.DecodeContext(ctx, &actual); err != nil {
return nil, err
}
return actual, nil
}
type compatibleSource interface {
PostgresPool() *pgxpool.Pool
}
// validate compatible sources are still compatible
var _ compatibleSource = &alloydbpg.Source{}
var _ compatibleSource = &cloudsqlpg.Source{}
var _ compatibleSource = &postgres.Source{}
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
type Config struct {
Name string `yaml:"name" validate:"required"`
Kind string `yaml:"kind" validate:"required"`
Source string `yaml:"source" validate:"required"`
Description string `yaml:"description"`
AuthRequired []string `yaml:"authRequired"`
}
// validate interface
var _ tools.ToolConfig = Config{}
func (cfg Config) ToolConfigKind() string {
return kind
}
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
// verify source exists
rawS, ok := srcs[cfg.Source]
if !ok {
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
}
// verify the source is compatible
s, ok := rawS.(compatibleSource)
if !ok {
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
}
allParameters := parameters.Parameters{
parameters.NewStringParameterWithRequired("role_name", "Optional: The owner name to filter the stored procedures by. Defaults to NULL.", false),
parameters.NewStringParameterWithRequired("schema_name", "Optional: The schema name to filter the stored procedures by. Defaults to NULL.", false),
parameters.NewIntParameterWithDefault("limit", 20, "Optional: The maximum number of stored procedures to return. Defaults to 20."),
}
paramManifest := allParameters.Manifest()
if cfg.Description == "" {
cfg.Description = "Retrieves stored procedure metadata returning schema name, procedure name, procedure owner, language, definition, and description, filtered by optional role name (procedure owner), schema name, and limit (default 20)."
}
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
// finish tool setup
return Tool{
name: cfg.Name,
kind: cfg.Kind,
authRequired: cfg.AuthRequired,
allParams: allParameters,
pool: s.PostgresPool(),
manifest: tools.Manifest{
Description: cfg.Description,
Parameters: paramManifest,
AuthRequired: cfg.AuthRequired,
},
mcpManifest: mcpManifest,
}, nil
}
// validate interface
var _ tools.Tool = Tool{}
type Tool struct {
Config
name string `yaml:"name"`
kind string `yaml:"kind"`
authRequired []string `yaml:"authRequired"`
allParams parameters.Parameters `yaml:"allParams"`
pool *pgxpool.Pool
manifest tools.Manifest
mcpManifest tools.McpManifest
}
func (t Tool) ToConfig() tools.ToolConfig {
return t.Config
}
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
paramsMap := params.AsMap()
newParams, err := parameters.GetParams(t.allParams, paramsMap)
if err != nil {
return nil, fmt.Errorf("unable to extract standard params %w", err)
}
sliceParams := newParams.AsSlice()
results, err := t.pool.Query(ctx, listStoredProcedure, sliceParams...)
if err != nil {
return nil, fmt.Errorf("unable to execute query: %w", err)
}
defer results.Close()
fields := results.FieldDescriptions()
var out []map[string]any
for results.Next() {
values, err := results.Values()
if err != nil {
return nil, fmt.Errorf("unable to parse row: %w", err)
}
rowMap := make(map[string]any)
for i, field := range fields {
rowMap[string(field.Name)] = values[i]
}
out = append(out, rowMap)
}
return out, nil
}
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
return parameters.ParseParams(t.allParams, data, claims)
}
func (t Tool) Manifest() tools.Manifest {
return t.manifest
}
func (t Tool) McpManifest() tools.McpManifest {
return t.mcpManifest
}
func (t Tool) Authorized(verifiedAuthServices []string) bool {
return tools.IsAuthorized(t.authRequired, verifiedAuthServices)
}
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) {
return false, nil
}
func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) {
return "Authorization", nil
}

View File

@@ -0,0 +1,95 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgresliststoredprocedure_test
import (
"testing"
yaml "github.com/goccy/go-yaml"
"github.com/google/go-cmp/cmp"
"github.com/googleapis/genai-toolbox/internal/server"
"github.com/googleapis/genai-toolbox/internal/testutils"
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresliststoredprocedure"
)
func TestParseFromYamlPostgresListStoredProcedure(t *testing.T) {
ctx, err := testutils.ContextWithNewLogger()
if err != nil {
t.Fatalf("unexpected error: %s", err)
}
tcs := []struct {
desc string
in string
want server.ToolConfigs
}{
{
desc: "basic example",
in: `
tools:
example_tool:
kind: postgres-list-stored-procedure
source: my-postgres-instance
description: some description
authRequired:
- my-google-auth-service
- other-auth-service
`,
want: server.ToolConfigs{
"example_tool": postgresliststoredprocedure.Config{
Name: "example_tool",
Kind: "postgres-list-stored-procedure",
Source: "my-postgres-instance",
Description: "some description",
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
},
},
},
{
desc: "basic example",
in: `
tools:
example_tool:
kind: postgres-list-stored-procedure
source: my-postgres-instance
description: some description
`,
want: server.ToolConfigs{
"example_tool": postgresliststoredprocedure.Config{
Name: "example_tool",
Kind: "postgres-list-stored-procedure",
Source: "my-postgres-instance",
Description: "some description",
AuthRequired: []string{},
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
got := struct {
Tools server.ToolConfigs `yaml:"tools"`
}{}
// Parse contents
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
if err != nil {
t.Fatalf("unable to unmarshal: %s", err)
}
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
t.Fatalf("incorrect parse: diff %v", diff)
}
})
}
}

View File

@@ -201,6 +201,7 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
}
// Test connection with different IP type

View File

@@ -185,6 +185,7 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
}
// Test connection with different IP type

View File

@@ -213,6 +213,7 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
PostgresListPGSettingsToolKind = "postgres-list-pg-settings"
PostgresListDatabaseStatsToolKind = "postgres-list-database-stats"
PostgresListRolesToolKind = "postgres-list-roles"
PostgresListStoredProcedureToolKind = "postgres-list-stored-procedure"
)
tools, ok := config["tools"].(map[string]any)
@@ -310,6 +311,11 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
"kind": PostgresListRolesToolKind,
"source": "my-instance",
}
tools["list_stored_procedure"] = map[string]any{
"kind": PostgresListStoredProcedureToolKind,
"source": "my-instance",
}
config["tools"] = tools
return config
}

View File

@@ -164,4 +164,5 @@ func TestPostgres(t *testing.T) {
tests.RunPostgresListPgSettingsTest(t, ctx, pool)
tests.RunPostgresListDatabaseStatsTest(t, ctx, pool)
tests.RunPostgresListRolesTest(t, ctx, pool)
tests.RunPostgresListStoredProcedureTest(t, ctx, pool)
}

View File

@@ -4390,6 +4390,257 @@ func RunPostgresListTableStatsTest(t *testing.T, ctx context.Context, pool *pgxp
}
}
// RunPostgresListStoredProcedureTest runs tests for the postgres list-stored-procedure tool
func RunPostgresListStoredProcedureTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
type storedProcedureDetails struct {
SchemaName string `json:"schema_name"`
Name string `json:"name"`
Owner string `json:"owner"`
Language string `json:"language"`
Definition string `json:"definition"`
Description any `json:"description"`
}
// Create test schema
testSchemaName := "test_proc_schema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createSchemaStmt := fmt.Sprintf("CREATE SCHEMA %s", testSchemaName)
if _, err := pool.Exec(ctx, createSchemaStmt); err != nil {
t.Fatalf("unable to create test schema: %v", err)
}
defer func() {
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", testSchemaName)
if _, err := pool.Exec(ctx, dropSchemaStmt); err != nil {
t.Logf("warning: unable to drop test schema: %v", err)
}
}()
// Create test procedures
proc1Name := "test_proc_1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createProc1Stmt := fmt.Sprintf(`
CREATE PROCEDURE %s.%s(p_count INT)
LANGUAGE plpgsql
AS $$
BEGIN
INSERT INTO test_table VALUES (p_count);
COMMIT;
END;
$$
`, testSchemaName, proc1Name)
if _, err := pool.Exec(ctx, createProc1Stmt); err != nil {
t.Fatalf("unable to create test procedure 1: %v", err)
}
// Add a comment/description to the procedure
commentStmt := fmt.Sprintf("COMMENT ON PROCEDURE %s.%s(INT) IS 'Test procedure that inserts a record'", testSchemaName, proc1Name)
if _, err := pool.Exec(ctx, commentStmt); err != nil {
t.Logf("warning: unable to add comment to procedure: %v", err)
}
// Create a second test procedure
proc2Name := "test_proc_2_" + strings.ReplaceAll(uuid.New().String(), "-", "")
createProc2Stmt := fmt.Sprintf(`
CREATE PROCEDURE %s.%s()
LANGUAGE plpgsql
AS $$
DECLARE
v_count INT;
BEGIN
SELECT COUNT(*) INTO v_count FROM test_table;
RAISE NOTICE 'Total records: %%', v_count;
END;
$$
`, testSchemaName, proc2Name)
if _, err := pool.Exec(ctx, createProc2Stmt); err != nil {
t.Fatalf("unable to create test procedure 2: %v", err)
}
invokeTcs := []struct {
name string
requestBody io.Reader
wantStatusCode int
shouldHaveData bool
expectedCount int
filterByRole string
filterBySchema string
}{
{
name: "list stored procedures with no arguments (default limit 20)",
requestBody: bytes.NewBufferString(`{}`),
wantStatusCode: http.StatusOK,
shouldHaveData: false, // may or may not have data depending on what's in the database
},
{
name: "list stored procedures filtering by specific schema",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s"}`, testSchemaName)),
wantStatusCode: http.StatusOK,
shouldHaveData: true,
expectedCount: 2,
filterBySchema: testSchemaName,
},
{
name: "list stored procedures filtering by procedure owner (postgres)",
requestBody: bytes.NewBufferString(`{"role_name": "postgres"}`),
wantStatusCode: http.StatusOK,
shouldHaveData: false, // might have procedures owned by postgres
},
{
name: "list stored procedures with custom limit",
requestBody: bytes.NewBufferString(`{"limit": 5}`),
wantStatusCode: http.StatusOK,
shouldHaveData: false,
},
{
name: "list stored procedures filtering by schema and role",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "role_name": "postgres"}`, testSchemaName)),
wantStatusCode: http.StatusOK,
shouldHaveData: true,
expectedCount: 2,
filterBySchema: testSchemaName,
filterByRole: "postgres",
},
{
name: "list stored procedures with non-existent schema",
requestBody: bytes.NewBufferString(`{"schema_name": "non_existent_schema_xyz"}`),
wantStatusCode: http.StatusOK,
shouldHaveData: false,
},
{
name: "list stored procedures with non-existent role",
requestBody: bytes.NewBufferString(`{"role_name": "non_existent_role_xyz"}`),
wantStatusCode: http.StatusOK,
shouldHaveData: false,
},
{
name: "list stored procedures with partial schema name match",
requestBody: bytes.NewBufferString(`{"schema_name": "test_proc"}`),
wantStatusCode: http.StatusOK,
shouldHaveData: true,
expectedCount: 2,
},
{
name: "list stored procedures with limit 1",
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "limit": 1}`, testSchemaName)),
wantStatusCode: http.StatusOK,
shouldHaveData: true,
expectedCount: 1,
filterBySchema: testSchemaName,
},
}
for _, tc := range invokeTcs {
t.Run(tc.name, func(t *testing.T) {
const api = "http://127.0.0.1:5000/api/tool/list_stored_procedure/invoke"
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
if resp.StatusCode != tc.wantStatusCode {
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
}
if tc.wantStatusCode != http.StatusOK {
return
}
var bodyWrapper struct {
Result json.RawMessage `json:"result"`
}
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
t.Fatalf("error decoding response wrapper: %v", err)
}
var resultString string
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
resultString = string(bodyWrapper.Result)
}
var got []storedProcedureDetails
if resultString != "null" {
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
t.Fatalf("failed to unmarshal result: %v, result string: %s", err, resultString)
}
}
// Verify expected data presence
if tc.shouldHaveData {
if len(got) == 0 {
t.Fatalf("expected data but got empty result")
}
// If filtering by schema, verify all results are from that schema
if tc.filterBySchema != "" {
for _, proc := range got {
if proc.SchemaName != tc.filterBySchema && !strings.Contains(proc.SchemaName, tc.filterBySchema) {
t.Errorf("procedure schema %s does not match filter %s", proc.SchemaName, tc.filterBySchema)
}
}
}
// If filtering by role, verify all results are owned by that role
if tc.filterByRole != "" {
for _, proc := range got {
if proc.Owner != tc.filterByRole {
t.Errorf("procedure owner %s does not match filter %s", proc.Owner, tc.filterByRole)
}
}
}
// Verify expected count if specified
if tc.expectedCount > 0 && len(got) != tc.expectedCount {
t.Errorf("expected %d procedures but got %d", tc.expectedCount, len(got))
}
}
// Verify result structure and data types
for _, proc := range got {
// Verify all required fields are present and non-empty
if proc.SchemaName == "" {
t.Errorf("schema_name should not be empty")
}
if proc.Name == "" {
t.Errorf("procedure name should not be empty")
}
if proc.Owner == "" {
t.Errorf("owner should not be empty")
}
if proc.Language == "" {
t.Errorf("language should not be empty")
}
if proc.Definition == "" {
t.Errorf("definition should not be empty")
}
// Verify definition contains CREATE PROCEDURE
if !strings.Contains(proc.Definition, "CREATE PROCEDURE") {
t.Logf("warning: definition may not be a valid CREATE PROCEDURE statement: %s", proc.Definition)
}
// Verify language is a valid PostgreSQL language
validLanguages := []string{"plpgsql", "sql", "c", "internal", "plperl", "pltcl", "plpython"}
found := false
for _, lang := range validLanguages {
if proc.Language == lang {
found = true
break
}
}
if !found {
t.Logf("warning: language %s may not be a standard PostgreSQL language", proc.Language)
}
}
// Verify results are sorted by schema_name and name
if len(got) > 1 {
for i := 0; i < len(got)-1; i++ {
currentKey := fmt.Sprintf("%s.%s", got[i].SchemaName, got[i].Name)
nextKey := fmt.Sprintf("%s.%s", got[i+1].SchemaName, got[i+1].Name)
if currentKey > nextKey {
t.Logf("warning: results may not be sorted by schema_name and name")
}
}
}
})
}
}
// RunRequest is a helper function to send HTTP requests and return the response
func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
// Send request