mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-01-10 07:58:12 -05:00
feat: add list-query-stats and get-column-cardinality functions (#1976)
Adds the following tools for Postgre (1) get_column_cardinality - to fetch the estimates of the distinct column counts in the table for particular column or all columns (2) list-query-stats - to obtain the query level statistics in a database. This tool requires pg_stat_statements extension as a prerequisite. <img width="2428" height="1368" alt="image" src="https://github.com/user-attachments/assets/5d9b22f0-6b09-4abe-8411-b1139387e259" /> <img width="3774" height="1010" alt="image" src="https://github.com/user-attachments/assets/b1d9fdf1-8a3b-4afe-ab98-63226a7e3705" /> PR Checklist [Y] Make sure you reviewed [CONTRIBUTING.md](https://github.com/googleapis/genai-toolbox/blob/main/CONTRIBUTING.md) [Y] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/genai-toolbox/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea [Y] Ensure the tests and linter pass [Y] Code coverage does not decrease (if any source code was changed) [Y] Appropriate docs were updated (if necessary) [Y] Make sure to add ! if this involve a breaking change 🛠️ Fixes https://github.com/googleapis/genai-toolbox/issues/1691 --------- Co-authored-by: Averi Kitsch <akitsch@google.com>
This commit is contained in:
@@ -180,11 +180,13 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresdatabaseoverview"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresexecutesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistactivequeries"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistavailableextensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistindexes"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistinstalledextensions"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistlocks"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistschemas"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistsequences"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslisttables"
|
||||
|
||||
@@ -1478,7 +1478,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"alloydb_postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "alloydb_postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1508,7 +1508,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"cloud_sql_postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "cloud_sql_postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1608,7 +1608,7 @@ func TestPrebuiltTools(t *testing.T) {
|
||||
wantToolset: server.ToolsetConfigs{
|
||||
"postgres_database_tools": tools.ToolsetConfig{
|
||||
Name: "postgres_database_tools",
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats"},
|
||||
ToolNames: []string{"execute_sql", "list_tables", "list_active_queries", "list_available_extensions", "list_installed_extensions", "list_autovacuum_configurations", "list_memory_configurations", "list_top_bloated_tables", "list_replication_slots", "list_invalid_indexes", "get_query_plan", "list_views", "list_schemas", "database_overview", "list_triggers", "list_indexes", "list_sequences", "long_running_transactions", "list_locks", "replication_stats", "list_query_stats", "get_column_cardinality"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -71,6 +71,12 @@ cluster][alloydb-free-trial].
|
||||
- [`postgres-replication-stats`](../tools/postgres/postgres-replication-stats.md)
|
||||
List replication stats in a PostgreSQL database.
|
||||
|
||||
- [`postgres-list-query-stats`](../tools/postgres/postgres-list-query-stats.md)
|
||||
List query statistics in a PostgreSQL database.
|
||||
|
||||
- [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md)
|
||||
List cardinality of columns in a table in a PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [AlloyDB using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/alloydb_pg_mcp/)
|
||||
|
||||
@@ -67,6 +67,12 @@ to a database by following these instructions][csql-pg-quickstart].
|
||||
- [`postgres-replication-stats`](../tools/postgres/postgres-replication-stats.md)
|
||||
List replication stats in a PostgreSQL database.
|
||||
|
||||
- [`postgres-list-query-stats`](../tools/postgres/postgres-list-query-stats.md)
|
||||
List query statistics in a PostgreSQL database.
|
||||
|
||||
- [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md)
|
||||
List cardinality of columns in a table in a PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [Cloud SQL for Postgres using
|
||||
|
||||
@@ -62,6 +62,12 @@ reputation for reliability, feature robustness, and performance.
|
||||
- [`postgres-replication-stats`](../tools/postgres/postgres-replication-stats.md)
|
||||
List replication stats in a PostgreSQL database.
|
||||
|
||||
- [`postgres-list-query-stats`](../tools/postgres/postgres-list-query-stats.md)
|
||||
List query statistics in a PostgreSQL database.
|
||||
|
||||
- [`postgres-get-column-cardinality`](../tools/postgres/postgres-get-column-cardinality.md)
|
||||
List cardinality of columns in a table in a PostgreSQL database.
|
||||
|
||||
### Pre-built Configurations
|
||||
|
||||
- [PostgreSQL using MCP](https://googleapis.github.io/genai-toolbox/how-to/connect-ide/postgres_mcp/)
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
---
|
||||
title: "postgres-get-column-cardinality"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "postgres-get-column-cardinality" tool estimates the number of unique values in one or all columns of a Postgres database table.
|
||||
aliases:
|
||||
- /resources/tools/postgres-get-column-cardinality
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `postgres-get-column-cardinality` tool estimates the number of unique values
|
||||
(cardinality) for one or all columns in a specific PostgreSQL table by using the
|
||||
database's internal statistics. It's compatible with any of the following sources:
|
||||
|
||||
- [alloydb-postgres](../../sources/alloydb-pg.md)
|
||||
- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
|
||||
- [postgres](../../sources/postgres.md)
|
||||
|
||||
`postgres-get-column-cardinality` returns detailed information as JSON about column
|
||||
cardinality values, ordered by estimated cardinality in descending order. The tool takes
|
||||
the following input parameters:
|
||||
|
||||
- `schema_name` (required): The schema name in which the table is present.
|
||||
- `table_name` (required): The table name in which the column is present.
|
||||
- `column_name` (optional): The column name for which the cardinality is to be found.
|
||||
If not provided, cardinality for all columns will be returned. Default: `""`.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
get_column_cardinality:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: postgres-source
|
||||
description: Estimates the number of unique values (cardinality) quickly for one or all columns in a specific PostgreSQL table by using the database's internal statistics, returning the results in descending order of estimated cardinality. Please run ANALYZE on the table before using this tool to get accurate results. The tool returns the column_name and the estimated_cardinality. If the column_name is not provided, the tool returns all columns along with their estimated cardinality.
|
||||
```
|
||||
|
||||
The response is a json array with the following elements:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"column_name": "name of the column",
|
||||
"estimated_cardinality": "estimated number of unique values in the column"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
For accurate results, it's recommended to run `ANALYZE` on the table before using this
|
||||
tool. The `ANALYZE` command updates the database statistics that this tool relies on
|
||||
to estimate cardinality.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:--------:|:------------:|------------------------------------------------------|
|
||||
| kind | string | true | Must be "postgres-get-column-cardinality". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -0,0 +1,71 @@
|
||||
---
|
||||
title: "postgres-list-query-stats"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
The "postgres-list-query-stats" tool lists query statistics from a Postgres database.
|
||||
aliases:
|
||||
- /resources/tools/postgres-list-query-stats
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `postgres-list-query-stats` tool retrieves query statistics from the
|
||||
`pg_stat_statements` extension in a PostgreSQL database. It provides detailed
|
||||
performance metrics for executed queries. It's compatible with any of the following
|
||||
sources:
|
||||
|
||||
- [alloydb-postgres](../../sources/alloydb-pg.md)
|
||||
- [cloud-sql-postgres](../../sources/cloud-sql-pg.md)
|
||||
- [postgres](../../sources/postgres.md)
|
||||
|
||||
`postgres-list-query-stats` lists detailed query statistics as JSON, ordered by
|
||||
total execution time in descending order. The tool takes the following input parameters:
|
||||
|
||||
- `database_name` (optional): The database name to filter query stats for. The input is
|
||||
used within a LIKE clause. Default: `""` (all databases).
|
||||
- `limit` (optional): The maximum number of results to return. Default: `50`.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
list_query_stats:
|
||||
kind: postgres-list-query-stats
|
||||
source: postgres-source
|
||||
description: List query statistics from pg_stat_statements, showing performance metrics for queries including execution counts, timing information, and resource usage. Results are ordered by total execution time descending.
|
||||
```
|
||||
|
||||
The response is a json array with the following elements:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"datname": "database name",
|
||||
"query": "the SQL query text",
|
||||
"calls": "number of times the query was executed",
|
||||
"total_exec_time": "total execution time in milliseconds",
|
||||
"min_exec_time": "minimum execution time in milliseconds",
|
||||
"max_exec_time": "maximum execution time in milliseconds",
|
||||
"mean_exec_time": "mean execution time in milliseconds",
|
||||
"rows": "total number of rows retrieved or affected",
|
||||
"shared_blks_hit": "number of shared block cache hits",
|
||||
"shared_blks_read": "number of shared block disk reads"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
This tool requires the `pg_stat_statements` extension to be installed and enabled
|
||||
on the PostgreSQL database. The `pg_stat_statements` extension tracks execution
|
||||
statistics for all SQL statements executed by the server, which is useful for
|
||||
identifying slow queries and understanding query performance patterns.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|-------------|:--------:|:------------:|------------------------------------------------------|
|
||||
| kind | string | true | Must be "postgres-list-query-stats". |
|
||||
| source | string | true | Name of the source the SQL should execute on. |
|
||||
| description | string | true | Description of the tool that is passed to the LLM. |
|
||||
@@ -192,6 +192,14 @@ tools:
|
||||
kind: postgres-list-triggers
|
||||
source: alloydb-pg-source
|
||||
|
||||
list_query_stats:
|
||||
kind: postgres-list-query-stats
|
||||
source: alloydb-pg-source
|
||||
|
||||
get_column_cardinality:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: alloydb-pg-source
|
||||
|
||||
toolsets:
|
||||
alloydb_postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -214,4 +222,5 @@ toolsets:
|
||||
- long_running_transactions
|
||||
- list_locks
|
||||
- replication_stats
|
||||
|
||||
- list_query_stats
|
||||
- get_column_cardinality
|
||||
|
||||
@@ -194,6 +194,14 @@ tools:
|
||||
kind: postgres-list-sequences
|
||||
source: cloudsql-pg-source
|
||||
|
||||
list_query_stats:
|
||||
kind: postgres-list-query-stats
|
||||
source: cloudsql-pg-source
|
||||
|
||||
get_column_cardinality:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: cloudsql-pg-source
|
||||
|
||||
toolsets:
|
||||
cloud_sql_postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -216,3 +224,5 @@ toolsets:
|
||||
- long_running_transactions
|
||||
- list_locks
|
||||
- replication_stats
|
||||
- list_query_stats
|
||||
- get_column_cardinality
|
||||
|
||||
@@ -193,6 +193,14 @@ tools:
|
||||
kind: postgres-list-sequences
|
||||
source: postgresql-source
|
||||
|
||||
list_query_stats:
|
||||
kind: postgres-list-query-stats
|
||||
source: postgresql-source
|
||||
|
||||
get_column_cardinality:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: postgresql-source
|
||||
|
||||
toolsets:
|
||||
postgres_database_tools:
|
||||
- execute_sql
|
||||
@@ -215,3 +223,6 @@ toolsets:
|
||||
- long_running_transactions
|
||||
- list_locks
|
||||
- replication_stats
|
||||
- list_query_stats
|
||||
- get_column_cardinality
|
||||
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgresgetcolumncardinality
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const kind string = "postgres-get-column-cardinality"
|
||||
|
||||
const getColumnCardinality = `
|
||||
SELECT
|
||||
s.attname AS column_name,
|
||||
ROUND(
|
||||
CASE
|
||||
WHEN s.n_distinct < 0 THEN ABS(s.n_distinct) * c.reltuples
|
||||
ELSE s.n_distinct
|
||||
END
|
||||
) AS estimated_cardinality
|
||||
FROM pg_stats s
|
||||
JOIN pg_class c ON s.tablename = c.relname
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE s.schemaname = $1 -- Parameter for schema
|
||||
AND s.tablename = $2 -- Parameter for table
|
||||
AND n.nspname = $1 -- (same parameter as schema)
|
||||
AND s.attname = COALESCE($3, s.attname)
|
||||
ORDER BY estimated_cardinality DESC;
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
var _ compatibleSource = &cloudsqlpg.Source{}
|
||||
var _ compatibleSource = &postgres.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithDefault("schema_name", "public", "Optional: The schema name in which the table is present."),
|
||||
parameters.NewStringParameterWithRequired("table_name", "Required: The table name in which the column is present.", true),
|
||||
parameters.NewStringParameterWithRequired("column_name", "Optional: The column name for which the cardinality is to be found. If not provided, cardinality for all columns will be returned.", false),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
if cfg.Description == "" {
|
||||
cfg.Description = "Estimates the number of unique values (cardinality) quickly for one or all columns in a specific PostgreSQL table by using the database's internal statistics, returning the results in descending order of estimated cardinality. Please run ANALYZE on the table before using this tool to get accurate results. The tool returns the column_name and the estimated_cardinality. If the column_name is not provided, the tool returns all columns along with their estimated cardinality."
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
name: cfg.Name,
|
||||
kind: cfg.Kind,
|
||||
authRequired: cfg.AuthRequired,
|
||||
allParams: allParameters,
|
||||
pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: paramManifest,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
name string `yaml:"name"`
|
||||
kind string `yaml:"kind"`
|
||||
authRequired []string `yaml:"authRequired"`
|
||||
allParams parameters.Parameters `yaml:"allParams"`
|
||||
pool *pgxpool.Pool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
newParams, err := parameters.GetParams(t.allParams, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
|
||||
results, err := t.pool.Query(ctx, getColumnCardinality, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
var out []map[string]any
|
||||
|
||||
for results.Next() {
|
||||
values, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
rowMap := make(map[string]any)
|
||||
for i, field := range fields {
|
||||
rowMap[string(field.Name)] = values[i]
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return err.Error(), fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.allParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.authRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgresgetcolumncardinality_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgresgetcolumncardinality"
|
||||
)
|
||||
|
||||
func TestParseFromYamlPostgresGetColumnCardinality(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgresgetcolumncardinality.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-get-column-cardinality",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-get-column-cardinality
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgresgetcolumncardinality.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-get-column-cardinality",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgreslistquerystats
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/alloydbpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/postgres"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
const kind string = "postgres-list-query-stats"
|
||||
|
||||
const listQueryStats = `
|
||||
SELECT
|
||||
d.datname,
|
||||
s.query,
|
||||
s.calls,
|
||||
s.total_exec_time,
|
||||
s.min_exec_time,
|
||||
s.max_exec_time,
|
||||
s.mean_exec_time,
|
||||
s.rows,
|
||||
s.shared_blks_hit,
|
||||
s.shared_blks_read
|
||||
FROM
|
||||
pg_stat_statements s
|
||||
JOIN pg_database d ON d.oid = s.dbid
|
||||
WHERE d.datname <> 'cloudsqladmin' AND ($1::text IS NULL OR d.datname LIKE '%' || $1::text || '%')
|
||||
ORDER BY total_exec_time DESC
|
||||
LIMIT COALESCE($2::int, 50);
|
||||
`
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type compatibleSource interface {
|
||||
PostgresPool() *pgxpool.Pool
|
||||
}
|
||||
|
||||
// validate compatible sources are still compatible
|
||||
var _ compatibleSource = &alloydbpg.Source{}
|
||||
var _ compatibleSource = &cloudsqlpg.Source{}
|
||||
var _ compatibleSource = &postgres.Source{}
|
||||
|
||||
var compatibleSources = [...]string{alloydbpg.SourceKind, cloudsqlpg.SourceKind, postgres.SourceKind}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(compatibleSource)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources)
|
||||
}
|
||||
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithDefault("database_name", "", "Optional: The database name to list query stats for."),
|
||||
parameters.NewIntParameterWithDefault("limit", 50, "Optional: The maximum number of results to return. Defaults to 50."),
|
||||
}
|
||||
paramManifest := allParameters.Manifest()
|
||||
|
||||
if cfg.Description == "" {
|
||||
cfg.Description = "Lists performance statistics for executed queries ordered by total time, filtering by database name pattern if provided. This tool requires the pg_stat_statements extension to be installed. The tool returns the database name, query text, execution count, timing metrics (total, min, max, mean), rows affected, and buffer cache I/O statistics (hits and reads)."
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
// finish tool setup
|
||||
return Tool{
|
||||
name: cfg.Name,
|
||||
kind: cfg.Kind,
|
||||
authRequired: cfg.AuthRequired,
|
||||
allParams: allParameters,
|
||||
pool: s.PostgresPool(),
|
||||
manifest: tools.Manifest{
|
||||
Description: cfg.Description,
|
||||
Parameters: paramManifest,
|
||||
AuthRequired: cfg.AuthRequired,
|
||||
},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
name string `yaml:"name"`
|
||||
kind string `yaml:"kind"`
|
||||
authRequired []string `yaml:"authRequired"`
|
||||
allParams parameters.Parameters `yaml:"allParams"`
|
||||
pool *pgxpool.Pool
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
func (t Tool) Invoke(ctx context.Context, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
|
||||
newParams, err := parameters.GetParams(t.allParams, paramsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract standard params %w", err)
|
||||
}
|
||||
sliceParams := newParams.AsSlice()
|
||||
|
||||
results, err := t.pool.Query(ctx, listQueryStats, sliceParams...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
defer results.Close()
|
||||
|
||||
fields := results.FieldDescriptions()
|
||||
var out []map[string]any
|
||||
|
||||
for results.Next() {
|
||||
values, err := results.Values()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse row: %w", err)
|
||||
}
|
||||
rowMap := make(map[string]any)
|
||||
for i, field := range fields {
|
||||
rowMap[string(field.Name)] = values[i]
|
||||
}
|
||||
out = append(out, rowMap)
|
||||
}
|
||||
|
||||
if err := results.Err(); err != nil {
|
||||
return err.Error(), fmt.Errorf("unable to execute query: %w", err)
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.allParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.authRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgreslistquerystats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/postgres/postgreslistquerystats"
|
||||
)
|
||||
|
||||
func TestParseFromYamlPostgresListQueryStats(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-list-query-stats
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
- other-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgreslistquerystats.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-list-query-stats",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{"my-google-auth-service", "other-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
example_tool:
|
||||
kind: postgres-list-query-stats
|
||||
source: my-postgres-instance
|
||||
description: some description
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"example_tool": postgreslistquerystats.Config{
|
||||
Name: "example_tool",
|
||||
Kind: "postgres-list-query-stats",
|
||||
Source: "my-postgres-instance",
|
||||
Description: "some description",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
@@ -193,6 +193,8 @@ func TestAlloyDBPgToolEndpoints(t *testing.T) {
|
||||
tests.RunPostgresListLocksTest(t, ctx, pool)
|
||||
tests.RunPostgresReplicationStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresLongRunningTransactionsTest(t, ctx, pool)
|
||||
tests.RunPostgresListQueryStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -177,6 +177,8 @@ func TestCloudSQLPgSimpleToolEndpoints(t *testing.T) {
|
||||
tests.RunPostgresListLocksTest(t, ctx, pool)
|
||||
tests.RunPostgresReplicationStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresLongRunningTransactionsTest(t, ctx, pool)
|
||||
tests.RunPostgresListQueryStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
// Test connection with different IP type
|
||||
|
||||
@@ -205,6 +205,8 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
|
||||
PostgresLongRunningTransactionsToolKind = "postgres-long-running-transactions"
|
||||
PostgresListLocksToolKind = "postgres-list-locks"
|
||||
PostgresReplicationStatsToolKind = "postgres-replication-stats"
|
||||
PostgresListQueryStatsToolKind = "postgres-list-query-stats"
|
||||
PostgresGetColumnCardinalityToolKind = "postgres-get-column-cardinality"
|
||||
)
|
||||
|
||||
tools, ok := config["tools"].(map[string]any)
|
||||
@@ -275,6 +277,16 @@ func AddPostgresPrebuiltConfig(t *testing.T, config map[string]any) map[string]a
|
||||
"kind": PostgresReplicationStatsToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
tools["list_query_stats"] = map[string]any{
|
||||
"kind": PostgresListQueryStatsToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
|
||||
tools["get_column_cardinality"] = map[string]any{
|
||||
"kind": PostgresGetColumnCardinalityToolKind,
|
||||
"source": "my-instance",
|
||||
}
|
||||
|
||||
config["tools"] = tools
|
||||
return config
|
||||
}
|
||||
|
||||
@@ -156,4 +156,6 @@ func TestPostgres(t *testing.T) {
|
||||
tests.RunPostgresLongRunningTransactionsTest(t, ctx, pool)
|
||||
tests.RunPostgresListLocksTest(t, ctx, pool)
|
||||
tests.RunPostgresReplicationStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresListQueryStatsTest(t, ctx, pool)
|
||||
tests.RunPostgresGetColumnCardinalityTest(t, ctx, pool)
|
||||
}
|
||||
|
||||
274
tests/tool.go
274
tests/tool.go
@@ -1078,7 +1078,7 @@ func setupPostgresSchemas(t *testing.T, ctx context.Context, pool *pgxpool.Pool,
|
||||
}
|
||||
|
||||
return func() {
|
||||
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s", schemaName)
|
||||
dropSchemaStmt := fmt.Sprintf("DROP SCHEMA %s CASCADE", schemaName)
|
||||
_, err := pool.Exec(ctx, dropSchemaStmt)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to drop schema: %v", err)
|
||||
@@ -3163,6 +3163,278 @@ func RunPostgresReplicationStatsTest(t *testing.T, ctx context.Context, pool *pg
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresGetColumnCardinalityTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
schemaName := "testschema_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
tableName := "table1_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
cleanup := setupPostgresSchemas(t, ctx, pool, schemaName)
|
||||
defer cleanup()
|
||||
|
||||
// Create table with multiple columns
|
||||
createTableStmt := fmt.Sprintf(`
|
||||
CREATE TABLE %s.%s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
email VARCHAR(100) UNIQUE,
|
||||
name VARCHAR(50),
|
||||
status VARCHAR(20),
|
||||
created_at TIMESTAMP
|
||||
)
|
||||
`, schemaName, tableName)
|
||||
|
||||
if _, err := pool.Exec(ctx, createTableStmt); err != nil {
|
||||
t.Fatalf("unable to create table: %s", err)
|
||||
}
|
||||
|
||||
// Insert larger sample data to ensure statistics are collected
|
||||
insertStmt := fmt.Sprintf(`
|
||||
INSERT INTO %s.%s (email, name, status, created_at) VALUES
|
||||
('user1@example.com', 'Alice', 'active', NOW()),
|
||||
('user2@example.com', 'Bob', 'inactive', NOW()),
|
||||
('user3@example.com', 'Charlie', 'active', NOW()),
|
||||
('user4@example.com', 'David', 'active', NOW()),
|
||||
('user5@example.com', 'Eve', 'inactive', NOW()),
|
||||
('user6@example.com', 'Frank', 'active', NOW()),
|
||||
('user7@example.com', 'Grace', 'inactive', NOW()),
|
||||
('user8@example.com', 'Henry', 'active', NOW()),
|
||||
('user9@example.com', 'Ivy', 'active', NOW()),
|
||||
('user10@example.com', 'Jack', 'inactive', NOW())
|
||||
`, schemaName, tableName)
|
||||
|
||||
if _, err := pool.Exec(ctx, insertStmt); err != nil {
|
||||
t.Fatalf("unable to insert data: %s", err)
|
||||
}
|
||||
|
||||
// Run ANALYZE to update statistics
|
||||
analyzeStmt := fmt.Sprintf(`ANALYZE %s.%s`, schemaName, tableName)
|
||||
if _, err := pool.Exec(ctx, analyzeStmt); err != nil {
|
||||
t.Fatalf("unable to run ANALYZE: %s", err)
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
shouldHaveData bool // Whether we expect data in the response
|
||||
}{
|
||||
{
|
||||
name: "get cardinality for a specific column",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "email"}`, schemaName, tableName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
},
|
||||
{
|
||||
name: "get cardinality for all columns",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s"}`, schemaName, tableName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: true,
|
||||
},
|
||||
{
|
||||
name: "get cardinality with non-existent column",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "%s", "table_name": "%s", "column_name": "non_existent"}`, schemaName, tableName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
{
|
||||
name: "get cardinality with non-existent schema",
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"schema_name": "non_existent_schema", "table_name": "%s"}`, tableName)),
|
||||
wantStatusCode: http.StatusOK,
|
||||
shouldHaveData: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/get_column_cardinality/invoke"
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got []map[string]any
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested result string: %v", err)
|
||||
}
|
||||
|
||||
// Verify that we got the expected data presence
|
||||
if tc.shouldHaveData {
|
||||
if len(got) == 0 {
|
||||
t.Logf("warning: expected data but got empty result. This can happen if pg_stats is not populated yet.")
|
||||
return
|
||||
}
|
||||
|
||||
// Verify column names and cardinality values
|
||||
for _, row := range got {
|
||||
columnName, ok := row["column_name"].(string)
|
||||
if !ok {
|
||||
t.Fatalf("column_name is not a string: %v", row["column_name"])
|
||||
}
|
||||
|
||||
// Check that estimated_cardinality is present and is a number
|
||||
cardinality, ok := row["estimated_cardinality"]
|
||||
if !ok {
|
||||
t.Fatalf("estimated_cardinality is missing for column %s", columnName)
|
||||
}
|
||||
|
||||
// Convert to float64 for numeric checks
|
||||
cardinalityFloat, ok := cardinality.(float64)
|
||||
if !ok {
|
||||
t.Fatalf("estimated_cardinality is not a number: %v", cardinality)
|
||||
}
|
||||
|
||||
// Cardinality should be >= 0
|
||||
if cardinalityFloat < 0 {
|
||||
t.Errorf("cardinality for column %s is negative: %v", columnName, cardinalityFloat)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(got) != 0 {
|
||||
t.Errorf("expected no data but got: %v", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createPostgresExtension(t *testing.T, ctx context.Context, pool *pgxpool.Pool, extensionName string) func() {
|
||||
createExtensionCmd := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
|
||||
_, err := pool.Exec(ctx, createExtensionCmd)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create extension: %v", err)
|
||||
}
|
||||
return func() {
|
||||
dropExtensionCmd := fmt.Sprintf("DROP EXTENSION IF EXISTS %s", extensionName)
|
||||
_, err := pool.Exec(ctx, dropExtensionCmd)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to drop extension: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunPostgresListQueryStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
// Insert a simple query by running a SELECT statement
|
||||
// This will record statistics in pg_stat_statements
|
||||
selectStmt := "SELECT 1 as test_query"
|
||||
if _, err := pool.Exec(ctx, selectStmt); err != nil {
|
||||
t.Logf("warning: unable to execute test query: %s", err)
|
||||
}
|
||||
|
||||
dropExtensionFunc := createPostgresExtension(t, ctx, pool, "pg_stat_statements")
|
||||
defer dropExtensionFunc()
|
||||
|
||||
type queryStatDetails struct {
|
||||
Datname string `json:"datname"`
|
||||
Query string `json:"query"`
|
||||
Calls any `json:"calls"`
|
||||
TotalExecTime any `json:"total_exec_time"`
|
||||
MinExecTime any `json:"min_exec_time"`
|
||||
MaxExecTime any `json:"max_exec_time"`
|
||||
MeanExecTime any `json:"mean_exec_time"`
|
||||
Rows any `json:"rows"`
|
||||
SharedBlksHit any `json:"shared_blks_hit"`
|
||||
SharedBlksRead any `json:"shared_blks_read"`
|
||||
}
|
||||
|
||||
invokeTcs := []struct {
|
||||
name string
|
||||
requestBody io.Reader
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "list query stats with default limit",
|
||||
requestBody: bytes.NewBufferString(`{}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "list query stats with custom limit",
|
||||
requestBody: bytes.NewBufferString(`{"limit": 10}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "list query stats for specific database",
|
||||
requestBody: bytes.NewBufferString(`{"database_name": "postgres"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "list query stats with non-existent database name",
|
||||
requestBody: bytes.NewBufferString(`{"database_name": "non_existent_db_xyz"}`),
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range invokeTcs {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
const api = "http://127.0.0.1:5000/api/tool/list_query_stats/invoke"
|
||||
resp, respBody := RunRequest(t, http.MethodPost, api, tc.requestBody, nil)
|
||||
if resp.StatusCode != tc.wantStatusCode {
|
||||
t.Fatalf("wrong status code: got %d, want %d, body: %s", resp.StatusCode, tc.wantStatusCode, string(respBody))
|
||||
}
|
||||
if tc.wantStatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
|
||||
var bodyWrapper struct {
|
||||
Result json.RawMessage `json:"result"`
|
||||
}
|
||||
if err := json.Unmarshal(respBody, &bodyWrapper); err != nil {
|
||||
t.Fatalf("error decoding response wrapper: %v", err)
|
||||
}
|
||||
|
||||
var resultString string
|
||||
if err := json.Unmarshal(bodyWrapper.Result, &resultString); err != nil {
|
||||
resultString = string(bodyWrapper.Result)
|
||||
}
|
||||
|
||||
var got []map[string]any
|
||||
if err := json.Unmarshal([]byte(resultString), &got); err != nil {
|
||||
t.Fatalf("failed to unmarshal nested result string: %v, resultString: %s", err, resultString)
|
||||
}
|
||||
|
||||
// For databases with pg_stat_statements available, verify response structure
|
||||
if len(got) > 0 {
|
||||
// Verify the response has the expected fields
|
||||
requiredFields := []string{"datname", "query", "calls", "total_exec_time", "min_exec_time", "max_exec_time", "mean_exec_time", "rows", "shared_blks_hit", "shared_blks_read"}
|
||||
for _, field := range requiredFields {
|
||||
if _, ok := got[0][field]; !ok {
|
||||
t.Errorf("missing expected field: %s in result: %v", field, got[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify data types
|
||||
var stat queryStatDetails
|
||||
statData, _ := json.Marshal(got[0])
|
||||
if err := json.Unmarshal(statData, &stat); err != nil {
|
||||
t.Logf("warning: unable to unmarshal query stat: %v", err)
|
||||
}
|
||||
|
||||
// Verify that results are ordered by total_exec_time (descending)
|
||||
if len(got) > 1 {
|
||||
for i := 0; i < len(got)-1; i++ {
|
||||
currentTime, ok1 := got[i]["total_exec_time"].(float64)
|
||||
nextTime, ok2 := got[i+1]["total_exec_time"].(float64)
|
||||
if ok1 && ok2 && currentTime < nextTime {
|
||||
t.Logf("warning: results may not be ordered by total_exec_time descending: %f vs %f", currentTime, nextTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RunRequest is a helper function to send HTTP requests and return the response
|
||||
func RunRequest(t *testing.T, method, url string, body io.Reader, headers map[string]string) (*http.Response, []byte) {
|
||||
// Send request
|
||||
|
||||
Reference in New Issue
Block a user