diff --git a/.ci/integration.cloudbuild.yaml b/.ci/integration.cloudbuild.yaml index bfa16e7156..96bbbd8b9b 100644 --- a/.ci/integration.cloudbuild.yaml +++ b/.ci/integration.cloudbuild.yaml @@ -354,6 +354,30 @@ steps: postgressql \ postgresexecutesql + - id: "cockroachdb" + name: golang:1 + waitFor: ["compile-test-binary"] + entrypoint: /bin/bash + env: + - "GOPATH=/gopath" + - "COCKROACHDB_DATABASE=$_DATABASE_NAME" + - "COCKROACHDB_PORT=$_COCKROACHDB_PORT" + - "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL" + secretEnv: ["COCKROACHDB_USER", "COCKROACHDB_HOST","CLIENT_ID"] + volumes: + - name: "go" + path: "/gopath" + args: + - -c + - | + .ci/test_with_coverage.sh \ + "CockroachDB" \ + cockroachdb \ + cockroachdbsql \ + cockroachdbexecutesql \ + cockroachdblisttables \ + cockroachdblistschemas + - id: "spanner" name: golang:1 waitFor: ["compile-test-binary"] @@ -1129,6 +1153,11 @@ availableSecrets: env: MARIADB_HOST - versionName: projects/$PROJECT_ID/secrets/mongodb_uri/versions/latest env: MONGODB_URI + - versionName: projects/$PROJECT_ID/secrets/cockroachdb_user/versions/latest + env: COCKROACHDB_USER + - versionName: projects/$PROJECT_ID/secrets/cockroachdb_host/versions/latest + env: COCKROACHDB_HOST + options: logging: CLOUD_LOGGING_ONLY @@ -1189,6 +1218,9 @@ substitutions: _SINGLESTORE_PORT: "3308" _SINGLESTORE_DATABASE: "singlestore" _SINGLESTORE_USER: "root" + _COCKROACHDB_HOST: 127.0.0.1 + _COCKROACHDB_PORT: "26257" + _COCKROACHDB_USER: "root" _MARIADB_PORT: "3307" _MARIADB_DATABASE: test_database _SNOWFLAKE_DATABASE: "test" diff --git a/cmd/root.go b/cmd/root.go index 5e59997211..3d62c11dc8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -110,6 +110,10 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlmysql/cloudsqlmysqlcreateinstance" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgcreateinstances" _ "github.com/googleapis/genai-toolbox/internal/tools/cloudsqlpg/cloudsqlpgupgradeprecheck" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" _ "github.com/googleapis/genai-toolbox/internal/tools/couchbase" _ "github.com/googleapis/genai-toolbox/internal/tools/dataform/dataformcompilelocal" _ "github.com/googleapis/genai-toolbox/internal/tools/dataplex/dataplexlookupentry" @@ -256,6 +260,7 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmssql" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlmysql" _ "github.com/googleapis/genai-toolbox/internal/sources/cloudsqlpg" + _ "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" _ "github.com/googleapis/genai-toolbox/internal/sources/couchbase" _ "github.com/googleapis/genai-toolbox/internal/sources/dataplex" _ "github.com/googleapis/genai-toolbox/internal/sources/dgraph" diff --git a/docs/en/resources/sources/cockroachdb.md b/docs/en/resources/sources/cockroachdb.md new file mode 100644 index 0000000000..9ecf884ce5 --- /dev/null +++ b/docs/en/resources/sources/cockroachdb.md @@ -0,0 +1,242 @@ +--- +title: "CockroachDB" +type: docs +weight: 1 +description: > + CockroachDB is a distributed SQL database built for cloud applications. + +--- + +## About + +[CockroachDB][crdb-docs] is a distributed SQL database designed for cloud-native applications. It provides strong consistency, horizontal scalability, and built-in resilience with automatic failover and recovery. CockroachDB uses the PostgreSQL wire protocol, making it compatible with many PostgreSQL tools and drivers while providing unique features like multi-region deployments and distributed transactions. + +**Minimum Version:** CockroachDB v25.1 or later is recommended for full tool compatibility. + +[crdb-docs]: https://www.cockroachlabs.com/docs/ + +## Available Tools + +- [`cockroachdb-sql`](../tools/cockroachdb/cockroachdb-sql.md) + Execute SQL queries as prepared statements in CockroachDB (alias for execute-sql). + +- [`cockroachdb-execute-sql`](../tools/cockroachdb/cockroachdb-execute-sql.md) + Run parameterized SQL statements in CockroachDB. + +- [`cockroachdb-list-schemas`](../tools/cockroachdb/cockroachdb-list-schemas.md) + List schemas in a CockroachDB database. + +- [`cockroachdb-list-tables`](../tools/cockroachdb/cockroachdb-list-tables.md) + List tables in a CockroachDB database. + +## Requirements + +### Database User + +This source uses standard authentication. You will need to [create a CockroachDB user][crdb-users] to login to the database with. For CockroachDB Cloud deployments, SSL/TLS is required. + +[crdb-users]: https://www.cockroachlabs.com/docs/stable/create-user.html + +### SSL/TLS Configuration + +CockroachDB Cloud clusters require SSL/TLS connections. Use the `queryParams` section to configure SSL settings: + +- **For CockroachDB Cloud**: Use `sslmode: require` at minimum +- **For self-hosted with certificates**: Use `sslmode: verify-full` with certificate paths +- **For local development only**: Use `sslmode: disable` (not recommended for production) + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + maxRetries: 5 + retryBaseDelay: 500ms + queryParams: + sslmode: require + application_name: my-app + + # MCP Security Settings (recommended for production) + readOnlyMode: true # Read-only by default (MCP best practice) + enableWriteMode: false # Set to true to allow write operations + maxRowLimit: 1000 # Limit query results + queryTimeoutSec: 30 # Prevent long-running queries + enableTelemetry: true # Enable observability + telemetryVerbose: false # Set true for detailed logs + clusterID: "my-cluster" # Optional identifier + +tools: + list_expenses: + type: cockroachdb-sql + source: my_cockroachdb + description: List all expenses + statement: SELECT id, description, amount, category FROM expenses WHERE user_id = $1 + parameters: + - name: user_id + type: string + description: The user's ID + + describe_expenses: + type: cockroachdb-describe-table + source: my_cockroachdb + description: Describe the expenses table schema + + list_expenses_indexes: + type: cockroachdb-list-indexes + source: my_cockroachdb + description: List indexes on the expenses table +``` + +## Configuration Parameters + +### Required Parameters + +| Parameter | Type | Description | +|-----------|------|-------------| +| `type` | string | Must be `cockroachdb` | +| `host` | string | The hostname or IP address of the CockroachDB cluster | +| `port` | string | The port number (typically "26257") | +| `user` | string | The database user name | +| `database` | string | The database name to connect to | + +### Optional Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `password` | string | "" | The database password (can be empty for certificate-based auth) | +| `maxRetries` | integer | 5 | Maximum number of connection retry attempts | +| `retryBaseDelay` | string | "500ms" | Base delay between retry attempts (exponential backoff) | +| `queryParams` | map | {} | Additional connection parameters (e.g., SSL configuration) | + +### MCP Security Parameters + +CockroachDB integration includes security features following the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) specification: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `readOnlyMode` | boolean | true | Enables read-only mode by default (MCP requirement) | +| `enableWriteMode` | boolean | false | Explicitly enable write operations (INSERT/UPDATE/DELETE/CREATE/DROP) | +| `maxRowLimit` | integer | 1000 | Maximum rows returned per SELECT query (auto-adds LIMIT clause) | +| `queryTimeoutSec` | integer | 30 | Query timeout in seconds to prevent long-running queries | +| `enableTelemetry` | boolean | true | Enable structured logging of tool invocations | +| `telemetryVerbose` | boolean | false | Enable detailed JSON telemetry output | +| `clusterID` | string | "" | Optional cluster identifier for telemetry | + +### Query Parameters + +Common query parameters for CockroachDB connections: + +| Parameter | Values | Description | +|-----------|--------|-------------| +| `sslmode` | `disable`, `require`, `verify-ca`, `verify-full` | SSL/TLS mode (CockroachDB Cloud requires `require` or higher) | +| `sslrootcert` | file path | Path to root certificate for SSL verification | +| `sslcert` | file path | Path to client certificate | +| `sslkey` | file path | Path to client key | +| `application_name` | string | Application name for connection tracking | + +## Best Practices + +### Security and MCP Compliance + +**Read-Only by Default**: The integration follows MCP best practices by defaulting to read-only mode. This prevents accidental data modifications: + +```yaml +sources: + my_cockroachdb: + readOnlyMode: true # Default behavior + enableWriteMode: false # Explicit write opt-in required +``` + +To enable write operations: + +```yaml +sources: + my_cockroachdb: + readOnlyMode: false # Disable read-only protection + enableWriteMode: true # Explicitly allow writes +``` + +**Query Limits**: Automatic row limits prevent excessive data retrieval: +- SELECT queries automatically get `LIMIT 1000` appended (configurable via `maxRowLimit`) +- Queries are terminated after 30 seconds (configurable via `queryTimeoutSec`) + +**Observability**: Structured telemetry provides visibility into tool usage: +- Tool invocations are logged with status, latency, and row counts +- SQL queries are redacted to protect sensitive values +- Set `telemetryVerbose: true` for detailed JSON logs + +### Use UUID Primary Keys + +CockroachDB performs best with UUID primary keys rather than sequential integers to avoid transaction hotspots: + +```sql +CREATE TABLE expenses ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + description TEXT, + amount DECIMAL(10,2) +); +``` + +### Automatic Transaction Retry + +This source uses the official `cockroach-go/v2` library which provides automatic transaction retry for serialization conflicts. For write operations requiring explicit transaction control, tools can use the `ExecuteTxWithRetry` method. + +### Multi-Region Deployments + +CockroachDB supports multi-region deployments with automatic data distribution. Configure your cluster's regions and survival goals separately from the Toolbox configuration. The source will connect to any node in the cluster. + +### Connection Pooling + +The source maintains a connection pool to the CockroachDB cluster. The pool automatically handles: +- Load balancing across cluster nodes +- Connection retry with exponential backoff +- Health checking of connections + +## Troubleshooting + +### SSL/TLS Errors + +If you encounter "server requires encryption" errors: + +1. For CockroachDB Cloud, ensure `sslmode` is set to `require` or higher: + ```yaml + queryParams: + sslmode: require + ``` + +2. For certificate verification, download your cluster's root certificate and configure: + ```yaml + queryParams: + sslmode: verify-full + sslrootcert: /path/to/ca.crt + ``` + +### Connection Timeouts + +If experiencing connection timeouts: + +1. Check network connectivity to the CockroachDB cluster +2. Verify firewall rules allow connections on port 26257 +3. For CockroachDB Cloud, ensure IP allowlisting is configured +4. Increase `maxRetries` or `retryBaseDelay` if needed + +### Transaction Retry Errors + +CockroachDB may encounter serializable transaction conflicts. The integration automatically handles these retries using the cockroach-go library. If you see retry-related errors, check: + +1. Database load and contention +2. Query patterns that might cause conflicts +3. Consider using `SELECT FOR UPDATE` for explicit locking + +## Additional Resources + +- [CockroachDB Documentation](https://www.cockroachlabs.com/docs/) +- [CockroachDB Best Practices](https://www.cockroachlabs.com/docs/stable/performance-best-practices-overview.html) +- [Multi-Region Capabilities](https://www.cockroachlabs.com/docs/stable/multiregion-overview.html) +- [Connection Parameters](https://www.cockroachlabs.com/docs/stable/connection-parameters.html) diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md b/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md new file mode 100644 index 0000000000..10a78926f4 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-execute-sql.md @@ -0,0 +1,273 @@ +--- +title: "cockroachdb-execute-sql" +type: docs +weight: 1 +description: > + Execute ad-hoc SQL statements against a CockroachDB database. + +--- + +## About + +A `cockroachdb-execute-sql` tool executes ad-hoc SQL statements against a CockroachDB database. This tool is designed for interactive workflows where the SQL query is provided dynamically at runtime, making it ideal for developer assistance and exploratory data analysis. + +The tool takes a single `sql` parameter containing the SQL statement to execute and returns the query results. + +> **Note:** This tool is intended for developer assistant workflows with human-in-the-loop and shouldn't be used for production agents. For production use cases with predefined queries, use [cockroachdb-sql](./cockroachdb-sql.md) instead. + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + execute_sql: + type: cockroachdb-execute-sql + source: my_cockroachdb + description: Execute any SQL statement against the CockroachDB database +``` + +## Usage Examples + +### Simple SELECT Query + +```json +{ + "sql": "SELECT * FROM users LIMIT 10" +} +``` + +### Query with Aggregations + +```json +{ + "sql": "SELECT category, COUNT(*) as count, SUM(amount) as total FROM expenses GROUP BY category ORDER BY total DESC" +} +``` + +### Database Introspection + +```json +{ + "sql": "SHOW TABLES" +} +``` + +```json +{ + "sql": "SHOW COLUMNS FROM expenses" +} +``` + +### Multi-Region Information + +```json +{ + "sql": "SHOW REGIONS FROM DATABASE defaultdb" +} +``` + +```json +{ + "sql": "SHOW ZONE CONFIGURATIONS" +} +``` + +## CockroachDB-Specific Features + +### Check Cluster Version + +```json +{ + "sql": "SELECT version()" +} +``` + +### View Node Status + +```json +{ + "sql": "SELECT node_id, address, locality, is_live FROM crdb_internal.gossip_nodes" +} +``` + +### Check Replication Status + +```json +{ + "sql": "SELECT range_id, start_key, end_key, replicas, lease_holder FROM crdb_internal.ranges LIMIT 10" +} +``` + +### View Table Regions + +```json +{ + "sql": "SHOW REGIONS FROM TABLE expenses" +} +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-execute-sql` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Parameters + +The tool accepts a single runtime parameter: + +| Parameter | Type | Description | +|-----------|------|-------------| +| `sql` | string | The SQL statement to execute | + +## Best Practices + +### Use for Exploration, Not Production + +This tool is ideal for: +- Interactive database exploration +- Ad-hoc analysis and reporting +- Debugging and troubleshooting +- Schema inspection + +For production use cases, use [cockroachdb-sql](./cockroachdb-sql.md) with parameterized queries. + +### Be Cautious with Data Modification + +While this tool can execute any SQL statement, be careful with: +- `INSERT`, `UPDATE`, `DELETE` statements +- `DROP` or `ALTER` statements +- Schema changes in production + +### Use LIMIT for Large Results + +Always use `LIMIT` clauses when exploring data: + +```sql +SELECT * FROM large_table LIMIT 100 +``` + +### Leverage CockroachDB's SQL Extensions + +CockroachDB supports PostgreSQL syntax plus extensions: + +```sql +-- Show database survival goal +SHOW SURVIVAL GOAL FROM DATABASE defaultdb; + +-- View zone configurations +SHOW ZONE CONFIGURATION FOR TABLE expenses; + +-- Check table localities +SHOW CREATE TABLE expenses; +``` + +## Error Handling + +The tool will return descriptive errors for: +- **Syntax errors**: Invalid SQL syntax +- **Permission errors**: Insufficient user privileges +- **Connection errors**: Network or authentication issues +- **Runtime errors**: Constraint violations, type mismatches, etc. + +## Security Considerations + +### SQL Injection Risk + +Since this tool executes arbitrary SQL, it should only be used with: +- Trusted users in interactive sessions +- Human-in-the-loop workflows +- Development and testing environments + +Never expose this tool directly to end users without proper authorization controls. + +### Use Authentication + +Configure the `authRequired` field to restrict access: + +```yaml +tools: + execute_sql: + type: cockroachdb-execute-sql + source: my_cockroachdb + description: Execute SQL statements + authRequired: + - my-auth-service +``` + +### Read-Only Users + +For safer exploration, create read-only database users: + +```sql +CREATE USER readonly_user; +GRANT SELECT ON DATABASE defaultdb TO readonly_user; +``` + +## Common Use Cases + +### Database Administration + +```sql +-- View database size +SELECT + table_name, + pg_size_pretty(pg_total_relation_size(table_name::regclass)) AS size +FROM information_schema.tables +WHERE table_schema = 'public' +ORDER BY pg_total_relation_size(table_name::regclass) DESC; +``` + +### Performance Analysis + +```sql +-- Find slow queries +SELECT query, count, mean_latency +FROM crdb_internal.statement_statistics +WHERE mean_latency > INTERVAL '1 second' +ORDER BY mean_latency DESC +LIMIT 10; +``` + +### Data Quality Checks + +```sql +-- Find NULL values +SELECT COUNT(*) as null_count +FROM expenses +WHERE description IS NULL OR amount IS NULL; + +-- Find duplicates +SELECT user_id, email, COUNT(*) as count +FROM users +GROUP BY user_id, email +HAVING COUNT(*) > 1; +``` + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - For parameterized, production-ready queries +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB SQL Reference](https://www.cockroachlabs.com/docs/stable/sql-statements.html) - Official SQL documentation diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md b/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md new file mode 100644 index 0000000000..8a9ee11292 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-list-schemas.md @@ -0,0 +1,305 @@ +--- +title: "cockroachdb-list-schemas" +type: docs +weight: 1 +description: > + List schemas in a CockroachDB database. + +--- + +## About + +The `cockroachdb-list-schemas` tool retrieves a list of schemas (namespaces) in a CockroachDB database. Schemas are used to organize database objects such as tables, views, and functions into logical groups. + +This tool is useful for: +- Understanding database organization +- Discovering available schemas +- Multi-tenant application analysis +- Schema-level access control planning + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + list_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: List all schemas in the database +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-list-schemas` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Output Structure + +The tool returns a list of schemas with the following information: + +```json +[ + { + "catalog_name": "defaultdb", + "schema_name": "public", + "is_user_defined": true + }, + { + "catalog_name": "defaultdb", + "schema_name": "analytics", + "is_user_defined": true + } +] +``` + +### Fields + +| Field | Type | Description | +|-------|------|-------------| +| `catalog_name` | string | The database (catalog) name | +| `schema_name` | string | The schema name | +| `is_user_defined` | boolean | Whether this is a user-created schema (excludes system schemas) | + +## Usage Example + +```json +{} +``` + +No parameters are required. The tool automatically lists all user-defined schemas. + +## Default Schemas + +CockroachDB includes several standard schemas: + +- **`public`**: The default schema for user objects +- **`pg_catalog`**: PostgreSQL system catalog (excluded from results) +- **`information_schema`**: SQL standard metadata views (excluded from results) +- **`crdb_internal`**: CockroachDB internal metadata (excluded from results) +- **`pg_extension`**: PostgreSQL extension objects (excluded from results) + +The tool filters out system schemas and only returns user-defined schemas. + +## Schema Management in CockroachDB + +### Creating Schemas + +```sql +CREATE SCHEMA analytics; +``` + +### Using Schemas + +```sql +-- Create table in specific schema +CREATE TABLE analytics.revenue ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + amount DECIMAL(10,2), + date DATE +); + +-- Query from specific schema +SELECT * FROM analytics.revenue; +``` + +### Schema Search Path + +The search path determines which schemas are searched for unqualified object names: + +```sql +-- Show current search path +SHOW search_path; + +-- Set search path +SET search_path = analytics, public; +``` + +## Multi-Tenant Applications + +Schemas are commonly used for multi-tenant applications: + +```sql +-- Create schema per tenant +CREATE SCHEMA tenant_acme; +CREATE SCHEMA tenant_globex; + +-- Create same table structure in each schema +CREATE TABLE tenant_acme.orders (...); +CREATE TABLE tenant_globex.orders (...); +``` + +The `cockroachdb-list-schemas` tool helps discover all tenant schemas: + +```yaml +tools: + list_tenants: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + List all tenant schemas in the database. + Each schema represents a separate tenant's data namespace. +``` + +## Best Practices + +### Use Schemas for Organization + +Group related tables into schemas: + +```sql +CREATE SCHEMA sales; +CREATE SCHEMA inventory; +CREATE SCHEMA hr; + +CREATE TABLE sales.orders (...); +CREATE TABLE inventory.products (...); +CREATE TABLE hr.employees (...); +``` + +### Schema Naming Conventions + +Use clear, descriptive schema names: +- Lowercase names +- Use underscores for multi-word names +- Avoid reserved keywords +- Use prefixes for grouped schemas (e.g., `tenant_`, `app_`) + +### Schema-Level Permissions + +Schemas enable fine-grained access control: + +```sql +-- Grant access to specific schema +GRANT USAGE ON SCHEMA analytics TO analyst_role; +GRANT SELECT ON ALL TABLES IN SCHEMA analytics TO analyst_role; + +-- Revoke access +REVOKE ALL ON SCHEMA hr FROM public; +``` + +## Integration with Other Tools + +### Combined with List Tables + +```yaml +tools: + list_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: List all schemas first + + list_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + List tables in the database. + Use list_schemas first to understand schema organization. +``` + +### Schema Discovery Workflow + +1. Call `cockroachdb-list-schemas` to discover schemas +2. Call `cockroachdb-list-tables` to see tables in each schema +3. Generate queries using fully qualified names: `schema.table` + +## Common Use Cases + +### Discover Database Structure + +```yaml +tools: + discover_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + Discover how the database is organized into schemas. + Use this to understand the logical grouping of tables. +``` + +### Multi-Tenant Analysis + +```yaml +tools: + list_tenant_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + List all tenant schemas (each tenant has their own schema). + Schema names follow the pattern: tenant_ +``` + +### Schema Migration Planning + +```yaml +tools: + audit_schemas: + type: cockroachdb-list-schemas + source: my_cockroachdb + description: | + Audit existing schemas before migration. + Identifies all schemas that need to be migrated. +``` + +## Error Handling + +The tool handles common errors: +- **Connection errors**: Returns connection failure details +- **Permission errors**: Returns error if user lacks USAGE privilege +- **Empty results**: Returns empty array if no user schemas exist + +## Permissions Required + +To list schemas, the user needs: +- `CONNECT` privilege on the database +- No specific schema privileges required for listing + +To query objects within schemas, the user needs: +- `USAGE` privilege on the schema +- Appropriate object privileges (SELECT, INSERT, etc.) + +## CockroachDB-Specific Features + +### System Schemas + +CockroachDB includes PostgreSQL-compatible system schemas plus CockroachDB-specific ones: + +- `crdb_internal.*`: CockroachDB internal metadata and statistics +- `pg_catalog.*`: PostgreSQL system catalog +- `information_schema.*`: SQL standard information schema + +These are automatically filtered from the results. + +### User-Defined Flag + +The `is_user_defined` field helps distinguish: +- `true`: User-created schemas +- `false`: System schemas (already filtered out) + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Official documentation diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md b/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md new file mode 100644 index 0000000000..339dbd2320 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-list-tables.md @@ -0,0 +1,344 @@ +--- +title: "cockroachdb-list-tables" +type: docs +weight: 1 +description: > + List tables in a CockroachDB database with schema details. + +--- + +## About + +The `cockroachdb-list-tables` tool retrieves a list of tables from a CockroachDB database. It provides detailed information about table structure, including columns, constraints, indexes, and foreign key relationships. + +This tool is useful for: +- Database schema discovery +- Understanding table relationships +- Generating context for AI-powered database queries +- Documentation and analysis + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + list_all_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: List all user tables in the database with their structure +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-list-tables` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description for the LLM | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `authRequired` | array | List of authentication services required | + +## Parameters + +The tool accepts optional runtime parameters: + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `table_names` | array | all tables | List of specific table names to retrieve | +| `output_format` | string | "detailed" | Output format: "simple" or "detailed" | + +## Output Formats + +### Simple Format + +Returns basic table information: +- Table name +- Row count estimate +- Size information + +```json +{ + "table_names": ["users"], + "output_format": "simple" +} +``` + +### Detailed Format (Default) + +Returns comprehensive table information: +- Table name and schema +- All columns with types and constraints +- Primary keys +- Foreign keys and relationships +- Indexes +- Check constraints +- Table size and row counts + +```json +{ + "table_names": ["users", "orders"], + "output_format": "detailed" +} +``` + +## Usage Examples + +### List All Tables + +```json +{} +``` + +### List Specific Tables + +```json +{ + "table_names": ["users", "orders", "expenses"] +} +``` + +### Simple Output + +```json +{ + "output_format": "simple" +} +``` + +## Output Structure + +### Simple Format Output + +```json +{ + "table_name": "users", + "estimated_rows": 1000, + "size": "128 KB" +} +``` + +### Detailed Format Output + +```json +{ + "table_name": "users", + "schema": "public", + "columns": [ + { + "name": "id", + "type": "UUID", + "nullable": false, + "default": "gen_random_uuid()" + }, + { + "name": "email", + "type": "STRING", + "nullable": false, + "default": null + }, + { + "name": "created_at", + "type": "TIMESTAMP", + "nullable": false, + "default": "now()" + } + ], + "primary_key": ["id"], + "indexes": [ + { + "name": "users_pkey", + "columns": ["id"], + "unique": true, + "primary": true + }, + { + "name": "users_email_idx", + "columns": ["email"], + "unique": true, + "primary": false + } + ], + "foreign_keys": [], + "constraints": [ + { + "name": "users_email_check", + "type": "CHECK", + "definition": "email ~* '^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Z|a-z]{2,}$'" + } + ] +} +``` + +## CockroachDB-Specific Information + +### UUID Primary Keys + +The tool recognizes CockroachDB's recommended UUID primary key pattern: + +```sql +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + ... +); +``` + +### Multi-Region Tables + +For multi-region tables, the output includes locality information: + +```json +{ + "table_name": "users", + "locality": "REGIONAL BY ROW", + "regions": ["us-east-1", "us-west-2", "eu-west-1"] +} +``` + +### Interleaved Tables + +The tool shows parent-child relationships for interleaved tables (legacy feature): + +```json +{ + "table_name": "order_items", + "interleaved_in": "orders" +} +``` + +## Best Practices + +### Use for Schema Discovery + +The tool is ideal for helping AI assistants understand your database structure: + +```yaml +tools: + discover_schema: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + Use this tool first to understand the database schema before generating queries. + It shows all tables, their columns, data types, and relationships. +``` + +### Filter Large Schemas + +For databases with many tables, specify relevant tables: + +```json +{ + "table_names": ["users", "orders", "products"], + "output_format": "detailed" +} +``` + +### Use Simple Format for Overviews + +When you need just table names and sizes: + +```json +{ + "output_format": "simple" +} +``` + +## Excluded Tables + +The tool automatically excludes system tables and schemas: +- `pg_catalog.*` - PostgreSQL system catalog +- `information_schema.*` - SQL standard information schema +- `crdb_internal.*` - CockroachDB internal tables +- `pg_extension.*` - PostgreSQL extension tables + +Only user-created tables in the public schema (and other user schemas) are returned. + +## Error Handling + +The tool handles common errors: +- **Table not found**: Returns empty result for non-existent tables +- **Permission errors**: Returns error if user lacks SELECT privileges +- **Connection errors**: Returns connection failure details + +## Integration with AI Assistants + +### Prompt Example + +```yaml +tools: + list_tables: + type: cockroachdb-list-tables + source: my_cockroachdb + description: | + Lists all tables in the database with detailed schema information. + Use this tool to understand: + - What tables exist + - What columns each table has + - Data types and constraints + - Relationships between tables (foreign keys) + - Available indexes + + Always call this tool before generating SQL queries to ensure + you use correct table and column names. +``` + +## Common Use Cases + +### Generate Context for Queries + +```json +{} +``` + +This provides comprehensive schema information that helps AI assistants generate accurate SQL queries. + +### Analyze Table Structure + +```json +{ + "table_names": ["users"], + "output_format": "detailed" +} +``` + +Perfect for understanding a specific table's structure, constraints, and relationships. + +### Quick Schema Overview + +```json +{ + "output_format": "simple" +} +``` + +Gets a quick list of tables with basic statistics. + +## Performance Considerations + +- **Simple format** is faster for large databases +- **Detailed format** queries system tables extensively +- Specifying `table_names` reduces query time +- Results are fetched in a single query for efficiency + +## See Also + +- [cockroachdb-sql](./cockroachdb-sql.md) - Execute parameterized queries +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - Execute ad-hoc SQL +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference +- [CockroachDB Schema Design](https://www.cockroachlabs.com/docs/stable/schema-design-overview.html) - Best practices diff --git a/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md b/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md new file mode 100644 index 0000000000..aa31edcd52 --- /dev/null +++ b/docs/en/resources/tools/cockroachdb/cockroachdb-sql.md @@ -0,0 +1,291 @@ +--- +title: "cockroachdb-sql" +type: docs +weight: 1 +description: > + Execute parameterized SQL queries in CockroachDB. + +--- + +## About + +The `cockroachdb-sql` tool allows you to execute parameterized SQL queries against a CockroachDB database. This tool supports prepared statements with parameter binding, template parameters for dynamic query construction, and automatic transaction retry for resilience against serialization conflicts. + +## Example + +```yaml +sources: + my_cockroachdb: + type: cockroachdb + host: your-cluster.cockroachlabs.cloud + port: "26257" + user: myuser + password: mypassword + database: defaultdb + queryParams: + sslmode: require + +tools: + get_user_orders: + type: cockroachdb-sql + source: my_cockroachdb + description: Get all orders for a specific user + statement: | + SELECT o.id, o.order_date, o.total_amount, o.status + FROM orders o + WHERE o.user_id = $1 + ORDER BY o.order_date DESC + parameters: + - name: user_id + type: string + description: The UUID of the user +``` + +## Configuration + +### Required Fields + +| Field | Type | Description | +|-------|------|-------------| +| `type` | string | Must be `cockroachdb-sql` | +| `source` | string | Name of the CockroachDB source to use | +| `description` | string | Human-readable description of what the tool does | +| `statement` | string | The SQL query to execute | + +### Optional Fields + +| Field | Type | Description | +|-------|------|-------------| +| `parameters` | array | List of parameter definitions for the query | +| `templateParameters` | array | List of template parameters for dynamic query construction | +| `authRequired` | array | List of authentication services required | + +## Parameters + +Parameters allow you to safely pass values into your SQL queries using prepared statements. CockroachDB uses PostgreSQL-style parameter placeholders: `$1`, `$2`, etc. + +### Parameter Types + +- `string`: Text values +- `number`: Numeric values (integers or decimals) +- `boolean`: True/false values +- `array`: Array of values + +### Example with Multiple Parameters + +```yaml +tools: + filter_expenses: + type: cockroachdb-sql + source: my_cockroachdb + description: Filter expenses by category and date range + statement: | + SELECT id, description, amount, category, expense_date + FROM expenses + WHERE user_id = $1 + AND category = $2 + AND expense_date >= $3 + AND expense_date <= $4 + ORDER BY expense_date DESC + parameters: + - name: user_id + type: string + description: The user's UUID + - name: category + type: string + description: Expense category (e.g., "Food", "Transport") + - name: start_date + type: string + description: Start date in YYYY-MM-DD format + - name: end_date + type: string + description: End date in YYYY-MM-DD format +``` + +## Template Parameters + +Template parameters enable dynamic query construction by replacing placeholders in the SQL statement before parameter binding. This is useful for dynamic table names, column names, or query structure. + +### Example with Template Parameters + +```yaml +tools: + get_column_data: + type: cockroachdb-sql + source: my_cockroachdb + description: Get data from a specific column + statement: | + SELECT {{column_name}} + FROM {{table_name}} + WHERE user_id = $1 + LIMIT 100 + templateParameters: + - name: table_name + type: string + description: The table to query + - name: column_name + type: string + description: The column to retrieve + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +## Best Practices + +### Use UUID Primary Keys + +CockroachDB performs best with UUID primary keys to avoid transaction hotspots: + +```sql +CREATE TABLE orders ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL, + order_date TIMESTAMP DEFAULT now(), + total_amount DECIMAL(10,2) +); +``` + +### Use Indexes for Performance + +Create indexes on frequently queried columns: + +```sql +CREATE INDEX idx_orders_user_id ON orders(user_id); +CREATE INDEX idx_orders_date ON orders(order_date DESC); +``` + +### Use JOINs Efficiently + +CockroachDB supports standard SQL JOINs. Keep joins efficient by: +- Adding appropriate indexes +- Using UUIDs for foreign keys +- Limiting result sets with WHERE clauses + +```yaml +tools: + get_user_with_orders: + type: cockroachdb-sql + source: my_cockroachdb + description: Get user details with their recent orders + statement: | + SELECT u.name, u.email, o.id as order_id, o.order_date, o.total_amount + FROM users u + LEFT JOIN orders o ON u.id = o.user_id + WHERE u.id = $1 + ORDER BY o.order_date DESC + LIMIT 10 + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +### Handle NULL Values + +Use COALESCE or NULL checks when dealing with nullable columns: + +```sql +SELECT id, description, COALESCE(notes, 'No notes') as notes +FROM expenses +WHERE user_id = $1 +``` + +## Error Handling + +The tool automatically handles: +- **Connection errors**: Retried with exponential backoff +- **Serialization conflicts**: Automatically retried using cockroach-go library +- **Invalid parameters**: Returns descriptive error messages +- **SQL syntax errors**: Returns database error details + +## Advanced Usage + +### Aggregations + +```yaml +tools: + expense_summary: + type: cockroachdb-sql + source: my_cockroachdb + description: Get expense summary by category for a user + statement: | + SELECT + category, + COUNT(*) as count, + SUM(amount) as total_amount, + AVG(amount) as avg_amount + FROM expenses + WHERE user_id = $1 + AND expense_date >= $2 + GROUP BY category + ORDER BY total_amount DESC + parameters: + - name: user_id + type: string + description: The user's UUID + - name: start_date + type: string + description: Start date in YYYY-MM-DD format +``` + +### Window Functions + +```yaml +tools: + running_total: + type: cockroachdb-sql + source: my_cockroachdb + description: Get running total of expenses + statement: | + SELECT + expense_date, + amount, + SUM(amount) OVER (ORDER BY expense_date) as running_total + FROM expenses + WHERE user_id = $1 + ORDER BY expense_date + parameters: + - name: user_id + type: string + description: The user's UUID +``` + +### Common Table Expressions (CTEs) + +```yaml +tools: + top_spenders: + type: cockroachdb-sql + source: my_cockroachdb + description: Find top spending users + statement: | + WITH user_totals AS ( + SELECT + user_id, + SUM(amount) as total_spent + FROM expenses + WHERE expense_date >= $1 + GROUP BY user_id + ) + SELECT + u.name, + u.email, + ut.total_spent + FROM user_totals ut + JOIN users u ON ut.user_id = u.id + ORDER BY ut.total_spent DESC + LIMIT 10 + parameters: + - name: start_date + type: string + description: Start date in YYYY-MM-DD format +``` + +## See Also + +- [cockroachdb-execute-sql](./cockroachdb-execute-sql.md) - For ad-hoc SQL execution +- [cockroachdb-list-tables](./cockroachdb-list-tables.md) - List tables in the database +- [cockroachdb-list-schemas](./cockroachdb-list-schemas.md) - List database schemas +- [CockroachDB Source](../../sources/cockroachdb.md) - Source configuration reference diff --git a/go.mod b/go.mod index f74a7dc2c6..9fd4617e1b 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.30.0 github.com/apache/cassandra-gocql-driver/v2 v2.0.0 github.com/cenkalti/backoff/v5 v5.0.3 + github.com/cockroachdb/cockroach-go/v2 v2.4.2 github.com/couchbase/gocb/v2 v2.11.1 github.com/couchbase/tools-common/http v1.0.9 github.com/elastic/elastic-transport-go/v8 v8.8.0 diff --git a/go.sum b/go.sum index c3f59c72d4..8ebebe11d5 100644 --- a/go.sum +++ b/go.sum @@ -800,6 +800,8 @@ github.com/cncf/xds/go v0.0.0-20230105202645-06c439db220b/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20230607035331-e9ce68804cb4/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443 h1:aQ3y1lwWyqYPiWZThqv1aFbZMiM9vblcSArJRf2Irls= github.com/cncf/xds/go v0.0.0-20250501225837-2ac532fd4443/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cockroachdb/cockroach-go/v2 v2.4.2 h1:QB0ozDWQUUJ0GP8Zw63X/qHefPTCpLvtfCs6TLrPgyE= +github.com/cockroachdb/cockroach-go/v2 v2.4.2/go.mod h1:9U179XbCx4qFWtNhc7BiWLPfuyMVQ7qdAhfrwLz1vH0= github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= @@ -960,6 +962,8 @@ github.com/godror/godror v0.49.6 h1:ts4ZGw8uLJ42e1D7aXmVuSrld0/lzUzmIUjuUuQOgGM= github.com/godror/godror v0.49.6/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8= github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw= github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= diff --git a/internal/sources/cockroachdb/cockroachdb.go b/internal/sources/cockroachdb/cockroachdb.go new file mode 100644 index 0000000000..5a90fcf53e --- /dev/null +++ b/internal/sources/cockroachdb/cockroachdb.go @@ -0,0 +1,430 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "math" + "net/url" + "regexp" + "strings" + "time" + + crdbpgx "github.com/cockroachdb/cockroach-go/v2/crdb/crdbpgxv5" + "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel/trace" +) + +const SourceKind string = "cockroachdb" +const SourceType string = "cockroachdb" + +var _ sources.SourceConfig = Config{} + +func init() { + if !sources.Register(SourceKind, newConfig) { + panic(fmt.Sprintf("source kind %q already registered", SourceKind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) { + // MCP compliance: Read-only by default, require explicit opt-in for writes + actual := Config{ + Name: name, + MaxRetries: 5, + RetryBaseDelay: "500ms", + ReadOnlyMode: true, // MCP requirement: read-only by default + EnableWriteMode: false, // Must be explicitly enabled + MaxRowLimit: 1000, // MCP requirement: limit query results + QueryTimeoutSec: 30, // MCP requirement: prevent long-running queries + EnableTelemetry: true, // MCP requirement: observability + TelemetryVerbose: false, + } + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + + // Security validation: If EnableWriteMode is true, ReadOnlyMode should be false + if actual.EnableWriteMode { + actual.ReadOnlyMode = false + } + + return actual, nil +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Host string `yaml:"host" validate:"required"` + Port string `yaml:"port" validate:"required"` + User string `yaml:"user" validate:"required"` + Password string `yaml:"password"` + Database string `yaml:"database" validate:"required"` + QueryParams map[string]string `yaml:"queryParams"` + MaxRetries int `yaml:"maxRetries"` + RetryBaseDelay string `yaml:"retryBaseDelay"` + + // MCP Security Features + ReadOnlyMode bool `yaml:"readOnlyMode"` // Default: true (enforced in Initialize) + EnableWriteMode bool `yaml:"enableWriteMode"` // Explicit opt-in for write operations + MaxRowLimit int `yaml:"maxRowLimit"` // Default: 1000 + QueryTimeoutSec int `yaml:"queryTimeoutSec"` // Default: 30 + + // Observability + EnableTelemetry bool `yaml:"enableTelemetry"` // Default: true + TelemetryVerbose bool `yaml:"telemetryVerbose"` // Default: false + ClusterID string `yaml:"clusterID"` // Optional cluster identifier for telemetry +} + +func (r Config) SourceConfigKind() string { + return SourceKind +} + +func (r Config) SourceConfigType() string { + return SourceType +} + +func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) { + retryBaseDelay, err := time.ParseDuration(r.RetryBaseDelay) + if err != nil { + return nil, fmt.Errorf("invalid retryBaseDelay: %w", err) + } + + pool, err := initCockroachDBConnectionPoolWithRetry(ctx, tracer, r.Name, r.Host, r.Port, r.User, r.Password, r.Database, r.QueryParams, r.MaxRetries, retryBaseDelay) + if err != nil { + return nil, fmt.Errorf("unable to create pool: %w", err) + } + + s := &Source{ + Config: r, + Pool: pool, + } + return s, nil +} + +var _ sources.Source = &Source{} + +type Source struct { + Config + Pool *pgxpool.Pool +} + +func (s *Source) SourceKind() string { + return SourceKind +} + +func (s *Source) SourceType() string { + return SourceType +} + +func (s *Source) ToConfig() sources.SourceConfig { + return s.Config +} + +func (s *Source) CockroachDBPool() *pgxpool.Pool { + return s.Pool +} + +func (s *Source) PostgresPool() *pgxpool.Pool { + return s.Pool +} + +// ExecuteTxWithRetry executes a function within a transaction with automatic retry logic +// using the official CockroachDB retry mechanism from cockroach-go/v2 +func (s *Source) ExecuteTxWithRetry(ctx context.Context, fn func(pgx.Tx) error) error { + return crdbpgx.ExecuteTx(ctx, s.Pool, pgx.TxOptions{}, fn) +} + +// Query executes a query using the connection pool with MCP security enforcement. +// For read-only queries, connection-level retry is sufficient. +// For write operations requiring transaction retry, use ExecuteTxWithRetry directly. +// Note: Callers should manage context timeouts as needed. +func (s *Source) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + // MCP Security Check 1: Enforce write operation restrictions + if err := s.CanExecuteWrite(sql); err != nil { + return nil, err + } + + // MCP Security Check 2: Apply query limits (row limit) + modifiedSQL, err := s.ApplyQueryLimits(sql) + if err != nil { + return nil, err + } + + return s.Pool.Query(ctx, modifiedSQL, args...) +} + +// ============================================================================ +// MCP Security & Observability Features +// ============================================================================ + +// TelemetryEvent represents a structured telemetry event for MCP tool calls +type TelemetryEvent struct { + Timestamp time.Time `json:"timestamp"` + ToolName string `json:"tool_name"` + ClusterID string `json:"cluster_id"` + Database string `json:"database"` + User string `json:"user"` + SQLRedacted string `json:"sql_redacted"` // Query with values redacted + Status string `json:"status"` // "success" | "failure" + ErrorCode string `json:"error_code,omitempty"` + ErrorMsg string `json:"error_msg,omitempty"` + LatencyMs int64 `json:"latency_ms"` + RowsAffected int64 `json:"rows_affected,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// StructuredError represents an MCP-compliant error with error codes +type StructuredError struct { + Code string `json:"error_code"` + Message string `json:"message"` + Details map[string]any `json:"details,omitempty"` +} + +func (e *StructuredError) Error() string { + return fmt.Sprintf("[%s] %s", e.Code, e.Message) +} + +// MCP Error Codes +const ( + ErrCodeUnauthorized = "CRDB_UNAUTHORIZED" + ErrCodeReadOnlyViolation = "CRDB_READONLY_VIOLATION" + ErrCodeQueryTimeout = "CRDB_QUERY_TIMEOUT" + ErrCodeRowLimitExceeded = "CRDB_ROW_LIMIT_EXCEEDED" + ErrCodeInvalidSQL = "CRDB_INVALID_SQL" + ErrCodeConnectionFailed = "CRDB_CONNECTION_FAILED" + ErrCodeWriteModeRequired = "CRDB_WRITE_MODE_REQUIRED" + ErrCodeQueryExecutionFailed = "CRDB_QUERY_EXECUTION_FAILED" +) + +// SQLStatementType represents the type of SQL statement +type SQLStatementType int + +const ( + SQLTypeUnknown SQLStatementType = iota + SQLTypeSelect + SQLTypeInsert + SQLTypeUpdate + SQLTypeDelete + SQLTypeDDL // CREATE, ALTER, DROP + SQLTypeTruncate + SQLTypeExplain + SQLTypeShow + SQLTypeSet +) + +// ClassifySQL analyzes a SQL statement and returns its type +func ClassifySQL(sql string) SQLStatementType { + // Normalize: trim and convert to uppercase for analysis + normalized := strings.TrimSpace(strings.ToUpper(sql)) + + if normalized == "" { + return SQLTypeUnknown + } + + // Remove comments + normalized = regexp.MustCompile(`--.*`).ReplaceAllString(normalized, "") + normalized = regexp.MustCompile(`/\*.*?\*/`).ReplaceAllString(normalized, "") + normalized = strings.TrimSpace(normalized) + + // Check statement type + switch { + case strings.HasPrefix(normalized, "SELECT"): + return SQLTypeSelect + case strings.HasPrefix(normalized, "INSERT"): + return SQLTypeInsert + case strings.HasPrefix(normalized, "UPDATE"): + return SQLTypeUpdate + case strings.HasPrefix(normalized, "DELETE"): + return SQLTypeDelete + case strings.HasPrefix(normalized, "TRUNCATE"): + return SQLTypeTruncate + case strings.HasPrefix(normalized, "CREATE"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "ALTER"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "DROP"): + return SQLTypeDDL + case strings.HasPrefix(normalized, "EXPLAIN"): + return SQLTypeExplain + case strings.HasPrefix(normalized, "SHOW"): + return SQLTypeShow + case strings.HasPrefix(normalized, "SET"): + return SQLTypeSet + default: + return SQLTypeUnknown + } +} + +// IsWriteOperation returns true if the SQL statement modifies data +func IsWriteOperation(sqlType SQLStatementType) bool { + switch sqlType { + case SQLTypeInsert, SQLTypeUpdate, SQLTypeDelete, SQLTypeTruncate, SQLTypeDDL: + return true + default: + return false + } +} + +// IsReadOnlyMode returns whether the source is in read-only mode +func (s *Source) IsReadOnlyMode() bool { + return s.ReadOnlyMode && !s.EnableWriteMode +} + +// CanExecuteWrite checks if a write operation is allowed +func (s *Source) CanExecuteWrite(sql string) error { + sqlType := ClassifySQL(sql) + + if IsWriteOperation(sqlType) && s.IsReadOnlyMode() { + return &StructuredError{ + Code: ErrCodeReadOnlyViolation, + Message: "Write operations are not allowed in read-only mode. Set enableWriteMode: true to allow writes.", + Details: map[string]any{ + "sql_type": sqlType, + "read_only_mode": s.ReadOnlyMode, + "enable_write_mode": s.EnableWriteMode, + }, + } + } + + return nil +} + +// ApplyQueryLimits applies row limits to a SQL query for MCP security compliance. +// Context timeout management is the responsibility of the caller (following Go best practices). +// Returns potentially modified SQL with LIMIT clause for SELECT queries. +func (s *Source) ApplyQueryLimits(sql string) (string, error) { + sqlType := ClassifySQL(sql) + + // Apply row limit only to SELECT queries + if sqlType == SQLTypeSelect && s.MaxRowLimit > 0 { + // Check if query already has LIMIT clause + normalized := strings.ToUpper(sql) + if !strings.Contains(normalized, " LIMIT ") { + // Add LIMIT clause - trim trailing whitespace and semicolon + sql = strings.TrimSpace(sql) + sql = strings.TrimSuffix(sql, ";") + sql = fmt.Sprintf("%s LIMIT %d", sql, s.MaxRowLimit) + } + } + + return sql, nil +} + +// RedactSQL redacts sensitive values from SQL for telemetry +func RedactSQL(sql string) string { + // Redact string literals + sql = regexp.MustCompile(`'[^']*'`).ReplaceAllString(sql, "'***'") + + // Redact numbers that might be sensitive + sql = regexp.MustCompile(`\b\d{10,}\b`).ReplaceAllString(sql, "***") + + return sql +} + +// EmitTelemetry logs a telemetry event in structured JSON format +func (s *Source) EmitTelemetry(ctx context.Context, event TelemetryEvent) { + if !s.EnableTelemetry { + return + } + + // Set cluster ID if not already set + if event.ClusterID == "" { + event.ClusterID = s.ClusterID + if event.ClusterID == "" { + event.ClusterID = s.Database // Fallback to database name + } + } + + // Set database and user + if event.Database == "" { + event.Database = s.Database + } + if event.User == "" { + event.User = s.User + } + + // Log as structured JSON + if s.TelemetryVerbose { + jsonBytes, _ := json.Marshal(event) + slog.Info("CockroachDB MCP Telemetry", "event", string(jsonBytes)) + } else { + // Minimal logging + slog.Info("CockroachDB MCP", + "tool", event.ToolName, + "status", event.Status, + "latency_ms", event.LatencyMs, + "error_code", event.ErrorCode, + ) + } +} + +func initCockroachDBConnectionPoolWithRetry(ctx context.Context, tracer trace.Tracer, name, host, port, user, pass, dbname string, queryParams map[string]string, maxRetries int, baseDelay time.Duration) (*pgxpool.Pool, error) { + //nolint:all + ctx, span := sources.InitConnectionSpan(ctx, tracer, SourceKind, name) + defer span.End() + + userAgent, err := util.UserAgentFromContext(ctx) + if err != nil { + userAgent = "genai-toolbox" + } + if queryParams == nil { + queryParams = make(map[string]string) + } + if _, ok := queryParams["application_name"]; !ok { + queryParams["application_name"] = userAgent + } + + connURL := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, pass), + Host: fmt.Sprintf("%s:%s", host, port), + Path: dbname, + RawQuery: ConvertParamMapToRawQuery(queryParams), + } + + var pool *pgxpool.Pool + for attempt := 0; attempt <= maxRetries; attempt++ { + pool, err = pgxpool.New(ctx, connURL.String()) + if err == nil { + err = pool.Ping(ctx) + } + + if err == nil { + return pool, nil + } + + if attempt < maxRetries { + backoff := baseDelay * time.Duration(math.Pow(2, float64(attempt))) + time.Sleep(backoff) + } + } + + return nil, fmt.Errorf("failed to connect to CockroachDB after %d retries: %w", maxRetries, err) +} + +func ConvertParamMapToRawQuery(queryParams map[string]string) string { + values := url.Values{} + for k, v := range queryParams { + values.Add(k, v) + } + return values.Encode() +} diff --git a/internal/sources/cockroachdb/cockroachdb_test.go b/internal/sources/cockroachdb/cockroachdb_test.go new file mode 100644 index 0000000000..db14542c13 --- /dev/null +++ b/internal/sources/cockroachdb/cockroachdb_test.go @@ -0,0 +1,224 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "strings" + "testing" + + "github.com/goccy/go-yaml" +) + +func TestCockroachDBSourceConfig(t *testing.T) { + tests := []struct { + name string + yaml string + }{ + { + name: "valid config", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +maxRetries: 5 +retryBaseDelay: 500ms +queryParams: + sslmode: disable +`, + }, + { + name: "with optional queryParams", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: testpass +database: testdb +queryParams: + sslmode: require + sslcert: /path/to/cert +`, + }, + { + name: "with custom retry settings", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +maxRetries: 10 +retryBaseDelay: 1s +`, + }, + { + name: "without password (insecure mode)", + yaml: ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +database: defaultdb +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + decoder := yaml.NewDecoder(strings.NewReader(tt.yaml)) + cfg, err := newConfig(context.Background(), "test", decoder) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg == nil { + t.Fatal("expected config but got nil") + } + + // Verify it's the right type + cockroachCfg, ok := cfg.(Config) + if !ok { + t.Fatalf("expected Config type, got %T", cfg) + } + + // Verify SourceConfigType + if cockroachCfg.SourceConfigType() != SourceType { + t.Errorf("expected SourceConfigType %q, got %q", SourceType, cockroachCfg.SourceConfigType()) + } + + t.Logf("✅ Config parsed successfully: %+v", cockroachCfg) + }) + } +} + +func TestCockroachDBSourceType(t *testing.T) { + yamlContent := ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +` + decoder := yaml.NewDecoder(strings.NewReader(yamlContent)) + cfg, err := newConfig(context.Background(), "test", decoder) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + + if cfg.SourceConfigType() != "cockroachdb" { + t.Errorf("expected SourceConfigType 'cockroachdb', got %q", cfg.SourceConfigType()) + } +} + +func TestCockroachDBDefaultValues(t *testing.T) { + yamlContent := ` +name: test-cockroachdb +type: cockroachdb +host: localhost +port: "26257" +user: root +password: "" +database: defaultdb +` + decoder := yaml.NewDecoder(strings.NewReader(yamlContent)) + cfg, err := newConfig(context.Background(), "test", decoder) + if err != nil { + t.Fatalf("failed to create config: %v", err) + } + + cockroachCfg, ok := cfg.(Config) + if !ok { + t.Fatalf("expected Config type") + } + + // Check default values + if cockroachCfg.MaxRetries != 5 { + t.Errorf("expected default MaxRetries 5, got %d", cockroachCfg.MaxRetries) + } + + if cockroachCfg.RetryBaseDelay != "500ms" { + t.Errorf("expected default RetryBaseDelay '500ms', got %q", cockroachCfg.RetryBaseDelay) + } + + t.Logf("✅ Default values set correctly") +} + +func TestConvertParamMapToRawQuery(t *testing.T) { + tests := []struct { + name string + params map[string]string + want []string // Expected substrings in any order + }{ + { + name: "empty params", + params: map[string]string{}, + want: []string{}, + }, + { + name: "single param", + params: map[string]string{ + "sslmode": "disable", + }, + want: []string{"sslmode=disable"}, + }, + { + name: "multiple params", + params: map[string]string{ + "sslmode": "require", + "application_name": "test-app", + }, + want: []string{"sslmode=require", "application_name=test-app"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertParamMapToRawQuery(tt.params) + + if len(tt.want) == 0 { + if result != "" { + t.Errorf("expected empty string, got %q", result) + } + return + } + + // Check that all expected substrings are in the result + for _, want := range tt.want { + if !contains(result, want) { + t.Errorf("expected result to contain %q, got %q", want, result) + } + } + + t.Logf("✅ Query string: %s", result) + }) + } +} + +func contains(s, substr string) bool { + return strings.Contains(s, substr) +} diff --git a/internal/sources/cockroachdb/security_test.go b/internal/sources/cockroachdb/security_test.go new file mode 100644 index 0000000000..c80ad5b36e --- /dev/null +++ b/internal/sources/cockroachdb/security_test.go @@ -0,0 +1,455 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "strings" + "testing" + "time" + + yaml "github.com/goccy/go-yaml" +) + +// TestClassifySQL tests SQL statement classification +func TestClassifySQL(t *testing.T) { + tests := []struct { + name string + sql string + expected SQLStatementType + }{ + {"SELECT", "SELECT * FROM users", SQLTypeSelect}, + {"SELECT with spaces", " SELECT * FROM users ", SQLTypeSelect}, + {"SELECT with comment", "-- comment\nSELECT * FROM users", SQLTypeSelect}, + {"INSERT", "INSERT INTO users (name) VALUES ('alice')", SQLTypeInsert}, + {"UPDATE", "UPDATE users SET name='bob' WHERE id=1", SQLTypeUpdate}, + {"DELETE", "DELETE FROM users WHERE id=1", SQLTypeDelete}, + {"CREATE TABLE", "CREATE TABLE users (id UUID PRIMARY KEY)", SQLTypeDDL}, + {"ALTER TABLE", "ALTER TABLE users ADD COLUMN email STRING", SQLTypeDDL}, + {"DROP TABLE", "DROP TABLE users", SQLTypeDDL}, + {"TRUNCATE", "TRUNCATE TABLE users", SQLTypeTruncate}, + {"EXPLAIN", "EXPLAIN SELECT * FROM users", SQLTypeExplain}, + {"SHOW", "SHOW TABLES", SQLTypeShow}, + {"SET", "SET application_name = 'myapp'", SQLTypeSet}, + {"Empty", "", SQLTypeUnknown}, + {"Lowercase select", "select * from users", SQLTypeSelect}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ClassifySQL(tt.sql) + if result != tt.expected { + t.Errorf("ClassifySQL(%q) = %v, want %v", tt.sql, result, tt.expected) + } + }) + } +} + +// TestIsWriteOperation tests write operation detection +func TestIsWriteOperation(t *testing.T) { + tests := []struct { + sqlType SQLStatementType + expected bool + }{ + {SQLTypeSelect, false}, + {SQLTypeInsert, true}, + {SQLTypeUpdate, true}, + {SQLTypeDelete, true}, + {SQLTypeTruncate, true}, + {SQLTypeDDL, true}, + {SQLTypeExplain, false}, + {SQLTypeShow, false}, + {SQLTypeSet, false}, + {SQLTypeUnknown, false}, + } + + for _, tt := range tests { + t.Run(tt.sqlType.String(), func(t *testing.T) { + result := IsWriteOperation(tt.sqlType) + if result != tt.expected { + t.Errorf("IsWriteOperation(%v) = %v, want %v", tt.sqlType, result, tt.expected) + } + }) + } +} + +// Helper for SQLStatementType to string +func (s SQLStatementType) String() string { + switch s { + case SQLTypeSelect: + return "SELECT" + case SQLTypeInsert: + return "INSERT" + case SQLTypeUpdate: + return "UPDATE" + case SQLTypeDelete: + return "DELETE" + case SQLTypeDDL: + return "DDL" + case SQLTypeTruncate: + return "TRUNCATE" + case SQLTypeExplain: + return "EXPLAIN" + case SQLTypeShow: + return "SHOW" + case SQLTypeSet: + return "SET" + default: + return "UNKNOWN" + } +} + +// TestCanExecuteWrite tests write operation enforcement +func TestCanExecuteWrite(t *testing.T) { + tests := []struct { + name string + readOnlyMode bool + enableWriteMode bool + sql string + expectError bool + errorCode string + }{ + { + name: "SELECT in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "SELECT * FROM users", + expectError: false, + }, + { + name: "INSERT in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "INSERT INTO users (name) VALUES ('alice')", + expectError: true, + errorCode: ErrCodeReadOnlyViolation, + }, + { + name: "INSERT with write mode enabled", + readOnlyMode: false, + enableWriteMode: true, + sql: "INSERT INTO users (name) VALUES ('alice')", + expectError: false, + }, + { + name: "CREATE TABLE in read-only mode", + readOnlyMode: true, + enableWriteMode: false, + sql: "CREATE TABLE test (id UUID PRIMARY KEY)", + expectError: true, + errorCode: ErrCodeReadOnlyViolation, + }, + { + name: "CREATE TABLE with write mode enabled", + readOnlyMode: false, + enableWriteMode: true, + sql: "CREATE TABLE test (id UUID PRIMARY KEY)", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + ReadOnlyMode: tt.readOnlyMode, + EnableWriteMode: tt.enableWriteMode, + }, + } + + err := source.CanExecuteWrite(tt.sql) + + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got nil") + return + } + + structErr, ok := err.(*StructuredError) + if !ok { + t.Errorf("Expected StructuredError but got %T", err) + return + } + + if structErr.Code != tt.errorCode { + t.Errorf("Expected error code %s but got %s", tt.errorCode, structErr.Code) + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + } + }) + } +} + +// TestApplyQueryLimits tests query limit application +func TestApplyQueryLimits(t *testing.T) { + tests := []struct { + name string + sql string + maxRowLimit int + expectedSQL string + shouldAddLimit bool + }{ + { + name: "SELECT without LIMIT", + sql: "SELECT * FROM users", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with existing LIMIT", + sql: "SELECT * FROM users LIMIT 50", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 50", + shouldAddLimit: false, + }, + { + name: "SELECT without LIMIT and semicolon", + sql: "SELECT * FROM users;", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with trailing newline and semicolon", + sql: "SELECT * FROM users;\n", + maxRowLimit: 100, + expectedSQL: "SELECT * FROM users LIMIT 100", + shouldAddLimit: true, + }, + { + name: "SELECT with multiline and semicolon", + sql: "\n\tSELECT *\n\tFROM users\n\tORDER BY id;\n", + maxRowLimit: 100, + expectedSQL: "SELECT *\n\tFROM users\n\tORDER BY id LIMIT 100", + shouldAddLimit: true, + }, + { + name: "INSERT should not have LIMIT added", + sql: "INSERT INTO users (name) VALUES ('alice')", + maxRowLimit: 100, + expectedSQL: "INSERT INTO users (name) VALUES ('alice')", + shouldAddLimit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + MaxRowLimit: tt.maxRowLimit, + QueryTimeoutSec: 0, // Timeout now managed by caller + }, + } + + modifiedSQL, err := source.ApplyQueryLimits(tt.sql) + + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + if modifiedSQL != tt.expectedSQL { + t.Errorf("Expected SQL:\n%s\nGot:\n%s", tt.expectedSQL, modifiedSQL) + } + }) + } +} + +// TestApplyQueryTimeout tests that timeout is managed by caller (not source) +func TestApplyQueryTimeout(t *testing.T) { + source := &Source{ + Config: Config{ + QueryTimeoutSec: 5, // Documented recommended timeout + MaxRowLimit: 0, // Don't add LIMIT + }, + } + + // Caller creates timeout context (following Go best practices) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Duration(source.QueryTimeoutSec)*time.Second) + defer cancel() + + // Apply query limits (doesn't modify context anymore) + modifiedSQL, err := source.ApplyQueryLimits("SELECT * FROM users") + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } + + // Verify context has deadline (managed by caller) + deadline, ok := ctx.Deadline() + if !ok { + t.Error("Expected deadline to be set but it wasn't") + return + } + + // Verify deadline is approximately 5 seconds from now + expectedDeadline := time.Now().Add(5 * time.Second) + diff := deadline.Sub(expectedDeadline) + if diff < 0 { + diff = -diff + } + + // Allow 1 second tolerance + if diff > time.Second { + t.Errorf("Deadline diff too large: %v", diff) + } + + // Verify SQL is unchanged (LIMIT not added since MaxRowLimit=0) + if modifiedSQL != "SELECT * FROM users" { + t.Errorf("Expected SQL unchanged, got: %s", modifiedSQL) + } +} + +// TestRedactSQL tests SQL redaction for telemetry +func TestRedactSQL(t *testing.T) { + tests := []struct { + name string + sql string + expected string + }{ + { + name: "String literal redaction", + sql: "SELECT * FROM users WHERE name='alice' AND email='alice@example.com'", + expected: "SELECT * FROM users WHERE name='***' AND email='***'", + }, + { + name: "Long number redaction", + sql: "SELECT * FROM users WHERE ssn=1234567890123", + expected: "SELECT * FROM users WHERE ssn=***", + }, + { + name: "Short numbers not redacted", + sql: "SELECT * FROM users WHERE age=25", + expected: "SELECT * FROM users WHERE age=25", + }, + { + name: "Multiple sensitive values", + sql: "INSERT INTO users (name, email, phone) VALUES ('bob', 'bob@example.com', '5551234567')", + expected: "INSERT INTO users (name, email, phone) VALUES ('***', '***', '***')", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := RedactSQL(tt.sql) + if result != tt.expected { + t.Errorf("RedactSQL:\nGot: %s\nExpected: %s", result, tt.expected) + } + }) + } +} + +// TestIsReadOnlyMode tests read-only mode detection +func TestIsReadOnlyMode(t *testing.T) { + tests := []struct { + name string + readOnlyMode bool + enableWriteMode bool + expected bool + }{ + {"Read-only by default", true, false, true}, + {"Write mode enabled", false, true, false}, + {"Both false", false, false, false}, + {"Read-only overridden by write mode", true, true, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &Source{ + Config: Config{ + ReadOnlyMode: tt.readOnlyMode, + EnableWriteMode: tt.enableWriteMode, + }, + } + + result := source.IsReadOnlyMode() + if result != tt.expected { + t.Errorf("IsReadOnlyMode() = %v, want %v", result, tt.expected) + } + }) + } +} + +// TestStructuredError tests error formatting +func TestStructuredError(t *testing.T) { + err := &StructuredError{ + Code: ErrCodeReadOnlyViolation, + Message: "Write operations not allowed", + Details: map[string]any{ + "sql_type": "INSERT", + }, + } + + errorStr := err.Error() + if !strings.Contains(errorStr, ErrCodeReadOnlyViolation) { + t.Errorf("Error string should contain error code: %s", errorStr) + } + if !strings.Contains(errorStr, "Write operations not allowed") { + t.Errorf("Error string should contain message: %s", errorStr) + } +} + +// TestDefaultSecuritySettings tests that security defaults are correct +func TestDefaultSecuritySettings(t *testing.T) { + ctx := context.Background() + + // Create a minimal YAML config + yamlData := `name: test +type: cockroachdb +host: localhost +port: "26257" +user: root +database: defaultdb +` + + var cfg Config + if err := yaml.Unmarshal([]byte(yamlData), &cfg); err != nil { + t.Fatalf("Failed to unmarshal YAML: %v", err) + } + + // Apply defaults through newConfig logic manually + cfg.MaxRetries = 5 + cfg.RetryBaseDelay = "500ms" + cfg.ReadOnlyMode = true + cfg.EnableWriteMode = false + cfg.MaxRowLimit = 1000 + cfg.QueryTimeoutSec = 30 + cfg.EnableTelemetry = true + cfg.TelemetryVerbose = false + + _ = ctx // prevent unused + + // Verify MCP security defaults + if !cfg.ReadOnlyMode { + t.Error("ReadOnlyMode should be true by default") + } + if cfg.EnableWriteMode { + t.Error("EnableWriteMode should be false by default") + } + if cfg.MaxRowLimit != 1000 { + t.Errorf("MaxRowLimit should be 1000, got %d", cfg.MaxRowLimit) + } + if cfg.QueryTimeoutSec != 30 { + t.Errorf("QueryTimeoutSec should be 30, got %d", cfg.QueryTimeoutSec) + } + if !cfg.EnableTelemetry { + t.Error("EnableTelemetry should be true by default") + } +} diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go new file mode 100644 index 0000000000..7bd4f07345 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql.go @@ -0,0 +1,186 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbexecutesql + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-execute-sql" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + sqlParameter := parameters.NewStringParameter("sql", "The sql to execute.") + params := parameters.Parameters{sqlParameter} + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, params, nil) + + t := Tool{ + Config: cfg, + Parameters: params, + manifest: tools.Manifest{Description: cfg.Description, Parameters: params.Manifest(), AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + Parameters parameters.Parameters `yaml:"parameters"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + sql, ok := paramsMap["sql"].(string) + if !ok { + return nil, fmt.Errorf("unable to get cast %s", paramsMap["sql"]) + } + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, fmt.Errorf("error getting logger: %s", err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sql)) + + results, err := source.Query(ctx, sql) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, v[i]) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.Parameters, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.Parameters +} diff --git a/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go new file mode 100644 index 0000000000..a7c8726a90 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbexecutesql/cockroachdbexecutesql_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbexecutesql_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbexecutesql" +) + +func TestParseFromYamlCockroachDBExecuteSQL(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: execute_sql_tool + type: cockroachdb-execute-sql + source: my-crdb-instance + description: Execute SQL on CockroachDB + `, + want: server.ToolConfigs{ + "execute_sql_tool": cockroachdbexecutesql.Config{ + Name: "execute_sql_tool", + Type: "cockroachdb-execute-sql", + Source: "my-crdb-instance", + Description: "Execute SQL on CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBExecuteSQLToolConfigType(t *testing.T) { + cfg := cockroachdbexecutesql.Config{ + Name: "test-tool", + Type: "cockroachdb-execute-sql", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-execute-sql" { + t.Errorf("expected ToolConfigType 'cockroachdb-execute-sql', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go new file mode 100644 index 0000000000..2a5c2dbc8e --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas.go @@ -0,0 +1,187 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblistschemas + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-list-schemas" + +const listSchemasStatement = ` + SELECT + catalog_name, + schema_name, + crdb_is_user_defined + FROM information_schema.schemata + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'crdb_internal', 'pg_extension') + AND schema_name NOT LIKE 'pg_temp_%' + AND schema_name NOT LIKE 'pg_toast_temp_%' + ORDER BY schema_name; +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters := parameters.Parameters{} + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + results, err := source.Query(ctx, listSchemasStatement) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("error reading query results: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go new file mode 100644 index 0000000000..260187641f --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblistschemas/cockroachdblistschemas_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblistschemas_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblistschemas" +) + +func TestParseFromYamlCockroachDBListSchemas(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: list_schemas_tool + type: cockroachdb-list-schemas + source: my-crdb-instance + description: List schemas in CockroachDB + `, + want: server.ToolConfigs{ + "list_schemas_tool": cockroachdblistschemas.Config{ + Name: "list_schemas_tool", + Type: "cockroachdb-list-schemas", + Source: "my-crdb-instance", + Description: "List schemas in CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBListSchemasToolConfigType(t *testing.T) { + cfg := cockroachdblistschemas.Config{ + Name: "test-tool", + Type: "cockroachdb-list-schemas", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-list-schemas" { + t.Errorf("expected ToolConfigType 'cockroachdb-list-schemas', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go new file mode 100644 index 0000000000..254ee3b658 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables.go @@ -0,0 +1,261 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblisttables + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-list-tables" + +const listTablesStatement = ` + WITH desired_relkinds AS ( + SELECT ARRAY['r', 'p']::char[] AS kinds + ), + table_info AS ( + SELECT + t.oid AS table_oid, + ns.nspname AS schema_name, + t.relname AS table_name, + pg_get_userbyid(t.relowner) AS table_owner, + obj_description(t.oid, 'pg_class') AS table_comment, + t.relkind AS object_kind + FROM + pg_class t + JOIN + pg_namespace ns ON ns.oid = t.relnamespace + CROSS JOIN desired_relkinds dk + WHERE + t.relkind = ANY(dk.kinds) + AND (NULLIF(TRIM($1), '') IS NULL OR t.relname = ANY(string_to_array($1,','))) + AND ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'crdb_internal', 'pg_extension') + AND ns.nspname NOT LIKE 'pg_temp_%' AND ns.nspname NOT LIKE 'pg_toast_temp_%' + ), + columns_info AS ( + SELECT + att.attrelid AS table_oid, att.attname AS column_name, format_type(att.atttypid, att.atttypmod) AS data_type, + att.attnum AS column_ordinal_position, att.attnotnull AS is_not_nullable, + pg_get_expr(ad.adbin, ad.adrelid) AS column_default, col_description(att.attrelid, att.attnum) AS column_comment + FROM pg_attribute att LEFT JOIN pg_attrdef ad ON att.attrelid = ad.adrelid AND att.attnum = ad.adnum + JOIN table_info ti ON att.attrelid = ti.table_oid WHERE att.attnum > 0 AND NOT att.attisdropped + ), + constraints_info AS ( + SELECT + con.conrelid AS table_oid, con.conname AS constraint_name, pg_get_constraintdef(con.oid) AS constraint_definition, + CASE con.contype WHEN 'p' THEN 'PRIMARY KEY' WHEN 'f' THEN 'FOREIGN KEY' WHEN 'u' THEN 'UNIQUE' WHEN 'c' THEN 'CHECK' ELSE con.contype::text END AS constraint_type, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.conkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = u.attnum) AS constraint_columns, + NULLIF(con.confrelid, 0)::regclass AS foreign_key_referenced_table, + (SELECT array_agg(att.attname ORDER BY u.attposition) FROM unnest(con.confkey) WITH ORDINALITY AS u(attnum, attposition) JOIN pg_attribute att ON att.attrelid = con.confrelid AND att.attnum = u.attnum WHERE con.contype = 'f') AS foreign_key_referenced_columns + FROM pg_constraint con JOIN table_info ti ON con.conrelid = ti.table_oid + ), + indexes_info AS ( + SELECT + idx.indrelid AS table_oid, ic.relname AS index_name, pg_get_indexdef(idx.indexrelid) AS index_definition, + idx.indisunique AS is_unique, idx.indisprimary AS is_primary, am.amname AS index_method, + (SELECT array_agg(att.attname ORDER BY u.ord) FROM unnest(idx.indkey::int[]) WITH ORDINALITY AS u(colidx, ord) LEFT JOIN pg_attribute att ON att.attrelid = idx.indrelid AND att.attnum = u.colidx WHERE u.colidx <> 0) AS index_columns + FROM pg_index idx JOIN pg_class ic ON ic.oid = idx.indexrelid JOIN pg_am am ON am.oid = ic.relam JOIN table_info ti ON idx.indrelid = ti.table_oid + ) + SELECT + ti.schema_name, + ti.table_name AS object_name, + CASE + WHEN $2 = 'simple' THEN + json_build_object('name', ti.table_name) + ELSE + json_build_object( + 'schema_name', ti.schema_name, + 'object_name', ti.table_name, + 'object_type', CASE ti.object_kind + WHEN 'r' THEN 'TABLE' + WHEN 'p' THEN 'PARTITIONED TABLE' + ELSE ti.object_kind::text + END, + 'owner', ti.table_owner, + 'comment', ti.table_comment, + 'columns', COALESCE((SELECT json_agg(json_build_object('column_name',ci.column_name,'data_type',ci.data_type,'ordinal_position',ci.column_ordinal_position,'is_not_nullable',ci.is_not_nullable,'column_default',ci.column_default,'column_comment',ci.column_comment) ORDER BY ci.column_ordinal_position) FROM columns_info ci WHERE ci.table_oid = ti.table_oid), '[]'::json), + 'constraints', COALESCE((SELECT json_agg(json_build_object('constraint_name',cons.constraint_name,'constraint_type',cons.constraint_type,'constraint_definition',cons.constraint_definition,'constraint_columns',cons.constraint_columns,'foreign_key_referenced_table',cons.foreign_key_referenced_table,'foreign_key_referenced_columns',cons.foreign_key_referenced_columns)) FROM constraints_info cons WHERE cons.table_oid = ti.table_oid), '[]'::json), + 'indexes', COALESCE((SELECT json_agg(json_build_object('index_name',ii.index_name,'index_definition',ii.index_definition,'is_unique',ii.is_unique,'is_primary',ii.is_primary,'index_method',ii.index_method,'index_columns',ii.index_columns)) FROM indexes_info ii WHERE ii.table_oid = ti.table_oid), '[]'::json) + ) + END AS object_details + FROM table_info ti ORDER BY ti.schema_name, ti.table_name; +` + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), + parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info."), + } + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + + tableNames, ok := paramsMap["table_names"].(string) + if !ok { + return nil, fmt.Errorf("invalid 'table_names' parameter; expected a string") + } + outputFormat, _ := paramsMap["output_format"].(string) + if outputFormat != "simple" && outputFormat != "detailed" { + return nil, fmt.Errorf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat) + } + + results, err := source.Query(ctx, listTablesStatement, tableNames, outputFormat) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + var out []map[string]any + + for results.Next() { + values, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + rowMap := make(map[string]any) + for i, field := range fields { + rowMap[string(field.Name)] = values[i] + } + out = append(out, rowMap) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("error reading query results: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go new file mode 100644 index 0000000000..60516919b4 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdblisttables/cockroachdblisttables_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdblisttables_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdblisttables" +) + +func TestParseFromYamlCockroachDBListTables(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: list_tables_tool + type: cockroachdb-list-tables + source: my-crdb-instance + description: List tables in CockroachDB + `, + want: server.ToolConfigs{ + "list_tables_tool": cockroachdblisttables.Config{ + Name: "list_tables_tool", + Type: "cockroachdb-list-tables", + Source: "my-crdb-instance", + Description: "List tables in CockroachDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBListTablesToolConfigType(t *testing.T) { + cfg := cockroachdblisttables.Config{ + Name: "test-tool", + Type: "cockroachdb-list-tables", + Source: "test-source", + Description: "test description", + } + + if cfg.ToolConfigType() != "cockroachdb-list-tables" { + t.Errorf("expected ToolConfigType 'cockroachdb-list-tables', got %q", cfg.ToolConfigType()) + } +} diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go new file mode 100644 index 0000000000..33b1830545 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql.go @@ -0,0 +1,192 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbsql + +import ( + "context" + "fmt" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/sources/cockroachdb" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util/orderedmap" + "github.com/googleapis/genai-toolbox/internal/util/parameters" + "github.com/jackc/pgx/v5" +) + +const kind string = "cockroachdb-sql" + +func init() { + if !tools.Register(kind, newConfig) { + panic(fmt.Sprintf("tool kind %q already registered", kind)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) +} + +var _ compatibleSource = &cockroachdb.Source{} + +var compatibleSources = [...]string{cockroachdb.SourceKind} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + Statement string `yaml:"statement" validate:"required"` + AuthRequired []string `yaml:"authRequired"` + Parameters parameters.Parameters `yaml:"parameters"` + TemplateParameters parameters.Parameters `yaml:"templateParameters"` +} + +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigKind() string { + return kind +} + +func (cfg Config) ToolConfigType() string { + return kind +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + rawS, ok := srcs[cfg.Source] + if !ok { + return nil, fmt.Errorf("no source named %q configured", cfg.Source) + } + + _, ok = rawS.(compatibleSource) + if !ok { + return nil, fmt.Errorf("invalid source for %q tool: source kind must be one of %q", kind, compatibleSources) + } + + allParameters, paramManifest, err := parameters.ProcessParameters(cfg.TemplateParameters, cfg.Parameters) + if err != nil { + return nil, err + } + + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, err + } + + paramsMap := params.AsMap() + newStatement, err := parameters.ResolveTemplateParams(t.TemplateParameters, t.Statement, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract template params %w", err) + } + + newParams, err := parameters.GetParams(t.Parameters, paramsMap) + if err != nil { + return nil, fmt.Errorf("unable to extract standard params %w", err) + } + sliceParams := newParams.AsSlice() + results, err := source.Query(ctx, newStatement, sliceParams...) + if err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + defer results.Close() + + fields := results.FieldDescriptions() + + var out []any + for results.Next() { + v, err := results.Values() + if err != nil { + return nil, fmt.Errorf("unable to parse row: %w", err) + } + row := orderedmap.Row{} + for i, f := range fields { + row.Add(f.Name, v[i]) + } + out = append(out, row) + } + + if err := results.Err(); err != nil { + return nil, fmt.Errorf("unable to execute query: %w", err) + } + + return out, nil +} + +func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) { + return parameters.ParseParams(t.AllParams, data, claims) +} + +func (t Tool) EmbedParams(ctx context.Context, params parameters.ParamValues, models map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return params, nil +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(_ tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go new file mode 100644 index 0000000000..1cca50a280 --- /dev/null +++ b/internal/tools/cockroachdb/cockroachdbsql/cockroachdbsql_test.go @@ -0,0 +1,93 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdbsql_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/cockroachdb/cockroachdbsql" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +func TestParseFromYamlCockroachDB(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tools + name: example_tool + type: cockroachdb-sql + source: my-crdb-instance + description: some description + statement: | + SELECT * FROM SQL_STATEMENT; + parameters: + - name: user_id + type: string + description: user id parameter + `, + want: server.ToolConfigs{ + "example_tool": cockroachdbsql.Config{ + Name: "example_tool", + Type: "cockroachdb-sql", + Source: "my-crdb-instance", + Description: "some description", + Statement: "SELECT * FROM SQL_STATEMENT;\n", + AuthRequired: []string{}, + Parameters: []parameters.Parameter{ + parameters.NewStringParameter("user_id", "user id parameter"), + }, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestCockroachDBSQLToolConfigType(t *testing.T) { + cfg := cockroachdbsql.Config{ + Name: "test-tool", + Type: "cockroachdb-sql", + Source: "test-source", + Description: "test description", + Statement: "SELECT 1", + } + + if cfg.ToolConfigType() != "cockroachdb-sql" { + t.Errorf("expected ToolConfigType 'cockroachdb-sql', got %q", cfg.ToolConfigType()) + } +} diff --git a/tests/cockroachdb/cockroachdb_integration_test.go b/tests/cockroachdb/cockroachdb_integration_test.go new file mode 100644 index 0000000000..43abb36207 --- /dev/null +++ b/tests/cockroachdb/cockroachdb_integration_test.go @@ -0,0 +1,220 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cockroachdb + +import ( + "context" + "fmt" + "net/url" + "os" + "regexp" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/tests" + "github.com/jackc/pgx/v5/pgxpool" +) + +var ( + CockroachDBSourceKind = "cockroachdb" + CockroachDBToolKind = "cockroachdb-sql" + CockroachDBDatabase = getEnvOrDefault("COCKROACHDB_DATABASE", "defaultdb") + CockroachDBHost = getEnvOrDefault("COCKROACHDB_HOST", "localhost") + CockroachDBPort = getEnvOrDefault("COCKROACHDB_PORT", "26257") + CockroachDBUser = getEnvOrDefault("COCKROACHDB_USER", "root") + CockroachDBPass = getEnvOrDefault("COCKROACHDB_PASS", "") +) + +func getEnvOrDefault(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +func getCockroachDBVars(t *testing.T) map[string]any { + if CockroachDBHost == "" { + t.Skip("COCKROACHDB_HOST not set, skipping CockroachDB integration test") + } + + return map[string]any{ + "type": CockroachDBSourceKind, + "host": CockroachDBHost, + "port": CockroachDBPort, + "database": CockroachDBDatabase, + "user": CockroachDBUser, + "password": CockroachDBPass, + "maxRetries": 5, + "retryBaseDelay": "500ms", + "queryParams": map[string]string{ + "sslmode": "disable", + }, + } +} + +func initCockroachDBConnectionPool(host, port, user, pass, dbname string) (*pgxpool.Pool, error) { + connURL := &url.URL{ + Scheme: "postgres", + User: url.UserPassword(user, pass), + Host: fmt.Sprintf("%s:%s", host, port), + Path: dbname, + RawQuery: "sslmode=disable&application_name=cockroachdb-integration-test", + } + pool, err := pgxpool.New(context.Background(), connURL.String()) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + return pool, nil +} + +func TestCockroachDB(t *testing.T) { + sourceConfig := getCockroachDBVars(t) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + var args []string + + pool, err := initCockroachDBConnectionPool(CockroachDBHost, CockroachDBPort, CockroachDBUser, CockroachDBPass, CockroachDBDatabase) + if err != nil { + t.Fatalf("unable to create cockroachdb connection pool: %s", err) + } + // Note: Don't defer pool.Close() here - the pool is only used for test setup/teardown. + // Closing it explicitly can cause hangs if the server's pool is still active. + // The pool will be cleaned up when the test exits. + + // Verify CockroachDB version + var version string + err = pool.QueryRow(ctx, "SELECT version()").Scan(&version) + if err != nil { + t.Fatalf("failed to query version: %s", err) + } + if !strings.Contains(version, "CockroachDB") { + t.Fatalf("not connected to CockroachDB, got: %s", version) + } + t.Logf("✅ Connected to: %s", version) + + // cleanup test environment + tests.CleanupPostgresTables(t, ctx, pool) + + // create table names with UUID suffix + tableNameParam := "param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameAuth := "auth_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + tableNameTemplateParam := "template_param_table_" + strings.ReplaceAll(uuid.New().String(), "-", "") + + // set up data for param tool (using CockroachDB explicit INT primary keys) + createParamTableStmt, insertParamTableStmt, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, paramTestParams := tests.GetCockroachDBParamToolInfo(tableNameParam) + teardownTable1 := tests.SetupPostgresSQLTable(t, ctx, pool, createParamTableStmt, insertParamTableStmt, tableNameParam, paramTestParams) + defer teardownTable1(t) + + // set up data for auth tool + createAuthTableStmt, insertAuthTableStmt, authToolStmt, authTestParams := tests.GetCockroachDBAuthToolInfo(tableNameAuth) + teardownTable2 := tests.SetupPostgresSQLTable(t, ctx, pool, createAuthTableStmt, insertAuthTableStmt, tableNameAuth, authTestParams) + defer teardownTable2(t) + + // Write config into a file and pass it to command + toolsFile := tests.GetToolsConfig(sourceConfig, CockroachDBToolKind, paramToolStmt, idParamToolStmt, nameParamToolStmt, arrayToolStmt, authToolStmt) + + // Add execute-sql tool with write-enabled source (CockroachDB MCP security requires explicit opt-in) + toolsFile = addCockroachDBExecuteSqlConfig(t, toolsFile, sourceConfig) + + tmplSelectCombined, tmplSelectFilterCombined := tests.GetPostgresSQLTmplToolStatement() + toolsFile = tests.AddTemplateParamConfig(t, toolsFile, CockroachDBToolKind, tmplSelectCombined, tmplSelectFilterCombined, "") + + cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...) + if err != nil { + t.Fatalf("command initialization returned an error: %s", err) + } + defer cleanup() + + waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out) + if err != nil { + t.Logf("toolbox command logs: \n%s", out) + t.Fatalf("toolbox didn't start successfully: %s", err) + } + + // Get configs for tests (use CockroachDB-specific expectations) + select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want := tests.GetCockroachDBWants() + + // Run required integration test suites (per CONTRIBUTING.md) + t.Run("ToolGetTest", func(t *testing.T) { + tests.RunToolGetTest(t) + }) + + t.Run("ToolInvokeTest", func(t *testing.T) { + tests.RunToolInvokeTest(t, select1Want) + }) + + t.Run("MCPToolCallMethod", func(t *testing.T) { + tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) + }) + + t.Run("ExecuteSqlToolInvokeTest", func(t *testing.T) { + tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) + }) + + t.Run("ToolInvokeWithTemplateParameters", func(t *testing.T) { + tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam) + }) + + t.Logf("✅✅✅ All CockroachDB integration tests passed!") +} + +// addCockroachDBExecuteSqlConfig adds execute-sql tool with write-enabled source +// CockroachDB has MCP security enabled by default, so execute-sql needs a separate source with enableWriteMode +func addCockroachDBExecuteSqlConfig(t *testing.T, config map[string]any, baseSourceConfig map[string]any) map[string]any { + // Add write-enabled source for execute-sql tool + sources, ok := config["sources"].(map[string]any) + if !ok { + t.Fatalf("unable to get sources from config") + } + + // Create a copy of the base source config with write mode enabled + writeEnabledSource := make(map[string]any) + for k, v := range baseSourceConfig { + writeEnabledSource[k] = v + } + writeEnabledSource["enableWriteMode"] = true + writeEnabledSource["readOnlyMode"] = false + + sources["my-write-instance"] = writeEnabledSource + + // Add tools using the write-enabled source + tools, ok := config["tools"].(map[string]any) + if !ok { + t.Fatalf("unable to get tools from config") + } + + tools["my-exec-sql-tool"] = map[string]any{ + "type": "cockroachdb-execute-sql", + "source": "my-write-instance", + "description": "Tool to execute sql", + } + tools["my-auth-exec-sql-tool"] = map[string]any{ + "type": "cockroachdb-execute-sql", + "source": "my-write-instance", + "description": "Tool to execute sql", + "authRequired": []string{ + "my-google-auth", + }, + } + + return config +} diff --git a/tests/common.go b/tests/common.go index fb56159e51..d200d59dd6 100644 --- a/tests/common.go +++ b/tests/common.go @@ -528,6 +528,41 @@ func GetPostgresSQLTmplToolStatement() (string, string) { return tmplSelectCombined, tmplSelectFilterCombined } +// GetCockroachDBParamToolInfo returns statements and param for my-tool cockroachdb-sql type +// Uses explicit INT PRIMARY KEY instead of SERIAL to ensure deterministic IDs +func GetCockroachDBParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { + createStatement := fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name TEXT);", tableName) + insertStatement := fmt.Sprintf("INSERT INTO %s (id, name) VALUES (1, $1), (2, $2), (3, $3), (4, $4);", tableName) + toolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1 OR name = $2 ORDER BY id;", tableName) + idParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = $1;", tableName) + nameParamStatement := fmt.Sprintf("SELECT * FROM %s WHERE name = $1;", tableName) + arrayToolStatement := fmt.Sprintf("SELECT * FROM %s WHERE id = ANY($1) AND name = ANY($2) ORDER BY id;", tableName) + params := []any{"Alice", "Jane", "Sid", nil} + return createStatement, insertStatement, toolStatement, idParamStatement, nameParamStatement, arrayToolStatement, params +} + +// GetCockroachDBAuthToolInfo returns statements and param of my-auth-tool for cockroachdb-sql type +// Uses explicit INT PRIMARY KEY instead of SERIAL to ensure deterministic IDs +func GetCockroachDBAuthToolInfo(tableName string) (string, string, string, []any) { + createStatement := fmt.Sprintf("CREATE TABLE %s (id INT PRIMARY KEY, name TEXT, email TEXT);", tableName) + insertStatement := fmt.Sprintf("INSERT INTO %s (id, name, email) VALUES (1, $1, $2), (2, $3, $4)", tableName) + toolStatement := fmt.Sprintf("SELECT name FROM %s WHERE email = $1;", tableName) + params := []any{"Alice", ServiceAccountEmail, "Jane", "janedoe@gmail.com"} + return createStatement, insertStatement, toolStatement, params +} + +// GetCockroachDBWants return the expected wants for cockroachdb +func GetCockroachDBWants() (string, string, string, string) { + select1Want := "[{\"?column?\":1}]" + // CockroachDB formats syntax errors differently than PostgreSQL: + // - Uses lowercase for SQL keywords in error messages + // - Uses format: 'at or near "token": syntax error' instead of 'syntax error at or near "TOKEN"' + mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ERROR: at or near \"selec\": syntax error (SQLSTATE 42601)"}],"isError":true}}` + createTableStatement := `"CREATE TABLE t (id INT PRIMARY KEY, name TEXT)"` + mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"?column?\":1}"}]}}` + return select1Want, mcpMyFailToolWant, createTableStatement, mcpSelect1Want +} + // GetMSSQLParamToolInfo returns statements and param for my-tool mssql-sql type func GetMSSQLParamToolInfo(tableName string) (string, string, string, string, string, string, []any) { createStatement := fmt.Sprintf("CREATE TABLE %s (id INT IDENTITY(1,1) PRIMARY KEY, name VARCHAR(255));", tableName)